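"""Evaluate span-level attribution (citation) systems on the ELI5 and XOR-AttriQA datasets.

For each sample, the selected attribution system (sliding-window, gradient-based, or
prompt-based) maps answer text to supporting document spans; results are written to
Excel reports and, for ELI5, scored with the ALCE evaluation script.
"""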
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Union
import numpy as np
import argparse
import ast
import tqdm
import os
import re
import gc
from citation_systems import (
    SlidingWindowSystem,
    GradientBasedSystem,
    PromptBasedSystem,
    SaliencyBasedSystem,
    CitationResults,
)
from utils.pipeline_helpers import run_command
from utils.reporting_helpers import write_results_to_excel, print_stats
from utils.token_text_mapper import RAGTokenTextMapper
from utils.config_loader import ConfigLoader
from prompts import get_rag_messages


def load_eli5_data(file_path: str) -> List[Dict]:
    """Load the ELI5 evaluation set from a JSON file."""
    with open(file_path, 'r') as f:
        return json.load(f)


def load_eli5_ans_data(file_path: str) -> List[str]:
    """Load pre-generated ELI5 answers from a JSON file."""
    with open(file_path, 'r') as f:
        return json.load(f)


def load_xorattriqa_data(file_path: str) -> List[Dict]:
    """Load a XOR-AttriQA split from a JSONL file (one JSON object per line)."""
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data


def preprocess_eli5_sample(sample: Dict) -> Dict:
    """Normalize an ELI5 sample, dropping removed or empty question context."""
    return {
        'question_ctx': sample['question_ctx'] if sample['question_ctx'] and sample['question_ctx'] != "[removed]" else None,
        'question': sample['question'],
        'answer': sample['answer'],
        'docs': [{'text': doc['text'], 'title': doc['title']} for doc in sample['docs']],
        'claims': sample['claims'],
    }


def preprocess_xorattriqa_sample(sample: Dict) -> Dict:
    """Normalize a XOR-AttriQA sample to the query/answer/documents format used below."""
    return {
        'query': sample['query_translated_en'].replace('[b]', '').strip(),
        'answer': ast.literal_eval(sample['answers_translated_en']),
        'documents': [{"text": sample['passage_en'], "title": "Untitled"}],
        'attributable_gt': sample['ais'],
    }


def generate_citations(
    config: Dict[str, str],
    citation_system: Union[SaliencyBasedSystem, PromptBasedSystem],
    query: str,
    answer: str,
    documents: List[Dict[str, str]],
    sentence_tokenize_answer: bool = True,
    query_context=None,
) -> CitationResults:
    """Generate citation spans for a single (query, answer, documents) example."""
    if isinstance(citation_system, PromptBasedSystem):
        return citation_system.generate_citations(query, answer, documents, query_context)
    else:
        # Saliency-based systems (sliding-window, gradient-based) operate on the tokenized RAG prompt.
        messages = get_rag_messages(config, query, documents, query_context, answer)
        input_ids = citation_system.tokenizer.apply_chat_template(messages, return_tensors="pt")
        token_text_mapper = RAGTokenTextMapper(
            citation_system.tokenizer,
            config["context_regex"],
            config["document_regex"],
            input_ids=input_ids,
        )
        context_length = token_text_mapper.get_context_length()
        answer_bounds = None
        document_bounds = token_text_mapper.get_document_bounds()
        if not sentence_tokenize_answer:
            # Collapse the per-sentence bounds into a single (start, end) span covering the whole answer.
            answer_bounds = citation_system._get_answer_bounds_all_sentences(token_text_mapper, context_length)
            answer_start = min([bounds[1][0] for bounds in answer_bounds])
            answer_end = max([bounds[1][1] for bounds in answer_bounds])
            answer_bounds = (answer_start, answer_end)
        return citation_system.generate_citations(
            token_text_mapper,
            context_length,
            answer_bounds,
            document_bounds,
        )


def evaluate_attribution_xorattriqa(citation_results: CitationResults, attributable: bool) -> Dict:
    """
    Evaluate the attribution result for the XOR-AttriQA task.

    Args:
    - citation_results: CitationResults object containing the citation spans and texts
    - attributable: Whether the answer is attributable to any source (ground truth)

    Returns:
    - Dictionary containing the evaluation metrics:
        - pred: Whether the system attributed the answer to any source
        - correct_attribution: Whether the system attribution matches the ground truth
        - num_citation_spans: Number of citation spans attributed to the answer
        - citation_spans: List of citation span texts
    """
    if not citation_results.results:
        system_attributed = False
        doc_spans = []
    else:
        citation_result = citation_results.results[0]
        # Check if the system attributed the answer to any source
        doc_spans = [span for span in citation_result.citation_spans if span.document_span]
        system_attributed = len(doc_spans) > 0
    # Compare system attribution with ground truth
    correct_attribution = system_attributed == attributable
    return {
        "pred": system_attributed,
        "correct_attribution": int(correct_attribution),
        "num_citation_spans": len(doc_spans),
        "citation_spans": [span.document_span.text for span in doc_spans],
    }


def evaluate_attribution_eli5(processed_eli5_data: Dict, output_path: str, file_name="eli5") -> None:
    """
    Evaluate the attribution results for the ELI5 dataset using the ALCE evaluation script.
    """
    # Save the processed ELI5 data (with cited outputs) to a JSON file for the evaluator.
    with open(f"{output_path}/{file_name}.json", 'w') as f:
        json.dump(processed_eli5_data, f)
    print("Evaluating ELI5 data...")
    run_command(
        f"python run_eli5_eval.py --f {output_path}/{file_name}.json "
        f"--citations --claims_nli --report {output_path}/{file_name}.xlsx"
    )


def main(args):
    # Load prompt config
    config_loader = ConfigLoader()
    config = config_loader.load_config(args.yaml_base_config, args.yaml_config)
    if args.gen_seed:
        print("Setting config['gen_seed'] to", args.gen_seed)
        config["gen_seed"] = args.gen_seed
    config_name = os.path.splitext(os.path.basename(args.yaml_config))[0]
    exp_output_path = os.path.join(args.output_path, config_name, args.attribution_system)
    os.makedirs(exp_output_path, exist_ok=True)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model'])
    model = AutoModelForCausalLM.from_pretrained(config['model'], torch_dtype=torch.float16, device_map="auto")

    # Instantiate the requested attribution system
    if args.attribution_system == "sliding_window":
        attribution_system = SlidingWindowSystem(
            model,
            tokenizer,
            z_threshold=config['z_threshold'],
            window_batch_size=config['window_batch_size'],
            smoothing_window_size=config.get('smoothing_window_size'),
        )
    elif args.attribution_system == "gradient_based":
        attribution_system = GradientBasedSystem(
            model,
            tokenizer,
            z_threshold=config['z_threshold'],
            smoothing_window_size=config.get('smoothing_window_size'),
        )
    elif args.attribution_system == "prompt_based":
        attribution_system = PromptBasedSystem(config, model, tokenizer)
    all_xorattriqa_results = []
    if args.xorattriqa_path and not args.eli5_only:
        # Load all .jsonl files in the directory
        jsonl_files = [f for f in os.listdir(args.xorattriqa_path) if f.endswith('.jsonl')]
        # Remove train, val and toy files
        jsonl_files = [f for f in jsonl_files if 'train' not in f and 'val' not in f and 'toy' not in f]
        jsonl_paths = [os.path.join(args.xorattriqa_path, f) for f in jsonl_files]
        xorattriqa_results = {}
        overall_stats = {}
        result_logs = ""
        for jsonl_path in jsonl_paths:
            file_name = os.path.basename(jsonl_path)
            xorattriqa_results[file_name] = []
            xorattriqa_data = load_xorattriqa_data(jsonl_path)
            print(f"Evaluating XOR-AttriQA data from {file_name}...")
            # Iterate with a progress bar
            for sample in tqdm.tqdm(xorattriqa_data, desc=f"Evaluating {file_name} XOR-AttriQA"):
                processed_sample = preprocess_xorattriqa_sample(sample)
                if len(processed_sample['answer']) > 1:
                    # Evaluate each answer sentence separately. This is the strictest evaluation.
                    for sentence in processed_sample['answer']:
                        citation_results = generate_citations(
                            config,
                            attribution_system,
                            processed_sample['query'],
                            sentence,
                            processed_sample['documents'],
                            sentence_tokenize_answer=False,
                        )
                        partial_sample = {**processed_sample, 'answer': sentence}
                        evaluation = evaluate_attribution_xorattriqa(citation_results, partial_sample['attributable_gt'])
                        xorattriqa_results[file_name].append({**partial_sample, **evaluation})
                else:
                    citation_results = generate_citations(
                        config,
                        attribution_system,
                        processed_sample['query'],
                        processed_sample['answer'][0],
                        processed_sample['documents'],
                        sentence_tokenize_answer=False,
                    )
                    evaluation = evaluate_attribution_xorattriqa(citation_results, processed_sample['attributable_gt'])
                    xorattriqa_results[file_name].append({**processed_sample, **evaluation})
            all_xorattriqa_results.extend(xorattriqa_results[file_name])

            # Write per-file results to Excel
            print(f"Saving {file_name} XOR-AttriQA results to Excel...")
            write_results_to_excel(xorattriqa_results[file_name], f"{exp_output_path}/xorattriqa_{file_name.split('.')[0]}.xlsx")

            # Compute aggregate stats for this file
            aggregate_stats = {
                "mean_accuracy": np.mean([r['correct_attribution'] for r in xorattriqa_results[file_name]]),
                "mean_extracted_spans": np.mean([r['num_citation_spans'] for r in xorattriqa_results[file_name]]),
            }
            overall_stats[file_name] = aggregate_stats
            result_logs += "\n" + str(print_stats(aggregate_stats, return_table=True))

        # Compute overall mean and stdev of accuracy and extracted spans across all XOR-AttriQA results
        final_stats = {
            "overall_mean_accuracy": np.mean([r['correct_attribution'] for r in all_xorattriqa_results]),
            "overall_mean_extracted_spans": np.mean([r['num_citation_spans'] for r in all_xorattriqa_results]),
            "overall_stdev_accuracy": np.std([r['correct_attribution'] for r in all_xorattriqa_results]),
            "overall_stdev_extracted_spans": np.std([r['num_citation_spans'] for r in all_xorattriqa_results]),
            "mean_of_means": np.mean([v['mean_accuracy'] for v in overall_stats.values()]),
            "mean_of_extracted_spans": np.mean([v['mean_extracted_spans'] for v in overall_stats.values()]),
        }
        result_logs += "\n" + str(print_stats(final_stats, return_table=True))

        # Write the accumulated stats tables to a text file
        with open(f"{exp_output_path}/xorattriqa_results.txt", 'w') as f:
            f.write(str(result_logs))
    if args.eli5_path and not args.xorattriqa_only:
        # ELI5 evaluation
        eli5_data = load_eli5_data(args.eli5_path)
        eli5_ans_data = load_eli5_ans_data(args.eli5_ans_path) if args.eli5_ans_path else None
        if eli5_ans_data and len(eli5_ans_data) != len(eli5_data):
            raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")
        # Strip "[citation needed]" and "[conflicts with ...]" markers from the cited output
        clean_up_conflicts = r'\s*\[citation needed\]|\s*\[conflicts with [^\]]*\]'
        processed_eli5 = {'data': [preprocess_eli5_sample(sample) for sample in eli5_data]}
        for i, sample in enumerate(tqdm.tqdm(processed_eli5['data'], desc="Evaluating ELI5")):
            if eli5_ans_data:
                sample['answer'] = eli5_ans_data[i]
            citation_results = generate_citations(
                config,
                attribution_system,
                sample['question'],
                sample['answer'],
                sample['docs'],
                query_context=sample["question_ctx"],
            )
            cited_text = str(citation_results)
            cited_text = re.sub(clean_up_conflicts, '', cited_text)
            # Replace newlines with spaces, collapse repeated whitespace, and trim the result.
            cited_text = cited_text.replace('\n', ' ').strip()
            cited_text = re.sub(r'\s+', ' ', cited_text)
            cited_spans = ""
            for citation_result in citation_results.results:
                cited_spans += f"~~~~~~~~\n{citation_result.answer_text}\n~~~~~~~~\n"
                for span in citation_result.citation_spans:
                    if span.document_span:
                        cited_spans += span.document_span.text + "\n**********\n"
            sample = {**sample, 'output': cited_text, 'cited_spans': cited_spans}
            processed_eli5['data'][i] = sample

        if torch.cuda.is_available():
            # Get rid of the attribution system and free up GPU memory before the ALCE evaluation
            del attribution_system
            gc.collect()
            torch.cuda.empty_cache()
        evaluate_attribution_eli5(processed_eli5, exp_output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run attribution evaluation on the ELI5 and XOR-AttriQA datasets")
    parser.add_argument("--eli5_path", type=str, help="Path to the ELI5 dataset JSON file", default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    parser.add_argument("--eli5_ans_path", type=str, help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.", default=None)
    parser.add_argument("--xorattriqa_path", type=str, help="Path to the directory of XOR-AttriQA JSONL files", default="./xor_attriqa/in-language")
    parser.add_argument("--attribution_system", type=str, choices=["sliding_window", "gradient_based", "prompt_based"], help="Attribution system to use", default="sliding_window")
    parser.add_argument("--output_path", type=str, help="Path to output all Excel files", default="attribution_results")
    parser.add_argument("--eli5_only", action="store_true", help="Run attribution evaluation only on the ELI5 dataset")
    parser.add_argument("--xorattriqa_only", action="store_true", help="Run attribution evaluation only on the XOR-AttriQA dataset")
    parser.add_argument("--yaml_base_config", type=str, help="Path to the base YAML config file", default="base_config.yaml")
    parser.add_argument("--yaml_config", type=str, help="Path to the experiment-specific YAML config file", default="mistral_7B.yaml")
    parser.add_argument("--gen_seed", type=int, help="Random seed for generation", default=42)
    args = parser.parse_args()
    main(args)
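
# Example invocations (the script filename below is assumed, not taken from the repository;
# the data paths rely on the argparse defaults above):
#   python run_attribution_eval.py --attribution_system sliding_window --yaml_config mistral_7B.yaml
#   python run_attribution_eval.py --attribution_system prompt_based --eli5_only --output_path attribution_results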