import json
import numpy as np
import argparse
from sklearn.metrics import precision_score, recall_score, f1_score
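
# NOTE (assumption): the sketch below reconstructs the input layout this script expects,
# inferred from the field accesses in main(); the real JSON files may contain more fields.
#   eli5_path -> list of items, each with item["docs"][i]["text"] (the retrieved documents)
#   gt_path   -> list of items, each with item["answer_sentences"]; per sentence,
#                sentence["citation_spans"][i] / sentence["conflict_spans"][i] is a list of
#                span *groups* for document i, each group a list of {"start": ..., "end": ...} spans
#   pred_path -> same layout, but each document holds at most one (ungrouped) span group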


def load_json_data(json_file):
    """
    Load data from the JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='Compute metrics on the pipeline output')
    parser.add_argument("--eli5_path", type=str,
                        help="Path to the ELI5 dataset JSON file",
                        default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument('--gt_path', type=str,
                        help='Path to the JSON file containing the ground truth annotations',
                        default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json')
    parser.add_argument('--pred_path', type=str,
                        help='Path to the JSON file containing the predicted spans',
                        default='./eli5_salsa_sample.json')
    parser.add_argument("--doc_level_f1_threshold", type=float,
                        help="Threshold for the document-level F1 score to determine if a document can be considered 'correctly' cited",
                        default=0.5)
    return parser.parse_args()
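
# Example invocation (illustrative; the script filename and data paths are placeholders
# for wherever this file and the JSON inputs actually live):
#
#   python compute_metrics.py \
#       --gt_path ./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json \
#       --pred_path ./eli5_salsa_sample.json \
#       --doc_level_f1_threshold 0.5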


def get_span_group_char_indicators(doc_len, span_group):
    """Return a {0, 1} array of length doc_len marking the character positions covered by any span in span_group."""
    char_inds = np.zeros(doc_len, dtype=int)
    for span in span_group:
        char_inds[span["start"]:span["end"]] = 1
    return char_inds
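
# Worked example (illustrative values, not taken from the dataset): for doc_len = 10 and
# span_group = [{"start": 2, "end": 5}, {"start": 7, "end": 8}], the function returns
#   array([0, 0, 1, 1, 1, 0, 0, 1, 0, 0])
# i.e. a {0, 1} flag for every character position covered by at least one span.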


def add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds):
    # Compute summary statistics
    num_samples = len(gt_inds)
    num_gt_positive = np.sum(gt_inds).item()
    num_pred_positive = np.sum(pred_inds).item()
    percent_gt_positive = 100 * (num_gt_positive / num_samples)
    percent_pred_positive = 100 * (num_pred_positive / num_samples)
    precision = precision_score(gt_inds, pred_inds)
    recall = recall_score(gt_inds, pred_inds)
    f1 = f1_score(gt_inds, pred_inds)

    # Add to the dictionary
    summary_stats[span_type][level] = {
        'num_samples': num_samples,
        'num_gt_positive': num_gt_positive,
        'num_pred_positive': num_pred_positive,
        'percent_gt_positive': percent_gt_positive,
        'percent_pred_positive': percent_pred_positive,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

    # Print out
    print()
    print(f"({level}-level) Number of GT positives: {num_gt_positive} / {num_samples} ({percent_gt_positive:.2f}%)")
    print(f"({level}-level) Number of predicted positives: {num_pred_positive} / {num_samples} ({percent_pred_positive:.2f}%)")
    print(f"({level}-level) Precision: {precision:.4f}")
    print(f"({level}-level) Recall: {recall:.4f}")
    print(f"({level}-level) F1: {f1:.4f}")
    print()
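
# For reference, the sklearn metrics above reduce to the standard definitions on these
# {0, 1} indicator arrays: precision = TP / (TP + FP), recall = TP / (TP + FN),
# f1 = 2 * precision * recall / (precision + recall).
# e.g. gt_inds = [1, 1, 0, 0] and pred_inds = [1, 0, 1, 0] give precision = recall = f1 = 0.5.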


def main():
    # Parse command line arguments
    args = parse_args()
    print("\n---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred_path}")
    print("---------------------------------------------------------------------------------------------------------\n")

    # Load both JSON files
    eli5_data = load_json_data(args.eli5_path)
    gt_data = load_json_data(args.gt_path)
    pred_data = load_json_data(args.pred_path)

    # Iterate through the data
    citation_gt_char_inds = []
    citation_pred_char_inds = []
    conflict_gt_char_inds = []
    conflict_pred_char_inds = []
    citation_gt_doc_inds = []
    citation_pred_doc_inds = []
    conflict_gt_doc_inds = []
    conflict_pred_doc_inds = []
    for eli5_item, gt_item, pred_item in zip(eli5_data, gt_data, pred_data):
        for gt_sentence, pred_sentence in zip(gt_item['answer_sentences'], pred_item['answer_sentences']):
            for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
                ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
                ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
            ]:
                for i, (gt_doc_span_groups, pred_doc_span_groups) in enumerate(
                    zip(gt_sentence[f'{span_type}_spans'], pred_sentence[f'{span_type}_spans'])
                ):
                    # Get pred and gt indicators {0, 1} for each character position in the document
                    doc_len = len(eli5_item["docs"][i]["text"])

                    # Predicted spans aren't grouped, so there will only be one group
                    pred_doc_spans = pred_doc_span_groups[0] if pred_doc_span_groups else []
                    pred_doc_char_inds = get_span_group_char_indicators(doc_len, pred_doc_spans)

                    # Ground truth spans are grouped
                    gt_doc_span_groups = gt_doc_span_groups if gt_doc_span_groups else [[]]
                    gt_doc_group_char_inds = [
                        get_span_group_char_indicators(doc_len, gt_doc_spans)
                        for gt_doc_spans in gt_doc_span_groups
                    ]

                    # Get the F1 score for each group vs the predicted indicators and pick the best matching group
                    # (GT span groups are OR'd together, meaning that the predicted spans can match any of the groups to be correct)
                    doc_group_f1_scores = [
                        f1_score(gt_doc_char_inds, pred_doc_char_inds, zero_division=0)
                        for gt_doc_char_inds in gt_doc_group_char_inds
                    ]
                    best_doc_group_f1_idx = np.argmax(doc_group_f1_scores)
                    doc_f1_score = doc_group_f1_scores[best_doc_group_f1_idx]
                    gt_doc_char_inds = gt_doc_group_char_inds[best_doc_group_f1_idx]

                    # Append to the character-level indicator lists
                    gt_char_inds.append(gt_doc_char_inds)
                    pred_char_inds.append(pred_doc_char_inds)

                    # Get pred and gt indicators {0, 1} for each document
                    gt_doc_ind = 1 if np.sum(gt_doc_char_inds) > 0 else 0
                    if gt_doc_ind == 0:
                        # If no GT spans exist, then consider the document cited if any predicted spans exist
                        pred_doc_ind = 1 if np.sum(pred_doc_char_inds) > 0 else 0
                    else:
                        # If GT spans exist, then consider the document cited if the F1 score is above the threshold,
                        # i.e. the predicted citation spans are close enough to the GT spans
                        pred_doc_ind = 1 if doc_f1_score >= args.doc_level_f1_threshold else 0

                    # Append to the document-level indicator lists
                    gt_doc_inds.append(gt_doc_ind)
                    pred_doc_inds.append(pred_doc_ind)

    citation_gt_char_inds = np.concatenate(citation_gt_char_inds)
    citation_pred_char_inds = np.concatenate(citation_pred_char_inds)
    conflict_gt_char_inds = np.concatenate(conflict_gt_char_inds)
    conflict_pred_char_inds = np.concatenate(conflict_pred_char_inds)
    citation_gt_doc_inds = np.array(citation_gt_doc_inds)
    citation_pred_doc_inds = np.array(citation_pred_doc_inds)
    conflict_gt_doc_inds = np.array(conflict_gt_doc_inds)
    conflict_pred_doc_inds = np.array(conflict_pred_doc_inds)

    # Print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {},
    }
    for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
        ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
        ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
    ]:
        print("-------------------------")
        print(f"> Span type: {span_type}")
        print("-------------------------")
        for level, gt_inds, pred_inds in [("char", gt_char_inds, pred_char_inds), ("doc", gt_doc_inds, pred_doc_inds)]:
            add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds)

    # Save the summary statistics
    output_path = args.pred_path.replace(".json", "_summary_stats.json")
    with open(output_path, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Summary statistics saved to: {output_path}")


if __name__ == "__main__":
    main()