Skip to content
Permalink
4d623e0e9d
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
198 lines (177 sloc) 8.23 KB
import streamlit as st
import argparse
import json
import nltk
import os
import re
from streamlit_annotation_tools import text_highlighter
from streamlit_extras.row import row
from utils.answer_helpers import LIST_INDEXER_REGEX
# Configure the page for the whole app: browser-tab title and the wide layout.
st.set_page_config(page_title="Span Annotator", layout="wide")
@st.cache_resource
def load_eli5(eli5_path, ans_path=None):
    """Load the ELI5 examples, optionally swapping in generated answers.

    Decorated with ``st.cache_resource``, so Streamlit reuses the loaded list
    across reruns instead of re-reading the files.

    Args:
        eli5_path: path to the ELI5 dataset JSON (a list of examples).
        ans_path: optional path to a JSON list of generated answers. When
            given, ``examples[i]["answer"]`` is overwritten with ``answers[i]``.

    Returns:
        The list of example dicts.

    Raises:
        ValueError: if the answers list and the dataset differ in length.
    """
    with open(eli5_path, "r", encoding="utf-8") as handle:
        examples = json.load(handle)

    # Without generated answers the original (Reddit) answers are kept as-is.
    if ans_path is None:
        return examples

    with open(ans_path, "r", encoding="utf-8") as handle:
        answers = json.load(handle)
    if len(answers) != len(examples):
        raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")

    for idx, generated in enumerate(answers):
        examples[idx]["answer"] = generated
    return examples
@st.cache_resource
def load_annotations(eli5_path, ann_path, ans_path=None):
    """Load the span-annotation list, creating an empty one on first use.

    On first use (``ann_path`` does not exist yet), one empty record per ELI5
    example is written to ``ann_path``. Each record has no answer sentences and
    ``span_annotation_done`` set to False. Decorated with
    ``st.cache_resource``, so reruns reuse the same in-memory list.

    Args:
        eli5_path: path to the ELI5 dataset JSON.
        ann_path: path to the span-annotation JSON file.
        ans_path: optional generated-answers JSON, forwarded to ``load_eli5``.

    Returns:
        The list of annotation records, one per ELI5 example.

    Raises:
        ValueError: if an existing annotation file does not have exactly one
            record per ELI5 example. The UI indexes both lists with the same
            position, so a mismatch would misalign (and then overwrite)
            annotations.
    """
    eli5_data = load_eli5(eli5_path, ans_path)
    if not os.path.exists(ann_path):
        empty_ann_data = [
            {
                "answer_sentences": [],
                "span_annotation_done": False,
            }
            for _ in eli5_data
        ]
        save_annotations(ann_path, empty_ann_data)
    with open(ann_path, "r", encoding="utf-8") as f:
        ann_data = json.load(f)
    # Fail loudly rather than silently pairing annotations with the wrong
    # questions (e.g. an annotation file reused with a different dataset).
    if len(ann_data) != len(eli5_data):
        raise ValueError(
            f"Annotation file {ann_path} has {len(ann_data)} entries but the "
            f"ELI5 dataset has {len(eli5_data)} questions"
        )
    return ann_data
def save_annotations(ann_path, ann_data):
    """Persist ``ann_data`` to ``ann_path`` as indented JSON and queue a toast.

    The data is first written to a sibling ``<ann_path>.tmp`` file, which is
    then moved over ``ann_path`` with ``os.replace``. A crash or serialization
    error mid-write therefore leaves the previous annotation file intact
    instead of truncated.

    Args:
        ann_path: destination JSON path.
        ann_data: JSON-serializable annotation list.
    """
    tmp_path = f"{ann_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(ann_data, f, indent=4)
        os.replace(tmp_path, ann_path)
    except BaseException:
        # Drop the partial copy; the real annotation file was never touched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    add_status("Data saved!", "✅")
def count_annotations(answer_sentences):
    """Summarise how much annotation each answer sentence has received.

    Args:
        answer_sentences: list of sentence records. Each record holds
            ``"citation_spans"`` and ``"conflict_spans"``; each of those is a
            per-document list whose entries are collections of spans.

    Returns:
        Three parallel lists with one entry per sentence:
        - the total number of citation spans,
        - the total number of conflict spans,
        - the number of documents carrying at least one span of either type.
    """
    citation_totals, conflict_totals, doc_totals = [], [], []
    for sentence in answer_sentences:
        per_type = {}
        docs_with_spans = set()
        for kind in ("citation_spans", "conflict_spans"):
            running = 0
            for doc_idx, groups in enumerate(sentence[kind]):
                sizes = [len(group) for group in groups]
                running += sum(sizes)
                # A document counts once it has any non-empty span group.
                if any(sizes):
                    docs_with_spans.add(doc_idx)
            per_type[kind] = running
        citation_totals.append(per_type["citation_spans"])
        conflict_totals.append(per_type["conflict_spans"])
        doc_totals.append(len(docs_with_spans))
    return citation_totals, conflict_totals, doc_totals
