import streamlit as st
import torch
import argparse
import html
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_utils import set_seed
from streamlit_extras.word_importances import format_word_importances
from annotated_text import annotated_text
from streamlit_annotation_tools import text_highlighter
import utils
from examples import EXAMPLES
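# How to launch, as a minimal sketch (assumption: this file is saved as
# demo.py; the snippet does not name it). Streamlit forwards everything after
# "--" to this script's argparse parser, so the model can be overridden:
#
#   streamlit run demo.py -- --model mistralai/Mistral-7B-Instruct-v0.1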
st.set_page_config(page_title="Citation Demo", layout="wide")
class streamlit_page_state:
    """Tracks a dirty flag so a widget change triggers exactly one rerun."""

    def __init__(self):
        self.dirty = False

    def set_dirty(self):
        self.dirty = True

    def check_rerun(self):
        if self.dirty:
            self.dirty = False
            st.rerun()
@st.cache_resource
def get_model(_args):
    # Cached across reruns; the leading underscore tells Streamlit not to
    # try to hash the argparse namespace.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(_args.model)
    model = AutoModelForCausalLM.from_pretrained(
        _args.model, torch_dtype=torch.float16
    ).to(device)
    return tokenizer, model, device
@st.cache_data
def generate_citations(
    dialogue_history,
    window_size,
    overlapping_tokens,
    citation_span_padding,
    normalize_citation_scores,
    answer_bounds=None,
):
    # Relies on the module-level `args` set in the __main__ block below.
    tokenizer, model, device = get_model(args)
    input_ids, context_length = utils.get_input_ids(tokenizer, dialogue_history)
    with torch.no_grad():
        inp_ids = torch.tensor(input_ids).to(device)
        attention_mask = torch.ones_like(inp_ids)
        # Pre-compute the key/value cache for the context once so the
        # windowed masking passes do not have to re-encode it.
        past_key_values = model(
            inp_ids[:, :context_length - 1],
            attention_mask=attention_mask[:, :context_length - 1],
            return_dict=True,
        ).past_key_values
    return utils.generate_citations_logic(
        tokenizer,
        model,
        device,
        input_ids,
        context_length,
        window_size,
        overlapping_tokens,
        citation_span_padding,
        normalize_citation_scores,
        past_key_values,
        answer_bounds,
    )
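# For orientation: utils.generate_citations_logic is defined elsewhere in the
# repo and not shown here. Below is a minimal, illustrative sketch (names,
# signatures, and details are assumptions, not the repo's actual helper) of
# the windowed-masking idea the UI describes ("Context masking windows",
# "Relative loss"): mask each context window, re-score the answer tokens, and
# treat the loss increase as that window's saliency.
import torch.nn.functional as F


def _answer_loss_sketch(model, input_ids, context_length):
    # Cross-entropy of the answer tokens given everything before them.
    with torch.no_grad():
        logits = model(input_ids).logits
    # Logits at position t-1 predict the token at position t, so score the
    # answer region [context_length:] against logits [context_length-1:-1].
    shift_logits = logits[:, context_length - 1 : -1, :]
    shift_labels = input_ids[:, context_length:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
    ).item()


def _window_saliency_sketch(model, input_ids, context_length, window_size, overlap, mask_id):
    # Slide a window over the context; masking a window the answer depends on
    # should raise the answer's loss. mask_id is whatever filler token the
    # tokenizer offers (pad/unk), another assumption of this sketch.
    base = _answer_loss_sketch(model, input_ids, context_length)
    step = max(1, window_size - overlap)
    scores = []
    for start in range(0, context_length - window_size + 1, step):
        masked = input_ids.clone()
        masked[:, start : start + window_size] = mask_id
        loss = _answer_loss_sketch(model, masked, context_length)
        scores.append(((start, start + window_size), loss - base))
    return scores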
def show_citation_spans(citation_spans, citation_span_texts, texts, color=None):
    # Interleave plain text with numbered, highlighted citation spans.
    annotated_texts = []
    last_span_end = 0
    for i, (span, span_text) in enumerate(zip(citation_spans, citation_span_texts)):
        span_start, span_end = span
        if span_start > last_span_end:
            annotated_texts.append(" ".join(texts[last_span_end:span_start]).replace("```", ""))
        annotated_texts.append((span_text.replace("```", ""), str(i + 1), color))
        last_span_end = span_end
    if last_span_end < len(texts):
        annotated_texts.append(" ".join(texts[last_span_end:]).replace("```", ""))
    annotated_text(annotated_texts)
def run(args):
    if "page_state" not in st.session_state:
        st.session_state.page_state = streamlit_page_state()
    if "dialogue_history" not in st.session_state:
        st.session_state.dialogue_history = (
            "user: Hello, how are you?\n"
            "assistant: I'm fine, thank you. How can I help you today?"
        )
    tokenizer, model, device = get_model(args)

    hist_container = st.empty()
    st.session_state.dialogue_history = hist_container.text_area(
        "Dialogue history",
        st.session_state.dialogue_history,
        height=400,
        on_change=st.session_state.page_state.set_dirty,
    )
    st.session_state.page_state.check_rerun()

    tab_settings, tab_examples = st.sidebar.tabs(["Settings", "Examples"])
    with tab_settings:
        if st.button("Generate"):
            messages = utils.get_chat_messages(st.session_state.dialogue_history)
            if len(messages) > 0:
                input_ids, _ = utils.get_input_ids(
                    tokenizer, st.session_state.dialogue_history, return_tensors="pt"
                )
                set_seed(42)
                outputs = model.generate(
                    input_ids.to(device),
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                )
                generated = tokenizer.decode(
                    outputs[0, input_ids.shape[-1]:], skip_special_tokens=True
                )
                # Continue the dialogue as the opposite speaker of the last turn.
                speaker = "assistant" if messages[-1]["role"] == "user" else "user"
                st.session_state.dialogue_history += f"\n{speaker}: {generated.strip()}"
                st.rerun()
        window_size = st.slider("Window size (tokens)", 1, 10, 7)
        if window_size > 1:
            overlapping_tokens = st.slider(
                "Overlapping tokens", 0, window_size - 1, min(window_size - 1, 2)
            )
        else:
            overlapping_tokens = 0
        citation_span_padding = st.slider(
            "Citation span padding (# of contiguous salient blocks)", 0, 10, 3
        )
        normalize_citation_scores = st.toggle("Normalize citation scores", value=True)
    with tab_examples:
        example = st.selectbox("Select an example", list(EXAMPLES.keys()))
        if st.button("Use example"):
            st.session_state.dialogue_history = EXAMPLES[example]
            st.rerun()

    if not st.session_state.dialogue_history:
        st.warning("Please enter something into the dialogue history box.")
        st.stop()

    answer_text, answer_token_bounds = utils.get_answer_with_token_bounds(
        tokenizer, st.session_state.dialogue_history
    )
    highlight_result = text_highlighter(answer_text)
    selections = utils.get_answer_bounds_selections(highlight_result, answer_token_bounds)
    if selections:
        selection = st.radio(
            "Cite sources for answer span:", selections, format_func=lambda x: x[0]
        )
        _, answer_bounds = selection
        (
            texts,
            text_rel_losses,
            window_texts,
            window_ranges,
            rel_losses,
            citation_span_texts,
            citation_spans,
            contradiction_span_texts,
            contradiction_spans,
        ) = generate_citations(
            st.session_state.dialogue_history,
            window_size,
            overlapping_tokens,
            citation_span_padding,
            normalize_citation_scores,
            answer_bounds,
        )
        tab_saliency, tab_citations, tab_contradictions, tab_inline_citations, tab_details = st.tabs(
            ["Saliency", "Citations", "Contradictions", "Inline Citations", "Details"]
        )
        with tab_saliency:
            escaped_texts = [html.escape(text) for text in texts]
            imp_html = format_word_importances(words=escaped_texts, importances=text_rel_losses)
            st.write(imp_html, unsafe_allow_html=True)
        with tab_citations:
            show_citation_spans(citation_spans, citation_span_texts, texts)
        with tab_contradictions:
            show_citation_spans(contradiction_spans, contradiction_span_texts, texts, "#FA8072")
        with tab_inline_citations:
            # Assuming the context is the concatenation of all text segments.
            context = " ".join(texts)
            annotated_answer = utils.generate_inline_citations(
                context,
                answer_text,
                tokenizer,
                model,
                device,
                st.session_state.dialogue_history,
            )
            st.write(annotated_answer)
        with tab_details:
            st.markdown("#### Context masking windows")
            st.dataframe(
                {
                    "Window text": window_texts,
                    "Window range": window_ranges,
                    "Relative loss": rel_losses.tolist(),
                },
                use_container_width=True,
            )
            st.markdown("#### Contiguous blocks of text grouped by saliency score")
            st.dataframe({"Text": texts, "Relative loss": text_rel_losses}, use_container_width=True)
            st.markdown("#### Citation spans surrounding the most salient text blocks")
            st.dataframe(
                {"Citation text": citation_span_texts, "Span range": citation_spans},
                use_container_width=True,
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mistral-7B-Instruct-v0.1"
    )
    args = parser.parse_args()
    run(args)