saliency-based-citation/run_auto_annotate.py
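
# --- Usage sketch (inferred from the argparse defaults and path handling below) ---
# python run_auto_annotate.py \
#     --eli5_path ./data/eli5_eval_bm25_top100_reranked_oracle.json \
#     --model_name gpt-4o \
#     --prompt_type unified
#
# Span annotations are written to <answers_dir>/spans/<model_name>/<prompt_type>.json
# (or to "<eli5_path stem>_answers_original/..." when --eli5_ans_path is omitted),
# and raw model responses are appended to a matching *_response_logs.csv.
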
from utils.auto_annotate import generate_annotations
from tqdm import tqdm
from dotenv import load_dotenv, find_dotenv
import argparse
import json
import os
import pandas as pd


def load_eli5(eli5_path, ans_path=None):
    """Load ELI5 examples, optionally replacing answers with generated ones."""
    with open(eli5_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if ans_path is not None:
        with open(ans_path, "r", encoding="utf-8") as f:
            ans_data = json.load(f)
        if len(ans_data) != len(data):
            raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")
        for example, ans in zip(data, ans_data):
            example["answer"] = ans
    return data


def load_annotations(eli5_path, ann_path, ans_path=None):
    """Load annotations from ann_path, creating an empty annotation file if none exists."""
    if not os.path.exists(ann_path):
        eli5_data = load_eli5(eli5_path, ans_path)
        empty_ann_data = [
            {
                "answer_sentences": [],
                "span_annotation_done": False
            }
            for _ in eli5_data
        ]
        save_annotations(ann_path, empty_ann_data)
    with open(ann_path, "r", encoding="utf-8") as f:
        ann_data = json.load(f)
    return ann_data


def save_annotations(ann_path, ann_data):
    """Write the annotation list to ann_path as pretty-printed JSON."""
    with open(ann_path, "w", encoding="utf-8") as f:
        json.dump(ann_data, f, indent=4)


def save_response_logs(response_logs, ann_path):
    """Append model response logs to a CSV file stored next to the annotation file."""
    log_path = ann_path.replace(".json", "_response_logs.csv")
    # append if the log file exists, otherwise create it
    prev_log_df = None
    if os.path.exists(log_path):
        prev_log_df = pd.read_csv(log_path)
    log_df = pd.DataFrame(response_logs)
    if prev_log_df is not None:
        log_df = pd.concat([prev_log_df, log_df], ignore_index=True)
    log_df.to_csv(log_path, index=False)


def run(args):
    # Store annotations next to the generated answers, or in a
    # "<eli5_path stem>_answers_original" directory when using the original Reddit answers.
    if args.eli5_ans_path is not None:
        ann_base_path = os.path.dirname(args.eli5_ans_path)
    else:
        ann_base_path = os.path.splitext(args.eli5_path)[0] + "_answers_original"
    ann_path = os.path.join(ann_base_path, "spans", args.model_name, f"{args.prompt_type}.json")
    os.makedirs(os.path.dirname(ann_path), exist_ok=True)

    eli5_data = load_eli5(args.eli5_path, args.eli5_ans_path)
    ann_data = load_annotations(args.eli5_path, ann_path, args.eli5_ans_path)
    for i, (example, ann) in enumerate(tqdm(zip(eli5_data, ann_data), total=len(eli5_data))):
        if args.debug_num_samples and i == args.debug_num_samples:
            break
        if not ann["span_annotation_done"]:
            response_logs = generate_annotations(example, ann, args.model_name, args.prompt_type, temperature=args.temperature)
            ann["span_annotation_done"] = True
            save_annotations(ann_path, ann_data)
            # tag each response log entry with the example index
            for log in response_logs:
                log["example_idx"] = i
            # append to the log CSV file
            save_response_logs(response_logs, ann_path)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eli5_path",
        type=str,
        help="Path to the ELI5 dataset JSON file",
        default="./data/eli5_eval_bm25_top100_reranked_oracle.json",
    )
    parser.add_argument(
        "--eli5_ans_path",
        type=str,
        help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Name of the model to use for span annotation",
        choices=["gpt-4o", "claude-3-5-sonnet-20241022"],
        default="gpt-4o",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        help="Type of prompt to use for span annotation",
        choices=["unified", "per_answer_sentence_per_doc"],
        default="unified",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for span annotation",
        default=0.0,
    )
    parser.add_argument(
        "--debug_num_samples",
        type=int,
        help="Number of samples to run for debugging",
        default=None,
    )
    args = parser.parse_args()
    _ = load_dotenv(find_dotenv(), override=True)
    run(args)