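"""Run attribution evaluation on the ELI5 dataset.

For each sample, the selected attribution system (sliding_window,
gradient_based, or prompt_based) produces citation and conflict spans for
every answer sentence and document, and the results are written to an Excel
report and an annotation JSON file.

Example invocation (the script name below is assumed; substitute this
file's actual name in the repository):

    python eval_eli5.py --attribution_system sliding_window --debug_num_samples 10
"""
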
import argparse
import os
import json
import torch
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
from citation_systems import (
    SlidingWindowSystem,
    GradientBasedSystem,
    PromptBasedSystem,
)
from utils.token_text_mapper import RAGTokenTextMapper
from utils.config_loader import ConfigLoader
from tqdm import tqdm
from prompts import get_rag_messages
from utils.reporting_helpers import write_results_to_excel, write_results_to_annotation_json


def load_eli5_data(file_path: str) -> List[Dict]:
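    """Load the ELI5 evaluation samples (questions, contexts, and documents) from JSON."""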
    with open(file_path, 'r') as f:
        return json.load(f)


def load_eli5_ans_data(file_path: str) -> List[str]:
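    """Load a JSON list of pre-generated answers, one per ELI5 question."""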
    with open(file_path, 'r') as f:
        return json.load(f)


def get_doc_spans(spans, doc_idx):
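    """Return the relative bounds and texts of the spans that fall inside document doc_idx."""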
    doc_spans = [
        s.document_span for s in spans
        if s.document_span and s.document_span.document_index == doc_idx
    ]
    spans_bounds = [s.text_rel_bounds for s in doc_spans]
    spans_text = [s.text for s in doc_spans]
    return spans_bounds, spans_text


def main(args):
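    """Load the config, data, and model, run the attribution system, and save the results."""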
    # Load config
    config_loader = ConfigLoader()
    config = config_loader.load_config(args.yaml_base_config, args.yaml_config)

    # Config overrides from args (explicit None checks so 0 is a valid override)
    if args.window_batch_size is not None and args.window_batch_size != config["window_batch_size"]:
        print("Setting window_batch_size to", args.window_batch_size)
        config["window_batch_size"] = args.window_batch_size
    if args.gen_batch_size is not None and args.gen_batch_size != config["gen_batch_size"]:
        print("Setting gen_batch_size to", args.gen_batch_size)
        config["gen_batch_size"] = args.gen_batch_size
    if args.gen_seed is not None and args.gen_seed != config["gen_seed"]:
        print("Setting gen_seed to", args.gen_seed)
        config["gen_seed"] = args.gen_seed

    # Load evaluation data
    eli5_data = load_eli5_data(args.eli5_path)
    eli5_ans_data = load_eli5_ans_data(args.eli5_ans_path) if args.eli5_ans_path else None
    if eli5_ans_data and len(eli5_ans_data) != len(eli5_data):
        raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")

    # Prepare the output directory
    config_name = os.path.splitext(os.path.basename(args.yaml_config))[0]
    exp_output_path = os.path.join(args.output_path, config_name, args.attribution_system)
    os.makedirs(exp_output_path, exist_ok=True)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config["model"])
    model = AutoModelForCausalLM.from_pretrained(config["model"], torch_dtype=torch.float16, device_map="auto")

    # Initialize the attribution system
    system_kwargs = {"model": model, "tokenizer": tokenizer}
    if args.attribution_system != "prompt_based":
        # Span-scoring thresholds only apply to the sliding_window and
        # gradient_based systems
        system_kwargs["z_threshold"] = config["z_threshold"]
        system_kwargs["smoothing_window_size"] = config.get("smoothing_window_size")
    if args.attribution_system == "sliding_window":
        attribution_system = SlidingWindowSystem(window_batch_size=config["window_batch_size"], **system_kwargs)
    elif args.attribution_system == "gradient_based":
        attribution_system = GradientBasedSystem(**system_kwargs)
    else:
        attribution_system = PromptBasedSystem(config, **system_kwargs)

    # Run the system on the evaluation data
    results = []
    for i, sample in enumerate(tqdm(eli5_data, desc="Running samples")):
        if args.debug_num_samples and i == args.debug_num_samples:
            break
        question = sample["question"]
        # Use the generated answer if provided, otherwise the sample's reference answer
        answer = eli5_ans_data[i] if eli5_ans_data else sample["answer"]
        documents = sample["docs"]
        question_context = sample["question_ctx"]

        # Compute the attribution results
        if isinstance(attribution_system, PromptBasedSystem):
            citation_results = attribution_system.generate_citations(question, answer, documents, question_context)
        else:
            # Build the full RAG prompt, then map prompt tokens back to the
            # context and document boundaries for span-level attribution
            messages = get_rag_messages(config, question, documents, question_context, answer)
            input_ids = attribution_system.tokenizer.apply_chat_template(messages, return_tensors="pt")
            token_text_mapper = RAGTokenTextMapper(
                attribution_system.tokenizer,
                config["context_regex"],
                config["document_regex"],
                input_ids=input_ids,
            )
            context_length = token_text_mapper.get_context_length()
            document_bounds = token_text_mapper.get_document_bounds()
            citation_results = attribution_system.generate_citations(
                token_text_mapper,
                context_length,
                document_bounds=document_bounds,
            )

        # Collect the results
        for j, citation_result in enumerate(citation_results.results):
            for doc_idx, doc in enumerate(documents):
                citation_spans_bounds, citation_spans_text = get_doc_spans(citation_result.citation_spans, doc_idx)
                conflict_spans_bounds, conflict_spans_text = get_doc_spans(citation_result.conflict_spans, doc_idx)
                results.append({
                    "sample_id": i + 1,
                    "question_context": question_context,
                    "question": question,
                    "answer_sentence_id": j + 1,
                    "answer_sentence": citation_result.answer_text.strip(),
                    "doc_id": doc_idx + 1,
                    "doc_title": doc["title"],
                    "doc_text": doc["text"],
                    "citation_spans_bounds": citation_spans_bounds,
                    "citation_spans_text": citation_spans_text,
                    "conflict_spans_bounds": conflict_spans_bounds,
                    "conflict_spans_text": conflict_spans_text,
                })

    # Save the results
    output_filename = os.path.join(exp_output_path, "eli5.xlsx")
    write_results_to_excel(
        results,
        output_filename,
        wider_columns=["question_context", "doc_text", "citation_spans_text", "conflict_spans_text"],
    )
    output_filename = os.path.join(exp_output_path, "eli5.json")
    write_results_to_annotation_json(results, output_filename)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run attribution evaluation on ELI5 dataset")
    parser.add_argument(
        "--eli5_path",
        type=str,
        help="Path to the ELI5 dataset JSON file",
        default="./data/eli5_eval_bm25_top100_reranked_oracle.json",
    )
    parser.add_argument(
        "--eli5_ans_path",
        type=str,
        help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.",
        default=None,
    )
    parser.add_argument(
        "--attribution_system",
        type=str,
        choices=["sliding_window", "gradient_based", "prompt_based"],
        help="Attribution system to use",
        default="sliding_window",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path to output all results",
        default="attribution_results",
    )
    parser.add_argument(
        "--yaml_base_config",
        type=str,
        help="Path to the base YAML config file",
        default="base_config.yaml",
    )
    parser.add_argument(
        "--yaml_config",
        type=str,
        help="Path to the experiment-specific YAML config file",
        default="mistral_7B.yaml",
    )
    parser.add_argument(
        "--window_batch_size",
        type=int,
        help="Batch size for sliding_window system",
        default=4,
    )
    parser.add_argument(
        "--gen_batch_size",
        type=int,
        help="Batch size for generation in prompt_based system",
        default=4,
    )
    parser.add_argument(
        "--gen_seed",
        type=int,
        help="Random seed for generation in prompt_based system",
        default=42,
    )
    parser.add_argument(
        "--debug_num_samples",
        type=int,
        help="Number of samples to run for debugging",
        default=None,
    )
    args = parser.parse_args()
    print("Running with args:", args)
    main(args)