import streamlit as st
import argparse
import json
import nltk
import os
import re
from streamlit_annotation_tools import text_highlighter
from streamlit_extras.row import row
from utils.answer_helpers import LIST_INDEXER_REGEX
# Configure the Streamlit page: browser-tab title and a wide (full-width) layout.
st.set_page_config(page_title="Span Annotator", layout="wide")
def load_eli5(eli5_path, ans_path=None):
    """Read the ELI5 examples and optionally swap in generated answers.

    Args:
        eli5_path: Path to the JSON file holding the ELI5 examples.
        ans_path: Optional path to a JSON file with one generated answer per
            example. When given, each example's ``"answer"`` entry is
            overwritten with the answer at the same position.

    Returns:
        The parsed ELI5 data, with answers replaced when ``ans_path`` is set.

    Raises:
        ValueError: If the answers file and the ELI5 file differ in length.
    """
    with open(eli5_path, "r", encoding="utf-8") as eli5_file:
        data = json.load(eli5_file)

    # Nothing to merge: keep the answers that ship with the dataset.
    if ans_path is None:
        return data

    with open(ans_path, "r", encoding="utf-8") as ans_file:
        ans_data = json.load(ans_file)

    if len(ans_data) != len(data):
        raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")

    # Answers are matched to questions purely by position.
    for example, generated in zip(data, ans_data):
        example["answer"] = generated
    return data
def load_annotations(eli5_path, ann_path, ans_path=None):
    """Load span annotations, creating an empty annotation file if none exists.

    When ``ann_path`` does not exist yet, the ELI5 data is loaded and a file
    with one empty annotation record per ELI5 example is written first, so the
    returned list lines up index-for-index with the ELI5 data.

    Args:
        eli5_path: Path to the ELI5 dataset JSON file.
        ann_path: Path to the span annotations JSON file.
        ans_path: Optional path to generated answers, forwarded to
            ``load_eli5``.

    Returns:
        The annotation data parsed from ``ann_path``.
    """
    if not os.path.exists(ann_path):
        eli5_data = load_eli5(eli5_path, ans_path)
        # One placeholder record per question; sentences are filled in lazily
        # by the UI and the done flag starts out cleared.
        empty_ann_data = [
            {
                "answer_sentences": [],
                "span_annotation_done": False,
            }
            for _ in eli5_data
        ]
        save_annotations(ann_path, empty_ann_data)
    # Always read back from disk so both paths return the same representation.
    with open(ann_path, "r", encoding="utf-8") as f:
        ann_data = json.load(f)
    return ann_data
def save_annotations(ann_path, ann_data):
    """Write the annotation data to ``ann_path`` as indented JSON.

    The data is serialized before any file is opened and is written to a
    sibling temporary file that then replaces ``ann_path``. A serialization
    error or an interrupted write therefore cannot leave the existing
    annotation file truncated.

    Args:
        ann_path: Destination path of the annotations JSON file.
        ann_data: JSON-serializable annotation data.

    Raises:
        TypeError: If ``ann_data`` contains values JSON cannot encode; the
            existing file is left untouched in that case.
    """
    # Serialize first: opening the target in "w" mode before this succeeds
    # would wipe previously saved annotations on failure.
    payload = json.dumps(ann_data, indent=4)
    tmp_path = f"{ann_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(payload)
    # Swap the fully written file into place in a single step.
    os.replace(tmp_path, ann_path)
    add_status("Data saved!", "✅")
def count_annotations(answer_sentences):
    """Count annotated spans and annotated documents for each answer sentence.

    Args:
        answer_sentences: Sentence records, each holding ``"citation_spans"``
            and ``"conflict_spans"``: one entry per document, where every
            entry is a list of annotation sets (lists of spans).

    Returns:
        A tuple ``(n_sent_citation, n_sent_conflict, n_sent_ann_docs)`` of
        three lists with one element per sentence: the number of citation
        spans, the number of conflict spans, and the number of distinct
        documents carrying at least one span of either type.
    """
    n_sent_citation = []
    n_sent_conflict = []
    n_sent_ann_docs = []
    for sent in answer_sentences:
        n_ann = {
            "citation_spans": 0,
            "conflict_spans": 0,
        }
        # Indices of documents with at least one span, shared across both
        # span types so a document is only counted once per sentence.
        ann_docs = set()
        for span_type in ["citation_spans", "conflict_spans"]:
            for i, doc_spans in enumerate(sent[span_type]):
                n_doc_spans = sum(len(ann_set) for ann_set in doc_spans)
                if n_doc_spans > 0:
                    n_ann[span_type] += n_doc_spans
                    ann_docs.add(i)
        # Record this sentence's totals; callers index the results by
        # sentence position, so every sentence must contribute one entry.
        n_sent_citation.append(n_ann["citation_spans"])
        n_sent_conflict.append(n_ann["conflict_spans"])
        n_sent_ann_docs.append(len(ann_docs))
    return n_sent_citation, n_sent_conflict, n_sent_ann_docs
