import torch
import re
from torch.nn import CrossEntropyLoss
import streamlit as st
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
import nltk
from typing import List, Tuple, Dict
def get_chat_messages(dialogue_history):
messages = []
pairs = re.split("user:", dialogue_history, flags=re.IGNORECASE | re.DOTALL)
for pair in pairs:
pair = pair.strip()
if not pair:
continue
roles = re.split("assistant:", pair, flags=re.IGNORECASE | re.DOTALL)
messages.append({"role": "user", "content": roles[0].strip()})
if len(roles) > 1:
messages.append({"role": "assistant", "content": roles[1].strip()})
return messages
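# Illustrative example (hypothetical transcript, not from the repo): get_chat_messages
# splits a flat "user: ... assistant: ..." transcript into chat-template messages.
# >>> get_chat_messages("user: Who wrote Dune? assistant: Frank Herbert. user: When?")
# [{'role': 'user', 'content': 'Who wrote Dune?'},
#  {'role': 'assistant', 'content': 'Frank Herbert.'},
#  {'role': 'user', 'content': 'When?'}]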
# Note: This method is LLM-specific. To customize it for another LLM, manually add that
# model's chat-template control tokens (such as [INST] and [/INST]) to the method below.
def get_control_token_mask(tokenizer, input_ids):
control_token_mask = np.zeros_like(input_ids)
prev_token = None
for i, token_id in enumerate(input_ids[0]):
token = tokenizer.decode(token_id, skip_special_tokens=False)
if token == "INST" and prev_token == "/":
control_token_mask[0, i - 2 : i + 2] = 1
elif token == "INST":
control_token_mask[0, i - 1 : i + 2] = 1
# TODO: the "" token corresponds to token 28705, a "blank" token that is used in JSON formatting.
# this is considered "control" and excluded from the window because it detracts from
# the saliency of the rest of the tokens in the window.
# Investigate if we can remove this after switching from cross entropy to PMI.
elif token in tokenizer.all_special_tokens + [""]:
control_token_mask[0, i] = 1
prev_token = token
return control_token_mask
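# Sketch of the intended behavior (assuming a Mistral-style tokenizer that splits
# "[INST]" into "[", "INST", "]" and "[/INST]" into "[", "/", "INST", "]"): the two
# INST branches above mark the whole bracketed control sequence, and the final branch
# marks special tokens such as <s> and </s>, plus the "blank" token, individually.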
@st.cache_resource
def get_logits(
_model,
_past_key_values,
input_ids,
context_length,
mask_windows=None,
control_token_mask=None,
):
input_ids = torch.tensor(input_ids).to(_model.device)
# Reason for using context_length-1 instead of context_length:
# The last token of the context predicts the first token of the last turn, so we need to include it in the input
# if we want the loss to account for all tokens in the last turn.
last_turn_input_ids = input_ids[:, context_length-1:]
attention_mask = torch.ones_like(input_ids)
if mask_windows is None:
mask_windows = []
if len(mask_windows) > 0:
if control_token_mask is None:
raise ValueError(
"control_token_mask must be provided if mask_windows are provided"
)
control_token_mask = torch.tensor(control_token_mask).to(_model.device)
for mask_window_start, mask_window_end in mask_windows:
attention_mask[:, mask_window_start:mask_window_end] = control_token_mask[
:, mask_window_start:mask_window_end
]
with torch.no_grad():
logits = _model(
last_turn_input_ids,
attention_mask=attention_mask,
past_key_values=_past_key_values,
return_dict=True,
).logits
return logits, last_turn_input_ids
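# Usage sketch (the window values below are hypothetical): given a KV cache over the
# context, the last turn's logits can be recomputed while hiding context tokens 10..25
# from attention; control tokens in that window stay visible via control_token_mask.
# logits, last_ids = get_logits(_model=model, _past_key_values=past_key_values,
#                               input_ids=input_ids, context_length=context_length,
#                               mask_windows=[(10, 25)],
#                               control_token_mask=control_token_mask)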
@st.cache_data
def get_loss(
_model,
_past_key_values,
input_ids,
context_length,
mask_windows=None,
control_token_mask=None,
answer_bounds=None,
):
logits, labels = get_logits(
_model,
_past_key_values,
input_ids,
context_length,
mask_windows,
control_token_mask
)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# set all labels to -100 if they fall outside the answer bounds
if answer_bounds is not None:
shift_labels = shift_labels.clone()
bool_mask = torch.ones_like(shift_labels, dtype=torch.bool)
bool_mask[..., answer_bounds[0]:answer_bounds[1]] = False
shift_labels[bool_mask] = -100
# Flatten the tokens
shift_logits = shift_logits.view(-1, _model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
return loss
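# Note: -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so labels
# outside answer_bounds contribute nothing; the returned value is the mean cross
# entropy over the remaining (answer) tokens only.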
def get_input_ids(tokenizer, dialogue_history, return_tensors="np"):
messages = get_chat_messages(dialogue_history)
context_ids = tokenizer.apply_chat_template(messages[:-1], return_tensors=return_tensors)
context_length = context_ids.shape[-1]
input_ids = tokenizer.apply_chat_template(messages, return_tensors=return_tensors)
return input_ids, context_length
def get_answer_with_token_bounds(tokenizer, dialogue_history):
input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[0, context_length:])
answer_token_bounds = []
start_pos = 0
for token in answer_tokens:
answer_token_bounds.append((start_pos, start_pos+len(token)))
start_pos += len(token)
answer_text = "".join(answer_tokens).replace("▁", " ")
return answer_text, answer_token_bounds
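# Note (assuming a SentencePiece-style tokenizer where "▁" marks word boundaries):
# the character bounds are computed over the raw token strings, and since each "▁" is
# replaced by a single space, every (start, end) pair still indexes answer_text exactly.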
def get_answer_bounds_selections(highlight_result, answer_token_bounds):
if highlight_result is None:
highlight_result = []
selections = []
for selected_text in highlight_result:
if not selected_text:
continue
selected_start, selected_end, selected_label = selected_text[0]["start"], selected_text[0]["end"], selected_text[0]["label"]
token_start = [i for i, (start, end) in enumerate(answer_token_bounds) if start <= selected_start < end][0]
token_end = [i for i, (start, end) in enumerate(answer_token_bounds) if start < selected_end <= end][0]
answer_bounds = (token_start, token_end+1)
selections.append((selected_label, answer_bounds))
return selections
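# Illustrative example (hypothetical values; as consumed above, each highlight_result
# entry is a list whose first element is a dict with "start", "end", and "label" keys):
# with answer_token_bounds = [(0, 4), (4, 9), (9, 15)] and a selection covering
# characters 5..12, the selection maps to token bounds (1, 3).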
def create_spans(citation_span_padding, texts, text_rel_loss_bins):
citation_spans = []
citation_span_texts = []
padding_counter = 0
span_start = None
max_bin = max(text_rel_loss_bins)
for i, rel_loss_bin in enumerate(text_rel_loss_bins):
if rel_loss_bin == max_bin:
if span_start is None:
span_start = max(i - citation_span_padding, 0)
padding_counter = 0
elif span_start is not None:
if padding_counter == citation_span_padding or i == len(text_rel_loss_bins) - 1:
span_end = i + 1 if i == len(text_rel_loss_bins) - 1 else i
citation_spans.append((span_start, span_end))
citation_span_texts.append(" ".join(texts[span_start:span_end]))
span_start = None
padding_counter += 1
return citation_spans, citation_span_texts
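# Sketch of the grouping logic above, roughly: segments whose bin equals the top bin
# are merged into contiguous spans, each span is extended backwards by up to
# citation_span_padding segments, and a span is closed after citation_span_padding
# consecutive non-top segments (or the end of the sequence) have been seen.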
def discretize_scores(scores, n_bins=3):
"""Discretize losses into bins."""
kbins = KBinsDiscretizer(n_bins=n_bins, encode="ordinal", strategy="uniform")
bins = kbins.fit_transform(np.expand_dims(scores, 1)).squeeze(1).astype(int)
return bins
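# Illustrative example: with the default three uniform-width bins,
# discretize_scores(np.array([0.0, 0.5, 1.0])) returns array([0, 1, 2]).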
def generate_citations_logic(
    tokenizer, model, device, input_ids, context_length,
    window_size, overlapping_tokens, citation_span_padding,
    normalize_citation_scores, past_key_values=None, answer_bounds=None,
):
control_token_mask = get_control_token_mask(tokenizer, input_ids)
token_window_ids = [[] for _ in range(context_length)]
window_losses = []
window_texts = []
window_ranges = []
if past_key_values is None:
with torch.no_grad():
inp_ids = torch.tensor(input_ids).to(device)
attention_mask = torch.ones_like(inp_ids)
past_key_values = model(
inp_ids[:, :context_length-1],
attention_mask=attention_mask[:, :context_length-1],
return_dict=True,
).past_key_values
base_loss = get_loss(model, past_key_values, input_ids, context_length, answer_bounds=answer_bounds)
for start in range(0, context_length - overlapping_tokens, window_size - overlapping_tokens):
end = min(start + window_size, context_length)
text = tokenizer.decode(input_ids[0, start:end], skip_special_tokens=False)
for i in range(start, end):
token_window_ids[i].append(len(window_losses))
loss = get_loss(
model,
past_key_values,
input_ids,
context_length,
[(start, end)],
control_token_mask,
answer_bounds=answer_bounds,
)
window_losses.append(loss)
window_texts.append(text)
window_ranges.append((start, end))
window_losses = torch.stack(window_losses)
rel_losses = window_losses - base_loss
if normalize_citation_scores:
rel_losses /= rel_losses.abs().max()
token_rel_losses = [rel_losses[token_window_ids[i]].mean().item() for i in range(context_length)]
texts = []
text_rel_losses = []
text_start = 0
last_rel_loss = None
for i, rel_loss in enumerate(token_rel_losses):
if (last_rel_loss is not None and rel_loss != last_rel_loss) or i == len(token_rel_losses) - 1:
text_end = i + 1 if i == len(token_rel_losses) - 1 else i
text = tokenizer.decode(input_ids[0, text_start : text_end])
texts.append(text)
text_rel_loss = rel_loss if i == len(token_rel_losses)-1 else last_rel_loss
text_rel_losses.append(text_rel_loss)
text_start = i
last_rel_loss = rel_loss
text_rel_losses = np.array(text_rel_losses)
abs_rel_losses = np.abs(text_rel_losses)
is_positive = text_rel_losses >= 0
pos_rel_losses = abs_rel_losses[is_positive]
neg_rel_losses = abs_rel_losses[~is_positive]
pos_bins = discretize_scores(pos_rel_losses)
neg_bins = discretize_scores(neg_rel_losses)
text_rel_loss_bins = np.zeros(len(text_rel_losses), dtype=int)
text_rel_loss_bins[is_positive] = pos_bins
citation_spans, citation_span_texts = create_spans(
citation_span_padding, texts, text_rel_loss_bins
)
text_rel_loss_bins[is_positive] = 0
text_rel_loss_bins[~is_positive] = neg_bins
contradiction_spans, contradiction_span_texts = create_spans(
citation_span_padding, texts, text_rel_loss_bins
)
return texts, text_rel_losses, window_texts, window_ranges, rel_losses, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans
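# Summary of the attribution logic above: the context is scanned in overlapping
# windows; for each window, the answer loss is recomputed with that window hidden
# from attention (control tokens stay visible), and the loss increase relative to
# base_loss is the window's citation score, while a loss decrease marks a potential
# contradiction. Usage sketch (same parameter values as generate_inline_citations below):
# outputs = generate_citations_logic(tokenizer, model, model.device, input_ids,
#                                    context_length, window_size=7,
#                                    overlapping_tokens=2, citation_span_padding=3,
#                                    normalize_citation_scores=True)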
def generate_inline_citations(context: str, answer_text: str, tokenizer, model, device, dialogue_history) -> str:
# Sentence tokenize the answer_text
sentences = nltk.sent_tokenize(answer_text)
input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
annotated_sentences = []
current_position = 0
# Extract source spans from the context
source_pattern = r'\[( *\d+ *)\] Source:'
source_matches = list(re.finditer(source_pattern, context))
source_spans = [(int(m.group(1)), m.start()) for m in source_matches]
context_end = context.find('```') if '```' in context else len(context)
for sent_index, sentence in enumerate(sentences):
# Tokenize the sentence
sentence_tokens = tokenizer.encode(sentence)
answer_bounds = (current_position, current_position + len(sentence_tokens))
# Generate citation and contradiction spans for this sentence
texts, _, _, _, _, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans = generate_citations_logic(
tokenizer, model, device, input_ids, context_length,
window_size=7, overlapping_tokens=2, citation_span_padding=3,
normalize_citation_scores=True,
answer_bounds=answer_bounds
)
citation_span_indicies = []
contradiction_span_indicies = []
# Convert textspans to context_index
for span_start, span_end in citation_spans:
txt = " ".join(texts[span_start:span_end])
start = context.find(txt)
end = start + len(txt)
citation_span_indicies.append((start, end))
for span_start, span_end in contradiction_spans:
txt = " ".join(texts[span_start:span_end])
start = context.find(txt)
end = start + len(txt)
contradiction_span_indicies.append((start, end))
# Convert spans to dicts and add answer_sent_index
citation_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in citation_span_indicies]
contradiction_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in contradiction_span_indicies]
# Map spans to sources
def map_span_to_source(span):
span_start = span['start']
for i, (source_num, source_start) in enumerate(source_spans):
if i == len(source_spans) - 1:
if source_start <= span_start < context_end:
return source_num
elif source_start <= span_start < source_spans[i+1][1]:
return source_num
return "unknown"
for span in citation_dicts + contradiction_dicts:
span['source'] = map_span_to_source(span)
# Annotate the sentence
annotated_sentence = annotate_sentence(sentence, citation_dicts, contradiction_dicts, sent_index)
annotated_sentences.append(annotated_sentence)
current_position += len(sentence_tokens)
return " ".join(annotated_sentences)
def find_conflicting_spans(citations: List[Dict], contradictions: List[Dict]) -> List[Dict]:
conflicts = []
for citation in citations:
for contradiction in contradictions:
if spans_overlap(citation, contradiction):
conflicts.append({
'span': get_overlap(citation, contradiction),
'citation': citation,
'contradiction': contradiction
})
return conflicts
def spans_overlap(span1: Dict, span2: Dict) -> bool:
return span1['start'] < span2['end'] and span2['start'] < span1['end']
def get_overlap(span1: Dict, span2: Dict) -> Tuple[int, int]:
return (max(span1['start'], span2['start']), min(span1['end'], span2['end']))
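# Illustrative example: spans_overlap({'start': 0, 'end': 5}, {'start': 3, 'end': 8})
# is True and get_overlap returns (3, 5); find_conflicting_spans pairs every citation
# span with each contradiction span it overlaps.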
def annotate_sentence(sentence: str, citation_spans: List[Dict], contradiction_spans: List[Dict], sent_index: int) -> str:
# Filter spans for the current sentence
relevant_citations = [span for span in citation_spans if span['answer_sent_index'] == sent_index]
relevant_contradictions = [span for span in contradiction_spans if span['answer_sent_index'] == sent_index]
# Collect all annotations
annotations = []
for span in relevant_citations:
if span['source'] == "unknown":
continue
annotations.append(f"[{span['source']}]")
# Handle conflicts
for con_span in relevant_contradictions:
# conflicting_citations = [cit for cit in relevant_citations if spans_overlap(cit, con_span)]
# for cit in conflicting_citations:
# cit_num = cit['source']
con_num = con_span['source']
# if cit_num == "unknown" or con_num == "unknown":
if con_span['source'] == "unknown":
continue
# annotations.append(f"[{cit_num}][conflicts with [{con_num}]?]")
annotations.append(f"[conflicts with [{con_num}]?]")
# Remove duplicates and sort
annotations = sorted(set(annotations))
# Check for .</s> at the end of the sentence
eos_match = re.search(r'\.(\</s\>)?$', sentence)
if eos_match:
eos_start = eos_match.start()
stripped_sentence = sentence[:eos_start]
eos_token = sentence[eos_start:]
else:
stripped_sentence = sentence.rstrip()
eos_token = ''
# Apply annotations
if annotations:
result = stripped_sentence + ' ' + ' '.join(annotations)
else:
result = stripped_sentence + ' [citation needed]'
# Add the eos_token back
if eos_token:
result += eos_token
return result
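# End-to-end usage sketch. Everything below is an assumption for illustration: the
# [INST]/[/INST] handling and token id 28705 above suggest a Mistral-family instruct
# model, but any chat model with a matching control-token mask should work.
# from transformers import AutoModelForCausalLM, AutoTokenizer
# nltk.download("punkt")  # required once for nltk.sent_tokenize
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical choice
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# dialogue_history = (
#     "user: [1] Source: ...context passages... Question: ... "
#     "assistant: ...model answer..."
# )
# context = "..."  # the raw context string shown to the model
# answer_text, _ = get_answer_with_token_bounds(tokenizer, dialogue_history)
# print(generate_inline_citations(context, answer_text, tokenizer, model,
#                                 model.device, dialogue_history))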