saliency-based-citation/utils.py
import torch
import re
from torch.nn import CrossEntropyLoss
import streamlit as st
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
import nltk
from typing import List, Tuple, Dict


def get_chat_messages(dialogue_history):
    messages = []
    pairs = re.split("user:", dialogue_history, flags=re.IGNORECASE | re.DOTALL)
    for pair in pairs:
        pair = pair.strip()
        if not pair:
            continue
        roles = re.split("assistant:", pair, flags=re.IGNORECASE | re.DOTALL)
        messages.append({"role": "user", "content": roles[0].strip()})
        if len(roles) > 1:
            messages.append({"role": "assistant", "content": roles[1].strip()})
    return messages
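
# Illustrative example (not in the original file):
#   get_chat_messages("user: Hi assistant: Hello! user: Bye")
# returns
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "Bye"}]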


# Note: This method is LLM-specific. To customize, manually add the LLM-specific
# chat-template control tokens such as [INST] and [/INST] to the method below.
def get_control_token_mask(tokenizer, input_ids):
    control_token_mask = np.zeros_like(input_ids)
    prev_token = None
    for i, token_id in enumerate(input_ids[0]):
        token = tokenizer.decode(token_id, skip_special_tokens=False)
        if token == "INST" and prev_token == "/":
            control_token_mask[0, i - 2 : i + 2] = 1
        elif token == "INST":
            control_token_mask[0, i - 1 : i + 2] = 1
        # TODO: the "" token corresponds to token 28705, a "blank" token that is used in JSON formatting.
        # This is considered "control" and excluded from the window because it detracts from
        # the saliency of the rest of the tokens in the window.
        # Investigate if we can remove this after switching from cross entropy to PMI.
        elif token in tokenizer.all_special_tokens + [""]:
            control_token_mask[0, i] = 1
        prev_token = token
    return control_token_mask
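
# Illustrative sketch (not part of the original file) of producing the mask with a
# Hugging Face tokenizer; the checkpoint name is a hypothetical example, and the
# [INST]/[/INST] handling above assumes the tokenizer splits those markers into
# "/" and "INST" pieces:
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
#   input_ids = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "hello"}], return_tensors="np"
#   )
#   mask = get_control_token_mask(tokenizer, input_ids)
#   # mask has shape (1, seq_len) with 1s over chat-template control tokens.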


@st.cache_resource
def get_logits(
    _model,
    _past_key_values,
    input_ids,
    context_length,
    mask_windows=None,
    control_token_mask=None,
):
    input_ids = torch.tensor(input_ids).to(_model.device)
    # Reason for using context_length - 1 instead of context_length:
    # The last token of the context predicts the first token of the last turn, so we need to
    # include it in the input if we want the loss to account for all tokens in the last turn.
    last_turn_input_ids = input_ids[:, context_length - 1:]
    attention_mask = torch.ones_like(input_ids)
    if mask_windows is None:
        mask_windows = []
    if len(mask_windows) > 0:
        if control_token_mask is None:
            raise ValueError(
                "control_token_mask must be provided if mask_windows are provided"
            )
        control_token_mask = torch.tensor(control_token_mask).to(_model.device)
        for mask_window_start, mask_window_end in mask_windows:
            attention_mask[:, mask_window_start:mask_window_end] = control_token_mask[
                :, mask_window_start:mask_window_end
            ]
    with torch.no_grad():
        logits = _model(
            last_turn_input_ids,
            attention_mask=attention_mask,
            past_key_values=_past_key_values,
            return_dict=True,
        ).logits
    return logits, last_turn_input_ids


@st.cache_data
def get_loss(
    _model,
    _past_key_values,
    input_ids,
    context_length,
    mask_windows=None,
    control_token_mask=None,
    answer_bounds=None,
):
    logits, labels = get_logits(
        _model,
        _past_key_values,
        input_ids,
        context_length,
        mask_windows,
        control_token_mask,
    )
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Set labels to -100 (ignored by CrossEntropyLoss) if they fall outside the answer bounds
    if answer_bounds is not None:
        shift_labels = shift_labels.clone()
        bool_mask = torch.ones_like(shift_labels, dtype=torch.bool)
        bool_mask[..., answer_bounds[0]:answer_bounds[1]] = False
        shift_labels[bool_mask] = -100
    # Flatten the tokens
    shift_logits = shift_logits.view(-1, _model.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Ensure tensors are on the same device
    shift_labels = shift_labels.to(shift_logits.device)
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits, shift_labels)
    return loss
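
# Illustrative note (not in the original file): answer_bounds are token offsets into
# the shifted last-turn labels, so answer_bounds=(0, 5) restricts the loss to the
# first five tokens of the last turn; everything else is labeled -100, which
# CrossEntropyLoss ignores by default.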


def get_input_ids(tokenizer, dialogue_history, return_tensors="np"):
    messages = get_chat_messages(dialogue_history)
    context_ids = tokenizer.apply_chat_template(messages[:-1], return_tensors=return_tensors)
    context_length = context_ids.shape[-1]
    input_ids = tokenizer.apply_chat_template(messages, return_tensors=return_tensors)
    return input_ids, context_length


def get_answer_with_token_bounds(tokenizer, dialogue_history):
    input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
    answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[0, context_length:])
    answer_token_bounds = []
    start_pos = 0
    for token in answer_tokens:
        answer_token_bounds.append((start_pos, start_pos + len(token)))
        start_pos += len(token)
    answer_text = "".join(answer_tokens).replace("▁", " ")
    return answer_text, answer_token_bounds
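
# Illustrative example: if the last turn tokenizes to ["▁The", "▁answer"], the
# bounds are [(0, 4), (4, 11)] over the concatenated token string, and the returned
# text is " The answer" ("▁" is replaced by a space, so character offsets still line up).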


def get_answer_bounds_selections(highlight_result, answer_token_bounds):
    if highlight_result is None:
        highlight_result = []
    selections = []
    for selected_text in highlight_result:
        if not selected_text:
            continue
        selected_start = selected_text[0]["start"]
        selected_end = selected_text[0]["end"]
        selected_label = selected_text[0]["label"]
        token_start = [i for i, (start, end) in enumerate(answer_token_bounds) if start <= selected_start < end][0]
        token_end = [i for i, (start, end) in enumerate(answer_token_bounds) if start < selected_end <= end][0]
        answer_bounds = (token_start, token_end + 1)
        selections.append((selected_label, answer_bounds))
    return selections


def create_spans(citation_span_padding, texts, text_rel_loss_bins):
    citation_spans = []
    citation_span_texts = []
    padding_counter = 0
    span_start = None
    max_bin = max(text_rel_loss_bins)
    for i, rel_loss_bin in enumerate(text_rel_loss_bins):
        if rel_loss_bin == max_bin:
            if span_start is None:
                span_start = max(i - citation_span_padding, 0)
            padding_counter = 0
        elif span_start is not None:
            if padding_counter == citation_span_padding or i == len(text_rel_loss_bins) - 1:
                span_end = i + 1 if i == len(text_rel_loss_bins) - 1 else i
                citation_spans.append((span_start, span_end))
                citation_span_texts.append(" ".join(texts[span_start:span_end]))
                span_start = None
            padding_counter += 1
    return citation_spans, citation_span_texts
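
# Illustrative example (not in the original file): with citation_span_padding=1,
# texts = ["a", "b", "c", "d", "e", "f"] and bins [0, 2, 0, 0, 0, 0], the highest
# bin covers index 1, so the span is padded by one segment on each side and the
# result is roughly ([(0, 3)], ["a b c"]).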


def discretize_scores(scores, n_bins=3):
    """Discretize losses into bins."""
    kbins = KBinsDiscretizer(n_bins=n_bins, encode="ordinal", strategy="uniform")
    bins = kbins.fit_transform(np.expand_dims(scores, 1)).squeeze(1).astype(int)
    return bins
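
# Illustrative example: discretize_scores(np.array([0.1, 0.5, 0.9]), n_bins=3)
# returns array([0, 1, 2]), since the "uniform" strategy splits the value range
# into equal-width bins.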


def generate_citations_logic(
    tokenizer,
    model,
    device,
    input_ids,
    context_length,
    window_size,
    overlapping_tokens,
    citation_span_padding,
    normalize_citation_scores,
    past_key_values=None,
    answer_bounds=None,
):
    control_token_mask = get_control_token_mask(tokenizer, input_ids)
    token_window_ids = [[] for _ in range(context_length)]
    window_losses = []
    window_texts = []
    window_ranges = []
    if past_key_values is None:
        with torch.no_grad():
            inp_ids = torch.tensor(input_ids).to(device)
            attention_mask = torch.ones_like(inp_ids)
            past_key_values = model(
                inp_ids[:, :context_length - 1],
                attention_mask=attention_mask[:, :context_length - 1],
                return_dict=True,
            ).past_key_values
    base_loss = get_loss(model, past_key_values, input_ids, context_length, answer_bounds=answer_bounds)
    for start in range(0, context_length - overlapping_tokens, window_size - overlapping_tokens):
        end = min(start + window_size, context_length)
        text = tokenizer.decode(input_ids[0, start:end], skip_special_tokens=False)
        for i in range(start, end):
            token_window_ids[i].append(len(window_losses))
        loss = get_loss(
            model,
            past_key_values,
            input_ids,
            context_length,
            [(start, end)],
            control_token_mask,
            answer_bounds=answer_bounds,
        )
        window_losses.append(loss)
        window_texts.append(text)
        window_ranges.append((start, end))
    window_losses = torch.stack(window_losses)
    rel_losses = window_losses - base_loss
    if normalize_citation_scores:
        rel_losses /= rel_losses.abs().max()
    token_rel_losses = [rel_losses[token_window_ids[i]].mean().item() for i in range(context_length)]
    texts = []
    text_rel_losses = []
    text_start = 0
    last_rel_loss = None
    for i, rel_loss in enumerate(token_rel_losses):
        if (last_rel_loss is not None and rel_loss != last_rel_loss) or i == len(token_rel_losses) - 1:
            text_end = i + 1 if i == len(token_rel_losses) - 1 else i
            text = tokenizer.decode(input_ids[0, text_start:text_end])
            texts.append(text)
            text_rel_loss = rel_loss if i == len(token_rel_losses) - 1 else last_rel_loss
            text_rel_losses.append(text_rel_loss)
            text_start = i
        last_rel_loss = rel_loss
    text_rel_losses = np.array(text_rel_losses)
    abs_rel_losses = np.abs(text_rel_losses)
    is_positive = text_rel_losses >= 0
    pos_rel_losses = abs_rel_losses[is_positive]
    neg_rel_losses = abs_rel_losses[~is_positive]
    pos_bins = discretize_scores(pos_rel_losses)
    neg_bins = discretize_scores(neg_rel_losses)
    text_rel_loss_bins = np.zeros(len(text_rel_losses), dtype=int)
    text_rel_loss_bins[is_positive] = pos_bins
    citation_spans, citation_span_texts = create_spans(
        citation_span_padding, texts, text_rel_loss_bins
    )
    text_rel_loss_bins[is_positive] = 0
    text_rel_loss_bins[~is_positive] = neg_bins
    contradiction_spans, contradiction_span_texts = create_spans(
        citation_span_padding, texts, text_rel_loss_bins
    )
    return texts, text_rel_losses, window_texts, window_ranges, rel_losses, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans
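
# Illustrative sketch (not part of the original file) of calling the function
# directly, assuming a Hugging Face causal LM and its tokenizer are already loaded;
# the window/padding values mirror the ones used in generate_inline_citations below:
#
#   input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
#   (texts, text_rel_losses, window_texts, window_ranges, rel_losses,
#    citation_span_texts, citation_spans,
#    contradiction_span_texts, contradiction_spans) = generate_citations_logic(
#       tokenizer, model, model.device, input_ids, context_length,
#       window_size=7, overlapping_tokens=2, citation_span_padding=3,
#       normalize_citation_scores=True,
#   )
#
# Masking a window that the answer relies on raises the loss (positive relative
# loss), so those windows become citation spans; windows whose masking lowers the
# loss become contradiction spans.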


def generate_inline_citations(context: str, answer_text: str, tokenizer, model, device, dialogue_history) -> str:
    # Sentence-tokenize the answer text
    sentences = nltk.sent_tokenize(answer_text)
    input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
    annotated_sentences = []
    current_position = 0
    # Extract source spans from the context
    source_pattern = r'\[( *\d+ *)\] Source:'
    source_matches = list(re.finditer(source_pattern, context))
    source_spans = [(int(m.group(1)), m.start()) for m in source_matches]
    context_end = context.find('```') if '```' in context else len(context)
    for sent_index, sentence in enumerate(sentences):
        # Tokenize the sentence
        sentence_tokens = tokenizer.encode(sentence)
        answer_bounds = (current_position, current_position + len(sentence_tokens))
        # Generate citation and contradiction spans for this sentence
        texts, _, _, _, _, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans = generate_citations_logic(
            tokenizer, model, device, input_ids, context_length,
            window_size=7, overlapping_tokens=2, citation_span_padding=3,
            normalize_citation_scores=True,
            answer_bounds=answer_bounds,
        )
        citation_span_indicies = []
        contradiction_span_indicies = []
        # Convert text spans to character indices within the context
        for span_start, span_end in citation_spans:
            txt = " ".join(texts[span_start:span_end])
            start = context.find(txt)
            end = start + len(txt)
            citation_span_indicies.append((start, end))
        for span_start, span_end in contradiction_spans:
            txt = " ".join(texts[span_start:span_end])
            start = context.find(txt)
            end = start + len(txt)
            contradiction_span_indicies.append((start, end))
        # Convert spans to dicts and add answer_sent_index
        citation_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in citation_span_indicies]
        contradiction_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in contradiction_span_indicies]

        # Map spans to sources
        def map_span_to_source(span):
            span_start = span['start']
            for i, (source_num, source_start) in enumerate(source_spans):
                if i == len(source_spans) - 1:
                    if source_start <= span_start < context_end:
                        return source_num
                elif source_start <= span_start < source_spans[i + 1][1]:
                    return source_num
            return "unknown"

        for span in citation_dicts + contradiction_dicts:
            span['source'] = map_span_to_source(span)
        # Annotate the sentence
        annotated_sentence = annotate_sentence(sentence, citation_dicts, contradiction_dicts, sent_index)
        annotated_sentences.append(annotated_sentence)
        current_position += len(sentence_tokens)
    return " ".join(annotated_sentences)


def find_conflicting_spans(citations: List[Dict], contradictions: List[Dict]) -> List[Dict]:
    conflicts = []
    for citation in citations:
        for contradiction in contradictions:
            if spans_overlap(citation, contradiction):
                conflicts.append({
                    'span': get_overlap(citation, contradiction),
                    'citation': citation,
                    'contradiction': contradiction,
                })
    return conflicts


def spans_overlap(span1: Dict, span2: Dict) -> bool:
    return span1['start'] < span2['end'] and span2['start'] < span1['end']


def get_overlap(span1: Dict, span2: Dict) -> Tuple[int, int]:
    return (max(span1['start'], span2['start']), min(span1['end'], span2['end']))
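
# Illustrative example: spans_overlap({'start': 0, 'end': 5}, {'start': 3, 'end': 8})
# is True, and get_overlap on the same pair returns (3, 5).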


def annotate_sentence(sentence: str, citation_spans: List[Dict], contradiction_spans: List[Dict], sent_index: int) -> str:
    # Filter spans for the current sentence
    relevant_citations = [span for span in citation_spans if span['answer_sent_index'] == sent_index]
    relevant_contradictions = [span for span in contradiction_spans if span['answer_sent_index'] == sent_index]
    # Collect all annotations
    annotations = []
    for span in relevant_citations:
        if span['source'] == "unknown":
            continue
        annotations.append(f"[{span['source']}]")
    # Handle conflicts
    for con_span in relevant_contradictions:
        # conflicting_citations = [cit for cit in relevant_citations if spans_overlap(cit, con_span)]
        # for cit in conflicting_citations:
        #     cit_num = cit['source']
        con_num = con_span['source']
        # if cit_num == "unknown" or con_num == "unknown":
        if con_span['source'] == "unknown":
            continue
        # annotations.append(f"[{cit_num}][conflicts with [{con_num}]?]")
        annotations.append(f"[conflicts with [{con_num}]?]")
    # Remove duplicates and sort
    annotations = sorted(set(annotations))
    # Check for .</s> at the end of the sentence
    eos_match = re.search(r'\.(\</s\>)?$', sentence)
    if eos_match:
        eos_start = eos_match.start()
        stripped_sentence = sentence[:eos_start]
        eos_token = sentence[eos_start:]
    else:
        stripped_sentence = sentence.rstrip()
        eos_token = ''
    # Apply annotations
    if annotations:
        result = stripped_sentence + ' ' + ' '.join(annotations)
    else:
        result = stripped_sentence + ' [citation needed]'
    # Add the eos_token back
    if eos_token:
        result += eos_token
    return result
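
# Illustrative end-to-end sketch (not part of the original file); the checkpoint name
# is a hypothetical example, and nltk's "punkt" sentence tokenizer data must be
# available:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical choice
#   tokenizer = AutoTokenizer.from_pretrained(model_name)
#   model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
#   # context: the source-numbered passage(s) shown to the model;
#   # answer_text: the model's answer; dialogue_history: "user: ... assistant: ..."
#   annotated = generate_inline_citations(
#       context, answer_text, tokenizer, model, model.device, dialogue_history
#   )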