from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Optional, Union
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
AutoModelForCausalLM,
AutoTokenizer,
)
from sklearn.preprocessing import StandardScaler
from scipy.stats import entropy
from itertools import chain
import numpy as np
import torch
import nltk
import regex as re
from utils.token_text_mapper import TokenTextMapper
from utils.answer_helpers import get_answer_bounds_selections, LIST_INDEXER_REGEX
from utils.model_handlers import ModelHandler, HFModelHandler, APIModelHandler
from prompts import get_attribution_messages, get_fewshot_attribution_messages
# TODO: Add Python docstrings to all classes and methods
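# This module implements three context-attribution ("citation") systems:
# - SlidingWindowSystem: perturbation-based saliency. The answer loss is re-computed with
#   sliding windows of the context masked from attention; tokens whose masking raises the
#   loss are cited as supporting evidence, tokens whose masking lowers it are flagged as conflicts.
# - GradientBasedSystem: gradient-times-input saliency, computed with a single backward pass
#   per answer segment.
# - PromptBasedSystem: prompts a local or API model to re-write each document with
#   <|support|>/<|end_support|> (or <|conflict|>/<|end_conflict|>) tags and parses the tagged
#   spans back into citations.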
class Span:
def __init__(
self,
token_bounds: Tuple[int, int],
text_bounds: Tuple[int, int],
text: str,
):
self.token_bounds = token_bounds
self.text_bounds = text_bounds
self.text = text
def __str__(self):
return self.text
class DocumentSpan(Span):
def __init__(
self,
token_bounds: Tuple[int, int],
text_bounds: Tuple[int, int],
text: str,
token_rel_bounds: Tuple[int, int],
text_rel_bounds: Tuple[int, int],
document_index: int,
):
super().__init__(token_bounds, text_bounds, text)
self.token_rel_bounds = token_rel_bounds
self.text_rel_bounds = text_rel_bounds
self.document_index = document_index
class SalientSpan(Span):
def __init__(
self,
token_bounds: Tuple[int, int],
text_bounds: Tuple[int, int],
text: str,
document_span: Optional[DocumentSpan] = None,
):
super().__init__(token_bounds, text_bounds, text)
self.document_span = document_span
class CitationResult:
def __init__(
self,
token_answer_bounds: Tuple[int, int],
text_answer_bounds: Tuple[int, int],
answer_text: str,
saliency_scores: List[float],
citation_spans: List[SalientSpan],
conflict_spans: List[SalientSpan],
z_threshold: float,
):
self.token_answer_bounds = token_answer_bounds
self.text_answer_bounds = text_answer_bounds
self.answer_text = answer_text
self.saliency_scores = saliency_scores
self.citation_spans = citation_spans
self.conflict_spans = conflict_spans
self.z_threshold = z_threshold
def _to_str(self):
citation_docs = [
citation_span.document_span.document_index + 1
for citation_span in self.citation_spans
if citation_span.document_span
]
if citation_docs:
citation_str = "".join([f"[{d}]" for d in sorted(set(citation_docs))])
else:
citation_str = "[citation needed]"
conflict_docs = [
conflict_span.document_span.document_index + 1
for conflict_span in self.conflict_spans
if conflict_span.document_span
]
conflict_str = ", ".join([str(d) for d in sorted(set(conflict_docs))])
if conflict_str:
conflict_str = f"[conflicts with {conflict_str}]"
# Build the inline citation string
inline_citation = f" {citation_str}{conflict_str}"
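# e.g. an answer segment "X is true." supported by documents 1 and 3 and conflicting with
# document 2 is rendered as "X is true [1][3][conflicts with 2]."; a segment with no
# supporting document gets "[citation needed]".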
# Find the last sentence-ending punctuation in the text
match = re.search(r'([.!?])(?!.*[.!?])', self.answer_text, re.DOTALL)
if match:
punct_pos = match.start(1)
pre_punct_text = self.answer_text[:punct_pos]
punct = match.group(1)
post_punct_text = self.answer_text[punct_pos + 1:]
new_text = pre_punct_text + inline_citation + punct + post_punct_text
else:
new_text = self.answer_text + inline_citation
# Reconstruct the full text with citations before the period
inline_str = new_text
return inline_str
def __str__(self):
return self._to_str().strip()
class CitationResults:
def __init__(self, results: List[CitationResult], token_context_length: int, text_context_length: int, text: str, eos_token: str):
self.results = results
self.token_context_length = token_context_length
self.text_context_length = text_context_length
self.text = text
self.eos_token = eos_token
def __str__(self):
inline_strs = []
last_result_end = self.text_context_length
for result in self.results:
if result.text_answer_bounds[0] > last_result_end:
inline_strs.append(self.text[last_result_end:result.text_answer_bounds[0]])
inline_strs.append(result._to_str())
last_result_end = result.text_answer_bounds[1]
if last_result_end < len(self.text):
inline_strs.append(self.text[last_result_end:].rstrip().removesuffix(self.eos_token))
inline_str = "".join(inline_strs).strip()
return inline_str
class SaliencyBasedSystem(ABC):
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
entropy_scale_saliency_scores: bool = False,
smoothing_window_size: Optional[int] = None,
normalize_saliency_scores: bool = False,
z_threshold: Union[float, str] = "auto",
span_padding: int = 7,
):
self.model = model
self.tokenizer = tokenizer
self.entropy_scale_saliency_scores = entropy_scale_saliency_scores
self.smoothing_window_size = smoothing_window_size
self.normalize_saliency_scores = normalize_saliency_scores
self.z_threshold = z_threshold
self.span_padding = span_padding
@abstractmethod
def generate_citations(
self,
token_text_mapper: TokenTextMapper,
context_length: int,
answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
) -> CitationResults:
pass
def _get_answer_bounds_all_sentences(self, token_text_mapper, context_length):
answer_end = -1 if token_text_mapper.text.rstrip().endswith(token_text_mapper.tokenizer.eos_token) else None
answer_text = token_text_mapper.get_text(context_length, answer_end)
sentences = nltk.sent_tokenize(answer_text)
sentences_positions = []
for sent in sentences:
if re.match(LIST_INDEXER_REGEX, sent):
continue
sent = sent.strip()
sent_start = answer_text.find(sent)
sent_end = sent_start + len(sent)
sentences_positions.append([{
"start": sent_start,
"end": sent_end,
"label": sent,
}])
sentences_answer_bounds = get_answer_bounds_selections(sentences_positions, token_text_mapper, context_length)
return sentences_answer_bounds
def _get_span_document_bounds(
self,
span_start: int,
span_end: int,
document_bounds: Optional[List[Tuple[int, int]]] = None,
document_attribution_threshold: float = 0.4,
) -> Tuple[int, Optional[Tuple[int, int]]]:
if not document_bounds:
return -1, None
# Find the document that contains the most tokens in the span. Usually a span will be completely contained in a single document,
# but in case a span straddles two documents this will ensure that the document containing most of the span is selected.
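# A span is only attributed to a document if at least document_attribution_threshold (40% by
# default) of its tokens fall inside that document; otherwise it is left unattributed.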
best_doc_idx = -1
best_span_tokens_in_doc = 0
best_doc_span = None
for i, (doc_start, doc_end) in enumerate(document_bounds):
doc_span_start = max(span_start, doc_start)
doc_span_end = min(span_end, doc_end)
span_tokens_in_doc = doc_span_end - doc_span_start
if span_tokens_in_doc > best_span_tokens_in_doc:
best_doc_idx = i
best_span_tokens_in_doc = span_tokens_in_doc
best_doc_span = (doc_span_start, doc_span_end)
# If the span overlaps with any document, make sure it overlaps by at least the document_attribution_threshold
span_len = span_end - span_start
if best_span_tokens_in_doc < document_attribution_threshold * span_len:
best_doc_idx = -1
if best_doc_idx < 0:
best_doc_span = None
return best_doc_idx, best_doc_span
def _get_logits(
self,
input_ids,
context_length=None,
mask_windows=None,
control_token_mask=None,
past_key_values=None,
):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
# Reason for using context_length-1 instead of context_length:
# The last token of the context predicts the first token of the answer, so we need to include it in the input
# if we want the loss to account for all tokens in the answer.
answer_input_ids = input_ids[:, context_length-1:] if context_length else input_ids
answer_inputs_embeds = inputs_embeds[:, context_length-1:] if context_length else inputs_embeds
attention_mask = torch.ones_like(input_ids)
if mask_windows:
if control_token_mask is None:
raise ValueError(
"control_token_mask must be provided if mask_windows are provided"
)
if control_token_mask.shape != input_ids.shape:
raise ValueError(
"control_token_mask must have the same shape as input_ids"
)
if len(mask_windows) != input_ids.shape[0]:
raise ValueError(
"mask_windows must have the same length as the batch size"
)
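# For each sequence in the batch, hide its mask window from attention, except for control
# tokens (chat-template markers), which stay visible so that masking a context window never
# removes the prompt structure itself.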
for i, mask_window in enumerate(mask_windows):
if mask_window is not None:
attention_mask[i, mask_window[0]:mask_window[1]] = control_token_mask[
i, mask_window[0]:mask_window[1]
]
logits = self.model(
inputs_embeds=answer_inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
return_dict=True,
).logits
return logits, answer_input_ids, answer_inputs_embeds
def _entropy_scale_saliency_scores(self, logits: torch.Tensor, saliency_scores: torch.Tensor) -> torch.Tensor:
if logits.requires_grad:
logits = logits.detach()
# Upcast to float to avoid precision issues
logits = logits.float()
context_length = len(saliency_scores)
shift_logits = logits[..., :context_length-1, :]
shift_probs = torch.nn.functional.softmax(shift_logits, dim=-1)
shift_log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
shift_entropies = -torch.sum(shift_probs * shift_log_probs, dim=-1)
shift_entropies /= shift_entropies.max()
# Entropy at index i represents the next token entropy for input token i.
# To get the entropy of the distribution from which token i was sampled, we need to use the entropy at i-1.
shift_saliency_scores = saliency_scores[1:]
saliency_scores = torch.cat([saliency_scores[:1], shift_saliency_scores * shift_entropies[0]])
return saliency_scores
def _smooth_and_normalize_saliency_scores(self, saliency_scores: torch.Tensor) -> torch.Tensor:
if self.smoothing_window_size is not None and self.smoothing_window_size > 1:
avg_kernel = torch.tensor(
[1/self.smoothing_window_size for _ in range(self.smoothing_window_size)],
dtype=saliency_scores.dtype,
device=saliency_scores.device,
)
saliency_scores = torch.nn.functional.conv1d(
saliency_scores.unsqueeze(0).unsqueeze(0),
avg_kernel.unsqueeze(0).unsqueeze(0),
padding="same",
).squeeze(0).squeeze(0)
if self.normalize_saliency_scores:
saliency_scores = saliency_scores / saliency_scores.abs().max()
return saliency_scores
def _get_loss(
self,
logits: torch.Tensor,
labels: torch.Tensor,
context_length: Optional[int] = None,
answer_bounds: Optional[Tuple[int, int]] = None,
reduce: bool = True,
):
# Upcast to float to avoid precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
if answer_bounds is not None:
if context_length is None:
raise ValueError("context_length is required if answer_bounds is provided")
# set all labels to -100 if they fall outside the answer bounds
shift_labels = shift_labels.clone()
bool_mask = torch.ones_like(shift_labels, dtype=torch.bool)
answer_start = answer_bounds[0] - context_length
answer_end = answer_bounds[1] - context_length
bool_mask[..., answer_start:answer_end] = False
shift_labels[bool_mask] = -100
# Flatten the tokens
shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device)
reduction = "mean" if reduce else "none"
loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
loss = loss_fct(shift_logits, shift_labels)
if not reduce:
loss = loss.view(-1, labels.shape[1]-1)
return loss
def _pool_losses(self, losses: torch.Tensor, context_length: int, answer_bounds: Tuple[int, int]) -> torch.Tensor:
answer_start = answer_bounds[0] - context_length
answer_end = answer_bounds[1] - context_length
loss = losses[..., answer_start:answer_end].mean(dim=-1)
return loss
def _create_document_span(
self,
token_text_mapper: TokenTextMapper,
span_start: int,
span_end: int,
document_bounds: Optional[List[Tuple[int, int]]] = None,
) -> Optional[DocumentSpan]:
span_doc_index, span_doc_bounds = self._get_span_document_bounds(span_start, span_end, document_bounds)
if span_doc_index < 0:
return None
span_doc_text_bounds = token_text_mapper.get_text_bounds(*span_doc_bounds)
span_doc_text = token_text_mapper.get_text(*span_doc_bounds)
# span_doc_text can start with a space if span_doc_bounds[0] is a space-prefixed token. We want to skip this space
# to stay consistent with the original document text.
if span_doc_text[0] == " " and span_doc_text[1] != " ":
span_doc_text_bounds = (span_doc_text_bounds[0]+1, span_doc_text_bounds[1])
span_doc_text = span_doc_text[1:]
# Compute the relative bounds of the span within the document (where position 0 is the start of the document)
doc_start = document_bounds[span_doc_index][0]
span_doc_rel_bounds = (span_doc_bounds[0]-doc_start, span_doc_bounds[1]-doc_start)
doc_text_start = token_text_mapper.get_text_idx(doc_start)
# doc_text_start can land on a space if doc_start is a space-prefixed token. We want to skip this space
# to stay consistent with the original document text.
if token_text_mapper.text[doc_text_start] == " " and token_text_mapper.text[doc_text_start+1] != " ":
doc_text_start += 1
span_doc_text_rel_bounds = (span_doc_text_bounds[0]-doc_text_start, span_doc_text_bounds[1]-doc_text_start)
document_span = DocumentSpan(
span_doc_bounds,
span_doc_text_bounds,
span_doc_text,
span_doc_rel_bounds,
span_doc_text_rel_bounds,
span_doc_index,
)
return document_span
def _create_spans(
self,
token_text_mapper: TokenTextMapper,
saliency_scores: torch.Tensor,
is_positive: bool,
document_bounds: Optional[List[Tuple[int, int]]] = None,
) -> Tuple[List[SalientSpan], float]:
saliency_scores = saliency_scores.cpu().numpy()
abs_saliency_scores = np.abs(saliency_scores)
if is_positive:
mask = saliency_scores >= 0
else:
mask = saliency_scores < 0
# Compute automatic threshold for z-scores if needed
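# Since entropy(p) <= log(n), the normalized entropy term below lies in [0, 1], so the automatic
# threshold ranges from 2.0 (sharply peaked saliency) up to 2e ~= 5.44 (near-uniform saliency):
# the flatter the saliency distribution, the more extreme a z-score must be to open a span.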
if self.z_threshold == "auto":
saliency_entropy = entropy(abs_saliency_scores)
z_threshold = 2.0 * np.exp(saliency_entropy / np.log(len(abs_saliency_scores))).item()
else:
z_threshold = self.z_threshold
# Construct spans
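# A span opens span_padding tokens before the first token whose z-score clears the threshold
# and closes after span_padding consecutive sub-threshold tokens (or at the end of the context).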
spans = []
if mask.any():
abs_saliency_z_scores = self._standardize_scores(abs_saliency_scores)
abs_saliency_z_scores[~mask] = -3.0
padding_counter = 0
span_start = None
for i, z_score in enumerate(abs_saliency_z_scores):
if z_score >= z_threshold:
if span_start is None:
span_start = max(i - self.span_padding, 0)
padding_counter = 0
elif span_start is not None:
if padding_counter == self.span_padding or i == len(abs_saliency_z_scores) - 1:
span_end = i + 1 if i == len(abs_saliency_z_scores) - 1 else i
span_text_bounds = token_text_mapper.get_text_bounds(span_start, span_end)
span_text = token_text_mapper.get_text(span_start, span_end)
# If this span intersects a document, get the span bounds within the document
document_span = self._create_document_span(token_text_mapper, span_start, span_end, document_bounds)
spans.append(
SalientSpan(
(span_start, span_end),
span_text_bounds,
span_text,
document_span,
)
)
span_start = None
padding_counter += 1
return spans, z_threshold
def _standardize_scores(self, scores: np.ndarray) -> np.ndarray:
scores = scores.reshape(-1, 1)
sc = StandardScaler()
return sc.fit_transform(scores).squeeze(1)
class SlidingWindowSystem(SaliencyBasedSystem):
def __init__(self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
window_size: int = 7,
window_overlap: int = 2,
window_batch_size: int = 1,
entropy_scale_saliency_scores: bool = False,
smoothing_window_size: Optional[int] = None,
normalize_saliency_scores: bool = False,
z_threshold: Union[float, str] = "auto",
span_padding: int = 7,
):
super().__init__(
model,
tokenizer,
entropy_scale_saliency_scores,
smoothing_window_size,
normalize_saliency_scores,
z_threshold,
span_padding,
)
self.window_size = window_size
self.window_overlap = window_overlap
self.window_batch_size = window_batch_size
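# Minimal usage sketch (assumes a TokenTextMapper built from the full tokenized context +
# answer text; see utils/token_text_mapper.py for the actual constructor):
#   system = SlidingWindowSystem(model, tokenizer, window_size=7, window_overlap=2)
#   results = system.generate_citations(token_text_mapper, context_length, document_bounds=document_bounds)
#   print(results)  # answer text with inline [n] citations and [conflicts with n] markers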
def generate_citations(
self,
token_text_mapper: TokenTextMapper,
context_length: int,
answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
) -> CitationResults:
if not answer_bounds:
answer_bounds = self._get_answer_bounds_all_sentences(token_text_mapper, context_length)
answer_bounds = [bounds[1] for bounds in answer_bounds]
if isinstance(answer_bounds[0], int):
answer_bounds = [answer_bounds]
if document_bounds and isinstance(document_bounds[0], int):
document_bounds = [document_bounds]
# Get the context logits and the token-level losses for the answer
losses_dict, context_logits, _ = self._get_model_outputs(token_text_mapper, context_length)
# Compute saliency and extract citation/conflict spans for each answer segment
results = []
for answer_bounds_ in answer_bounds:
# Compute saliency scores for each token with respect to its impact on
# the likelihood of the current answer segment
saliency_scores = self._compute_saliency_scores(losses_dict, context_logits, context_length, answer_bounds_)
# Create citation spans
citation_spans, z_threshold = self._create_spans(token_text_mapper, saliency_scores, True, document_bounds)
# Create conflict spans
conflict_spans, _ = self._create_spans(token_text_mapper, saliency_scores, False, document_bounds)
text_answer_bounds = token_text_mapper.get_text_bounds(*answer_bounds_)
answer_text = token_text_mapper.get_text(*answer_bounds_)
results.append(
CitationResult(
answer_bounds_,
text_answer_bounds,
answer_text,
saliency_scores.tolist(),
citation_spans,
conflict_spans,
z_threshold,
)
)
text_context_length = token_text_mapper.get_text_idx(context_length)
citation_results = CitationResults(
results,
context_length,
text_context_length,
token_text_mapper.text,
token_text_mapper.tokenizer.eos_token,
)
return citation_results
def _get_model_outputs(
self,
token_text_mapper: TokenTextMapper,
context_length: int,
) -> Tuple[Dict[Optional[Tuple[int, int]], torch.Tensor], torch.Tensor, torch.Tensor]:
# Pre-compute the KV cache for the context
losses_dict = {}
input_ids = token_text_mapper.input_ids.clone().to(self.model.device)
with torch.no_grad():
outputs = self.model(input_ids[:, :context_length-1], return_dict=True)
past_key_values = outputs.past_key_values
context_logits, context_labels = outputs.logits, input_ids[:, :context_length-1]
# Pre-compute the token-level answer losses in batch: first with no windows applied and then for each window
control_token_mask = self._get_control_token_mask(input_ids)
window_ranges = [None] + self._get_window_ranges(context_length)
for i in range(0, len(window_ranges), self.window_batch_size):
batch_window_ranges = window_ranges[i:i+self.window_batch_size]
batch_input_ids = input_ids.expand(len(batch_window_ranges), input_ids.shape[-1])
batch_control_token_mask = control_token_mask.expand(len(batch_window_ranges), control_token_mask.shape[-1])
batch_past_key_values = [
(layer[0].expand(len(batch_window_ranges), *layer[0].shape[1:]), layer[1].expand(len(batch_window_ranges), *layer[1].shape[1:]))
for layer in past_key_values
]
with torch.no_grad():
batch_logits, batch_labels, _ = self._get_logits(
batch_input_ids,
context_length,
batch_window_ranges,
batch_control_token_mask,
batch_past_key_values,
)
batch_losses = self._get_loss(batch_logits, batch_labels, reduce=False)
for window_range, losses in zip(batch_window_ranges, batch_losses):
losses_dict[window_range] = losses
return losses_dict, context_logits, context_labels
def _compute_saliency_scores(
self,
losses_dict: Dict[Optional[Tuple[int, int]], torch.Tensor],
context_logits: torch.Tensor,
context_length: int,
answer_bounds: Tuple[int, int],
) -> torch.Tensor:
base_loss = self._pool_losses(losses_dict[None], context_length, answer_bounds)
window_losses, token_window_ids = self._calculate_window_losses(losses_dict, context_length, answer_bounds)
window_saliency_scores = window_losses - base_loss
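# Positive scores mean that masking the windows covering a token increases the answer loss
# (the token supports the answer); negative scores mean masking decreases the loss (the token
# conflicts with the answer).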
token_saliency_scores = torch.stack([window_saliency_scores[token_window_ids[i]].mean() for i in range(context_length)])
if self.entropy_scale_saliency_scores:
token_saliency_scores = self._entropy_scale_saliency_scores(context_logits, token_saliency_scores)
token_saliency_scores = self._smooth_and_normalize_saliency_scores(token_saliency_scores)
return token_saliency_scores
def _get_control_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
# TODO: Pass answer_prefix in from the configured answer template. Not sure how to support general
# "user-provided" patterns that should be considered control tokens. For example, the answer prefix
# should only be considered a control token if it's the first content in an assistant turn, and so on.
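# The mask marks chat-template control sequences (and the answer prefix) with 1 so that
# _get_logits keeps them attendable even when the window containing them is masked out.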
answer_prefix = "Answer:"
# Get the control token sequences for this model
ctl_seqs = [[self.tokenizer.convert_tokens_to_ids(tok)] for tok in self.tokenizer.all_special_tokens]
if self.tokenizer.chat_template is not None:
chat_template_ctl_seqs = [
# Llama-2 / Mistral
"[INST]",
"[/INST]",
f"[/INST] {answer_prefix}",
# Llama-3
"<|start_header_id|>system<|end_header_id|>\n\n",
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|start_header_id|>assistant<|end_header_id|>\n\n",
f"<|start_header_id|>assistant<|end_header_id|>\n\n{answer_prefix}",
# Zephyr
"<|system|>\n",
"<|user|>\n",
"<|assistant|>\n",
f"<|assistant|>\n{answer_prefix}",
# Zephyr tokenizes the < differently if it's at the start of a line
"\n<|system|>\n",
"\n<|user|>\n",
"\n<|assistant|>\n",
f"\n<|assistant|>\n{answer_prefix}",
# TODO: Add support for other model-specific chat template control sequences here
]
for ctl_seq in chat_template_ctl_seqs:
if ctl_seq.replace(answer_prefix, "").strip() in self.tokenizer.chat_template:
ctl_seqs.append(self.tokenizer.encode(ctl_seq, add_special_tokens=False))
# Create the mask
ctl_seqs = [torch.tensor(ctl_seq).to(input_ids.device) for ctl_seq in ctl_seqs]
control_token_mask = torch.zeros_like(input_ids)
for i in range(input_ids.shape[-1]):
for ctl_seq in ctl_seqs:
if i+len(ctl_seq) <= input_ids.shape[-1] and torch.all(input_ids[0, i:i+len(ctl_seq)] == ctl_seq):
control_token_mask[0, i:i+len(ctl_seq)] = 1
return control_token_mask
def _calculate_window_losses(self, losses_dict, context_length: int, answer_bounds: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
window_losses = []
token_window_ids = [[] for _ in range(context_length)]
for start, end in self._get_window_ranges(context_length):
for i in range(start, end):
token_window_ids[i].append(len(window_losses))
loss = self._pool_losses(losses_dict[(start,end)], context_length, answer_bounds)
window_losses.append(loss)
return torch.stack(window_losses), token_window_ids
def _get_window_ranges(self, context_length: int) -> List[Tuple[int, int]]:
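# e.g. window_size=7, window_overlap=2, context_length=20 -> [(0, 7), (5, 12), (10, 17), (15, 20)]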
return [(start, min(start + self.window_size, context_length))
for start in range(0, context_length - self.window_overlap, self.window_size - self.window_overlap)]
class GradientBasedSystem(SaliencyBasedSystem):
def __init__(self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
entropy_scale_saliency_scores: bool = False,
smoothing_window_size: Optional[int] = None,
normalize_saliency_scores: bool = False,
z_threshold: Union[float, str] = "auto",
span_padding: int = 7,
):
super().__init__(
model,
tokenizer,
entropy_scale_saliency_scores,
smoothing_window_size,
normalize_saliency_scores,
z_threshold,
span_padding,
)
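# Unlike SlidingWindowSystem, saliency is obtained from a single forward/backward pass per
# answer segment (gradient-times-input) instead of re-scoring the answer once per masked
# context window.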
def generate_citations(
self,
token_text_mapper: TokenTextMapper,
context_length: int,
answer_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
document_bounds: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
) -> CitationResults:
# TODO: much of this method is shared with the sliding window implementation.
# See how we can abstract the shared logic out to SaliencyBasedSystem
if not answer_bounds:
answer_bounds = self._get_answer_bounds_all_sentences(token_text_mapper, context_length)
answer_bounds = [bounds[1] for bounds in answer_bounds]
if isinstance(answer_bounds[0], int):
answer_bounds = [answer_bounds]
if document_bounds and isinstance(document_bounds[0], int):
document_bounds = [document_bounds]
# Compute saliency and extract citation/conflict spans for each answer segment
results = []
for answer_bounds_ in answer_bounds:
# Compute saliency scores for each token with respect to its impact on
# the likelihood of the current answer segment
saliency_scores = self._compute_saliency_scores(token_text_mapper.input_ids, context_length, answer_bounds_)
# Create citation spans
citation_spans, z_threshold = self._create_spans(token_text_mapper, saliency_scores, True, document_bounds)
# Create conflict spans
conflict_spans, _ = self._create_spans(token_text_mapper, saliency_scores, False, document_bounds)
text_answer_bounds = token_text_mapper.get_text_bounds(*answer_bounds_)
answer_text = token_text_mapper.get_text(*answer_bounds_)
results.append(
CitationResult(
answer_bounds_,
text_answer_bounds,
answer_text,
saliency_scores.tolist(),
citation_spans,
conflict_spans,
z_threshold,
)
)
text_context_length = token_text_mapper.get_text_idx(context_length)
citation_results = CitationResults(
results,
context_length,
text_context_length,
token_text_mapper.text,
token_text_mapper.tokenizer.eos_token,
)
return citation_results
def _compute_saliency_scores(
self,
input_ids: torch.Tensor,
context_length: int,
answer_bounds: Tuple[int, int],
) -> torch.Tensor:
# We have to separately compute the logits for each answer segment because the gradient computation
# will be different for each segment
input_ids = input_ids.clone().to(self.model.device)
logits, labels, inputs_embeds = self._get_logits(input_ids)
loss = self._get_loss(logits, labels, 1, answer_bounds)
inputs_embeds.retain_grad()
# backpropagate the loss to the input embeddings
self.model.zero_grad()
loss.backward()
# Compute the component of the gradient along each context token's input embedding
# (gradient-times-input saliency, normalized by the embedding norm)
inputs_grads = inputs_embeds.grad[0, :context_length]
inputs_embeds = inputs_embeds.detach()[0, :context_length]
token_saliency_scores = -(inputs_grads * inputs_embeds).sum(dim=-1) / inputs_embeds.norm(p=2, dim=-1)
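# The negation keeps the sign convention of SlidingWindowSystem: tokens whose presence lowers
# the answer loss (supporting evidence) get positive saliency, tokens that raise it get negative.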
if self.entropy_scale_saliency_scores:
token_saliency_scores = self._entropy_scale_saliency_scores(logits, token_saliency_scores)
token_saliency_scores = self._smooth_and_normalize_saliency_scores(token_saliency_scores)
return token_saliency_scores
class PromptBasedCitationResult(CitationResult):
def __init__(
self,
text_answer_bounds: Tuple[int, int],
answer_text: str,
citation_spans: List[SalientSpan],
conflict_spans: List[SalientSpan],
citation_messages: List[List[Dict[str, str]]],
conflict_messages: List[List[Dict[str, str]]],
):
super().__init__(None, text_answer_bounds, answer_text, None, citation_spans, conflict_spans, 0.0)
self.citation_messages = citation_messages
self.conflict_messages = conflict_messages
class PromptBasedCitationResults(CitationResults):
def __init__(self, results: List[PromptBasedCitationResult], text_context_length: int, text: str):
super().__init__(results, None, text_context_length, text, "")
class PromptBasedSystem:
def __init__(
self,
config: Dict[str, str],
model: Optional[PreTrainedModel] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
):
self.config = config
self.model_handler = self._create_model_handler(model, tokenizer)
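# Sketch of the expected config keys (inferred from their usages in this class; the values
# shown are illustrative only):
#   config = {
#       "model": "gpt-4-turbo",      # or a local HF model name
#       "api_key": "...",            # only required for API models
#       "temperature": 0.0, "top_p": 1.0, "max_new_tokens": 512,
#       "gen_batch_size": 1, "gen_seed": 42,
#       "document_template": "...",  # formatted with ID= (index), T= (title), P= (passage text)
#       "answer_template": "...",    # formatted with A= (answer)
#       "context_regex": "...",      # locates the end of the context in the assembled text body
#   }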
def _create_model_handler(
self,
model: Optional[PreTrainedModel] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
) -> ModelHandler:
model_name = self.config['model']
if model is not None and tokenizer is not None:
return HFModelHandler(model, tokenizer)
elif model_name.startswith(('gpt-4', 'claude')):
return APIModelHandler(api_key=self.config.get('api_key'), model_name=model_name)
else:
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', offload_folder="offload_citation_system", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return HFModelHandler(model, tokenizer)
def generate_citations(
self,
question: str,
answer: str,
documents: List[Dict[str, str]],
question_context: Optional[str] = None,
) -> PromptBasedCitationResults:
# Set up the citation and conflict prompt messages for each (question/answer, sentence, document) triple
answer_sentences = nltk.sent_tokenize(answer)
answer_sentences = [sent.strip() for sent in answer_sentences if not re.match(LIST_INDEXER_REGEX, sent)]
citation_messages = [[] for _ in range(len(answer_sentences))]
conflict_messages = [[] for _ in range(len(answer_sentences))]
for answer_sent, sent_citation_messages, sent_conflict_messages in zip(answer_sentences, citation_messages, conflict_messages):
for doc_idx, doc in enumerate(documents):
for for_conflicts in [False, True]:
messages = get_fewshot_attribution_messages(self.config, for_conflicts)
messages.extend(
get_attribution_messages(
self.config,
question,
answer,
doc_idx,
doc,
answer_sent,
for_conflicts,
question_context,
response_prefill=self.model_handler.supports_response_prefill,
prepend_instruction=False,
)
)
if for_conflicts:
sent_conflict_messages.append(messages)
else:
sent_citation_messages.append(messages)
# Generate the attribution responses in batch
generation_kwargs = {
"temperature": self.config['temperature'],
"top_p": self.config['top_p'],
"max_tokens": self.config['max_new_tokens'],
"batch_size": self.config.get('gen_batch_size', 1),
"seed": self.config.get("gen_seed"),
}
# Prepare list of sentence-document messages for batch inference
all_messages = citation_messages + conflict_messages
all_sent_messages = list(chain.from_iterable(all_messages))
# Do batch inference
responses = self.model_handler.generate(all_sent_messages, **generation_kwargs)
for messages, response in zip(all_sent_messages, responses):
if self.model_handler.supports_response_prefill:
messages[-1]["content"] += response
else:
messages.append({"role": "assistant", "content": response})
# Process the responses to get the CitationResults
citation_results = self._process_cited_responses(
citation_messages,
conflict_messages,
answer,
answer_sentences,
documents,
)
return citation_results
def _process_cited_responses(
self,
citation_messages: List[List[List[Dict[str, str]]]],
conflict_messages: List[List[List[Dict[str, str]]]],
answer: str,
answer_sentences: List[str],
documents: List[Dict[str, str]],
) -> PromptBasedCitationResults:
# Prepare text body for results object
text = ""
document_start_positions = []
for i, doc in enumerate(documents):
document_start_positions.append(len(text))
text += self.config["document_template"].format(ID=i+1, T=doc.get('title', 'Untitled'), P=doc['text']) + "\n"
text += "\n" + self.config["answer_template"].format(A=answer)
# Extract citation and conflict spans from the responses
citation_results = []
for answer_sent, sent_citation_messages, sent_conflict_messages in zip(answer_sentences, citation_messages, conflict_messages):
citation_spans = []
conflict_spans = []
for doc_idx, doc in enumerate(documents):
citation_response, conflict_response = sent_citation_messages[doc_idx][-1]["content"], sent_conflict_messages[doc_idx][-1]["content"]
doc_start_pos = document_start_positions[doc_idx]
citation_spans.extend(self._extract_spans(citation_response, doc_idx, doc, doc_start_pos, False))
conflict_spans.extend(self._extract_spans(conflict_response, doc_idx, doc, doc_start_pos, True))
answer_start_pos = text.rfind(answer_sent)
citation_results.append(
PromptBasedCitationResult(
(answer_start_pos, answer_start_pos+len(answer_sent)),
answer_sent,
citation_spans,
conflict_spans,
sent_citation_messages,
sent_conflict_messages,
)
)
# Create the results object
text_context_length = list(re.finditer(self.config["context_regex"], text))[-1].end()
citation_results = PromptBasedCitationResults(citation_results, text_context_length, text)
return citation_results
def _extract_spans(
self,
response: str,
doc_idx: int,
doc: Dict[str, str],
doc_start_pos: int,
for_conflicts: bool,
) -> List[SalientSpan]:
spans = []
attr_tag = "conflict" if for_conflicts else "support"
attr_start_tag = f"<|{attr_tag}|>"
attr_end_tag = f"<|end_{attr_tag}|>"
# Get re-written document body from the response
response_doc_match = re.search(self.config["document_regex"], response, re.DOTALL)
if response_doc_match is None:
return spans
response_doc_body = response_doc_match.group(1)
doc_start_pos += response_doc_match.start(1)
# The document needs to have been faithfully re-written by the model, or else we can't expect match bounds to be accurate
# when applied to the original document text.
if response_doc_body.replace(attr_start_tag, "").replace(attr_end_tag, "") != doc["text"]:
return spans
# Extract spans from the response
span_regex = f"{attr_start_tag}(.+?){attr_end_tag}".replace("|", r"\|")
span_matches = re.finditer(span_regex, response_doc_body, re.DOTALL)
for i, span_match in enumerate(span_matches):
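# Positions in the tagged response are shifted back to positions in the untagged document
# text: before the i-th span's content there are i complete start/end tag pairs plus that
# span's own start tag.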
offset = (len(attr_start_tag) + len(attr_end_tag)) * i + len(attr_start_tag)
span_text = span_match.group(1)
span_start = span_match.start(1) - offset
span_end = span_match.end(1) - offset
span_text_bounds = (span_start+doc_start_pos, span_end+doc_start_pos)
span_text_rel_bounds = (span_start, span_end)
spans.append(
SalientSpan(
None,
span_text_bounds,
span_text,
DocumentSpan(
None,
span_text_bounds,
span_text,
None,
span_text_rel_bounds,
doc_idx,
),
)
)
return spans