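"""
Compute span-level IoU and precision/recall/F1 between predicted citation/conflict
spans and ground truth annotations stored in JSON files, then write the augmented
predictions and summary statistics back to disk.
"""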
import json
import pandas as pd
import ast
import statistics
import argparse
from metrics.iou import IoU
from sklearn.metrics import precision_score, recall_score, f1_score
from typing import List, Tuple
def create_binary_classifications(pred_spans, grouped_true_spans, iou_threshold=0.4) -> Tuple[List[int], List[int]]:
    """
    Create binary classifications for spans based on IoU matching.

    Args:
        pred_spans: List of predicted spans
        grouped_true_spans: List of lists of ground truth spans
        iou_threshold: Threshold for considering spans as matching

    Returns:
        tuple: Lists of ground truth and prediction binary labels
    """
    gt_labels = []
    pred_labels = []
    iou_metric = IoU()

    # Find the best matching group of ground truth spans based on average IoU
    best_group_iou = 0
    best_group_idx = 0
    for group_idx, true_spans in enumerate(grouped_true_spans):
        group_ious = []
        for true_span in true_spans:
            max_iou = 0
            for pred_span in pred_spans:
                # Calculate IoU between true_span and pred_span
                iou_score = iou_metric.compute([true_span], [pred_span])
                if iou_score is not None:
                    group_ious.append(iou_score)
                    if iou_score > max_iou:
                        max_iou = iou_score
        avg_group_iou = sum(group_ious) / len(group_ious) if group_ious else 0
        if avg_group_iou > best_group_iou:
            best_group_iou = avg_group_iou
            best_group_idx = group_idx  # enumerate() already starts at 0, so no offset is needed
    # Use the best matching group to create binary classifications
    true_spans = grouped_true_spans[best_group_idx] if grouped_true_spans else []

    # For each ground truth span, record whether any prediction matches it
    for true_span in true_spans:
        max_iou = 0
        for pred_span in pred_spans:
            # Calculate IoU
            iou_score = iou_metric.compute([pred_span], [true_span])
            if iou_score is not None:
                max_iou = max(max_iou, iou_score)
        gt_labels.append(1)  # This is a ground truth span
        pred_labels.append(1 if max_iou >= iou_threshold else 0)

    # For each predicted span, record unmatched predictions as false positives
    for pred_span in pred_spans:
        max_iou = 0
        for true_span in true_spans:
            # Calculate IoU
            iou_score = iou_metric.compute([pred_span], [true_span])
            if iou_score is not None:
                max_iou = max(max_iou, iou_score)
        # TODO: Is this double counting?
        # Only add if this prediction didn't match any ground truth span
        if max_iou < iou_threshold:
            gt_labels.append(0)  # This is not a ground truth span
            pred_labels.append(1)  # But it was predicted

    return gt_labels, pred_labels

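# Illustrative sketch of the expected shapes (hypothetical spans; actual IoU values
# depend on IoU.compute, so no concrete numbers are asserted here):
#   pred_spans = [(0, 10), (50, 60)]
#   grouped_true_spans = [[(0, 12)], [(48, 61), (70, 80)]]
#   gt_labels, pred_labels = create_binary_classifications(pred_spans, grouped_true_spans)
# One (gt=1, pred=0/1) pair is emitted per span in the best-matching ground truth group,
# plus one (gt=0, pred=1) pair for every prediction that matched no ground truth span.
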
def load_excel_data(excel_file):
    """
    Load data from the Excel file.

    Returns:
        DataFrame: The data from the Excel file.
    """
    df = pd.read_excel(excel_file)
    return df


def load_json_data(json_file):
    """
    Load data from the JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data

def get_spans_from_row(row):
    """
    Extract spans from a row in the DataFrame.

    Returns:
        list of tuples: List of spans.
    """
    # Assuming 'citation_spans_bound' is a column in the DataFrame containing spans as strings,
    # for example: "[(start1, end1), (start2, end2), ...]"
    spans = ast.literal_eval(row['citation_spans_bound']) if row['citation_spans_bound'] else []
    return spans

def get_spans_from_json(doc_item):
    """
    Extract spans from a document item in the JSON data.

    Returns:
        list of lists of tuples: Grouped spans, one list of (start, end) tuples per annotation group.
    """
    grouped_spans = []
    # 'citation_spans' is assumed to be in the structure as per the given JSON
    # Adjust the parsing logic based on the exact structure
    for span_group in doc_item:
        spans = []
        for span in span_group:
            spans.append((span['start'], span['end']))
        grouped_spans.append(spans)
    return grouped_spans

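# Illustrative sketch of the assumed per-document structure (hypothetical values):
#   doc_item = [
#       [{"start": 0, "end": 12}, {"start": 30, "end": 45}],  # first annotation group
#       [{"start": 5, "end": 20}],                            # alternative annotation group
#   ]
#   get_spans_from_json(doc_item) -> [[(0, 12), (30, 45)], [(5, 20)]]
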
def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='Compute IoU scores between predicted and ground truth span annotations')
    parser.add_argument('--pred', type=str,
                        help='Path to the JSON file containing the predicted spans',
                        default='./eli5_salsa_sample.json')
    parser.add_argument('--gt', type=str,
                        help='Path to the JSON file containing the ground truth annotations',
                        default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json')
    parser.add_argument('--output', type=str,
                        help='Path to save the output to a JSON file (optional)',
                        default='./eli5_salsa_sample_iou.json')
    return parser.parse_args()

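# Example invocation (the script filename is assumed; the paths are the defaults above):
#   python compute_span_iou.py \
#       --pred ./eli5_salsa_sample.json \
#       --gt ./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json \
#       --output ./eli5_salsa_sample_iou.json
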
def main():
    # Parse command line arguments
    args = parse_args()

    print("---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred}")
    print("---------------------------------------------------------------------------------------------------------")

    # Load both JSON files
    pred_data = load_json_data(args.pred)
    gt_data = load_json_data(args.gt)

    # Initialize metrics
    iou_metric = IoU()
    all_gt_labels = []
    all_pred_labels = []
    dataset_citation_scores = []
    dataset_conflict_scores = []
    dataset_citation_prf1 = []
    dataset_conflict_prf1 = []

    # Iterate through the data
    for i, (pred_item, gt_item) in enumerate(zip(pred_data, gt_data)):
        # Get spans from both JSONs
        for pred_sentence, gt_sentence in zip(pred_item['answer_sentences'], gt_item['answer_sentences']):
            # if pred_sentence['sentence'].strip() != gt_sentence['sentence'].strip():
            #     print(f"Sentence mismatch: {pred_sentence['sentence']} != {gt_sentence['sentence']}")
            #     continue
            for span_type, dataset_scores, dataset_prf1 in [
                ("citation", dataset_citation_scores, dataset_citation_prf1),
                ("conflict", dataset_conflict_scores, dataset_conflict_prf1)
            ]:
                doc_scores = []
                for pred_doc, gt_doc in zip(pred_sentence[f'{span_type}_spans'], gt_sentence[f'{span_type}_spans']):
                    # Handle the spans of the current type for this document
                    grouped_pred_spans = get_spans_from_json(pred_doc)
                    grouped_true_spans = get_spans_from_json(gt_doc)
                    # grouped_pred_spans and grouped_true_spans are lists of lists of tuples
                    pred_spans = grouped_pred_spans[0] if grouped_pred_spans else []
                    # Get binary classifications
                    gt_labels, pred_labels = create_binary_classifications(
                        pred_spans, grouped_true_spans, iou_threshold=0.4
                    )
                    # Compute IoU scores only where GT is 1
                    # TODO: This is almost non-existent for conflicting spans.
                    doc_scores = [
                        iou_metric.compute([pred_span], [true_span])
                        for pred_span, true_span, gt_label in zip(pred_spans, grouped_true_spans, gt_labels)
                        if gt_label == 1
                    ]
                    all_gt_labels.extend(gt_labels)
                    all_pred_labels.extend(pred_labels)

                # Store scores in the prediction
                pred_sentence[f'{span_type}_iou_scores'] = doc_scores
                # Calculate mean scores if there are any valid scores
                # for that sentence in both the prediction and the ground truth
                # TODO: Check this assumption.
                pred_sentence[f'mean_{span_type}_iou'] = statistics.mean(doc_scores) if doc_scores else None
                dataset_scores.append(pred_sentence[f'mean_{span_type}_iou'])
    # Save the augmented predictions if output path is provided
    if args.output:
        with open(args.output, 'w') as f:
            json.dump(pred_data, f, indent=2)
        print(f"Augmented predictions saved to: {args.output}")

    # Print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {}
    }
    print("\nSummary Statistics:")
    for span_type, dataset_scores in [
        ("citation", dataset_citation_scores),
        ("conflict", dataset_conflict_scores)
    ]:
        # IoU summary statistics
        valid_scores = [score for score in dataset_scores if score is not None]
        print(f"\nNumber of valid IoU scores ({span_type}): {len(valid_scores)} / {len(dataset_scores)}")
        if valid_scores:
            summary_stats[span_type]['mean_iou'] = sum(valid_scores) / len(valid_scores)
            summary_stats[span_type]['median_iou'] = statistics.median(valid_scores)
            summary_stats[span_type]['std_iou'] = statistics.stdev(valid_scores)
            print(f"Average IoU Score: {summary_stats[span_type]['mean_iou']:.4f}")
            print(f"Median IoU Score: {summary_stats[span_type]['median_iou']:.4f}")
            print(f"Standard Deviation: {summary_stats[span_type]['std_iou']:.4f}")
        else:
            print(f"No valid {span_type} IoU scores to compute statistics.")

        # PRF1 summary statistics
        # Note: all_gt_labels / all_pred_labels aggregate both citation and conflict spans,
        # so the same precision/recall/F1 values are reported for both span types.
        precision = precision_score(all_gt_labels, all_pred_labels, zero_division=0)
        recall = recall_score(all_gt_labels, all_pred_labels, zero_division=0)
        f1 = f1_score(all_gt_labels, all_pred_labels, zero_division=0)
        print("\n\nDocument-level metrics:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        summary_stats[span_type]['precision'] = precision
        summary_stats[span_type]['recall'] = recall
        summary_stats[span_type]['f1'] = f1

    # Save the summary statistics if output path is provided
    if args.output:
        summary_stats_path = args.output.replace('.json', '_summary.json')
        with open(summary_stats_path, 'w') as f:
            json.dump(summary_stats, f, indent=2)
        print(f"\nSummary statistics saved to: {summary_stats_path}")
    print()

if __name__ == "__main__":
    main()