# saliency-based-citation/citation_demo.py
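"""Streamlit demo for saliency-based citation.

Edit or generate a dialogue, highlight a span of the latest answer, and the
app shows which parts of the preceding context were most salient to it: a
word-importance view, extracted citation and contradiction spans, inline
citations, and the underlying masking-window scores.
"""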
import argparse
import html

import numpy as np
import streamlit as st
import torch
from annotated_text import annotated_text
from streamlit_annotation_tools import text_highlighter
from streamlit_extras.word_importances import format_word_importances
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_utils import set_seed

import utils
from examples import EXAMPLES

st.set_page_config(page_title="Citation Demo", layout="wide")


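# Lets a widget callback mark the page as stale; check_rerun() then forces a
# rerun so code that already executed this run sees the updated value.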
class streamlit_page_state:
    def __init__(self):
        self.dirty = False

    def set_dirty(self):
        self.dirty = True

    def check_rerun(self):
        if self.dirty:
            self.dirty = False
            st.rerun()


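# Load the tokenizer and model once per process; st.cache_resource keeps them
# alive across Streamlit reruns.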
@st.cache_resource
def get_model(_args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(_args.model)
    model = AutoModelForCausalLM.from_pretrained(
        _args.model, torch_dtype=torch.float16
    ).to(device)
    return tokenizer, model, device


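# Score context windows for the selected answer span. The key/value cache for
# the shared context prefix is precomputed here and handed to
# utils.generate_citations_logic. Note: this cached function reads the
# module-level `args` parsed in the main block.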
@st.cache_data
def generate_citations(
    dialogue_history,
    window_size,
    overlapping_tokens,
    citation_span_padding,
    normalize_citation_scores,
    answer_bounds=None,
):
    tokenizer, model, device = get_model(args)
    input_ids, context_length = utils.get_input_ids(tokenizer, dialogue_history)
    with torch.no_grad():
        inp_ids = torch.tensor(input_ids).to(device)
        attention_mask = torch.ones_like(inp_ids)
        past_key_values = model(
            inp_ids[:, : context_length - 1],
            attention_mask=attention_mask[:, : context_length - 1],
            return_dict=True,
        ).past_key_values
    return utils.generate_citations_logic(
        tokenizer,
        model,
        device,
        input_ids,
        context_length,
        window_size,
        overlapping_tokens,
        citation_span_padding,
        normalize_citation_scores,
        past_key_values,
        answer_bounds,
    )


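# Render the context with cited spans highlighted and numbered via
# annotated_text; triple backticks are stripped, presumably so stray code
# fences don't break the rendering.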
def show_citation_spans(citation_spans, citation_span_texts, texts, color=None):
    annotated_texts = []
    last_span_end = 0
    for i, (span, span_text) in enumerate(zip(citation_spans, citation_span_texts)):
        span_start, span_end = span
        if span_start > last_span_end:
            annotated_texts.append(
                " ".join(texts[last_span_end:span_start]).replace("```", "")
            )
        annotated_texts.append((span_text.replace("```", ""), str(i + 1), color))
        last_span_end = span_end
    if last_span_end < len(texts):
        annotated_texts.append(" ".join(texts[last_span_end:]).replace("```", ""))
    annotated_text(annotated_texts)


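# Page logic: edit the dialogue, optionally generate the next turn, then
# highlight an answer span and inspect the citations computed for it.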
def run(args):
    if "page_state" not in st.session_state:
        st.session_state.page_state = streamlit_page_state()
    if "dialogue_history" not in st.session_state:
        st.session_state.dialogue_history = (
            "user: Hello, how are you?\n"
            "assistant: I'm fine, thank you. How can I help you today?"
        )
    tokenizer, model, device = get_model(args)

    hist_container = st.empty()
    st.session_state.dialogue_history = hist_container.text_area(
        "Dialogue history",
        st.session_state.dialogue_history,
        height=400,
        on_change=st.session_state.page_state.set_dirty,
    )
    st.session_state.page_state.check_rerun()

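    # Sidebar: generation settings and canned example dialogues.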
    tab_settings, tab_examples = st.sidebar.tabs(["Settings", "Examples"])
    with tab_settings:
        if st.button("Generate"):
            messages = utils.get_chat_messages(st.session_state.dialogue_history)
            if len(messages) > 0:
                input_ids, _ = utils.get_input_ids(
                    tokenizer, st.session_state.dialogue_history, return_tensors="pt"
                )
                set_seed(42)
                outputs = model.generate(
                    input_ids.to(device),
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                )
                generated = tokenizer.decode(
                    outputs[0, input_ids.shape[-1] :], skip_special_tokens=True
                )
                # Alternate speakers: answer a user turn as the assistant, and
                # vice versa.
                speaker = "assistant" if messages[-1]["role"] == "user" else "user"
                st.session_state.dialogue_history += f"\n{speaker}: {generated.strip()}"
                st.rerun()
        window_size = st.slider("Window size (tokens)", 1, 10, 7)
        if window_size > 1:
            overlapping_tokens = st.slider(
                "Overlapping tokens", 0, window_size - 1, min(window_size - 1, 2)
            )
        else:
            overlapping_tokens = 0
        citation_span_padding = st.slider(
            "Citation span padding (# of contiguous salient blocks)", 0, 10, 3
        )
        normalize_citation_scores = st.toggle("Normalize citation scores", value=True)
    with tab_examples:
        example = st.selectbox("Select an example", list(EXAMPLES.keys()))
        if st.button("Use example"):
            st.session_state.dialogue_history = EXAMPLES[example]
            st.rerun()

    if not st.session_state.dialogue_history:
        st.warning("Please enter something into the dialogue history box.")
        st.stop()

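    # Let the user highlight a span of the final answer to cite sources for.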
    answer_text, answer_token_bounds = utils.get_answer_with_token_bounds(
        tokenizer, st.session_state.dialogue_history
    )
    highlight_result = text_highlighter(answer_text)
    selections = utils.get_answer_bounds_selections(highlight_result, answer_token_bounds)
    if selections:
        selection = st.radio(
            "Cite sources for answer span:", selections, format_func=lambda x: x[0]
        )
        _, answer_bounds = selection
        (
            texts,
            text_rel_losses,
            window_texts,
            window_ranges,
            rel_losses,
            citation_span_texts,
            citation_spans,
            contradiction_span_texts,
            contradiction_spans,
        ) = generate_citations(
            st.session_state.dialogue_history,
            window_size,
            overlapping_tokens,
            citation_span_padding,
            normalize_citation_scores,
            answer_bounds,
        )

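        # Present the saliency heat map, extracted spans, and diagnostics.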
        tab_saliency, tab_citations, tab_contradictions, tab_inline_citations, tab_details = st.tabs(
            ["Saliency", "Citations", "Contradictions", "Inline Citations", "Details"]
        )
        with tab_saliency:
            # Escape the raw text so it renders safely as HTML.
            escaped_texts = [html.escape(text) for text in texts]
            imp_html = format_word_importances(
                words=escaped_texts, importances=text_rel_losses
            )
            st.write(imp_html, unsafe_allow_html=True)
        with tab_citations:
            show_citation_spans(citation_spans, citation_span_texts, texts)
        with tab_contradictions:
            show_citation_spans(
                contradiction_spans, contradiction_span_texts, texts, "#FA8072"
            )
        with tab_inline_citations:
            # Assuming the context is the concatenation of all text segments.
            context = " ".join(texts)
            annotated_answer = utils.generate_inline_citations(
                context,
                answer_text,
                tokenizer,
                model,
                device,
                st.session_state.dialogue_history,
            )
            st.write(annotated_answer)
        with tab_details:
            st.markdown("#### Context masking windows")
            st.dataframe(
                {
                    "Window text": window_texts,
                    "Window range": window_ranges,
                    "Relative loss": rel_losses.tolist(),
                },
                use_container_width=True,
            )
            st.markdown("#### Contiguous blocks of text grouped by saliency score")
            st.dataframe(
                {"Text": texts, "Relative loss": text_rel_losses},
                use_container_width=True,
            )
            st.markdown("#### Citation spans surrounding the most salient text blocks")
            st.dataframe(
                {"Citation text": citation_span_texts, "Span range": citation_spans},
                use_container_width=True,
            )


if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--model", type=str, default="mistralai/Mistral-7B-Instruct-v0.1" | |
) | |
args = parser.parse_args() | |
run(args) |