def count_done(ann_data):
    """Return how many annotation records are flagged as done.

    A record counts when its ``"span_annotation_done"`` entry is truthy;
    records without that key are treated as not done.
    """
    return sum(1 for record in ann_data if record.get("span_annotation_done", False))
def add_status(body, icon="ℹ️"):
    """Queue a status message in the Streamlit session state.

    Args:
        body: Message text.
        icon: Icon stored alongside the message.
    """
    entry = {"body": body, "icon": icon}
    # Reuse the queue already stored in the session, or start a new one.
    queued = st.session_state.get("statuses", [])
    queued.append(entry)
    st.session_state.statuses = queued
def show_status():
    """Display every queued status message, then empty the queue.

    Messages are the ``{"body", "icon"}`` records queued by ``add_status``.
    """
    statuses = st.session_state.get("statuses", [])
    for status in statuses:
        # NOTE(review): the display call was not visible in the reviewed
        # source; st.toast is used because it accepts the same body/icon pair
        # that add_status records. Confirm this is the intended widget.
        st.toast(status["body"], icon=status["icon"])
    # Clear the queue so messages are not shown again on the next rerun.
    st.session_state.statuses = []
# Render the annotator UI: sidebar navigation and done/progress controls,
# the current question, its answer sentences, and one span highlighter per
# document. `args` supplies eli5_path, eli5_ans_path and ann_path.
# NOTE(review): this block's indentation is flattened and a few statements
# look truncated in this view; each such spot is flagged below.
def run(args):
# eli5_data and ann_data will always be the same length
eli5_data = load_eli5(args.eli5_path, args.eli5_ans_path)
ann_data = load_annotations(args.eli5_path, args.ann_path, args.eli5_ans_path)
# First run of a session: start at the first question and first sentence.
if "current_idx" not in st.session_state:
st.session_state.current_idx = 0
if "current_sent_idx" not in st.session_state:
st.session_state.current_sent_idx = 0
with st.sidebar:
st.sidebar.title("✍️ Span Annotator")
# Two equally weighted columns holding the Prev / Next buttons.
nav_row = row([1, 1])
current_idx = st.session_state.current_idx
# Prev / Next move by one question, clamped to [0, len(eli5_data) - 1].
if nav_row.button("⬅️ Prev"):
current_idx = max(0, st.session_state.current_idx - 1)
if nav_row.button("Next ➡️"):
current_idx = min(len(eli5_data) - 1, st.session_state.current_idx + 1)
# The input shows 1-based question numbers; subtract 1 to get the 0-based index.
current_idx = st.number_input("Go to question", value=current_idx+1, min_value=1, max_value=len(eli5_data), step=1)-1
# Switching question also resets the sentence selection to the first sentence.
if current_idx != st.session_state.current_idx:
st.session_state.current_idx = current_idx
st.session_state.current_sent_idx = 0
current_data = eli5_data[st.session_state.current_idx]
current_ann_data = ann_data[st.session_state.current_idx]
# Only one of the two toggle buttons is offered, depending on the stored flag;
# clicking either one updates the flag and saves the annotations right away.
# NOTE(review): is_done is not reassigned after a toggle, so the status label
# built from it below reflects the value read here — confirm this is intended.
is_done = current_ann_data.get("span_annotation_done", False)
if is_done and st.button("❌ Mark Not Done"):
current_ann_data["span_annotation_done"] = False
save_annotations(args.ann_path, ann_data)
if not is_done and st.button("✅ Mark Done"):
current_ann_data["span_annotation_done"] = True
save_annotations(args.ann_path, ann_data)
# Counted after the toggle handlers, so it includes a change made just above.
n_done = count_done(ann_data)
st.markdown(f"### 📊 Progress: {n_done}/{len(ann_data)}")
status = "✅ Done" if is_done else "❌ Not Done"
st.markdown(f"#### 🤔 Question #{st.session_state.current_idx+1} / {len(eli5_data)} ({status})")
# The context expander is only offered when a context exists and is not the
# "[removed]" placeholder.
# NOTE(review): the bodies of the next expanders are not visible in this
# view — confirm what each one renders.
if current_data["question_ctx"] and current_data["question_ctx"] != "[removed]":
with st.expander("Question Context", expanded=True):
with st.expander("❓ Question", expanded=True):
with st.expander("💡 Answer Sentences", expanded=True):
# Build the sentence records the first time a question is opened: one record
# per sentence returned by nltk.sent_tokenize, skipping sentences whose start
# matches LIST_INDEXER_REGEX, each with one empty span list per document.
# NOTE(review): the braces of the per-sentence dict and the closing bracket
# of this list are not visible in this view.
if not current_ann_data.get("answer_sentences", []):
current_ann_data["answer_sentences"] = [
"sentence": sent,
"citation_spans": [[] for _ in current_data["docs"]],
"conflict_spans": [[] for _ in current_data["docs"]],
for sent in nltk.sent_tokenize(current_data["answer"])
if not re.match(LIST_INDEXER_REGEX, sent)
answer_sentences = current_ann_data["answer_sentences"]
n_sent_citation, n_sent_conflict, n_sent_ann_docs = count_annotations(answer_sentences)
n_docs = len(current_data["docs"])
# Sentence selector; every option label shows that sentence's annotation counts.
# NOTE(review): the widget call and its options argument are not visible in
# this view — presumably a selection widget over sentence indices; confirm.
current_sent_idx =
"Select a sentence",
format_func=lambda i: f"**[{n_sent_citation[i]} citations, {n_sent_conflict[i]} conflicts in {n_sent_ann_docs[i]}/{n_docs} docs]** {answer_sentences[i]['sentence']}"
# Persist the selected sentence in the session state.
if current_sent_idx != st.session_state.current_sent_idx:
st.session_state.current_sent_idx = current_sent_idx
selected_sent = answer_sentences[st.session_state.current_sent_idx]
# Reserved here so the header sits above the documents even though its text
# depends on the span type chosen further down.
ann_header = st.container()
with st.expander("Documents", expanded=True):
# Span-type selector; labels show "Citation" / "Conflict".
# NOTE(review): the widget call is not visible in this view — confirm.
span_type =
"Span type",
["citation_spans", "conflict_spans"],
format_func=lambda x: x.split("_")[0].capitalize()
span_icon = "📜" if span_type == "citation_spans" else "🔥"
ann_header.markdown(f"#### {span_icon} Annotating: {span_type.split('_')[0].capitalize()} Spans")
for i, doc in enumerate(current_data["docs"]):
# Title on the left (4 parts), claims popover on the right (1 part).
title_row = row([4, 1])
title_row.markdown(f"**[{i+1}] {doc['title']}**")
if span_type == "citation_spans":
# Claims whose flag in doc["answers_found"] equals 1, matched by position.
claims = [current_data["claims"][j] for j, ind in enumerate(doc["answers_found"]) if ind == 1]
with title_row.popover(f"Supports {len(claims)} claims"):
for claim in claims:
st.markdown(f"- {claim}")
# Highlighter seeded with the spans already stored for this sentence/document.
spans = text_highlighter(doc["text"], selected_sent[span_type][i])
# Treat a single empty annotation set as "no spans".
spans = [] if spans == [[]] else spans
# Save only when the highlighter returned a value that differs from what is stored.
if spans is not None and spans != selected_sent[span_type][i]:
selected_sent[span_type][i] = spans
save_annotations(args.ann_path, ann_data)
# Finally, show any status messages
# NOTE(review): no call follows the comment above in this view — presumably
# show_status() belongs here; confirm.
# Script entry point: parse the command-line options used by the annotator.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--eli5_path", type=str, help="Path to the ELI5 dataset JSON file", default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
parser.add_argument("--eli5_ans_path", type=str, help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.", default=None)
parser.add_argument("--ann_path", type=str, help="Path to the span annotations JSON file", default=None)
args = parser.parse_args()
# Default the annotation file to sit next to the dataset: the dataset path
# with its extension replaced by "_spans.json".
if args.ann_path is None:
args.ann_path = os.path.splitext(args.eli5_path)[0] + "_spans.json"