import torch
import re
from torch.nn import CrossEntropyLoss
import streamlit as st
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
import nltk
from typing import List, Tuple, Dict
def get_chat_messages(dialogue_history):
    """Parse a flat ``user: ... assistant: ...`` transcript into chat messages.

    The transcript is split on ``user:`` markers (case-insensitive); each
    resulting segment is split on ``assistant:``. The text before the first
    ``assistant:`` marker becomes a ``user`` message and, when present, the
    text after it becomes an ``assistant`` message.

    Args:
        dialogue_history: The whole dialogue as a single string.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in dialogue order.
    """
    messages = []
    flags = re.IGNORECASE | re.DOTALL
    for segment in re.split("user:", dialogue_history, flags=flags):
        segment = segment.strip()
        if not segment:
            # Splitting yields an empty segment before the first "user:"
            # marker (and for blank input); it must not become a message.
            continue
        roles = re.split("assistant:", segment, flags=flags)
        messages.append({"role": "user", "content": roles[0].strip()})
        if len(roles) > 1:
            messages.append({"role": "assistant", "content": roles[1].strip()})
    return messages
# Note: This method is LLM-specific. To customize it for another model, manually add that model's
# chat-template control tokens (such as [INST] and [/INST]) to the method below.
def get_control_token_mask(tokenizer, input_ids):
    """Mark chat-template control tokens in ``input_ids`` with 1, others with 0.

    Each id in ``input_ids[0]`` is decoded on its own. An ``"INST"`` token
    marks itself, the token after it and the one token before it (two tokens
    before it when the previous token is ``"/"``). Any token listed in
    ``tokenizer.all_special_tokens``, or decoding to the empty string, marks
    only itself.

    Args:
        tokenizer: Object providing ``decode(token_id, skip_special_tokens=...)``
            and ``all_special_tokens``.
        input_ids: Array of token ids indexed as ``input_ids[0][i]``.

    Returns:
        An array created with ``np.zeros_like(input_ids)`` holding 1 at
        control-token positions of row 0.
    """
    control_token_mask = np.zeros_like(input_ids)
    # TODO: the "" token corresponds to token 28705, a "blank" token that is used in JSON formatting.
    # this is considered "control" and excluded from the window because it detracts from
    # the saliency of the rest of the tokens in the window.
    # Investigate if we can remove this after switching from cross entropy to PMI.
    # Built once here instead of being rebuilt for every token.
    standalone_control_tokens = set(tokenizer.all_special_tokens) | {""}
    prev_token = None
    for i, token_id in enumerate(input_ids[0]):
        token = tokenizer.decode(token_id, skip_special_tokens=False)
        if token == "INST":
            # A closing tag has an extra "/" token between the bracket and "INST".
            lookback = 2 if prev_token == "/" else 1
            # Clamp at 0: a negative slice start would wrap around to the end
            # of the row and leave the leading tokens unmarked.
            control_token_mask[0, max(i - lookback, 0) : i + 2] = 1
        elif token in standalone_control_tokens:
            control_token_mask[0, i] = 1
        prev_token = token
    return control_token_mask
