# saliency-based-citation/compute_iou.py
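"""
Compute span-level IoU between predicted citation/conflict spans and
ground-truth annotations, augment the predictions with per-sentence IoU
scores, and print dataset-level summary statistics.
"""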
import argparse
import ast
import json
import statistics

import pandas as pd

from metrics.iou import IoU


def load_excel_data(excel_file):
    """
    Load data from the Excel file.

    Returns:
        DataFrame: The data from the Excel file.
    """
    df = pd.read_excel(excel_file)
    return df


def load_json_data(json_file):
    """
    Load data from the JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data
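
# Expected structure of both JSON files (inferred from the parsing below): a list
# of items, each with 'answer_sentences'; every sentence carries 'citation_spans'
# and 'conflict_spans', which hold per-document groups of {'start': ..., 'end': ...} dicts.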


def get_spans_from_row(row):
    """
    Extract spans from a row in the DataFrame.

    Returns:
        list of tuples: List of spans.
    """
    # Assuming 'citation_spans_bound' is a column in the DataFrame containing spans as strings,
    # for example: "[(start1, end1), (start2, end2), ...]"
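    # Illustrative (hypothetical cell value): "[(0, 12), (30, 45)]" parses to
    # [(0, 12), (30, 45)]; an empty/falsy cell yields [].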
    spans = ast.literal_eval(row['citation_spans_bound']) if row['citation_spans_bound'] else []
    return spans


def get_spans_from_json(doc_item):
    """
    Extract spans from a document item in the JSON data.

    Returns:
        list of lists of tuples: One list of (start, end) spans per span group.
    """
    grouped_spans = []
    # 'citation_spans' is assumed to follow the structure of the given JSON;
    # adjust the parsing logic if the exact structure differs.
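    # Illustrative (assumed shape): [[{"start": 0, "end": 12}, {"start": 30, "end": 45}]]
    # would yield [[(0, 12), (30, 45)]].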
    for span_group in doc_item:
        spans = []
        for span in span_group:
            spans.append((span['start'], span['end']))
        grouped_spans.append(spans)
    return grouped_spans


def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(
        description='Compute IoU scores between predicted and ground-truth spans')
    parser.add_argument('--pred', type=str, default='./eli5_salsa_sample.json',
                        help='Path to the JSON file containing the predicted spans')
    parser.add_argument('--gt', type=str,
                        default='./eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_joek_8502.json',
                        help='Path to the JSON file containing the ground truth annotations')
    parser.add_argument('--output', type=str, default='./eli5_salsa_sample_iou.json',
                        help='Path to save the output to a JSON file (optional)')
    return parser.parse_args()
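# Example invocation (defaults shown above; adjust paths to your data):
#   python compute_iou.py --pred ./eli5_salsa_sample.json \
#       --gt ./eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_joek_8502.json \
#       --output ./eli5_salsa_sample_iou.json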


def main():
    # Parse command line arguments
    args = parse_args()

    # Load both JSON files
    pred_data = load_json_data(args.pred)
    gt_data = load_json_data(args.gt)

    # TODO: Remove this once we have the full dataset!
    pred_data = pred_data[50:]
    gt_data = gt_data[50:]

    # Initialize IoU metric
    iou_metric = IoU()
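
    # For each answer sentence, IoU is computed per document for the citation
    # spans and for the conflict spans; the per-sentence means collected here
    # feed the dataset-level summary printed at the end.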
    dataset_citation_scores = []
    dataset_conflict_scores = []

    # Iterate through the data
    for i, (pred_item, gt_item) in enumerate(zip(pred_data, gt_data)):
        # Get spans from both JSONs
        for pred_sentence, gt_sentence in zip(pred_item['answer_sentences'], gt_item['answer_sentences']):
            # if pred_sentence['sentence'].strip() != gt_sentence['sentence'].strip():
            #     print(f"Sentence mismatch: {pred_sentence['sentence']} != {gt_sentence['sentence']}")
            #     continue
            doc_scores = []
            for pred_doc, gt_doc in zip(pred_sentence['citation_spans'], gt_sentence['citation_spans']):
                # Handle citation spans
                grouped_pred_spans = get_spans_from_json(pred_doc)
                grouped_true_spans = get_spans_from_json(gt_doc)
                # grouped_pred_spans and grouped_true_spans are lists of lists of tuples
                pred_spans = grouped_pred_spans[0] if grouped_pred_spans else []

                # Compute IoU
                iou_score = iou_metric.compute(pred_spans, grouped_true_spans)
                if iou_score is not None:
                    doc_scores.append(iou_score)

            # Handle conflict spans similarly
            conflict_scores = []
            for pred_doc, gt_doc in zip(pred_sentence['conflict_spans'], gt_sentence['conflict_spans']):
                grouped_pred_spans = get_spans_from_json(pred_doc)
                grouped_true_spans = get_spans_from_json(gt_doc)
                pred_spans = grouped_pred_spans[0] if grouped_pred_spans else []

                # Compute IoU for conflicts
                conflict_iou = iou_metric.compute(pred_spans, grouped_true_spans)
                if conflict_iou is not None:
                    conflict_scores.append(conflict_iou)

            # Store both scores in the prediction
            pred_sentence['citation_iou_scores'] = doc_scores
            pred_sentence['conflict_iou_scores'] = conflict_scores

            # Calculate mean scores if there are any valid scores. If no IoU score
            # was computed, default to 1.0: it means there are no citations for that
            # sentence in either the prediction or the ground truth.
            pred_sentence['mean_citation_iou'] = statistics.mean(doc_scores) if doc_scores else 1.0
            pred_sentence['mean_conflict_iou'] = statistics.mean(conflict_scores) if conflict_scores else 1.0

            dataset_citation_scores.append(pred_sentence['mean_citation_iou'])
            dataset_conflict_scores.append(pred_sentence['mean_conflict_iou'])

    # Save the augmented predictions if output path is provided
    if args.output:
        with open(args.output, 'w') as f:
            json.dump(pred_data, f, indent=2)
        print(f"Augmented predictions saved to: {args.output}")

    # Print summary statistics
    valid_citation_scores = [score for score in dataset_citation_scores if score is not None]
    valid_conflict_scores = [score for score in dataset_conflict_scores if score is not None]

    print("\nSummary Statistics:")
    print(f"Number of scored sentences: {len(valid_citation_scores)}")
    if valid_citation_scores:
        print(f"Average IoU Score: {sum(valid_citation_scores) / len(valid_citation_scores):.4f}")
        print(f"Median IoU Score: {statistics.median(valid_citation_scores):.4f}")
        print(f"Standard Deviation: {statistics.stdev(valid_citation_scores):.4f}")
        if valid_conflict_scores:
            print(f"Average Conflict IoU Score: {sum(valid_conflict_scores) / len(valid_conflict_scores):.4f}")
            print(f"Median Conflict IoU Score: {statistics.median(valid_conflict_scores):.4f}")
            print(f"Standard Deviation: {statistics.stdev(valid_conflict_scores):.4f}")
    else:
        print("No valid IoU scores to compute statistics.")


if __name__ == "__main__":
    main()