def count_done(ann_data):
    """Return how many annotation records have a truthy ``span_annotation_done``.

    Records without the key count as not done.
    """
    return sum(1 for record in ann_data if record.get("span_annotation_done", False))
def add_status(body, icon="ℹ️"):
    """Queue a toast (``body`` with ``icon``) in session state.

    The queued toasts are rendered later by ``show_status``.
    """
    queue = st.session_state.get("statuses", [])
    queue.append(dict(body=body, icon=icon))
    st.session_state.statuses = queue
def show_status():
    """Render every queued status as a toast, then empty the queue."""
    for pending in st.session_state.get("statuses", []):
        st.toast(**pending)
    st.session_state.statuses = []
def run(args):
    """Render one Streamlit pass of the span-annotation UI.

    Streamlit re-executes this function on every interaction. Navigation state
    lives in ``st.session_state``: ``current_idx`` selects the question and
    ``current_sent_idx`` selects the answer sentence. Annotation edits are
    written to ``args.ann_path`` via ``save_annotations``, followed by
    ``st.rerun()``.

    Args:
        args: parsed CLI namespace; reads ``eli5_path``, ``eli5_ans_path`` and
            ``ann_path``.
    """
    # NOTE(review): indentation of this block was reconstructed from a
    # flattened copy of the file; confirm which statements belong in the
    # sidebar vs. the main page against the original layout.
    # eli5_data and ann_data will always be the same length
    # NOTE(review): run() does not check this itself; both lists are indexed
    # with the same current_idx below.
    eli5_data = load_eli5(args.eli5_path, args.eli5_ans_path)
    ann_data = load_annotations(args.eli5_path, args.ann_path, args.eli5_ans_path)
    # Initialise navigation state on the first run of the session.
    if "current_idx" not in st.session_state:
        st.session_state.current_idx = 0
    if "current_sent_idx" not in st.session_state:
        st.session_state.current_sent_idx = 0
    with st.sidebar:
        st.sidebar.title("✍️ Span Annotator")
        nav_row = row([1, 1])
        current_idx = st.session_state.current_idx
        # Prev/Next compute a clamped tentative index. It is passed as the
        # number_input's (1-based) initial value, and the returned value
        # (converted back to 0-based) is what gets compared to session state.
        if nav_row.button("⬅️ Prev"):
            current_idx = max(0, st.session_state.current_idx - 1)
        if nav_row.button("Next ➡️"):
            current_idx = min(len(eli5_data) - 1, st.session_state.current_idx + 1)
        current_idx = st.number_input("Go to question", value=current_idx+1, min_value=1, max_value=len(eli5_data), step=1)-1
        # Switching questions resets the sentence selection and restarts the script.
        if current_idx != st.session_state.current_idx:
            st.session_state.current_idx = current_idx
            st.session_state.current_sent_idx = 0
            st.rerun()
        current_data = eli5_data[st.session_state.current_idx]
        current_ann_data = ann_data[st.session_state.current_idx]
        is_done = current_ann_data.get("span_annotation_done", False)
        # Only the button matching the current state is shown; clicking it
        # flips the flag, saves the whole annotation list, and reruns.
        if is_done and st.button("❌ Mark Not Done"):
            current_ann_data["span_annotation_done"] = False
            save_annotations(args.ann_path, ann_data)
            st.rerun()
        if not is_done and st.button("✅ Mark Done"):
            current_ann_data["span_annotation_done"] = True
            save_annotations(args.ann_path, ann_data)
            st.rerun()
        n_done = count_done(ann_data)
        st.markdown(f"### 📊 Progress: {n_done}/{len(ann_data)}")
    status = "✅ Done" if is_done else "❌ Not Done"
    st.markdown(f"#### 🤔 Question #{st.session_state.current_idx+1} / {len(eli5_data)} ({status})")
    # The question context is shown only when present and not "[removed]".
    if current_data["question_ctx"] and current_data["question_ctx"] != "[removed]":
        with st.expander("Question Context", expanded=True):
            st.markdown(current_data["question_ctx"])
    with st.expander("❓ Question", expanded=True):
        st.markdown(current_data["question"])
    with st.expander("💡 Answer Sentences", expanded=True):
        # On first view of a question, split its answer into sentences.
        # Sentences matching LIST_INDEXER_REGEX are dropped (presumably bare
        # list markers; confirm in utils.answer_helpers). Each sentence gets
        # one empty span list per document for both span types. This only
        # mutates the in-memory ann_data; it reaches disk on the next
        # save_annotations call.
        if not current_ann_data.get("answer_sentences", []):
            current_ann_data["answer_sentences"] = [
                {
                    "sentence": sent,
                    "citation_spans": [[] for _ in current_data["docs"]],
                    "conflict_spans": [[] for _ in current_data["docs"]],
                }
                for sent in nltk.sent_tokenize(current_data["answer"])
                if not re.match(LIST_INDEXER_REGEX, sent)
            ]
        answer_sentences = current_ann_data["answer_sentences"]
        n_sent_citation, n_sent_conflict, n_sent_ann_docs = count_annotations(answer_sentences)
        n_docs = len(current_data["docs"])
        # Each option label shows the sentence's citation and conflict span
        # counts and how many documents have spans.
        # NOTE(review): if no sentences survive the filter, the radio has no
        # options and the answer_sentences[...] lookup below looks fragile;
        # confirm the intended behavior.
        current_sent_idx = st.radio(
            "Select a sentence",
            range(len(answer_sentences)),
            index=st.session_state.current_sent_idx,
            format_func=lambda i: f"**[{n_sent_citation[i]} citations, {n_sent_conflict[i]} conflicts in {n_sent_ann_docs[i]}/{n_docs} docs]** {answer_sentences[i]['sentence']}"
        )
        if current_sent_idx != st.session_state.current_sent_idx:
            st.session_state.current_sent_idx = current_sent_idx
            st.rerun()
    selected_sent = answer_sentences[st.session_state.current_sent_idx]
    # Placeholder filled in below, once the span type has been chosen.
    ann_header = st.container()
    with st.expander("Documents", expanded=True):
        # Chooses which of the selected sentence's span lists the
        # highlighters below read from and write to.
        span_type = st.radio(
            "Span type",
            ["citation_spans", "conflict_spans"],
            horizontal=True,
            format_func=lambda x: x.split("_")[0].capitalize()
        )
        span_icon = "📜" if span_type == "citation_spans" else "🔥"
        ann_header.markdown(f"#### {span_icon} Annotating: {span_type.split('_')[0].capitalize()} Spans")
        for i, doc in enumerate(current_data["docs"]):
            title_row = row([4, 1])
            title_row.markdown(f"**[{i+1}] {doc['title']}**")
            # In citation mode, list the claims this document supports:
            # claims[j] for every j where doc["answers_found"][j] == 1.
            if span_type == "citation_spans":
                claims = [current_data["claims"][j] for j, ind in enumerate(doc["answers_found"]) if ind == 1]
                with title_row.popover(f"Supports {len(claims)} claims"):
                    for claim in claims:
                        st.markdown(f"- {claim}")
            # Seed the highlighter with the stored spans. [[]] is normalised
            # to [] (presumably the widget's "nothing selected" value). A
            # non-None result that differs from what is stored is saved and
            # triggers a rerun.
            spans = text_highlighter(doc["text"], selected_sent[span_type][i])
            spans = [] if spans == [[]] else spans
            if spans is not None and spans != selected_sent[span_type][i]:
                selected_sent[span_type][i] = spans
                save_annotations(args.ann_path, ann_data)
                st.rerun()
    # Finally, show any status messages
    show_status()
# Script entry point: parse CLI options, derive the default annotation path
# from the dataset path, then render the app.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eli5_path",
        type=str,
        help="Path to the ELI5 dataset JSON file",
        default="./data/eli5_eval_bm25_top100_reranked_oracle.json",
    )
    parser.add_argument(
        "--eli5_ans_path",
        type=str,
        help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.",
        default=None,
    )
    parser.add_argument(
        "--ann_path",
        type=str,
        help="Path to the span annotations JSON file",
        default=None,
    )
    args = parser.parse_args()
    # Default: "<dataset path without extension>_spans.json".
    if args.ann_path is None:
        args.ann_path = f"{os.path.splitext(args.eli5_path)[0]}_spans.json"
    run(args)