saliency-based-citation/run_pipeline.py
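"""Run the attribution/citation evaluation pipeline on the ELI5 and XOR-AttriQA datasets.

The script loads a causal LM, wraps it in one of the citation systems
(sliding_window, gradient_based, or prompt_based), generates citation spans for
each answer, and writes per-sample results and aggregate statistics to Excel and
text files under --output_path.

Example invocation (a sketch; flags and defaults are defined in the argparse
section at the bottom of this file):

    python run_pipeline.py --attribution_system sliding_window \
        --yaml_base_config base_config.yaml --yaml_config mistral_7B.yaml \
        --output_path attribution_results
"""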
import argparse
import ast
import gc
import json
import os
import re
from typing import Dict, List, Union

import numpy as np
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from citation_systems import (
    SlidingWindowSystem,
    GradientBasedSystem,
    PromptBasedSystem,
    SaliencyBasedSystem,
    CitationResults,
)
from prompts import get_rag_messages
from utils.config_loader import ConfigLoader
from utils.pipeline_helpers import run_command
from utils.reporting_helpers import write_results_to_excel, print_stats
from utils.token_text_mapper import RAGTokenTextMapper
def load_eli5_data(file_path: str) -> List[Dict]:
    with open(file_path, 'r') as f:
        return json.load(f)


def load_eli5_ans_data(file_path: str) -> List[str]:
    with open(file_path, 'r') as f:
        return json.load(f)


def load_xorattriqa_data(file_path: str) -> List[Dict]:
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data
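# Note: the two ELI5 loaders above read a single JSON document, while the
# XOR-AttriQA loader reads JSON Lines (one JSON object per line).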
def preprocess_eli5_sample(sample: Dict) -> Dict:
    return {
        'question_ctx': sample['question_ctx'] if sample['question_ctx'] and sample['question_ctx'] != "[removed]" else None,
        'question': sample['question'],
        'answer': sample['answer'],
        'docs': [{'text': doc['text'], 'title': doc['title']} for doc in sample['docs']],
        'claims': sample['claims']
    }


def preprocess_xorattriqa_sample(sample: Dict) -> Dict:
    return {
        'query': sample['query_translated_en'].replace('[b]', '').strip(),
        'answer': ast.literal_eval(sample['answers_translated_en']),
        'documents': [{"text": sample['passage_en'], "title": "Untitled"}],
        'attributable_gt': sample['ais']
    }
def generate_citations(
    config: Dict[str, str],
    citation_system: Union[SaliencyBasedSystem, PromptBasedSystem],
    query: str,
    answer: str,
    documents: List[Dict[str, str]],
    sentence_tokenize_answer: bool = True,
    query_context=None,
) -> CitationResults:
    if isinstance(citation_system, PromptBasedSystem):
        return citation_system.generate_citations(query, answer, documents, query_context)
    else:
        messages = get_rag_messages(config, query, documents, query_context, answer)
        input_ids = citation_system.tokenizer.apply_chat_template(messages, return_tensors="pt")
        token_text_mapper = RAGTokenTextMapper(
            citation_system.tokenizer,
            config["context_regex"],
            config["document_regex"],
            input_ids=input_ids,
        )
        context_length = token_text_mapper.get_context_length()
        answer_bounds = None
        document_bounds = token_text_mapper.get_document_bounds()
        if not sentence_tokenize_answer:
            answer_bounds = citation_system._get_answer_bounds_all_sentences(token_text_mapper, context_length)
            answer_start = min([bounds[1][0] for bounds in answer_bounds])
            answer_end = max([bounds[1][1] for bounds in answer_bounds])
            answer_bounds = (answer_start, answer_end)
        return citation_system.generate_citations(
            token_text_mapper,
            context_length,
            answer_bounds,
            document_bounds,
        )
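# When sentence_tokenize_answer is False, the whole answer is treated as a single
# span: its bounds are the min/max over the per-sentence bounds computed above.
#
# Minimal usage sketch (illustrative values only; in this pipeline the call is made
# from main() with samples loaded from the datasets):
#
#   results = generate_citations(
#       config,
#       attribution_system,
#       query="example question",
#       answer="a pre-generated answer to attribute",
#       documents=[{"text": "supporting passage", "title": "Untitled"}],
#       sentence_tokenize_answer=False,
#   )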
def evaluate_attribution_xorattriqa(citation_results: CitationResults, attributable: bool) -> Dict:
    """
    Evaluate the attribution result for the XOR-AttriQA task.

    Args:
    - citation_results: CitationResults object containing the citation spans and texts
    - attributable: Whether the answer is attributable to any source (ground truth)

    Returns:
    - Dictionary containing the evaluation metrics:
        - pred: Whether the system attributed the answer to any source
        - correct_attribution: Whether the system attribution matches the ground truth
        - num_citation_spans: Number of citation spans attributed to the answer
        - citation_spans: List of citation span texts
    """
    if not citation_results.results:
        system_attributed = False
        doc_spans = []
    else:
        citation_result = citation_results.results[0]
        # Check if the system attributed the answer to any source
        doc_spans = [span for span in citation_result.citation_spans if span.document_span]
        system_attributed = len(doc_spans) > 0

    # Compare system attribution with ground truth
    correct_attribution = system_attributed == attributable
    return {
        "pred": system_attributed,
        "correct_attribution": int(correct_attribution),
        "num_citation_spans": len(doc_spans),
        "citation_spans": [span.document_span.text for span in doc_spans],
    }
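# For example, if the gold label is attributable=True but the system finds no
# document spans, pred is False and correct_attribution is 0; if it finds at least
# one document span for that answer, correct_attribution is 1.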
def evaluate_attribution_eli5(processed_eli5_data: Dict, output_path: str, file_name="eli5") -> None:
    """
    Evaluate the attribution results for the ELI5 dataset using the ALCE evaluation script.
    """
    # Save the processed ELI5 data to a JSON file
    with open(f"{output_path}/{file_name}.json", 'w') as f:
        json.dump(processed_eli5_data, f)
    print("Evaluating ELI5 data...")
    run_command(f"python run_eli5_eval.py --f {output_path}/{file_name}.json --citations --claims_nli --report {output_path}/{file_name}.xlsx")
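# Note: run_eli5_eval.py is assumed to sit next to this script; it is invoked as a
# subprocess via run_command and writes its report to {output_path}/{file_name}.xlsx.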
def main(args):
    # Load prompt config
    config_loader = ConfigLoader()
    config = config_loader.load_config(args.yaml_base_config, args.yaml_config)
    if args.gen_seed:
        print("Setting config['gen_seed'] to", args.gen_seed)
        config["gen_seed"] = args.gen_seed
    config_name = os.path.splitext(os.path.basename(args.yaml_config))[0]
    exp_output_path = os.path.join(args.output_path, config_name, args.attribution_system)
    os.makedirs(exp_output_path, exist_ok=True)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model'])
    model = AutoModelForCausalLM.from_pretrained(config['model'], torch_dtype=torch.float16, device_map="auto")

    # Build the requested attribution system
    if args.attribution_system == "sliding_window":
        attribution_system = SlidingWindowSystem(model, tokenizer, z_threshold=config['z_threshold'], window_batch_size=config['window_batch_size'], smoothing_window_size=config.get('smoothing_window_size'))
    elif args.attribution_system == "gradient_based":
        attribution_system = GradientBasedSystem(model, tokenizer, z_threshold=config['z_threshold'], smoothing_window_size=config.get('smoothing_window_size'))
    elif args.attribution_system == "prompt_based":
        attribution_system = PromptBasedSystem(config, model, tokenizer)
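    # Note: the loaded YAML config is expected to provide at least 'model',
    # 'z_threshold', 'window_batch_size' (sliding window only), optionally
    # 'smoothing_window_size', and the 'context_regex' / 'document_regex'
    # fields used by RAGTokenTextMapper in generate_citations.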
    all_xorattriqa_results = []
    if args.xorattriqa_path and not args.eli5_only:
        # Load all JSONL files in the directory
        jsonl_files = [f for f in os.listdir(args.xorattriqa_path) if f.endswith('.jsonl')]
        # Skip the train, val and toy splits
        jsonl_files = [f for f in jsonl_files if 'train' not in f and 'val' not in f and 'toy' not in f]
        jsonl_paths = [os.path.join(args.xorattriqa_path, f) for f in jsonl_files]
        xorattriqa_results = {}
        overall_stats = {}
        result_logs = ""
        for jsonl_path in jsonl_paths:
            file_name = os.path.basename(jsonl_path)
            xorattriqa_results[file_name] = []
            xorattriqa_data = load_xorattriqa_data(jsonl_path)
            print(f"Evaluating XOR-AttriQA data from {file_name}...")
            for sample in tqdm.tqdm(xorattriqa_data, desc=f"Evaluating {file_name} XOR-AttriQA"):
                processed_sample = preprocess_xorattriqa_sample(sample)
                if len(processed_sample['answer']) > 1:
                    # Evaluate each answer sentence separately. This is the strictest evaluation.
                    for sentence in processed_sample['answer']:
                        citation_results = generate_citations(
                            config,
                            attribution_system,
                            processed_sample['query'],
                            sentence,
                            processed_sample['documents'],
                            sentence_tokenize_answer=False,
                        )
                        partial_sample = {**processed_sample, 'answer': sentence}
                        evaluation = evaluate_attribution_xorattriqa(citation_results, partial_sample['attributable_gt'])
                        xorattriqa_results[file_name].append({**partial_sample, **evaluation})
                else:
                    citation_results = generate_citations(
                        config,
                        attribution_system,
                        processed_sample['query'],
                        processed_sample['answer'][0],
                        processed_sample['documents'],
                        sentence_tokenize_answer=False,
                    )
                    evaluation = evaluate_attribution_xorattriqa(citation_results, processed_sample['attributable_gt'])
                    xorattriqa_results[file_name].append({**processed_sample, **evaluation})
            all_xorattriqa_results.extend(xorattriqa_results[file_name])

            # Write per-file results to Excel
            print(f"Saving {file_name} XOR-AttriQA results to Excel...")
            write_results_to_excel(xorattriqa_results[file_name], f"{exp_output_path}/xorattriqa_{file_name.split('.')[0]}.xlsx")

            # Compute per-file aggregate stats
            aggregate_stats = {
                "mean_accuracy": np.mean([r['correct_attribution'] for r in xorattriqa_results[file_name]]),
                "mean_extracted_spans": np.mean([r['num_citation_spans'] for r in xorattriqa_results[file_name]])
            }
            overall_stats[file_name] = aggregate_stats
            result_logs += "\n" + str(print_stats(aggregate_stats, return_table=True))

        # Compute the mean and standard deviation of accuracy and extracted spans over all XOR-AttriQA results
        final_stats = {
            "overall_mean_accuracy": np.mean([r['correct_attribution'] for r in all_xorattriqa_results]),
            "overall_mean_extracted_spans": np.mean([r['num_citation_spans'] for r in all_xorattriqa_results]),
            "overall_stdev_accuracy": np.std([r['correct_attribution'] for r in all_xorattriqa_results]),
            "overall_stdev_extracted_spans": np.std([r['num_citation_spans'] for r in all_xorattriqa_results]),
            "mean_of_means": np.mean([v['mean_accuracy'] for v in overall_stats.values()]),
            "mean_of_extracted_spans": np.mean([v['mean_extracted_spans'] for v in overall_stats.values()])
        }
        result_logs += "\n" + str(print_stats(final_stats, return_table=True))

        # Write the aggregated stats to a text file
        with open(f"{exp_output_path}/xorattriqa_results.txt", 'w') as f:
            f.write(str(result_logs))
    if args.eli5_path and not args.xorattriqa_only:
        # ELI5 evaluation
        eli5_data = load_eli5_data(args.eli5_path)
        eli5_ans_data = load_eli5_ans_data(args.eli5_ans_path) if args.eli5_ans_path else None
        if eli5_ans_data and len(eli5_ans_data) != len(eli5_data):
            raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")
        clean_up_conflicts = r'\s*\[citation needed\]|\s*\[conflicts with [^\]]*\]'
        processed_eli5 = {'data': [preprocess_eli5_sample(sample) for sample in eli5_data]}
        for i, sample in enumerate(tqdm.tqdm(processed_eli5['data'], desc="Evaluating ELI5")):
            if eli5_ans_data:
                sample['answer'] = eli5_ans_data[i]
            citation_results = generate_citations(
                config,
                attribution_system,
                sample['question'],
                sample['answer'],
                sample['docs'],
                query_context=sample["question_ctx"],
            )
            cited_text = str(citation_results)
            cited_text = re.sub(clean_up_conflicts, '', cited_text)
            # Normalize whitespace: replace newlines with spaces, collapse repeated spaces, and trim
            cited_text = cited_text.replace('\n', ' ').strip()
            cited_text = re.sub(r'\s+', ' ', cited_text)
            cited_spans = ""
            for citation_result in citation_results.results:
                cited_spans += f"~~~~~~~~\n{citation_result.answer_text}\n~~~~~~~~\n"
                for span in citation_result.citation_spans:
                    if span.document_span:
                        cited_spans += span.document_span.text + "\n**********\n"
            sample = {**sample, 'output': cited_text, 'cited_spans': cited_spans}
            processed_eli5['data'][i] = sample

        if torch.cuda.is_available():
            # Get rid of the attribution model and free up GPU memory before the ALCE evaluation
            del attribution_system
            gc.collect()
            torch.cuda.empty_cache()

        evaluate_attribution_eli5(processed_eli5, exp_output_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run attribution evaluation on ELI5 and XOR-AttriQA datasets")
    parser.add_argument("--eli5_path", type=str, help="Path to the ELI5 dataset JSON file", default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument("--eli5_ans_path", type=str, help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.", default=None)
    parser.add_argument("--xorattriqa_path", type=str, help="Path to the directory containing the XOR-AttriQA JSONL files", default="./xor_attriqa/in-language")
    parser.add_argument("--attribution_system", type=str, choices=["sliding_window", "gradient_based", "prompt_based"], help="Attribution system to use", default="sliding_window")
    parser.add_argument("--output_path", type=str, help="Path to output all Excel files", default="attribution_results")
    parser.add_argument("--eli5_only", action="store_true", help="Run attribution evaluation only on the ELI5 dataset")
    parser.add_argument("--xorattriqa_only", action="store_true", help="Run attribution evaluation only on the XOR-AttriQA dataset")
    parser.add_argument("--yaml_base_config", type=str, help="Path to the base YAML config file", default="base_config.yaml")
    parser.add_argument("--yaml_config", type=str, help="Path to the experiment-specific YAML config file", default="mistral_7B.yaml")
    parser.add_argument("--gen_seed", type=int, help="Random seed for generation", default=42)
    args = parser.parse_args()
    main(args)