"""Compute character-level and document-level precision/recall/F1 for predicted
citation and conflict spans against the ground truth span annotations."""

import argparse
import json

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score


def load_json_data(json_file):
    """
    Load data from a JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='Compute metrics on the pipeline output')
    parser.add_argument("--eli5_path", type=str, help="Path to the ELI5 dataset JSON file",
                        default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument('--gt_path', type=str, help='Path to the JSON file containing the ground truth annotations',
                        default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json')
    parser.add_argument('--pred_path', type=str, help='Path to the JSON file containing the predicted spans',
                        default='./eli5_salsa_sample.json')
    parser.add_argument("--doc_level_f1_threshold", type=float, default=0.5,
                        help="Threshold on the document-level F1 score used to decide whether a document counts as 'correctly' cited")
    return parser.parse_args()


def get_span_group_char_indicators(doc_len, span_group):
    """Return a {0, 1} indicator array of length doc_len marking the characters covered by any span in span_group."""
    char_inds = np.zeros(doc_len, dtype=int)
    for span in span_group:
        char_inds[span["start"]:span["end"]] = 1
    return char_inds


def add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds):
    """Compute precision/recall/F1 for one span type at one level ('char' or 'doc'),
    store them in summary_stats, and print them."""
    # Compute summary statistics
    num_samples = len(gt_inds)
    num_gt_positive = np.sum(gt_inds).item()
    num_pred_positive = np.sum(pred_inds).item()
    percent_gt_positive = 100 * (num_gt_positive / num_samples)
    percent_pred_positive = 100 * (num_pred_positive / num_samples)
    precision = precision_score(gt_inds, pred_inds)
    recall = recall_score(gt_inds, pred_inds)
    f1 = f1_score(gt_inds, pred_inds)

    # Add to the summary dictionary
    summary_stats[span_type][level] = {
        'num_samples': num_samples,
        'num_gt_positive': num_gt_positive,
        'num_pred_positive': num_pred_positive,
        'percent_gt_positive': percent_gt_positive,
        'percent_pred_positive': percent_pred_positive,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

    # Print the metrics
    print()
    print(f"({level}-level) Number of GT positives: {num_gt_positive} / {num_samples} ({percent_gt_positive:.2f}%)")
    print(f"({level}-level) Number of predicted positives: {num_pred_positive} / {num_samples} ({percent_pred_positive:.2f}%)")
    print(f"({level}-level) Precision: {precision:.4f}")
    print(f"({level}-level) Recall: {recall:.4f}")
    print(f"({level}-level) F1: {f1:.4f}")
    print()


def main():
    # Parse command line arguments
    args = parse_args()

    print("\n---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred_path}")
    print("---------------------------------------------------------------------------------------------------------\n")

    # Load the dataset, the ground truth annotations, and the predictions
    eli5_data = load_json_data(args.eli5_path)
    gt_data = load_json_data(args.gt_path)
    pred_data = load_json_data(args.pred_path)

    # Collect character-level and document-level indicators for both span types
    citation_gt_char_inds = []
    citation_pred_char_inds = []
    conflict_gt_char_inds = []
    conflict_pred_char_inds = []
    citation_gt_doc_inds = []
    citation_pred_doc_inds = []
    conflict_gt_doc_inds = []
    conflict_pred_doc_inds = []
    for eli5_item, gt_item, pred_item in zip(eli5_data, gt_data, pred_data):
        for gt_sentence, pred_sentence in zip(gt_item['answer_sentences'], pred_item['answer_sentences']):
            for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
                ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
                ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
            ]:
                for i, (gt_doc_span_groups, pred_doc_span_groups) in enumerate(
                    zip(gt_sentence[f'{span_type}_spans'], pred_sentence[f'{span_type}_spans'])
                ):
                    # Get pred and GT indicators {0, 1} for each character position in the document
                    doc_len = len(eli5_item["docs"][i]["text"])

                    # Predicted spans aren't grouped, so there is only one group
                    pred_doc_spans = pred_doc_span_groups[0] if pred_doc_span_groups else []
                    pred_doc_char_inds = get_span_group_char_indicators(doc_len, pred_doc_spans)

                    # Ground truth spans are grouped
                    gt_doc_span_groups = gt_doc_span_groups if gt_doc_span_groups else [[]]
                    gt_doc_group_char_inds = [
                        get_span_group_char_indicators(doc_len, gt_doc_spans)
                        for gt_doc_spans in gt_doc_span_groups
                    ]

                    # Get the F1 score of each GT group vs. the predicted indicators and pick the best-matching group
                    # (GT span groups are OR'd together, meaning the predicted spans only need to match one group to be correct)
                    doc_group_f1_scores = [
                        f1_score(gt_doc_char_inds, pred_doc_char_inds, zero_division=0)
                        for gt_doc_char_inds in gt_doc_group_char_inds
                    ]
                    best_doc_group_f1_idx = np.argmax(doc_group_f1_scores)
                    doc_f1_score = doc_group_f1_scores[best_doc_group_f1_idx]
                    gt_doc_char_inds = gt_doc_group_char_inds[best_doc_group_f1_idx]

                    # Append to the character-level indicator lists
                    gt_char_inds.append(gt_doc_char_inds)
                    pred_char_inds.append(pred_doc_char_inds)

                    # Get pred and GT indicators {0, 1} for each document
                    gt_doc_ind = 1 if np.sum(gt_doc_char_inds) > 0 else 0
                    if gt_doc_ind == 0:
                        # If no GT spans exist, then consider the document cited if any predicted spans exist
                        pred_doc_ind = 1 if np.sum(pred_doc_char_inds) > 0 else 0
                    else:
                        # If GT spans exist, then consider the document cited if the F1 score is above the threshold,
                        # i.e. the predicted citation spans are close enough to the GT spans
                        pred_doc_ind = 1 if doc_f1_score >= args.doc_level_f1_threshold else 0

                    # Append to the document-level indicator lists
                    gt_doc_inds.append(gt_doc_ind)
                    pred_doc_inds.append(pred_doc_ind)

    citation_gt_char_inds = np.concatenate(citation_gt_char_inds)
    citation_pred_char_inds = np.concatenate(citation_pred_char_inds)
    conflict_gt_char_inds = np.concatenate(conflict_gt_char_inds)
    conflict_pred_char_inds = np.concatenate(conflict_pred_char_inds)
    citation_gt_doc_inds = np.array(citation_gt_doc_inds)
    citation_pred_doc_inds = np.array(citation_pred_doc_inds)
    conflict_gt_doc_inds = np.array(conflict_gt_doc_inds)
    conflict_pred_doc_inds = np.array(conflict_pred_doc_inds)

    # Compute and print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {},
    }
    for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
        ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
        ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
    ]:
        print("-------------------------")
        print(f"> Span type: {span_type}")
        print("-------------------------")
        for level, gt_inds, pred_inds in [("char", gt_char_inds, pred_char_inds), ("doc", gt_doc_inds, pred_doc_inds)]:
            add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds)

    # Save the summary statistics next to the predictions file
    output_path = args.pred_path.replace(".json", "_summary_stats.json")
    with open(output_path, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Summary statistics saved to: {output_path}")


if __name__ == "__main__":
    main()
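
# Example invocation (a sketch: the filename "compute_metrics.py" is assumed, and the
# paths shown are simply the argparse defaults above; adjust them to your own data):
#
#   python compute_metrics.py \
#       --eli5_path ./data/eli5_eval_bm25_top100_reranked_oracle.json \
#       --gt_path ./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json \
#       --pred_path ./eli5_salsa_sample.json \
#       --doc_level_f1_threshold 0.5
#
# This prints character-level and document-level precision/recall/F1 for both the
# "citation" and "conflict" span types, and writes the same numbers to
# ./eli5_salsa_sample_summary_stats.json (i.e. --pred_path with ".json" replaced by
# "_summary_stats.json").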