saliency-based-citation/citation_systems.py
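"""Citation systems for attributing model answers to source documents.

Three approaches are implemented below:

* ``SlidingWindowSystem`` -- occlusion-style saliency: sliding windows of context tokens are
  masked out of the attention mask and the change in answer loss is measured per window.
* ``GradientBasedSystem`` -- gradient-times-input saliency on the context token embeddings.
* ``PromptBasedSystem`` -- prompts a model (local or API-based) to mark supporting and
  conflicting spans directly in a re-written copy of each document.

The saliency-based systems return ``CitationResults`` with per-sentence citation and conflict
spans; the prompt-based system returns ``PromptBasedCitationResults``.
"""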
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Optional, Union
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from sklearn.preprocessing import StandardScaler
from scipy.stats import entropy
from itertools import chain
import numpy as np
import torch
import nltk
import regex as re

from utils.token_text_mapper import TokenTextMapper
from utils.answer_helpers import get_answer_bounds_selections, LIST_INDEXER_REGEX
from utils.model_handlers import ModelHandler, HFModelHandler, APIModelHandler
from prompts import get_attribution_messages, get_fewshot_attribution_messages


# TODO: Add python docs to all classes and methods
class Span:
    def __init__(
        self,
        token_bounds: Tuple[int, int],
        text_bounds: Tuple[int, int],
        text: str,
    ):
        self.token_bounds = token_bounds
        self.text_bounds = text_bounds
        self.text = text

    def __str__(self):
        return self.text


class DocumentSpan(Span):
    def __init__(
        self,
        token_bounds: Tuple[int, int],
        text_bounds: Tuple[int, int],
        text: str,
        token_rel_bounds: Tuple[int, int],
        text_rel_bounds: Tuple[int, int],
        document_index: int,
    ):
        super().__init__(token_bounds, text_bounds, text)
        self.token_rel_bounds = token_rel_bounds
        self.text_rel_bounds = text_rel_bounds
        self.document_index = document_index


class SalientSpan(Span):
    def __init__(
        self,
        token_bounds: Tuple[int, int],
        text_bounds: Tuple[int, int],
        text: str,
        document_span: Optional[DocumentSpan] = None,
    ):
        super().__init__(token_bounds, text_bounds, text)
        self.document_span = document_span


class CitationResult:
    def __init__(
        self,
        token_answer_bounds: Tuple[int, int],
        text_answer_bounds: Tuple[int, int],
        answer_text: str,
        saliency_scores: List[float],
        citation_spans: List[SalientSpan],
        conflict_spans: List[SalientSpan],
        z_threshold: float,
    ):
        self.token_answer_bounds = token_answer_bounds
        self.text_answer_bounds = text_answer_bounds
        self.answer_text = answer_text
        self.saliency_scores = saliency_scores
        self.citation_spans = citation_spans
        self.conflict_spans = conflict_spans
        self.z_threshold = z_threshold

    def _to_str(self):
        citation_docs = [
            citation_span.document_span.document_index + 1
            for citation_span in self.citation_spans
            if citation_span.document_span
        ]
        if citation_docs:
            citation_str = "".join([f"[{d}]" for d in sorted(set(citation_docs))])
        else:
            citation_str = "[citation needed]"
        conflict_docs = [
            conflict_span.document_span.document_index + 1
            for conflict_span in self.conflict_spans
            if conflict_span.document_span
        ]
        conflict_str = ", ".join([str(d) for d in sorted(set(conflict_docs))])
        if conflict_str:
            conflict_str = f"[conflicts with {conflict_str}]"
        # Build the inline citation string
        inline_citation = f" {citation_str}{conflict_str}"
        # Find the last sentence-ending punctuation in the text
        match = re.search(r'([.!?]*)(?!.*[.!?])', self.answer_text, re.DOTALL)
        if match:
            punct_pos = match.start(1)
            pre_punct_text = self.answer_text[:punct_pos]
            punct = match.group(1)
            # Slice from the end of the punctuation group so that empty or multi-character
            # punctuation runs are neither dropped nor duplicated when reassembling the text
            post_punct_text = self.answer_text[match.end(1):]
            new_text = pre_punct_text + inline_citation + punct + post_punct_text
        else:
            new_text = self.answer_text + inline_citation
        # Reconstruct the full text with the citations placed before the final punctuation
        inline_str = new_text
        return inline_str

    def __str__(self):
        return self._to_str().strip()


class CitationResults:
    def __init__(self, results: List[CitationResult], token_context_length: int, text_context_length: int, text: str, eos_token: str):
        self.results = results
        self.token_context_length = token_context_length
        self.text_context_length = text_context_length
        self.text = text
        self.eos_token = eos_token

    def __str__(self):
        inline_strs = []
        last_result_end = self.text_context_length
        for result in self.results:
            if result.text_answer_bounds[0] > last_result_end:
                inline_strs.append(self.text[last_result_end:result.text_answer_bounds[0]])
            inline_strs.append(result._to_str())
            last_result_end = result.text_answer_bounds[1]
        if last_result_end < len(self.text):
            tail = self.text[last_result_end:].rstrip()
            # str.rstrip(eos_token) strips any trailing characters that appear in the EOS token
            # rather than the token as a suffix, so remove the suffix explicitly instead
            if self.eos_token and tail.endswith(self.eos_token):
                tail = tail[:-len(self.eos_token)]
            inline_strs.append(tail)
        inline_str = "".join(inline_strs).strip()
        return inline_str


class SaliencyBasedSystem(ABC):
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        entropy_scale_saliency_scores: bool = False,
        smoothing_window_size: Optional[int] = None,
        normalize_saliency_scores: bool = False,
        z_threshold: Union[float, str] = "auto",
        span_padding: int = 7,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.entropy_scale_saliency_scores = entropy_scale_saliency_scores
        self.smoothing_window_size = smoothing_window_size
        self.normalize_saliency_scores = normalize_saliency_scores
        self.z_threshold = z_threshold
        self.span_padding = span_padding

    @abstractmethod
    def generate_citations(
        self,
        token_text_mapper: TokenTextMapper,
        context_length: int,
        answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
    ) -> CitationResults:
        pass

    def _get_answer_bounds_all_sentences(self, token_text_mapper, context_length):
        answer_end = -1 if token_text_mapper.text.rstrip().endswith(token_text_mapper.tokenizer.eos_token) else None
        answer_text = token_text_mapper.get_text(context_length, answer_end)
        sentences = nltk.sent_tokenize(answer_text)
        sentences_positions = []
        for sent in sentences:
            if re.match(LIST_INDEXER_REGEX, sent):
                continue
            sent = sent.strip()
            sent_start = answer_text.find(sent)
            sent_end = sent_start + len(sent)
            sentences_positions.append([{
                "start": sent_start,
                "end": sent_end,
                "label": sent,
            }])
        sentences_answer_bounds = get_answer_bounds_selections(sentences_positions, token_text_mapper, context_length)
        return sentences_answer_bounds

    def _get_span_document_bounds(
        self,
        span_start: int,
        span_end: int,
        document_bounds: Optional[List[Tuple[int, int]]] = None,
        document_attribution_threshold: float = 0.4,
    ) -> Tuple[int, Optional[Tuple[int, int]]]:
        if not document_bounds:
            return -1, None
        # Find the document that contains the most tokens in the span. Usually a span will be completely
        # contained in a single document, but in case a span straddles two documents this ensures that the
        # document containing most of the span is selected.
        best_doc_idx = -1
        best_span_tokens_in_doc = 0
        best_doc_span = None
        for i, (doc_start, doc_end) in enumerate(document_bounds):
            doc_span_start = max(span_start, doc_start)
            doc_span_end = min(span_end, doc_end)
            span_tokens_in_doc = doc_span_end - doc_span_start
            if span_tokens_in_doc > best_span_tokens_in_doc:
                best_doc_idx = i
                best_span_tokens_in_doc = span_tokens_in_doc
                best_doc_span = (doc_span_start, doc_span_end)
        # If the span overlaps with any document, make sure it overlaps by at least document_attribution_threshold
        span_len = span_end - span_start
        if best_span_tokens_in_doc < document_attribution_threshold * span_len:
            best_doc_idx = -1
        if best_doc_idx < 0:
            best_doc_span = None
        return best_doc_idx, best_doc_span

    def _get_logits(
        self,
        input_ids,
        context_length=None,
        mask_windows=None,
        control_token_mask=None,
        past_key_values=None,
    ):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # Reason for using context_length-1 instead of context_length:
        # The last token of the context predicts the first token of the answer, so we need to include it
        # in the input if we want the loss to account for all tokens in the answer.
        answer_input_ids = input_ids[:, context_length-1:] if context_length else input_ids
        answer_inputs_embeds = inputs_embeds[:, context_length-1:] if context_length else inputs_embeds
        attention_mask = torch.ones_like(input_ids)
        if mask_windows:
            if control_token_mask is None:
                raise ValueError(
                    "control_token_mask must be provided if mask_windows are provided"
                )
            if control_token_mask.shape != input_ids.shape:
                raise ValueError(
                    "control_token_mask must have the same shape as input_ids"
                )
            if len(mask_windows) != input_ids.shape[0]:
                raise ValueError(
                    "mask_windows must have the same length as the batch size"
                )
            for i, mask_window in enumerate(mask_windows):
                if mask_window is not None:
                    attention_mask[i, mask_window[0]:mask_window[1]] = control_token_mask[
                        i, mask_window[0]:mask_window[1]
                    ]
        logits = self.model(
            inputs_embeds=answer_inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            return_dict=True,
        ).logits
        return logits, answer_input_ids, answer_inputs_embeds

    def _entropy_scale_saliency_scores(self, logits: torch.Tensor, saliency_scores: torch.Tensor) -> torch.Tensor:
        if logits.requires_grad:
            logits = logits.detach()
        # Upcast to float to avoid precision issues
        logits = logits.float()
        context_length = len(saliency_scores)
        shift_logits = logits[..., :context_length-1, :]
        shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)
        shift_log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
        shift_entropies = -torch.sum(shift_probs * shift_log_probs, dim=-1)
        shift_entropies /= shift_entropies.max()
        # Entropy at index i represents the next-token entropy for input token i.
        # To get the entropy of the distribution from which token i was sampled, we need to use entropy at i-1.
        shift_saliency_scores = saliency_scores[1:]
        saliency_scores = torch.cat([saliency_scores[:1], shift_saliency_scores * shift_entropies[0]])
        return saliency_scores

    def _smooth_and_normalize_saliency_scores(self, saliency_scores: torch.Tensor) -> torch.Tensor:
        if self.smoothing_window_size is not None and self.smoothing_window_size > 1:
            avg_kernel = torch.tensor(
                [1/self.smoothing_window_size for _ in range(self.smoothing_window_size)],
                dtype=saliency_scores.dtype,
                device=saliency_scores.device,
            )
            saliency_scores = torch.nn.functional.conv1d(
                saliency_scores.unsqueeze(0).unsqueeze(0),
                avg_kernel.unsqueeze(0).unsqueeze(0),
                padding="same",
            ).squeeze(0).squeeze(0)
        if self.normalize_saliency_scores:
            saliency_scores = saliency_scores / saliency_scores.abs().max()
        return saliency_scores

    def _get_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        context_length: Optional[int] = None,
        answer_bounds: Optional[Tuple[int, int]] = None,
        reduce: bool = True,
    ):
        # Upcast to float to avoid precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        if answer_bounds is not None:
            if context_length is None:
                raise ValueError("context_length is required if answer_bounds is provided")
            # Set all labels to -100 if they fall outside the answer bounds
            shift_labels = shift_labels.clone()
            bool_mask = torch.ones_like(shift_labels, dtype=torch.bool)
            answer_start = answer_bounds[0] - context_length
            answer_end = answer_bounds[1] - context_length
            bool_mask[..., answer_start:answer_end] = False
            shift_labels[bool_mask] = -100
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Ensure tensors are on the same device
        shift_labels = shift_labels.to(shift_logits.device)
        reduction = "mean" if reduce else "none"
        loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
        loss = loss_fct(shift_logits, shift_labels)
        if not reduce:
            loss = loss.view(-1, labels.shape[1]-1)
        return loss

    def _pool_losses(self, losses: torch.Tensor, context_length: int, answer_bounds: Tuple[int, int]) -> torch.Tensor:
        answer_start = answer_bounds[0] - context_length
        answer_end = answer_bounds[1] - context_length
        loss = losses[..., answer_start:answer_end].mean(dim=-1)
        return loss

    def _create_document_span(
        self,
        token_text_mapper: TokenTextMapper,
        span_start: int,
        span_end: int,
        document_bounds: Optional[List[Tuple[int, int]]] = None,
    ) -> Optional[DocumentSpan]:
        span_doc_index, span_doc_bounds = self._get_span_document_bounds(span_start, span_end, document_bounds)
        if span_doc_index < 0:
            return None
        span_doc_text_bounds = token_text_mapper.get_text_bounds(*span_doc_bounds)
        span_doc_text = token_text_mapper.get_text(*span_doc_bounds)
        # span_doc_text can start with a space if span_doc_bounds[0] is a space-prefixed token. We want to skip
        # this space to stay consistent with the original document text.
        if span_doc_text[0] == " " and span_doc_text[1] != " ":
            span_doc_text_bounds = (span_doc_text_bounds[0]+1, span_doc_text_bounds[1])
            span_doc_text = span_doc_text[1:]
        # Compute the relative bounds of the span within the document (where position 0 is the start of the document)
        doc_start = document_bounds[span_doc_index][0]
        span_doc_rel_bounds = (span_doc_bounds[0]-doc_start, span_doc_bounds[1]-doc_start)
        doc_text_start = token_text_mapper.get_text_idx(doc_start)
        # doc_text_start can land on a space if doc_start is a space-prefixed token. We want to skip this space
        # to stay consistent with the original document text.
        if token_text_mapper.text[doc_text_start] == " " and token_text_mapper.text[doc_text_start+1] != " ":
            doc_text_start += 1
        span_doc_text_rel_bounds = (span_doc_text_bounds[0]-doc_text_start, span_doc_text_bounds[1]-doc_text_start)
        document_span = DocumentSpan(
            span_doc_bounds,
            span_doc_text_bounds,
            span_doc_text,
            span_doc_rel_bounds,
            span_doc_text_rel_bounds,
            span_doc_index,
        )
        return document_span

    def _create_spans(
        self,
        token_text_mapper: TokenTextMapper,
        saliency_scores: torch.Tensor,
        is_positive: bool,
        document_bounds: Optional[List[Tuple[int, int]]] = None,
    ) -> Tuple[List[SalientSpan], float]:
        saliency_scores = saliency_scores.cpu().numpy()
        abs_saliency_scores = np.abs(saliency_scores)
        if is_positive:
            mask = saliency_scores >= 0
        else:
            mask = saliency_scores < 0
        # Compute automatic threshold for z-scores if needed
        if self.z_threshold == "auto":
            saliency_entropy = entropy(abs_saliency_scores)
            z_threshold = 2.0 * np.exp(saliency_entropy / np.log(len(abs_saliency_scores))).item()
        else:
            z_threshold = self.z_threshold
        # Construct spans
        spans = []
        if mask.any():
            abs_saliency_z_scores = self._standardize_scores(abs_saliency_scores)
            abs_saliency_z_scores[~mask] = -3.0
            padding_counter = 0
            span_start = None
            for i, z_score in enumerate(abs_saliency_z_scores):
                if z_score >= z_threshold:
                    if span_start is None:
                        span_start = max(i - self.span_padding, 0)
                    padding_counter = 0
                elif span_start is not None:
                    if padding_counter == self.span_padding or i == len(abs_saliency_z_scores) - 1:
                        span_end = i + 1 if i == len(abs_saliency_z_scores) - 1 else i
                        span_text_bounds = token_text_mapper.get_text_bounds(span_start, span_end)
                        span_text = token_text_mapper.get_text(span_start, span_end)
                        # If this span intersects a document, get the span bounds within the document
                        document_span = self._create_document_span(token_text_mapper, span_start, span_end, document_bounds)
                        spans.append(
                            SalientSpan(
                                (span_start, span_end),
                                span_text_bounds,
                                span_text,
                                document_span,
                            )
                        )
                        span_start = None
                    padding_counter += 1
        return spans, z_threshold

    def _standardize_scores(self, scores: np.ndarray) -> np.ndarray:
        scores = scores.reshape(-1, 1)
        sc = StandardScaler()
        return sc.fit_transform(scores).squeeze(1)
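

# Note on the "auto" z-score threshold used by SaliencyBasedSystem._create_spans above:
#
#     z_threshold = 2.0 * exp(H / ln(N))
#
# where H is the Shannon entropy of the absolute saliency scores (scipy normalizes them to a
# distribution) and N is the context length. When saliency mass is concentrated on a few tokens,
# H is small and the threshold approaches 2.0 standard deviations; when saliency is spread
# uniformly, H approaches ln(N) and the threshold rises toward 2e (about 5.44), so diffuse,
# uninformative score profiles yield few or no spans.
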

class SlidingWindowSystem(SaliencyBasedSystem):
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        window_size: int = 7,
        window_overlap: int = 2,
        window_batch_size: int = 1,
        entropy_scale_saliency_scores: bool = False,
        smoothing_window_size: Optional[int] = None,
        normalize_saliency_scores: bool = False,
        z_threshold: Union[float, str] = "auto",
        span_padding: int = 7,
    ):
        super().__init__(
            model,
            tokenizer,
            entropy_scale_saliency_scores,
            smoothing_window_size,
            normalize_saliency_scores,
            z_threshold,
            span_padding,
        )
        self.window_size = window_size
        self.window_overlap = window_overlap
        self.window_batch_size = window_batch_size

    def generate_citations(
        self,
        token_text_mapper: TokenTextMapper,
        context_length: int,
        answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
    ) -> CitationResults:
        if not answer_bounds:
            answer_bounds = self._get_answer_bounds_all_sentences(token_text_mapper, context_length)
            answer_bounds = [bounds[1] for bounds in answer_bounds]
        if isinstance(answer_bounds[0], int):
            answer_bounds = [answer_bounds]
        if document_bounds and isinstance(document_bounds[0], int):
            document_bounds = [document_bounds]
        # Get the context logits and the token-level losses for the answer
        losses_dict, context_logits, _ = self._get_model_outputs(token_text_mapper, context_length)
        # Compute saliency and extract citation/conflict spans for each answer segment
        results = []
        for answer_bounds_ in answer_bounds:
            # Compute saliency scores for each token with respect to its impact on
            # the likelihood of the current answer segment
            saliency_scores = self._compute_saliency_scores(losses_dict, context_logits, context_length, answer_bounds_)
            # Create citation spans
            citation_spans, z_threshold = self._create_spans(token_text_mapper, saliency_scores, True, document_bounds)
            # Create conflict spans
            conflict_spans, _ = self._create_spans(token_text_mapper, saliency_scores, False, document_bounds)
            text_answer_bounds = token_text_mapper.get_text_bounds(*answer_bounds_)
            answer_text = token_text_mapper.get_text(*answer_bounds_)
            results.append(
                CitationResult(
                    answer_bounds_,
                    text_answer_bounds,
                    answer_text,
                    saliency_scores.tolist(),
                    citation_spans,
                    conflict_spans,
                    z_threshold,
                )
            )
        text_context_length = token_text_mapper.get_text_idx(context_length)
        citation_results = CitationResults(
            results,
            context_length,
            text_context_length,
            token_text_mapper.text,
            token_text_mapper.tokenizer.eos_token,
        )
        return citation_results

    def _get_model_outputs(
        self,
        token_text_mapper: TokenTextMapper,
        context_length: int,
    ) -> Tuple[Dict[Optional[Tuple[int, int]], torch.Tensor], torch.Tensor, torch.Tensor]:
        # Pre-compute the KV cache for the context
        losses_dict = {}
        input_ids = token_text_mapper.input_ids.clone().to(self.model.device)
        with torch.no_grad():
            outputs = self.model(input_ids[:, :context_length-1], return_dict=True)
        past_key_values = outputs.past_key_values
        context_logits, context_labels = outputs.logits, input_ids[:, :context_length-1]
        # Pre-compute the token-level answer losses in batch: first with no windows applied and then for each window
        control_token_mask = self._get_control_token_mask(input_ids)
        window_ranges = [None] + self._get_window_ranges(context_length)
        for i in range(0, len(window_ranges), self.window_batch_size):
            batch_window_ranges = window_ranges[i:i+self.window_batch_size]
            batch_input_ids = input_ids.expand(len(batch_window_ranges), input_ids.shape[-1])
            batch_control_token_mask = control_token_mask.expand(len(batch_window_ranges), control_token_mask.shape[-1])
            batch_past_key_values = [
                (layer[0].expand(len(batch_window_ranges), *layer[0].shape[1:]), layer[1].expand(len(batch_window_ranges), *layer[1].shape[1:]))
                for layer in past_key_values
            ]
            with torch.no_grad():
                batch_logits, batch_labels, _ = self._get_logits(
                    batch_input_ids,
                    context_length,
                    batch_window_ranges,
                    batch_control_token_mask,
                    batch_past_key_values,
                )
                batch_losses = self._get_loss(batch_logits, batch_labels, reduce=False)
            for window_range, losses in zip(batch_window_ranges, batch_losses):
                losses_dict[window_range] = losses
        return losses_dict, context_logits, context_labels

    def _compute_saliency_scores(
        self,
        losses_dict: Dict[Optional[Tuple[int, int]], torch.Tensor],
        context_logits: torch.Tensor,
        context_length: int,
        answer_bounds: Tuple[int, int],
    ) -> torch.Tensor:
        base_loss = self._pool_losses(losses_dict[None], context_length, answer_bounds)
        window_losses, token_window_ids = self._calculate_window_losses(losses_dict, context_length, answer_bounds)
        window_saliency_scores = window_losses - base_loss
        token_saliency_scores = torch.stack([window_saliency_scores[token_window_ids[i]].mean() for i in range(context_length)])
        if self.entropy_scale_saliency_scores:
            token_saliency_scores = self._entropy_scale_saliency_scores(context_logits, token_saliency_scores)
        token_saliency_scores = self._smooth_and_normalize_saliency_scores(token_saliency_scores)
        return token_saliency_scores

    def _get_control_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        # TODO: Pass answer_prefix in from the configured answer template. Not sure how to support general
        # "user-provided" patterns that should be considered control tokens. For example, the answer prefix
        # should only be considered a control token if it's the first content in an assistant turn, and so on.
        answer_prefix = "Answer:"
        # Get the control token sequences for this model
        ctl_seqs = [[self.tokenizer.convert_tokens_to_ids(tok)] for tok in self.tokenizer.all_special_tokens]
        if self.tokenizer.chat_template is not None:
            chat_template_ctl_seqs = [
                # Llama-2 / Mistral
                "[INST]",
                "[/INST]",
                f"[/INST] {answer_prefix}",
                # Llama-3
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{answer_prefix}",
                # Zephyr
                "<|system|>\n",
                "<|user|>\n",
                "<|assistant|>\n",
                f"<|assistant|>\n{answer_prefix}",
                # Zephyr tokenizes the < differently if it's at the start of a line
                "\n<|system|>\n",
                "\n<|user|>\n",
                "\n<|assistant|>\n",
                f"\n<|assistant|>\n{answer_prefix}",
                # TODO: Add support for other model-specific chat template control sequences here
            ]
            for ctl_seq in chat_template_ctl_seqs:
                if ctl_seq.replace(answer_prefix, "").strip() in self.tokenizer.chat_template:
                    ctl_seqs.append(self.tokenizer.encode(ctl_seq, add_special_tokens=False))
        # Create the mask
        ctl_seqs = [torch.tensor(ctl_seq).to(input_ids.device) for ctl_seq in ctl_seqs]
        control_token_mask = torch.zeros_like(input_ids)
        for i in range(input_ids.shape[-1]):
            for ctl_seq in ctl_seqs:
                if i+len(ctl_seq) <= input_ids.shape[-1] and torch.all(input_ids[0, i:i+len(ctl_seq)] == ctl_seq):
                    control_token_mask[0, i:i+len(ctl_seq)] = 1
        return control_token_mask

    def _calculate_window_losses(self, losses_dict, context_length: int, answer_bounds: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        window_losses = []
        token_window_ids = [[] for _ in range(context_length)]
        for start, end in self._get_window_ranges(context_length):
            for i in range(start, end):
                token_window_ids[i].append(len(window_losses))
            loss = self._pool_losses(losses_dict[(start, end)], context_length, answer_bounds)
            window_losses.append(loss)
        return torch.stack(window_losses), token_window_ids

    def _get_window_ranges(self, context_length: int) -> List[Tuple[int, int]]:
        return [(start, min(start + self.window_size, context_length))
                for start in range(0, context_length - self.window_overlap, self.window_size - self.window_overlap)]
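

# Illustrative usage sketch for SlidingWindowSystem (the model name and the TokenTextMapper
# construction are assumptions here; see utils.token_text_mapper for the actual interface):
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct",
#                                                  device_map="auto", torch_dtype=torch.float16)
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
#     system = SlidingWindowSystem(model, tokenizer, window_size=7, window_overlap=2)
#     # token_text_mapper wraps the tokenized prompt+answer; context_length is the number of
#     # prompt tokens; document_bounds are the token ranges of each retrieved document.
#     results = system.generate_citations(token_text_mapper, context_length,
#                                         document_bounds=document_bounds)
#     print(results)  # answer text with inline [n] citations and [citation needed] markers
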

class GradientBasedSystem(SaliencyBasedSystem):
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        entropy_scale_saliency_scores: bool = False,
        smoothing_window_size: Optional[int] = None,
        normalize_saliency_scores: bool = False,
        z_threshold: Union[float, str] = "auto",
        span_padding: int = 7,
    ):
        super().__init__(
            model,
            tokenizer,
            entropy_scale_saliency_scores,
            smoothing_window_size,
            normalize_saliency_scores,
            z_threshold,
            span_padding,
        )

    def generate_citations(
        self,
        token_text_mapper: TokenTextMapper,
        context_length: int,
        answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
    ) -> CitationResults:
        # TODO: much of this method is shared with the sliding window implementation.
        # See how we can abstract the shared logic out to SaliencyBasedSystem
        if not answer_bounds:
            answer_bounds = self._get_answer_bounds_all_sentences(token_text_mapper, context_length)
            answer_bounds = [bounds[1] for bounds in answer_bounds]
        if isinstance(answer_bounds[0], int):
            answer_bounds = [answer_bounds]
        if document_bounds and isinstance(document_bounds[0], int):
            document_bounds = [document_bounds]
        # Compute saliency and extract citation/conflict spans for each answer segment
        results = []
        for answer_bounds_ in answer_bounds:
            # Compute saliency scores for each token with respect to its impact on
            # the likelihood of the current answer segment
            saliency_scores = self._compute_saliency_scores(token_text_mapper.input_ids, context_length, answer_bounds_)
            # Create citation spans
            citation_spans, z_threshold = self._create_spans(token_text_mapper, saliency_scores, True, document_bounds)
            # Create conflict spans
            conflict_spans, _ = self._create_spans(token_text_mapper, saliency_scores, False, document_bounds)
            text_answer_bounds = token_text_mapper.get_text_bounds(*answer_bounds_)
            answer_text = token_text_mapper.get_text(*answer_bounds_)
            results.append(
                CitationResult(
                    answer_bounds_,
                    text_answer_bounds,
                    answer_text,
                    saliency_scores.tolist(),
                    citation_spans,
                    conflict_spans,
                    z_threshold,
                )
            )
        text_context_length = token_text_mapper.get_text_idx(context_length)
        citation_results = CitationResults(
            results,
            context_length,
            text_context_length,
            token_text_mapper.text,
            token_text_mapper.tokenizer.eos_token,
        )
        return citation_results

    def _compute_saliency_scores(
        self,
        input_ids: torch.Tensor,
        context_length: int,
        answer_bounds: Tuple[int, int],
    ) -> torch.Tensor:
        # We have to separately compute the logits for each answer segment because the gradient computation
        # will be different for each segment
        input_ids = input_ids.clone().to(self.model.device)
        logits, labels, inputs_embeds = self._get_logits(input_ids)
        loss = self._get_loss(logits, labels, 1, answer_bounds)
        inputs_embeds.retain_grad()
        # Backpropagate the loss to the input embeddings
        self.model.zero_grad()
        loss.backward()
        # Gradient-times-input saliency: the (negated) component of the loss gradient along each context
        # token's embedding, normalized by the embedding norm, so that tokens whose embeddings reduce the
        # answer loss (i.e., support the answer) receive positive scores
        inputs_grads = inputs_embeds.grad[0, :context_length]
        inputs_embeds = inputs_embeds.detach()[0, :context_length]
        token_saliency_scores = -(inputs_grads * inputs_embeds).sum(dim=-1) / inputs_embeds.norm(p=2, dim=-1)
        if self.entropy_scale_saliency_scores:
            token_saliency_scores = self._entropy_scale_saliency_scores(logits, token_saliency_scores)
        token_saliency_scores = self._smooth_and_normalize_saliency_scores(token_saliency_scores)
        return token_saliency_scores


class PromptBasedCitationResult(CitationResult):
    def __init__(
        self,
        text_answer_bounds: Tuple[int, int],
        answer_text: str,
        citation_spans: List[SalientSpan],
        conflict_spans: List[SalientSpan],
        citation_messages: List[List[Dict[str, str]]],
        conflict_messages: List[List[Dict[str, str]]],
    ):
        super().__init__(None, text_answer_bounds, answer_text, None, citation_spans, conflict_spans, 0.0)
        self.citation_messages = citation_messages
        self.conflict_messages = conflict_messages


class PromptBasedCitationResults(CitationResults):
    def __init__(self, results: List[PromptBasedCitationResult], text_context_length: int, text: str):
        super().__init__(results, None, text_context_length, text, "")


class PromptBasedSystem:
    def __init__(
        self,
        config: Dict[str, str],
        model: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
    ):
        self.config = config
        self.model_handler = self._create_model_handler(model, tokenizer)

    def _create_model_handler(
        self,
        model: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
    ) -> ModelHandler:
        model_name = self.config['model']
        if model is not None and tokenizer is not None:
            return HFModelHandler(model, tokenizer)
        elif model_name.startswith(('gpt-4', 'claude')):
            return APIModelHandler(api_key=self.config.get('api_key'), model_name=model_name)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map='auto',
                offload_folder="offload_citation_system",
                torch_dtype=torch.float16,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            return HFModelHandler(model, tokenizer)

    def generate_citations(
        self,
        question: str,
        answer: str,
        documents: List[Dict[str, str]],
        question_context: Optional[str] = None,
    ) -> PromptBasedCitationResults:
        # Set up the citation and conflict prompt messages for each (question/answer, sentence, document) triple
        answer_sentences = nltk.sent_tokenize(answer)
        answer_sentences = [sent.strip() for sent in answer_sentences if not re.match(LIST_INDEXER_REGEX, sent)]
        citation_messages = [[] for _ in range(len(answer_sentences))]
        conflict_messages = [[] for _ in range(len(answer_sentences))]
        for answer_sent, sent_citation_messages, sent_conflict_messages in zip(answer_sentences, citation_messages, conflict_messages):
            for doc_idx, doc in enumerate(documents):
                for for_conflicts in [False, True]:
                    messages = get_fewshot_attribution_messages(self.config, for_conflicts)
                    messages.extend(
                        get_attribution_messages(
                            self.config,
                            question,
                            answer,
                            doc_idx,
                            doc,
                            answer_sent,
                            for_conflicts,
                            question_context,
                            response_prefill=self.model_handler.supports_response_prefill,
                            prepend_instruction=False,
                        )
                    )
                    if for_conflicts:
                        sent_conflict_messages.append(messages)
                    else:
                        sent_citation_messages.append(messages)
        # Generate the attribution responses in batch
        generation_kwargs = {
            "temperature": self.config['temperature'],
            "top_p": self.config['top_p'],
            "max_tokens": self.config['max_new_tokens'],
            "batch_size": self.config.get('gen_batch_size', 1),
            "seed": self.config.get("gen_seed"),
        }
        # Prepare list of sentence-document messages for batch inference
        all_messages = citation_messages + conflict_messages
        all_sent_messages = list(chain.from_iterable(all_messages))
        # Do batch inference
        responses = self.model_handler.generate(all_sent_messages, **generation_kwargs)
        for messages, response in zip(all_sent_messages, responses):
            if self.model_handler.supports_response_prefill:
                messages[-1]["content"] += response
            else:
                messages.append({"role": "assistant", "content": response})
        # Process the responses to get the CitationResults
        citation_results = self._process_cited_responses(
            citation_messages,
            conflict_messages,
            answer,
            answer_sentences,
            documents,
        )
        return citation_results

    def _process_cited_responses(
        self,
        citation_messages: List[List[List[Dict[str, str]]]],
        conflict_messages: List[List[List[Dict[str, str]]]],
        answer: str,
        answer_sentences: List[str],
        documents: List[Dict[str, str]],
    ) -> PromptBasedCitationResults:
        # Prepare text body for results object
        text = ""
        document_start_positions = []
        for i, doc in enumerate(documents):
            document_start_positions.append(len(text))
            text += self.config["document_template"].format(ID=i+1, T=doc.get('title', 'Untitled'), P=doc['text']) + "\n"
        text += "\n" + self.config["answer_template"].format(A=answer)
        # Extract citation and conflict spans from the responses
        citation_results = []
        for answer_sent, sent_citation_messages, sent_conflict_messages in zip(answer_sentences, citation_messages, conflict_messages):
            citation_spans = []
            conflict_spans = []
            for doc_idx, doc in enumerate(documents):
                citation_response, conflict_response = sent_citation_messages[doc_idx][-1]["content"], sent_conflict_messages[doc_idx][-1]["content"]
                doc_start_pos = document_start_positions[doc_idx]
                citation_spans.extend(self._extract_spans(citation_response, doc_idx, doc, doc_start_pos, False))
                conflict_spans.extend(self._extract_spans(conflict_response, doc_idx, doc, doc_start_pos, True))
            answer_start_pos = text.rfind(answer_sent)
            citation_results.append(
                PromptBasedCitationResult(
                    (answer_start_pos, answer_start_pos+len(answer_sent)),
                    answer_sent,
                    citation_spans,
                    conflict_spans,
                    sent_citation_messages,
                    sent_conflict_messages,
                )
            )
        # Create the results object
        text_context_length = list(re.finditer(self.config["context_regex"], text))[-1].end()
        citation_results = PromptBasedCitationResults(citation_results, text_context_length, text)
        return citation_results

    def _extract_spans(
        self,
        response: str,
        doc_idx: int,
        doc: Dict[str, str],
        doc_start_pos: int,
        for_conflicts: bool,
    ) -> List[SalientSpan]:
        spans = []
        attr_tag = "conflict" if for_conflicts else "support"
        attr_start_tag = f"<|{attr_tag}|>"
        attr_end_tag = f"<|end_{attr_tag}|>"
        # Get re-written document body from the response
        response_doc_match = re.search(self.config["document_regex"], response, re.DOTALL)
        if response_doc_match is None:
            return spans
        response_doc_body = response_doc_match.group(1)
        doc_start_pos += response_doc_match.start(1)
        # The document needs to have been faithfully re-written by the model, or else we can't expect match
        # bounds to be accurate when applied to the original document text.
        if response_doc_body.replace(attr_start_tag, "").replace(attr_end_tag, "") != doc["text"]:
            return spans
        # Extract spans from the response
        span_regex = f"{attr_start_tag}(.+?){attr_end_tag}".replace("|", r"\|")
        span_matches = re.finditer(span_regex, response_doc_body, re.DOTALL)
        for i, span_match in enumerate(span_matches):
            # The offset accounts for the tag characters inserted before this match in the re-written document,
            # so the bounds map back onto the original (tag-free) document text
            offset = (len(attr_start_tag) + len(attr_end_tag)) * i + len(attr_start_tag)
            span_text = span_match.group(1)
            span_start = span_match.start(1) - offset
            span_end = span_match.end(1) - offset
            span_text_bounds = (span_start+doc_start_pos, span_end+doc_start_pos)
            span_text_rel_bounds = (span_start, span_end)
            spans.append(
                SalientSpan(
                    None,
                    span_text_bounds,
                    span_text,
                    DocumentSpan(
                        None,
                        span_text_bounds,
                        span_text,
                        None,
                        span_text_rel_bounds,
                        doc_idx,
                    ),
                )
            )
        return spans
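

# Illustrative usage sketch for PromptBasedSystem (the config values are assumptions; the keys
# shown are the ones read by this class: 'model', 'temperature', 'top_p', 'max_new_tokens',
# 'document_template', 'answer_template', 'context_regex', plus whatever the prompt templates
# in prompts.get_attribution_messages require):
#
#     system = PromptBasedSystem(config)
#     documents = [{"title": "Doc 1", "text": "..."}, {"title": "Doc 2", "text": "..."}]
#     results = system.generate_citations(question, answer, documents)
#     print(results)  # answer sentences annotated with [n] / [citation needed] markers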