# Runs the model on the input and returns `(logits, last_turn_input_ids)`.
# When mask windows are given, the attention mask inside each window is replaced
# by the corresponding slice of `control_token_mask`.
# Raises ValueError when mask windows are given without a `control_token_mask`.
# NOTE(review): the parameter list after `def get_logits(`, some closing brackets and
# the arguments of the `_model(` call are not visible in this view; confirm against the full file.
def get_logits(
# Convert the token ids to a tensor on the same device as the model.
input_ids = torch.tensor(input_ids).to(_model.device)
# Reason for using context_length-1 instead of context_length:
# The last token of the context predicts the first token of the last turn, so we need to include it in the input
# if we want the loss to account for all tokens in the last turn.
last_turn_input_ids = input_ids[:, context_length-1:]
# Start by attending to every position.
attention_mask = torch.ones_like(input_ids)
# Treat a missing window list as "no windows to mask".
if mask_windows is None:
mask_windows = []
if len(mask_windows) > 0:
if control_token_mask is None:
raise ValueError(
"control_token_mask must be provided if mask_windows are provided"
control_token_mask = torch.tensor(control_token_mask).to(_model.device)
# Inside each window the attention mask takes the control-token mask's values,
# so only positions flagged there remain attended.
for mask_window_start, mask_window_end in mask_windows:
attention_mask[:, mask_window_start:mask_window_end] = control_token_mask[
:, mask_window_start:mask_window_end
# The forward pass runs with gradient tracking disabled.
with torch.no_grad():
logits = _model(
return logits, last_turn_input_ids
# Computes the cross-entropy loss between the logits returned by `get_logits` and the
# labels it returns, optionally restricted to the label positions inside `answer_bounds`.
# NOTE(review): the parameter list after `def get_loss(`, the arguments of the
# `get_logits(` call and the right-hand side of one assignment below are not visible
# in this view; confirm against the full file.
def get_loss(
logits, labels = get_logits(
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# set all labels to -100 if they fall outside the answer bounds
# (-100 is the default `ignore_index` of torch.nn.CrossEntropyLoss, so those positions do not contribute)
if answer_bounds is not None:
# Clone before the in-place write below.
shift_labels = shift_labels.clone()
# True marks positions to ignore; only the answer_bounds slice is cleared to False.
bool_mask = torch.ones_like(shift_labels, dtype=torch.bool)
bool_mask[..., answer_bounds[0]:answer_bounds[1]] = False
shift_labels[bool_mask] = -100
# Flatten the tokens
shift_logits = shift_logits.view(-1, _model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Ensure tensors are on the same device
# NOTE(review): the expression assigned here is missing from this view.
shift_labels =
# CrossEntropyLoss is constructed with its default arguments.
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
return loss
def get_input_ids(tokenizer, dialogue_history, return_tensors="np"):
    """Tokenize a dialogue with the tokenizer's chat template.

    Args:
        tokenizer: Object providing ``apply_chat_template(messages, return_tensors=...)``.
        dialogue_history: Transcript string, parsed with ``get_chat_messages``.
        return_tensors: Forwarded unchanged to ``apply_chat_template``.

    Returns:
        ``(input_ids, context_length)``: the template output for all messages,
        and the size of the last dimension of the template output for every
        message except the final one.
    """
    chat = get_chat_messages(dialogue_history)

    def _encode(turns):
        return tokenizer.apply_chat_template(turns, return_tensors=return_tensors)

    # The context (all turns but the last) is encoded first, then the full dialogue.
    context_length = _encode(chat[:-1]).shape[-1]
    return _encode(chat), context_length
def get_answer_with_token_bounds(tokenizer, dialogue_history):
    """Return the answer text and the character bounds of each answer token.

    The answer tokens are the tokens of ``input_ids[0]`` from ``context_length``
    onward, as returned by ``get_input_ids``.

    Args:
        tokenizer: Object providing ``convert_ids_to_tokens`` (also used by
            ``get_input_ids``).
        dialogue_history: Transcript string.

    Returns:
        ``(answer_text, answer_token_bounds)``: the concatenated token strings
        with every "▁" replaced by a space, and one ``(start, end)`` character
        offset pair per token, measured on the concatenated token strings.
    """
    input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
    pieces = tokenizer.convert_ids_to_tokens(input_ids[0, context_length:])

    bounds = []
    cursor = 0
    for piece in pieces:
        next_cursor = cursor + len(piece)
        bounds.append((cursor, next_cursor))
        cursor = next_cursor

    # "▁" and " " are both one character, so the offsets stay valid after the replace.
    return "".join(pieces).replace("▁", " "), bounds
def get_answer_bounds_selections(highlight_result, answer_token_bounds):
    """Convert highlighted character ranges into labelled token ranges.

    Args:
        highlight_result: Iterable of selections, or ``None``. Each selection
            is a sequence whose first element is a dict with ``"start"``,
            ``"end"`` and ``"label"`` keys; empty selections are skipped.
        answer_token_bounds: ``(start, end)`` character bounds per answer token.

    Returns:
        A list of ``(label, (token_start, token_end))`` tuples where
        ``token_end`` is exclusive.

    Raises:
        IndexError: If a selection's start or end lies in no token's bounds.
    """
    if highlight_result is None:
        highlight_result = []
    selections = []
    for selected_text in highlight_result:
        if not selected_text:
            # Nothing was highlighted for this entry.
            continue
        first = selected_text[0]
        selected_start, selected_end, selected_label = first["start"], first["end"], first["label"]
        # The token containing the first selected character.
        token_start = next(
            (i for i, (start, end) in enumerate(answer_token_bounds) if start <= selected_start < end),
            None,
        )
        # The token containing the last selected character; the selection end is exclusive.
        token_end = next(
            (i for i, (start, end) in enumerate(answer_token_bounds) if start < selected_end <= end),
            None,
        )
        if token_start is None or token_end is None:
            raise IndexError(
                f"selection [{selected_start}, {selected_end}) does not fall inside the answer tokens"
            )
        selections.append((selected_label, (token_start, token_end + 1)))
    return selections
# Builds index spans over `texts` around the entries of `text_rel_loss_bins` that equal
# the maximum bin value, and returns `(citation_spans, citation_span_texts)`:
# a list of (start, end) index pairs and the space-joined texts of each span.
# NOTE(review): `max()` raises ValueError on an empty `text_rel_loss_bins`.
# NOTE(review): indentation is not preserved in this view, so the exact nesting of the
# `padding_counter` updates (and whether an `else:` precedes the final increment)
# cannot be confirmed here; verify against the full file.
def create_spans(citation_span_padding, texts, text_rel_loss_bins):
citation_spans = []
citation_span_texts = []
padding_counter = 0
# None means no span is currently open.
span_start = None
max_bin = max(text_rel_loss_bins)
for i, rel_loss_bin in enumerate(text_rel_loss_bins):
if rel_loss_bin == max_bin:
if span_start is None:
# Open a span `citation_span_padding` entries before this one, clamped to index 0.
span_start = max(i - citation_span_padding, 0)
padding_counter = 0
elif span_start is not None:
# Close the open span once the padding count is reached or the last entry is hit.
if padding_counter == citation_span_padding or i == len(text_rel_loss_bins) - 1:
# The end is exclusive: the final entry is included only when it is the last index.
span_end = i + 1 if i == len(text_rel_loss_bins) - 1 else i
citation_spans.append((span_start, span_end))
citation_span_texts.append(" ".join(texts[span_start:span_end]))
span_start = None
padding_counter += 1
return citation_spans, citation_span_texts
def discretize_scores(scores, n_bins=3):
    """Discretize scores into equal-width ordinal bins.

    Args:
        scores: One-dimensional sequence or array of numeric scores.
        n_bins: Number of bins passed to ``KBinsDiscretizer``.

    Returns:
        A one-dimensional integer array with the bin index of each score.
        Empty input yields an empty integer array.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        # KBinsDiscretizer rejects arrays with zero samples; callers may pass an
        # empty selection (for example when no score has a given sign).
        return np.zeros(0, dtype=int)
    kbins = KBinsDiscretizer(n_bins=n_bins, encode="ordinal", strategy="uniform")
    bins = kbins.fit_transform(np.expand_dims(scores, 1)).squeeze(1).astype(int)
    return bins
# Scores context windows by how the loss changes relative to a base loss, groups the
# context into text segments with equal per-token relative loss, and derives citation
# spans (from non-negative relative losses) and contradiction spans (from negative ones).
# Returns a 9-tuple: texts, text_rel_losses, window_texts, window_ranges, rel_losses,
# citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans.
# NOTE(review): several call arguments, closing brackets and statements (e.g. whatever
# fills `token_window_ids`, `window_losses`, `window_texts`, `texts` and
# `text_rel_losses`) are not visible in this view; confirm against the full file.
def generate_citations_logic(tokenizer, model, device, input_ids, context_length, window_size, overlapping_tokens, citation_span_padding, normalize_citation_scores, past_key_values=None, answer_bounds=None):
control_token_mask = get_control_token_mask(tokenizer, input_ids)
# One list per context token; used below to index into `rel_losses`.
token_window_ids = [[] for _ in range(context_length)]
window_losses = []
window_texts = []
window_ranges = []
# When no cache is supplied, run the model over the context minus its last token.
if past_key_values is None:
with torch.no_grad():
inp_ids = torch.tensor(input_ids).to(device)
attention_mask = torch.ones_like(inp_ids)
past_key_values = model(
inp_ids[:, :context_length-1],
attention_mask=attention_mask[:, :context_length-1],
# Base loss: computed without any mask windows.
base_loss = get_loss(model, past_key_values, input_ids, context_length, answer_bounds=answer_bounds)
# Slide a window of `window_size` tokens over the context, stepping by
# `window_size - overlapping_tokens` so consecutive windows overlap.
for start in range(0, context_length - overlapping_tokens, window_size - overlapping_tokens):
end = min(start + window_size, context_length)
text = tokenizer.decode(input_ids[0, start:end], skip_special_tokens=False)
for i in range(start, end):
loss = get_loss(
[(start, end)],
window_ranges.append((start, end))
window_losses = torch.stack(window_losses)
# Relative loss of each window with respect to the base loss.
rel_losses = window_losses - base_loss
if normalize_citation_scores:
# In-place scaling by the largest absolute relative loss.
# NOTE(review): if every relative loss is 0 this divides by zero; verify whether that can occur.
rel_losses /= rel_losses.abs().max()
# Per-token score: mean relative loss over the window indices recorded for that token.
token_rel_losses = [rel_losses[token_window_ids[i]].mean().item() for i in range(context_length)]
texts = []
text_rel_losses = []
text_start = 0
last_rel_loss = None
# Cut a new text segment wherever the per-token relative loss changes, and at the final token.
for i, rel_loss in enumerate(token_rel_losses):
if (last_rel_loss is not None and rel_loss != last_rel_loss) or i == len(token_rel_losses) - 1:
# The final token is included in the last segment; otherwise the segment ends before token i.
text_end = i + 1 if i == len(token_rel_losses) - 1 else i
text = tokenizer.decode(input_ids[0, text_start : text_end])
text_rel_loss = rel_loss if i == len(token_rel_losses)-1 else last_rel_loss
text_start = i
last_rel_loss = rel_loss
text_rel_losses = np.array(text_rel_losses)
# Bin magnitudes separately for non-negative and negative relative losses.
abs_rel_losses = np.abs(text_rel_losses)
is_positive = text_rel_losses >= 0
pos_rel_losses = abs_rel_losses[is_positive]
neg_rel_losses = abs_rel_losses[~is_positive]
pos_bins = discretize_scores(pos_rel_losses)
neg_bins = discretize_scores(neg_rel_losses)
# Citation pass: only the non-negative segments carry a bin; the rest stay at 0.
text_rel_loss_bins = np.zeros(len(text_rel_losses), dtype=int)
text_rel_loss_bins[is_positive] = pos_bins
citation_spans, citation_span_texts = create_spans(
citation_span_padding, texts, text_rel_loss_bins
# Contradiction pass: reuse the array with non-negative segments zeroed and negative ones binned.
text_rel_loss_bins[is_positive] = 0
text_rel_loss_bins[~is_positive] = neg_bins
contradiction_spans, contradiction_span_texts = create_spans(
citation_span_padding, texts, text_rel_loss_bins
return texts, text_rel_losses, window_texts, window_ranges, rel_losses, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans
# Splits the answer into sentences, computes citation and contradiction spans for each
# sentence via `generate_citations_logic`, maps those spans to numbered sources found in
# `context`, annotates each sentence, and returns the annotated sentences joined by spaces.
# NOTE(review): parts of this function (the `source_spans` expression, the tail of the
# `generate_citations_logic(` call, and whatever fills `annotated_sentences`) are not
# visible in this view; confirm against the full file.
def generate_inline_citations(context: str, answer_text: str, tokenizer, model, device, dialogue_history) -> str:
# Sentence tokenize the answer_text
sentences = nltk.sent_tokenize(answer_text)
input_ids, context_length = get_input_ids(tokenizer, dialogue_history)
annotated_sentences = []
# Running token offset of the current sentence within the answer.
current_position = 0
# Extract source spans from the context
# The pattern matches markers of the form "[<number>] Source:", allowing spaces around the number.
source_pattern = r'\[( *\d+ *)\] Source:'
source_matches = list(re.finditer(source_pattern, context))
# NOTE(review): the first tuple element is incomplete in this view.
source_spans = [(int(, m.start()) for m in source_matches]
# The context ends at the first ``` fence when one exists, otherwise at the end of the string.
context_end = context.find('```') if '```' in context else len(context)
for sent_index, sentence in enumerate(sentences):
# Tokenize the sentence
sentence_tokens = tokenizer.encode(sentence)
# NOTE(review): presumably `encode` adds no special tokens here; otherwise these bounds would drift. Verify.
answer_bounds = (current_position, current_position + len(sentence_tokens))
# Generate citation and contradiction spans for this sentence
texts, _, _, _, _, citation_span_texts, citation_spans, contradiction_span_texts, contradiction_spans = generate_citations_logic(
tokenizer, model, device, input_ids, context_length,
window_size=7, overlapping_tokens=2, citation_span_padding=3,
citation_span_indicies = []
contradiction_span_indicies = []
# Convert textspans to context_index
# NOTE(review): `str.find` returns -1 when the joined text is not found in `context`,
# in which case `start` is -1 and `end` is `len(txt) - 1`.
for span_start, span_end in citation_spans:
txt = " ".join(texts[span_start:span_end])
start = context.find(txt)
end = start + len(txt)
citation_span_indicies.append((start, end))
for span_start, span_end in contradiction_spans:
txt = " ".join(texts[span_start:span_end])
start = context.find(txt)
end = start + len(txt)
contradiction_span_indicies.append((start, end))
# Convert spans to dicts and add answer_sent_index
citation_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in citation_span_indicies]
contradiction_dicts = [{'start': span[0], 'end': span[1], 'answer_sent_index': sent_index} for span in contradiction_span_indicies]
# Map spans to sources
# Returns the number of the source whose region contains the span's start, or "unknown".
def map_span_to_source(span):
span_start = span['start']
for i, (source_num, source_start) in enumerate(source_spans):
# The last source extends to `context_end`; every other source ends where the next begins.
if i == len(source_spans) - 1:
if source_start <= span_start < context_end:
return source_num
elif source_start <= span_start < source_spans[i+1][1]:
return source_num
return "unknown"
for span in citation_dicts + contradiction_dicts:
span['source'] = map_span_to_source(span)
# Annotate the sentence
annotated_sentence = annotate_sentence(sentence, citation_dicts, contradiction_dicts, sent_index)
# Advance the token offset past this sentence.
current_position += len(sentence_tokens)
return " ".join(annotated_sentences)
def find_conflicting_spans(citations: List[Dict], contradictions: List[Dict]) -> List[Dict]:
    """Find every citation/contradiction pair whose spans overlap.

    Args:
        citations: Span dicts with ``'start'`` and ``'end'`` keys.
        contradictions: Span dicts with ``'start'`` and ``'end'`` keys.

    Returns:
        One dict per overlapping pair with the keys ``'span'`` (the overlap as
        returned by ``get_overlap``), ``'citation'`` and ``'contradiction'``,
        ordered by citation and then by contradiction.
    """
    conflicts = []
    for citation in citations:
        for contradiction in contradictions:
            if spans_overlap(citation, contradiction):
                conflicts.append({
                    'span': get_overlap(citation, contradiction),
                    'citation': citation,
                    'contradiction': contradiction
                })
    return conflicts
def spans_overlap(span1: Dict, span2: Dict) -> bool:
    """Return True when the two spans share at least one position.

    Spans are dicts with ``'start'`` and ``'end'`` keys; spans that merely
    touch (one ends exactly where the other starts) do not overlap.
    """
    # The first span must start before the second one ends...
    if not span1['start'] < span2['end']:
        return False
    # ...and the second must start before the first one ends.
    return span2['start'] < span1['end']
def get_overlap(span1: Dict, span2: Dict) -> Tuple[int, int]:
    """Return the ``(start, end)`` of the region common to both spans.

    The start is the larger of the two starts and the end is the smaller of
    the two ends; no check is made that the spans actually overlap.
    """
    overlap_start = max(span1['start'], span2['start'])
    overlap_end = min(span1['end'], span2['end'])
    return overlap_start, overlap_end
# Appends source annotations to a sentence: annotations are collected from the citation
# and contradiction spans belonging to `sent_index`, de-duplicated and sorted, and placed
# before a trailing "." / ".</s>" if the sentence ends with one. The annotated sentence
# string is returned.
# NOTE(review): several branch bodies, `else:` lines and the start of the regex call are
# not visible in this view, so the exact branching cannot be confirmed here.
def annotate_sentence(sentence: str, citation_spans: List[Dict], contradiction_spans: List[Dict], sent_index: int) -> str:
# Filter spans for the current sentence
relevant_citations = [span for span in citation_spans if span['answer_sent_index'] == sent_index]
relevant_contradictions = [span for span in contradiction_spans if span['answer_sent_index'] == sent_index]
# Collect all annotations
annotations = []
for span in relevant_citations:
# NOTE(review): the handling for known/unknown citation sources is not visible in this view.
if span['source'] == "unknown":
# Handle conflicts
for con_span in relevant_contradictions:
# conflicting_citations = [cit for cit in relevant_citations if spans_overlap(cit, con_span)]
# for cit in conflicting_citations:
# cit_num = cit['source']
con_num = con_span['source']
# if cit_num == "unknown" or con_num == "unknown":
if con_span['source'] == "unknown":
# annotations.append(f"[{cit_num}][conflicts with [{con_num}]?]")
annotations.append(f"[conflicts with [{con_num}]?]")
# Remove duplicates and sort
annotations = sorted(set(annotations))
# Check for .</s> at the end of the sentence
# NOTE(review): the call that produces `eos_match` is incomplete in this view.
eos_match ='\.(\</s\>)?$', sentence)
if eos_match:
# Split the sentence at the final period so annotations can be placed before it.
eos_start = eos_match.start()
stripped_sentence = sentence[:eos_start]
eos_token = sentence[eos_start:]
stripped_sentence = sentence.rstrip()
eos_token = ''
# Apply annotations
if annotations:
result = stripped_sentence + ' ' + ' '.join(annotations)
result = stripped_sentence + ' [citation needed]'
# Add the eos_token back
if eos_token:
result += eos_token
return result