import argparse
import json
import os

import pandas as pd
from dotenv import find_dotenv, load_dotenv
from tqdm import tqdm

from utils.auto_annotate import generate_annotations


def load_eli5(eli5_path, ans_path=None):
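    """Load the ELI5 dataset and, if ans_path is given, swap in the generated answers."""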
    with open(eli5_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if ans_path is not None:
        with open(ans_path, "r", encoding="utf-8") as f:
            ans_data = json.load(f)
        if len(ans_data) != len(data):
            raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")
        for example, ans in zip(data, ans_data):
            example["answer"] = ans
    return data


def load_annotations(eli5_path, ann_path, ans_path=None):
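    """Load span annotations, initializing an empty annotation file if none exists yet."""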
    if not os.path.exists(ann_path):
        eli5_data = load_eli5(eli5_path, ans_path)
        empty_ann_data = [
            {
                "answer_sentences": [],
                "span_annotation_done": False,
            }
            for _ in eli5_data
        ]
        save_annotations(ann_path, empty_ann_data)
    with open(ann_path, "r", encoding="utf-8") as f:
        ann_data = json.load(f)
    return ann_data


def save_annotations(ann_path, ann_data):
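    """Write the annotation list to ann_path as indented JSON."""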
    with open(ann_path, "w", encoding="utf-8") as f:
        json.dump(ann_data, f, indent=4)


def save_response_logs(response_logs, ann_path):
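    """Append model response logs to a CSV stored alongside the annotation JSON."""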
    log_path = ann_path.replace(".json", "_response_logs.csv")
    # Append to the existing log if there is one; otherwise create a new file.
    prev_log_df = None
    if os.path.exists(log_path):
        prev_log_df = pd.read_csv(log_path)
    log_df = pd.DataFrame(response_logs)
    if prev_log_df is not None:
        log_df = pd.concat([prev_log_df, log_df], ignore_index=True)
    log_df.to_csv(log_path, index=False)


def run(args):
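    # Annotations go to <base>/spans/<model_name>/<prompt_type>.json, where <base> is the
    # directory of the generated answers or, when the original Reddit answers are used,
    # a "<dataset>_answers_original" directory derived from the dataset path.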
    if args.eli5_ans_path is not None:
        ann_base_path = os.path.dirname(args.eli5_ans_path)
    else:
        ann_base_path = os.path.splitext(args.eli5_path)[0] + "_answers_original"
    ann_path = os.path.join(ann_base_path, "spans", args.model_name, f"{args.prompt_type}.json")
    os.makedirs(os.path.dirname(ann_path), exist_ok=True)

    eli5_data = load_eli5(args.eli5_path, args.eli5_ans_path)
    ann_data = load_annotations(args.eli5_path, ann_path, args.eli5_ans_path)

    for i, (example, ann) in enumerate(tqdm(zip(eli5_data, ann_data), total=len(eli5_data))):
        if args.debug_num_samples and i == args.debug_num_samples:
            break
        if not ann["span_annotation_done"]:
            response_logs = generate_annotations(
                example, ann, args.model_name, args.prompt_type, temperature=args.temperature
            )
            ann["span_annotation_done"] = True
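            # Persist annotations after each example so partial progress is not lost.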
            save_annotations(ann_path, ann_data)
            # Tag each response log with the index of its example.
            for log in response_logs:
                log["example_idx"] = i
            # Append the responses to the log CSV file.
            save_response_logs(response_logs, ann_path)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eli5_path",
        type=str,
        help="Path to the ELI5 dataset JSON file",
        default="./data/eli5_eval_bm25_top100_reranked_oracle.json",
    )
    parser.add_argument(
        "--eli5_ans_path",
        type=str,
        help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Name of the model to use for span annotation",
        choices=["gpt-4o", "claude-3-5-sonnet-20241022"],
        default="gpt-4o",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        help="Type of prompt to use for span annotation",
        choices=["unified", "per_answer_sentence_per_doc"],
        default="unified",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Sampling temperature for span annotation",
        default=0.0,
    )
    parser.add_argument(
        "--debug_num_samples",
        type=int,
        help="Number of samples to run for debugging",
        default=None,
    )
    args = parser.parse_args()
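    # Load environment variables (e.g., API keys for the annotation models) from a .env file.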
    _ = load_dotenv(find_dotenv(), override=True)
    run(args)