Skip to content
Permalink
4d623e0e9d
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
151 lines (131 sloc) 6.35 KB
import json
import pandas as pd
import ast
import statistics
import argparse
from metrics.iou import IoU
def load_excel_data(excel_file):
    """Read an Excel file into a pandas DataFrame.

    Args:
        excel_file: Location of the Excel file, passed straight through to
            ``pd.read_excel``.

    Returns:
        DataFrame: The file's contents as parsed by ``pd.read_excel`` with
        its default options.
    """
    return pd.read_excel(excel_file)
def load_json_data(json_file):
    """Load and decode a JSON file.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document (the callers in this script expect a list).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON is exchanged as UTF-8; pinning the encoding keeps the result
    # independent of the platform's locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_spans_from_row(row):
    """Extract the citation spans stored in one row.

    The 'citation_spans_bound' entry is expected to hold a Python-literal
    string such as "[(start1, end1), (start2, end2), ...]".

    Args:
        row: Mapping-like row (e.g. a DataFrame row) with a
            'citation_spans_bound' entry.

    Returns:
        list of tuples: The parsed spans, or an empty list when the entry is
        empty (empty/whitespace-only string, None, or NaN).

    Raises:
        ValueError, SyntaxError: If a non-empty entry is not a valid Python
            literal.
    """
    cell = row['citation_spans_bound']
    if isinstance(cell, str):
        cell = cell.strip()
    elif isinstance(cell, float) and cell != cell:
        # NaN is the only float unequal to itself. pandas typically uses NaN
        # for empty cells, and NaN is truthy, so it must be caught explicitly
        # before the truthiness check below.
        return []
    if not cell:
        return []
    return ast.literal_eval(cell)
def get_spans_from_json(doc_item):
    """Convert one document's span groups into lists of (start, end) tuples.

    Args:
        doc_item: Iterable of span groups; each group is an iterable of
            mappings that have 'start' and 'end' keys.

    Returns:
        list of lists of tuples: One inner list per span group, in input
        order, holding a ``(start, end)`` tuple for every span in that group.

    Raises:
        KeyError: If a span lacks a 'start' or 'end' key.
    """
    return [
        [(span['start'], span['end']) for span in span_group]
        for span_group in doc_item
    ]
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the process command line, exactly as
            before this parameter existed.

    Returns:
        argparse.Namespace: With the string attributes ``pred``, ``gt`` and
        ``output``.
    """
    # Both inputs are JSON files, so the description no longer mentions Excel.
    parser = argparse.ArgumentParser(
        description='Compute IoU scores between predicted and ground-truth JSON span annotations')
    parser.add_argument('--pred', type=str,
                        help='Path to the JSON file containing the predicted spans',
                        default='./eli5_salsa_sample.json')
    parser.add_argument('--gt', type=str,
                        help='Path to the JSON file containing the ground truth annotations',
                        default='./eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_joek_8502.json')
    parser.add_argument('--output', type=str,
                        help='Path to save the output to a JSON file (optional)',
                        default='./eli5_salsa_sample_iou.json')
    return parser.parse_args(argv)
def _score_doc_pairs(iou_metric, pred_docs, gt_docs):
    """Return the non-None IoU scores for position-aligned document span lists."""
    scores = []
    for pred_doc, gt_doc in zip(pred_docs, gt_docs):
        grouped_pred_spans = get_spans_from_json(pred_doc)
        grouped_true_spans = get_spans_from_json(gt_doc)
        # Only the first predicted span group is scored; it is compared
        # against all ground-truth groups.
        pred_spans = grouped_pred_spans[0] if grouped_pred_spans else []
        score = iou_metric.compute(pred_spans, grouped_true_spans)
        if score is not None:
            scores.append(score)
    return scores


def _print_score_summary(label, scores):
    """Print average, median and standard deviation of a non-empty score list."""
    print(f"Average {label} Score: {sum(scores) / len(scores):.4f}")
    print(f"Median {label} Score: {statistics.median(scores):.4f}")
    # statistics.stdev raises StatisticsError for fewer than two values, so a
    # single-sample run reports 'nan' instead of crashing.
    spread = statistics.stdev(scores) if len(scores) > 1 else float('nan')
    print(f"Standard Deviation: {spread:.4f}")


def main():
    """Score predicted spans against ground truth and report IoU statistics.

    Loads the prediction and ground-truth JSON files named on the command
    line, computes per-document IoU for the citation and conflict spans of
    every answer sentence, stores the scores and their means on the
    prediction sentences, optionally writes the augmented predictions to the
    output path, and prints summary statistics.
    """
    args = parse_args()

    pred_data = load_json_data(args.pred)
    gt_data = load_json_data(args.gt)

    # TODO: Remove this once we have the full dataset!
    # Until then the first 50 items of both files are skipped.
    pred_data = pred_data[50:]
    gt_data = gt_data[50:]

    iou_metric = IoU()
    dataset_citation_scores = []
    dataset_conflict_scores = []

    # Items, sentences and documents are paired purely by position; zip stops
    # at the shorter side and the sentence text is not compared.
    for pred_item, gt_item in zip(pred_data, gt_data):
        for pred_sentence, gt_sentence in zip(pred_item['answer_sentences'], gt_item['answer_sentences']):
            doc_scores = _score_doc_pairs(
                iou_metric, pred_sentence['citation_spans'], gt_sentence['citation_spans'])
            conflict_scores = _score_doc_pairs(
                iou_metric, pred_sentence['conflict_spans'], gt_sentence['conflict_spans'])

            # Store both score lists on the prediction so they are saved with it.
            pred_sentence['citation_iou_scores'] = doc_scores
            pred_sentence['conflict_iou_scores'] = conflict_scores

            # If no IoU score was computed, the sentence is assumed to have no
            # spans in both the prediction and the ground truth, so the mean
            # is set to 1.0.
            pred_sentence['mean_citation_iou'] = statistics.mean(doc_scores) if doc_scores else 1.0
            pred_sentence['mean_conflict_iou'] = statistics.mean(conflict_scores) if conflict_scores else 1.0

            dataset_citation_scores.append(pred_sentence['mean_citation_iou'])
            dataset_conflict_scores.append(pred_sentence['mean_conflict_iou'])

    # Save the augmented predictions if an output path is provided.
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(pred_data, f, indent=2)
        print(f"Augmented predictions saved to: {args.output}")

    # Both lists receive exactly one float per sentence, so they always have
    # the same length and never contain None.
    print("\nSummary Statistics:")
    print(f"Number of samples: {len(dataset_citation_scores)}")
    if dataset_citation_scores:
        _print_score_summary("IoU", dataset_citation_scores)
        _print_score_summary("Conflict IoU", dataset_conflict_scores)
    else:
        print("No valid IoU scores to compute statistics.")
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()