saliency-based-citation/run_pipeline_new.py
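"""Run a citation attribution system over the ELI5 evaluation data.

For each ELI5 question, the selected attribution system (sliding_window,
gradient_based, or prompt_based) maps each answer sentence to citation and
conflict spans in the retrieved documents. Results are written to an Excel
report (eli5.xlsx) and an annotation JSON file (eli5.json) under
<output_path>/<config_name>/<attribution_system>/.

Example invocation (a sketch; assumes the default data file and YAML configs
named in the argparse defaults below are present):

    python run_pipeline_new.py --attribution_system sliding_window --yaml_config mistral_7B.yaml
"""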
import argparse
import json
import os
from typing import Dict, List

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from citation_systems import (
    SlidingWindowSystem,
    GradientBasedSystem,
    PromptBasedSystem,
)
from prompts import get_rag_messages
from utils.config_loader import ConfigLoader
from utils.reporting_helpers import write_results_to_excel, write_results_to_annotation_json
from utils.token_text_mapper import RAGTokenTextMapper


def load_eli5_data(file_path: str) -> List[Dict]:
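    """Load the ELI5 evaluation samples (a list of dicts) from a JSON file."""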
    with open(file_path, 'r') as f:
        return json.load(f)


def load_eli5_ans_data(file_path: str) -> List[str]:
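    """Load pre-generated ELI5 answers (a list of strings, one per question) from a JSON file."""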
    with open(file_path, 'r') as f:
        return json.load(f)


def get_doc_spans(spans, doc_idx):
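    """Return the relative bounds and texts of the spans located in document `doc_idx`."""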
    doc_spans = [
        s.document_span for s in spans
        if s.document_span and s.document_span.document_index == doc_idx
    ]
    spans_bounds = [s.text_rel_bounds for s in doc_spans]
    spans_text = [s.text for s in doc_spans]
    return spans_bounds, spans_text


def main(args):
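    """Run the configured attribution system on the ELI5 data and write the Excel/JSON reports."""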
    # Load config
    config_loader = ConfigLoader()
    config = config_loader.load_config(args.yaml_base_config, args.yaml_config)

    # Config overrides from args
    if args.window_batch_size and args.window_batch_size != config["window_batch_size"]:
        print("Setting window_batch_size to", args.window_batch_size)
        config["window_batch_size"] = args.window_batch_size
    if args.gen_batch_size and args.gen_batch_size != config["gen_batch_size"]:
        print("Setting gen_batch_size to", args.gen_batch_size)
        config["gen_batch_size"] = args.gen_batch_size
    if args.gen_seed and args.gen_seed != config["gen_seed"]:
        print("Setting gen_seed to", args.gen_seed)
        config["gen_seed"] = args.gen_seed

    # Load evaluation data
    eli5_data = load_eli5_data(args.eli5_path)
    eli5_ans_data = load_eli5_ans_data(args.eli5_ans_path) if args.eli5_ans_path else None
    if eli5_ans_data and len(eli5_ans_data) != len(eli5_data):
        raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")

    # Prepare the output directory
    config_name = os.path.splitext(os.path.basename(args.yaml_config))[0]
    exp_output_path = os.path.join(args.output_path, config_name, args.attribution_system)
    os.makedirs(exp_output_path, exist_ok=True)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config["model"])
    model = AutoModelForCausalLM.from_pretrained(config["model"], torch_dtype=torch.float16, device_map="auto")

    # Initialize the attribution system
    system_kwargs = {"model": model, "tokenizer": tokenizer}
    if args.attribution_system != "prompt_based":
        system_kwargs["z_threshold"] = config["z_threshold"]
        system_kwargs["smoothing_window_size"] = config.get("smoothing_window_size")
    if args.attribution_system == "sliding_window":
        attribution_system = SlidingWindowSystem(window_batch_size=config["window_batch_size"], **system_kwargs)
    elif args.attribution_system == "gradient_based":
        attribution_system = GradientBasedSystem(**system_kwargs)
    else:
        attribution_system = PromptBasedSystem(config, **system_kwargs)

    # Run the system on the evaluation data
    results = []
    for i, sample in enumerate(tqdm(eli5_data, desc="Running samples")):
        if args.debug_num_samples and i == args.debug_num_samples:
            break

        question = sample["question"]
        answer = eli5_ans_data[i] if eli5_ans_data else sample["answer"]
        documents = sample["docs"]
        question_context = sample["question_ctx"]

        # Compute the attribution results
        if isinstance(attribution_system, PromptBasedSystem):
            citation_results = attribution_system.generate_citations(question, answer, documents, question_context)
        else:
            messages = get_rag_messages(config, question, documents, question_context, answer)
            input_ids = attribution_system.tokenizer.apply_chat_template(messages, return_tensors="pt")
            token_text_mapper = RAGTokenTextMapper(
                attribution_system.tokenizer,
                config["context_regex"],
                config["document_regex"],
                input_ids=input_ids,
            )
            context_length = token_text_mapper.get_context_length()
            document_bounds = token_text_mapper.get_document_bounds()
            citation_results = attribution_system.generate_citations(
                token_text_mapper,
                context_length,
                document_bounds=document_bounds,
            )

        # Collect the results
        for j, citation_result in enumerate(citation_results.results):
            for doc_idx, doc in enumerate(documents):
                citation_spans_bounds, citation_spans_text = get_doc_spans(citation_result.citation_spans, doc_idx)
                conflict_spans_bounds, conflict_spans_text = get_doc_spans(citation_result.conflict_spans, doc_idx)
                results.append({
                    "sample_id": i + 1,
                    "question_context": question_context,
                    "question": question,
                    "answer_sentence_id": j + 1,
                    "answer_sentence": citation_result.answer_text.strip(),
                    "doc_id": doc_idx + 1,
                    "doc_title": doc["title"],
                    "doc_text": doc["text"],
                    "citation_spans_bounds": citation_spans_bounds,
                    "citation_spans_text": citation_spans_text,
                    "conflict_spans_bounds": conflict_spans_bounds,
                    "conflict_spans_text": conflict_spans_text,
                })

    # Save the results
    output_filename = os.path.join(exp_output_path, "eli5.xlsx")
    write_results_to_excel(
        results,
        output_filename,
        wider_columns=["question_context", "doc_text", "citation_spans_text", "conflict_spans_text"],
    )
    output_filename = os.path.join(exp_output_path, "eli5.json")
    write_results_to_annotation_json(results, output_filename)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run attribution evaluation on the ELI5 dataset")
    parser.add_argument(
        "--eli5_path",
        type=str,
        help="Path to the ELI5 dataset JSON file",
        default="./data/eli5_eval_bm25_top100_reranked_oracle.json",
    )
    parser.add_argument(
        "--eli5_ans_path",
        type=str,
        help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.",
        default=None,
    )
    parser.add_argument(
        "--attribution_system",
        type=str,
        choices=["sliding_window", "gradient_based", "prompt_based"],
        help="Attribution system to use",
        default="sliding_window",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Directory where all results are written",
        default="attribution_results",
    )
    parser.add_argument(
        "--yaml_base_config",
        type=str,
        help="Path to the base YAML config file",
        default="base_config.yaml",
    )
    parser.add_argument(
        "--yaml_config",
        type=str,
        help="Path to the experiment-specific YAML config file",
        default="mistral_7B.yaml",
    )
    parser.add_argument(
        "--window_batch_size",
        type=int,
        help="Batch size for the sliding_window system",
        default=4,
    )
    parser.add_argument(
        "--gen_batch_size",
        type=int,
        help="Batch size for generation in the prompt_based system",
        default=4,
    )
    parser.add_argument(
        "--gen_seed",
        type=int,
        help="Random seed for generation in the prompt_based system",
        default=42,
    )
    parser.add_argument(
        "--debug_num_samples",
        type=int,
        help="Number of samples to run for debugging",
        default=None,
    )

    args = parser.parse_args()
    print("Running with args:", args)
    main(args)