import argparse
import json
import os

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
def load_json_data(json_file):
    """
    Load data from the JSON file.

    Args:
        json_file (str): Path to the JSON file.

    Returns:
        The parsed JSON content. The inputs this script reads are expected
        to be lists of items.
    """
    # Explicit UTF-8 so document text decodes the same regardless of the
    # platform's default locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data
def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: Parsed arguments (eli5_path, gt_path, pred_path,
        doc_level_f1_threshold).
    """
    parser = argparse.ArgumentParser(description='Compute metrics on the pipeline output')
    parser.add_argument("--eli5_path", type=str,
                        help="Path to the ELI5 dataset JSON file", default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument('--gt_path', type=str,
                        help='Path to the JSON file containing the ground truth annotations', default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json')
    parser.add_argument('--pred_path', type=str,
                        help='Path to the JSON file containing the predicted spans', default='./eli5_salsa_sample.json')
    parser.add_argument("--doc_level_f1_threshold", type=float,
                        help="Threshold for the document-level F1 score to determine if a document can be considered 'correctly' cited", default=0.5)
    return parser.parse_args(argv)
def get_span_group_char_indicators(doc_len, span_group):
    """Build a per-character 0/1 mask of the characters covered by a group of spans.

    Args:
        doc_len (int): Length of the mask to build (one entry per document character).
        span_group (list[dict]): Spans carrying "start" (inclusive) and "end"
            (exclusive) character offsets.

    Returns:
        np.ndarray: Integer array of length ``doc_len`` with 1 wherever at least
        one span covers the character and 0 elsewhere.
    """
    mask = np.zeros(doc_len, dtype=int)
    bounds = [(span["start"], span["end"]) for span in span_group]
    for lo, hi in bounds:
        # Overlapping spans simply re-mark the same positions.
        mask[lo:hi] = 1
    return mask
def add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds):
    """Compute, store and print binary-classification stats for one span type and level.

    Args:
        summary_stats (dict): Output mapping. It must already hold a dict under
            ``span_type``; this call sets ``summary_stats[span_type][level]``.
        span_type (str): First-level key into ``summary_stats`` (e.g. "citation").
        level (str): Second-level key, also used as the print prefix (e.g. "char", "doc").
        gt_inds: Ground-truth 0/1 indicators.
        pred_inds: Predicted 0/1 indicators, aligned element-wise with ``gt_inds``.

    Empty input is reported as zero percentages and zero metrics instead of
    raising ZeroDivisionError.
    """
    # Compute summary statistics
    num_samples = len(gt_inds)
    num_gt_positive = int(np.sum(gt_inds))
    num_pred_positive = int(np.sum(pred_inds))
    if num_samples == 0:
        percent_gt_positive = percent_pred_positive = 0.0
        precision = recall = f1 = 0.0
    else:
        percent_gt_positive = 100 * (num_gt_positive / num_samples)
        percent_pred_positive = 100 * (num_pred_positive / num_samples)
        # zero_division=0 yields the same value as sklearn's default ("warn")
        # when a class is absent, but without emitting UndefinedMetricWarning.
        precision = float(precision_score(gt_inds, pred_inds, zero_division=0))
        recall = float(recall_score(gt_inds, pred_inds, zero_division=0))
        f1 = float(f1_score(gt_inds, pred_inds, zero_division=0))

    # Add to the dictionary (plain int/float so json.dump needs no special handling)
    summary_stats[span_type][level] = {
        'num_samples': num_samples,
        'num_gt_positive': num_gt_positive,
        'num_pred_positive': num_pred_positive,
        'percent_gt_positive': percent_gt_positive,
        'percent_pred_positive': percent_pred_positive,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

    # Print out
    print(f"({level}-level) Number of GT positives: {num_gt_positive} / {num_samples} ({percent_gt_positive:.2f}%)")
    print(f"({level}-level) Number of predicted positives: {num_pred_positive} / {num_samples} ({percent_pred_positive:.2f}%)")
    print(f"({level}-level) Precision: {precision:.4f}")
    print(f"({level}-level) Recall: {recall:.4f}")
    print(f"({level}-level) F1: {f1:.4f}")
def _score_document_spans(doc_len, gt_doc_span_groups, pred_doc_span_groups, f1_threshold):
    """Score one document's predicted spans against its ground-truth span groups.

    Args:
        doc_len (int): Length of the document text in characters.
        gt_doc_span_groups (list[list[dict]]): GT span groups. The groups are
            alternatives (OR'd), so the prediction may match any one of them.
            Falsy means "no GT spans".
        pred_doc_span_groups (list[list[dict]]): Predicted spans. Only the
            first group is used, because predictions are not grouped.
        f1_threshold (float): Minimum character-level F1 for the document to
            count as correctly predicted when GT spans exist.

    Returns:
        tuple: (gt_doc_char_inds, pred_doc_char_inds, gt_doc_ind, pred_doc_ind).
        The first two are per-character 0/1 arrays; the GT one is the
        best-matching group. The last two are the 0/1 document-level labels.
    """
    # Predicted spans aren't grouped so there will only be one group
    pred_doc_spans = pred_doc_span_groups[0] if pred_doc_span_groups else []
    pred_doc_char_inds = get_span_group_char_indicators(doc_len, pred_doc_spans)

    # Ground truth spans are grouped; with no groups, compare against one empty group
    gt_doc_span_groups = gt_doc_span_groups if gt_doc_span_groups else [[]]
    gt_doc_group_char_inds = [
        get_span_group_char_indicators(doc_len, gt_doc_spans)
        for gt_doc_spans in gt_doc_span_groups
    ]

    # Get the F1 score for each group vs the predicted indicators and pick the best-matching group
    # (GT span groups are OR'd together, meaning the predicted spans can match any group to be correct)
    doc_group_f1_scores = [
        f1_score(gt_doc_char_inds, pred_doc_char_inds, zero_division=0)
        for gt_doc_char_inds in gt_doc_group_char_inds
    ]
    best_doc_group_f1_idx = int(np.argmax(doc_group_f1_scores))
    doc_f1_score = doc_group_f1_scores[best_doc_group_f1_idx]
    gt_doc_char_inds = gt_doc_group_char_inds[best_doc_group_f1_idx]

    # Get pred and gt indicators {0, 1} for the document
    gt_doc_ind = 1 if np.sum(gt_doc_char_inds) > 0 else 0
    if gt_doc_ind == 0:
        # If no GT spans exist, consider the document cited if any predicted spans exist
        pred_doc_ind = 1 if np.sum(pred_doc_char_inds) > 0 else 0
    else:
        # If GT spans exist, consider the document cited if the F1 score reaches the threshold,
        # i.e. the predicted spans are close enough to the GT spans
        pred_doc_ind = 1 if doc_f1_score >= f1_threshold else 0

    return gt_doc_char_inds, pred_doc_char_inds, gt_doc_ind, pred_doc_ind


def main():
    """Compute char- and doc-level span metrics for the predictions and save them as JSON.

    Reads the ELI5 documents, the ground-truth spans and the predicted spans
    (paths from the command line). Prints precision/recall/F1 per span type
    ("citation", "conflict") and level ("char", "doc"), and writes the stats
    to ``<pred_path without extension>_summary_stats.json``.
    """
    # Parse command line arguments
    args = parse_args()
    print(f">>> Computing metrics for {args.pred_path}")

    # Load the three JSON files
    eli5_data = load_json_data(args.eli5_path)
    gt_data = load_json_data(args.gt_path)
    pred_data = load_json_data(args.pred_path)

    span_types = ("citation", "conflict")
    indicators = {
        span_type: {"gt_char": [], "pred_char": [], "gt_doc": [], "pred_doc": []}
        for span_type in span_types
    }

    # zip() stops at the shortest input, so only items present in all three files are scored
    for eli5_item, gt_item, pred_item in zip(eli5_data, gt_data, pred_data):
        for gt_sentence, pred_sentence in zip(gt_item['answer_sentences'], pred_item['answer_sentences']):
            for span_type in span_types:
                collected = indicators[span_type]
                doc_span_pairs = zip(gt_sentence[f'{span_type}_spans'], pred_sentence[f'{span_type}_spans'])
                for i, (gt_doc_span_groups, pred_doc_span_groups) in enumerate(doc_span_pairs):
                    doc_len = len(eli5_item["docs"][i]["text"])
                    gt_doc_char_inds, pred_doc_char_inds, gt_doc_ind, pred_doc_ind = _score_document_spans(
                        doc_len, gt_doc_span_groups, pred_doc_span_groups, args.doc_level_f1_threshold
                    )
                    # Append to the character-level indicator lists
                    collected["gt_char"].append(gt_doc_char_inds)
                    collected["pred_char"].append(pred_doc_char_inds)
                    # Append to the document-level indicator lists
                    collected["gt_doc"].append(gt_doc_ind)
                    collected["pred_doc"].append(pred_doc_ind)

    # Print summary statistics
    summary_stats = {span_type: {} for span_type in span_types}
    for span_type in span_types:
        print(f"> Span type: {span_type}")
        collected = indicators[span_type]
        if not collected["gt_doc"]:
            # np.concatenate([]) would raise; there is nothing to score for this span type
            print(f"No documents found for span type '{span_type}', skipping")
            continue
        levels = [
            ("char", np.concatenate(collected["gt_char"]), np.concatenate(collected["pred_char"])),
            ("doc", np.array(collected["gt_doc"]), np.array(collected["pred_doc"])),
        ]
        for level, gt_inds, pred_inds in levels:
            add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds)

    # Save the summary statistics next to the predictions. Replace only the final
    # extension: str.replace(".json", ...) returned pred_path unchanged when it lacked
    # a lowercase ".json", which made json.dump below overwrite the predictions file.
    pred_root, _ = os.path.splitext(args.pred_path)
    output_path = f"{pred_root}_summary_stats.json"
    with open(output_path, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Summary statistics saved to: {output_path}")
if __name__ == "__main__":