Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
saliency-based-citation/span_annotator.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
198 lines (177 sloc)
8.23 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import argparse | |
import json | |
import nltk | |
import os | |
import re | |
from streamlit_annotation_tools import text_highlighter | |
from streamlit_extras.row import row | |
from utils.answer_helpers import LIST_INDEXER_REGEX | |
# Configure the Streamlit page; must run before any other st.* call.
st.set_page_config(page_title="Span Annotator", layout="wide")
@st.cache_resource
def load_eli5(eli5_path, ans_path=None):
    """Load the ELI5 dataset, optionally substituting generated answers.

    If ans_path is given, each example's "answer" field is replaced by the
    corresponding entry of the generated-answers file; the two files must
    contain the same number of entries.
    """
    with open(eli5_path, "r", encoding="utf-8") as f:
        examples = json.load(f)
    if ans_path is None:
        return examples
    with open(ans_path, "r", encoding="utf-8") as f:
        generated_answers = json.load(f)
    if len(generated_answers) != len(examples):
        raise ValueError("Number of ELI5 questions and generated ELI5 answers do not match")
    for example, answer in zip(examples, generated_answers):
        example["answer"] = answer
    return examples
@st.cache_resource
def load_annotations(eli5_path, ann_path, ans_path=None):
    """Return the span-annotation records, creating an empty file if absent.

    A missing annotation file is seeded with one blank record per ELI5
    example so the two lists always stay index-aligned.
    """
    if not os.path.exists(ann_path):
        blank_records = [
            {"answer_sentences": [], "span_annotation_done": False}
            for _ in load_eli5(eli5_path, ans_path)
        ]
        save_annotations(ann_path, blank_records)
    with open(ann_path, "r", encoding="utf-8") as f:
        return json.load(f)
def save_annotations(ann_path, ann_data):
    """Atomically write the annotation list to ann_path as indented JSON.

    The data is first written to a temporary sibling file and then renamed
    into place with os.replace (atomic on both POSIX and Windows), so an
    interrupted save cannot leave a truncated or corrupt annotation file —
    the previous fully-written version survives any crash mid-dump.
    """
    tmp_path = ann_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(ann_data, f, indent=4)
    os.replace(tmp_path, ann_path)
    add_status("Data saved!", "✅")
def count_annotations(answer_sentences):
    """Tally annotations per answer sentence.

    Returns three parallel lists, one entry per sentence: the number of
    citation spans, the number of conflict spans, and the number of distinct
    documents that carry at least one span of either type.
    """
    citation_counts = []
    conflict_counts = []
    annotated_doc_counts = []
    for sentence in answer_sentences:
        totals = {"citation_spans": 0, "conflict_spans": 0}
        docs_with_spans = set()
        for key in ("citation_spans", "conflict_spans"):
            for doc_idx, doc_spans in enumerate(sentence[key]):
                for ann_set in doc_spans:
                    # Empty annotation sets contribute nothing and must not
                    # mark the document as annotated.
                    if len(ann_set) > 0:
                        totals[key] += len(ann_set)
                        docs_with_spans.add(doc_idx)
        citation_counts.append(totals["citation_spans"])
        conflict_counts.append(totals["conflict_spans"])
        annotated_doc_counts.append(len(docs_with_spans))
    return citation_counts, conflict_counts, annotated_doc_counts
def count_done(ann_data):
    """Return how many annotation records are flagged as finished."""
    return sum(1 for record in ann_data if record.get("span_annotation_done", False))
def add_status(body, icon="ℹ️"):
    """Queue a toast message to be displayed on the next show_status() call."""
    queued = st.session_state.get("statuses", [])
    queued.append({"body": body, "icon": icon})
    st.session_state.statuses = queued
def show_status():
    """Display all queued toast messages, then clear the queue."""
    for queued in st.session_state.get("statuses", []):
        st.toast(**queued)
    st.session_state.statuses = []
