saliency-based-citation/run_eval_char_binary.py
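"""
Evaluate predicted citation and conflict spans against ground-truth span annotations.

For every (answer sentence, document) pair in the ELI5 evaluation data, this script
builds binary character-level indicators from the annotated and predicted spans,
computes precision/recall/F1 at both the character level and the document level, and
saves the summary statistics to a JSON file next to the prediction file.

Example invocation (the remaining paths fall back to the argparse defaults):
    python run_eval_char_binary.py \
        --pred_path ./eli5_salsa_sample.json \
        --doc_level_f1_threshold 0.5
"""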
import argparse
import json

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

def load_json_data(json_file):
    """
    Load data from the JSON file.

    Returns:
        list: The data from the JSON file.
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    return data

def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(description='Compute metrics on the pipeline output')
    parser.add_argument("--eli5_path", type=str,
                        help="Path to the ELI5 dataset JSON file",
                        default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument('--gt_path', type=str,
                        help='Path to the JSON file containing the ground truth annotations',
                        default='./data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json')
    parser.add_argument('--pred_path', type=str,
                        help='Path to the JSON file containing the predicted spans',
                        default='./eli5_salsa_sample.json')
    parser.add_argument("--doc_level_f1_threshold", type=float,
                        help="Threshold for the document-level F1 score to determine if a document can be considered 'correctly' cited",
                        default=0.5)
    return parser.parse_args()
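
# Expected input layout (inferred from how main() reads the three files):
#   eli5 item:    {"docs": [{"text": "..."}, ...], ...}
#   gt/pred item: {"answer_sentences": [{"citation_spans": [...], "conflict_spans": [...]}, ...], ...}
# Each *_spans entry holds, per document, a list of span groups (the predictions use a single
# group), and every span is {"start": int, "end": int} in character offsets.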

def get_span_group_char_indicators(doc_len, span_group):
    """
    Build a 0/1 indicator array of length doc_len marking which character
    positions in the document are covered by the spans in span_group.
    """
    char_inds = np.zeros(doc_len, dtype=int)
    for span in span_group:
        char_inds[span["start"]:span["end"]] = 1
    return char_inds
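
# Example with hypothetical spans: for a document of length 10 and
# span_group = [{"start": 2, "end": 5}], get_span_group_char_indicators
# returns array([0, 0, 1, 1, 1, 0, 0, 0, 0, 0]).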

def add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds):
    """
    Compute precision/recall/F1 for one span type at one level ('char' or 'doc'),
    store the results in summary_stats, and print them.
    """
    # Compute summary statistics
    num_samples = len(gt_inds)
    num_gt_positive = np.sum(gt_inds).item()
    num_pred_positive = np.sum(pred_inds).item()
    percent_gt_positive = 100 * (num_gt_positive / num_samples)
    percent_pred_positive = 100 * (num_pred_positive / num_samples)
    precision = precision_score(gt_inds, pred_inds)
    recall = recall_score(gt_inds, pred_inds)
    f1 = f1_score(gt_inds, pred_inds)

    # Add to the dictionary
    summary_stats[span_type][level] = {
        'num_samples': num_samples,
        'num_gt_positive': num_gt_positive,
        'num_pred_positive': num_pred_positive,
        'percent_gt_positive': percent_gt_positive,
        'percent_pred_positive': percent_pred_positive,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

    # Print out
    print()
    print(f"({level}-level) Number of GT positives: {num_gt_positive} / {num_samples} ({percent_gt_positive:.2f}%)")
    print(f"({level}-level) Number of predicted positives: {num_pred_positive} / {num_samples} ({percent_pred_positive:.2f}%)")
    print(f"({level}-level) Precision: {precision:.4f}")
    print(f"({level}-level) Recall: {recall:.4f}")
    print(f"({level}-level) F1: {f1:.4f}")
    print()
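
# After add_summary_stats has run for every (span_type, level) combination, summary_stats
# looks like this (values are illustrative only):
#   {"citation": {"char": {"precision": 0.61, "recall": 0.58, "f1": 0.59, ...}, "doc": {...}},
#    "conflict": {...}}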

def main():
    # Parse command line arguments
    args = parse_args()
    print("\n---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred_path}")
    print("---------------------------------------------------------------------------------------------------------\n")

    # Load all three JSON files
    eli5_data = load_json_data(args.eli5_path)
    gt_data = load_json_data(args.gt_path)
    pred_data = load_json_data(args.pred_path)

    # Iterate through the data
    citation_gt_char_inds = []
    citation_pred_char_inds = []
    conflict_gt_char_inds = []
    conflict_pred_char_inds = []
    citation_gt_doc_inds = []
    citation_pred_doc_inds = []
    conflict_gt_doc_inds = []
    conflict_pred_doc_inds = []
    for eli5_item, gt_item, pred_item in zip(eli5_data, gt_data, pred_data):
        for gt_sentence, pred_sentence in zip(gt_item['answer_sentences'], pred_item['answer_sentences']):
            for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
                ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
                ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
            ]:
                for i, (gt_doc_span_groups, pred_doc_span_groups) in enumerate(
                    zip(gt_sentence[f'{span_type}_spans'], pred_sentence[f'{span_type}_spans'])
                ):
                    # Get pred and GT indicators {0, 1} for each character position in the document
                    doc_len = len(eli5_item["docs"][i]["text"])
                    # Predicted spans aren't grouped, so there will only be one group
                    pred_doc_spans = pred_doc_span_groups[0] if pred_doc_span_groups else []
                    pred_doc_char_inds = get_span_group_char_indicators(doc_len, pred_doc_spans)
                    # Ground truth spans are grouped
                    gt_doc_span_groups = gt_doc_span_groups if gt_doc_span_groups else [[]]
                    gt_doc_group_char_inds = [
                        get_span_group_char_indicators(doc_len, gt_doc_spans)
                        for gt_doc_spans in gt_doc_span_groups
                    ]
                    # Get the F1 score of each group vs. the predicted indicators and pick the best-matching group
                    # (GT span groups are OR'd together, meaning the predicted spans can match any of the groups to be correct)
                    doc_group_f1_scores = [
                        f1_score(gt_doc_char_inds, pred_doc_char_inds, zero_division=0)
                        for gt_doc_char_inds in gt_doc_group_char_inds
                    ]
                    best_doc_group_f1_idx = np.argmax(doc_group_f1_scores)
                    doc_f1_score = doc_group_f1_scores[best_doc_group_f1_idx]
                    gt_doc_char_inds = gt_doc_group_char_inds[best_doc_group_f1_idx]
                    # Append to the character-level indicator lists
                    gt_char_inds.append(gt_doc_char_inds)
                    pred_char_inds.append(pred_doc_char_inds)
                    # Get pred and GT indicators {0, 1} for each document
                    gt_doc_ind = 1 if np.sum(gt_doc_char_inds) > 0 else 0
                    if gt_doc_ind == 0:
                        # If no GT spans exist, consider the document cited if any predicted spans exist
                        pred_doc_ind = 1 if np.sum(pred_doc_char_inds) > 0 else 0
                    else:
                        # If GT spans exist, consider the document cited if the F1 score is above the threshold,
                        # i.e. the predicted citation spans are close enough to the GT spans
                        pred_doc_ind = 1 if doc_f1_score >= args.doc_level_f1_threshold else 0
                    # Append to the document-level indicator lists
                    gt_doc_inds.append(gt_doc_ind)
                    pred_doc_inds.append(pred_doc_ind)
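
    # At this point each *_char_inds list holds one 0/1 vector per (answer sentence, document)
    # pair and each *_doc_inds list holds one 0/1 label per pair; flatten them for scoring.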
    citation_gt_char_inds = np.concatenate(citation_gt_char_inds)
    citation_pred_char_inds = np.concatenate(citation_pred_char_inds)
    conflict_gt_char_inds = np.concatenate(conflict_gt_char_inds)
    conflict_pred_char_inds = np.concatenate(conflict_pred_char_inds)
    citation_gt_doc_inds = np.array(citation_gt_doc_inds)
    citation_pred_doc_inds = np.array(citation_pred_doc_inds)
    conflict_gt_doc_inds = np.array(conflict_gt_doc_inds)
    conflict_pred_doc_inds = np.array(conflict_pred_doc_inds)

    # Print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {},
    }
    for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
        ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
        ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
    ]:
        print("-------------------------")
        print(f"> Span type: {span_type}")
        print("-------------------------")
        for level, gt_inds, pred_inds in [("char", gt_char_inds, pred_char_inds), ("doc", gt_doc_inds, pred_doc_inds)]:
            add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds)

    # Save the summary statistics
    output_path = args.pred_path.replace(".json", "_summary_stats.json")
    with open(output_path, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Summary statistics saved to: {output_path}")


if __name__ == "__main__":
    main()