# saliency-based-citation/run_eval_iou.py
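"""
Evaluate predicted citation and conflict spans against ground-truth annotations.

For each answer sentence, predicted spans are matched to the best-matching group of
ground-truth spans via IoU. Per-span IoU scores and binary precision/recall/F1 are
reported, and the augmented predictions and summary statistics can be written to JSON.
"""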
import json
import pandas as pd
import ast
import statistics
import argparse
from metrics.iou import IoU
from sklearn.metrics import precision_score, recall_score, f1_score
from typing import List, Tuple
def create_binary_classifications(pred_spans, grouped_true_spans, iou_threshold=0.4) -> Tuple[List[int], List[int]]:
    """
    Create binary classifications for spans based on IoU matching.

    Args:
        pred_spans: List of predicted spans
        grouped_true_spans: List of lists of ground truth spans
        iou_threshold: Threshold for considering spans as matching

    Returns:
        tuple: Lists of ground truth and prediction binary labels
    """
    gt_labels = []
    pred_labels = []
    iou_metric = IoU()

    # Find the best matching ground-truth group based on average IoU
    best_group_iou = 0
    best_group_idx = 0
    for group_idx, true_spans in enumerate(grouped_true_spans):
        group_ious = []
        for true_span in true_spans:
            max_iou = 0
            for pred_span in pred_spans:
                # Calculate IoU between true_span and pred_span
                iou_score = iou_metric.compute([true_span], [pred_span])
                if iou_score is not None:
                    group_ious.append(iou_score)
                if iou_score is not None and iou_score > max_iou:
                    max_iou = iou_score
        avg_group_iou = sum(group_ious) / len(group_ious) if group_ious else 0
        if avg_group_iou > best_group_iou:
            best_group_iou = avg_group_iou
            best_group_idx = group_idx  # enumerate is already zero-based, so no offset is needed

    # Use the best matching group to create binary classifications
    true_spans = grouped_true_spans[best_group_idx] if grouped_true_spans else []

    # For each ground truth span: label 1, and predict 1 only if some prediction matches it
    for true_span in true_spans:
        max_iou = 0
        for pred_span in pred_spans:
            # Calculate IoU
            iou_score = iou_metric.compute([pred_span], [true_span])
            if iou_score is not None:
                max_iou = max(max_iou, iou_score)
        gt_labels.append(1)  # This is a ground truth span
        pred_labels.append(1 if max_iou >= iou_threshold else 0)

    # For each predicted span: count it as a false positive if it matched no ground truth span
    for pred_span in pred_spans:
        max_iou = 0
        for true_span in true_spans:
            # Calculate IoU
            iou_score = iou_metric.compute([pred_span], [true_span])
            if iou_score is not None:
                max_iou = max(max_iou, iou_score)
        # TODO: Is this double counting?
        # Only add if this prediction didn't match any ground truth span
        if max_iou < iou_threshold:
            gt_labels.append(0)  # This is not a ground truth span
            pred_labels.append(1)  # But it was predicted

    return gt_labels, pred_labels
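
# Illustrative example (assumes IoU.compute returns the overlap ratio of the two
# span lists, e.g. roughly 0.83 for (0, 10) vs. (0, 12)):
#   create_binary_classifications([(0, 10), (20, 30)], [[(0, 12)], [(50, 60)]])
#   -> ([1, 0], [1, 1])   # (0, 12) is matched; (20, 30) is an unmatched prediction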
def load_excel_data(excel_file):
    """
    Load data from the Excel file.

    Returns:
        DataFrame: The data from the Excel file.
    """
    df = pd.read_excel(excel_file)
    return df
def load_json_data(json_file):
    """
    Load data from the JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data
def get_spans_from_row(row):
    """
    Extract spans from a row in the DataFrame.

    Returns:
        list of tuples: List of spans.
    """
    # Assuming 'citation_spans_bound' is a column in the DataFrame containing spans as strings,
    # for example: "[(start1, end1), (start2, end2), ...]"
    spans = ast.literal_eval(row['citation_spans_bound']) if row['citation_spans_bound'] else []
    return spans
def get_spans_from_json(doc_item):
    """
    Extract grouped spans from a document item in the JSON data.

    Returns:
        list of lists of tuples: Groups of (start, end) spans.
    """
    grouped_spans = []
    # 'citation_spans' is assumed to be in the structure as per the given JSON
    # Adjust the parsing logic based on the exact structure
    for span_group in doc_item:
        spans = []
        for span in span_group:
            spans.append((span['start'], span['end']))
        grouped_spans.append(spans)
    return grouped_spans
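
# Illustrative input/output (structure inferred from the parsing above):
#   doc_item = [[{"start": 0, "end": 12}, {"start": 40, "end": 55}], [{"start": 80, "end": 95}]]
#   get_spans_from_json(doc_item) -> [[(0, 12), (40, 55)], [(80, 95)]]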
def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='Compute IoU scores between predicted and ground truth spans')
    parser.add_argument('--pred', type=str, default='./eli5_salsa_sample.json',
                        help='Path to the JSON file containing the predicted spans')
    parser.add_argument('--gt', type=str, default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json',
                        help='Path to the JSON file containing the ground truth annotations')
    parser.add_argument('--output', type=str, default='./eli5_salsa_sample_iou.json',
                        help='Path to save the output to a JSON file (optional)')
    return parser.parse_args()
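
# Example invocation (paths shown are the argparse defaults above; adjust to your data):
#   python run_eval_iou.py \
#       --pred ./eli5_salsa_sample.json \
#       --gt ./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json \
#       --output ./eli5_salsa_sample_iou.json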
def main():
    # Parse command line arguments
    args = parse_args()

    print("---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred}")
    print("---------------------------------------------------------------------------------------------------------")

    # Load both JSON files
    pred_data = load_json_data(args.pred)
    gt_data = load_json_data(args.gt)

    # Initialize metrics
    iou_metric = IoU()
    all_gt_labels = []
    all_pred_labels = []
    dataset_citation_scores = []
    dataset_conflict_scores = []
    dataset_citation_prf1 = []
    dataset_conflict_prf1 = []

    # Iterate through the data
    for i, (pred_item, gt_item) in enumerate(zip(pred_data, gt_data)):
        # Get spans from both JSONs
        for pred_sentence, gt_sentence in zip(pred_item['answer_sentences'], gt_item['answer_sentences']):
            # if pred_sentence['sentence'].strip() != gt_sentence['sentence'].strip():
            #     print(f"Sentence mismatch: {pred_sentence['sentence']} != {gt_sentence['sentence']}")
            #     continue
            for span_type, dataset_scores, dataset_prf1 in [
                ("citation", dataset_citation_scores, dataset_citation_prf1),
                ("conflict", dataset_conflict_scores, dataset_conflict_prf1)
            ]:
                doc_scores = []
                for pred_doc, gt_doc in zip(pred_sentence[f'{span_type}_spans'], gt_sentence[f'{span_type}_spans']):
                    # Handle citation/conflict spans for this document
                    grouped_pred_spans = get_spans_from_json(pred_doc)
                    grouped_true_spans = get_spans_from_json(gt_doc)
                    # grouped_pred_spans and grouped_true_spans are lists of lists of tuples
                    pred_spans = grouped_pred_spans[0] if grouped_pred_spans else []

                    # Get binary classifications
                    gt_labels, pred_labels = create_binary_classifications(
                        pred_spans, grouped_true_spans, iou_threshold=0.4
                    )

                    # Compute IoU scores only where GT is 1
                    # TODO: This is almost non-existent for conflicting spans.
                    doc_scores = [
                        iou_metric.compute([pred_span], [true_span])
                        for pred_span, true_span, gt_label in zip(pred_spans, grouped_true_spans, gt_labels)
                        if gt_label == 1
                    ]

                    all_gt_labels.extend(gt_labels)
                    all_pred_labels.extend(pred_labels)

                # Store score in the prediction
                pred_sentence[f'{span_type}_iou_scores'] = doc_scores

                # Calculate mean scores if there are any valid scores
                # for that sentence in both the prediction and the ground truth
                # TODO: Check this assumption.
                pred_sentence[f'mean_{span_type}_iou'] = statistics.mean(doc_scores) if doc_scores else None
                dataset_scores.append(pred_sentence[f'mean_{span_type}_iou'])

    # Save the augmented predictions if output path is provided
    if args.output:
        with open(args.output, 'w') as f:
            json.dump(pred_data, f, indent=2)
        print(f"Augmented predictions saved to: {args.output}")

    # Print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {}
    }
    print(f"\nSummary Statistics:")
    for span_type, dataset_scores in [
        ("citation", dataset_citation_scores),
        ("conflict", dataset_conflict_scores)
    ]:
        # IoU summary statistics
        valid_scores = [score for score in dataset_scores if score is not None]
        print(f"\nNumber of valid IoU scores ({span_type}): {len(valid_scores)} / {len(dataset_scores)}")
        if valid_scores:
            summary_stats[span_type]['mean_iou'] = sum(valid_scores) / len(valid_scores)
            summary_stats[span_type]['median_iou'] = statistics.median(valid_scores)
            summary_stats[span_type]['std_iou'] = statistics.stdev(valid_scores)
            print(f"Average IoU Score: {summary_stats[span_type]['mean_iou']:.4f}")
            print(f"Median IoU Score: {summary_stats[span_type]['median_iou']:.4f}")
            print(f"Standard Deviation: {summary_stats[span_type]['std_iou']:.4f}")
        else:
            print(f"No valid {span_type} IoU scores to compute statistics.")

        # PRF1 summary statistics
        # Note: all_gt_labels/all_pred_labels pool citation and conflict spans,
        # so the same document-level scores are reported for both span types.
        precision = precision_score(all_gt_labels, all_pred_labels, zero_division=0)
        recall = recall_score(all_gt_labels, all_pred_labels, zero_division=0)
        f1 = f1_score(all_gt_labels, all_pred_labels, zero_division=0)
        print(f"\n\nDocument-level metrics:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        summary_stats[span_type]['precision'] = precision
        summary_stats[span_type]['recall'] = recall
        summary_stats[span_type]['f1'] = f1

    # Save the summary statistics if output path is provided
    if args.output:
        summary_stats_path = args.output.replace('.json', '_summary.json')
        with open(summary_stats_path, 'w') as f:
            json.dump(summary_stats, f, indent=2)
        print(f"\nSummary statistics saved to: {summary_stats_path}")

    print()


if __name__ == "__main__":
    main()