def run(args):
    """Render the Streamlit span-annotation UI.

    args must provide eli5_path, eli5_ans_path, and ann_path (see the CLI
    parser at the bottom of this file). Navigation state lives in
    st.session_state; every mutation of the annotation data is saved to disk
    and followed by st.rerun().
    """
    # eli5_data and ann_data will always be the same length
    eli5_data = load_eli5(args.eli5_path, args.eli5_ans_path)
    ann_data = load_annotations(args.eli5_path, args.ann_path, args.eli5_ans_path)
    # Session state tracks the selected question and answer sentence.
    if "current_idx" not in st.session_state:
        st.session_state.current_idx = 0
    if "current_sent_idx" not in st.session_state:
        st.session_state.current_sent_idx = 0
    with st.sidebar:
        st.sidebar.title("✍️ Span Annotator")
        nav_row = row([1, 1])
        current_idx = st.session_state.current_idx
        if nav_row.button("⬅️ Prev"):
            current_idx = max(0, st.session_state.current_idx - 1)
        if nav_row.button("Next ➡️"):
            current_idx = min(len(eli5_data) - 1, st.session_state.current_idx + 1)
        # The number input is 1-based for the user; convert back to 0-based.
        current_idx = st.number_input("Go to question", value=current_idx+1, min_value=1, max_value=len(eli5_data), step=1)-1
        if current_idx != st.session_state.current_idx:
            # Switching question resets the sentence selection and re-renders.
            st.session_state.current_idx = current_idx
            st.session_state.current_sent_idx = 0
            st.rerun()
        current_data = eli5_data[st.session_state.current_idx]
        current_ann_data = ann_data[st.session_state.current_idx]
        is_done = current_ann_data.get("span_annotation_done", False)
        # Done-flag toggle: each branch persists the change and reruns.
        if is_done and st.button("❌ Mark Not Done"):
            current_ann_data["span_annotation_done"] = False
            save_annotations(args.ann_path, ann_data)
            st.rerun()
        if not is_done and st.button("✅ Mark Done"):
            current_ann_data["span_annotation_done"] = True
            save_annotations(args.ann_path, ann_data)
            st.rerun()
        n_done = count_done(ann_data)
        st.markdown(f"### 📊 Progress: {n_done}/{len(ann_data)}")
    status = "✅ Done" if is_done else "❌ Not Done"
    st.markdown(f"#### 🤔 Question #{st.session_state.current_idx+1} / {len(eli5_data)} ({status})")
    # "[removed]" marks deleted Reddit context; hide the expander in that case.
    if current_data["question_ctx"] and current_data["question_ctx"] != "[removed]":
        with st.expander("Question Context", expanded=True):
            st.markdown(current_data["question_ctx"])
    with st.expander("❓ Question", expanded=True):
        st.markdown(current_data["question"])
    with st.expander("💡 Answer Sentences", expanded=True):
        if not current_ann_data.get("answer_sentences", []):
            # First visit: split the answer into sentences and create empty
            # per-document span lists; sentences matching LIST_INDEXER_REGEX
            # (bare list markers) are skipped.
            current_ann_data["answer_sentences"] = [
                {
                    "sentence": sent,
                    "citation_spans": [[] for _ in current_data["docs"]],
                    "conflict_spans": [[] for _ in current_data["docs"]],
                }
                for sent in nltk.sent_tokenize(current_data["answer"])
                if not re.match(LIST_INDEXER_REGEX, sent)
            ]
        answer_sentences = current_ann_data["answer_sentences"]
        n_sent_citation, n_sent_conflict, n_sent_ann_docs = count_annotations(answer_sentences)
        n_docs = len(current_data["docs"])
        # Radio labels summarize annotation progress for each sentence.
        current_sent_idx = st.radio(
            "Select a sentence",
            range(len(answer_sentences)),
            index=st.session_state.current_sent_idx,
            format_func=lambda i: f"**[{n_sent_citation[i]} citations, {n_sent_conflict[i]} conflicts in {n_sent_ann_docs[i]}/{n_docs} docs]** {answer_sentences[i]['sentence']}"
        )
        if current_sent_idx != st.session_state.current_sent_idx:
            st.session_state.current_sent_idx = current_sent_idx
            st.rerun()
    selected_sent = answer_sentences[st.session_state.current_sent_idx]
    # Placeholder so the annotation header (which depends on the span type
    # chosen below) can render above the Documents expander.
    ann_header = st.container()
    with st.expander("Documents", expanded=True):
        span_type = st.radio(
            "Span type",
            ["citation_spans", "conflict_spans"],
            horizontal=True,
            format_func=lambda x: x.split("_")[0].capitalize()
        )
        span_icon = "📜" if span_type == "citation_spans" else "🔥"
        ann_header.markdown(f"#### {span_icon} Annotating: {span_type.split('_')[0].capitalize()} Spans")
        for i, doc in enumerate(current_data["docs"]):
            title_row = row([4, 1])
            title_row.markdown(f"**[{i+1}] {doc['title']}**")
            if span_type == "citation_spans":
                # answers_found is an indicator vector over claims — TODO confirm
                # against the dataset builder.
                claims = [current_data["claims"][j] for j, ind in enumerate(doc["answers_found"]) if ind == 1]
                with title_row.popover(f"Supports {len(claims)} claims"):
                    for claim in claims:
                        st.markdown(f"- {claim}")
            spans = text_highlighter(doc["text"], selected_sent[span_type][i])
            # The highlighter reports "no spans" as [[]]; normalize to [].
            spans = [] if spans == [[]] else spans
            # Persist and re-render only when the annotation actually changed.
            if spans is not None and spans != selected_sent[span_type][i]:
                selected_sent[span_type][i] = spans
                save_annotations(args.ann_path, ann_data)
                st.rerun()
    # Finally, show any status messages
    show_status()
if __name__ == "__main__":
    # Command-line entry point for the annotator app.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--eli5_path", type=str, help="Path to the ELI5 dataset JSON file", default="./data/eli5_eval_bm25_top100_reranked_oracle.json")
    arg_parser.add_argument("--eli5_ans_path", type=str, help="Path to the ELI5 generated answers JSON file. If not given, the original Reddit answers will be used.", default=None)
    arg_parser.add_argument("--ann_path", type=str, help="Path to the span annotations JSON file", default=None)
    cli_args = arg_parser.parse_args()
    if cli_args.ann_path is None:
        # Default the annotation file to sit next to the dataset file.
        cli_args.ann_path = os.path.splitext(cli_args.eli5_path)[0] + "_spans.json"
    run(cli_args)