saliency-based-citation/compute_comparison.py
"""Aggregate per-configuration `eli5_summary_stats.json` files into comparison
tables (and, optionally, range plots) across systems, LLMs, and ablations."""
import argparse
import json
import os

import pandas as pd  # df.plot below also requires matplotlib to be installed


def load_json_data(json_file):
    with open(json_file, "r") as f:
        data = json.load(f)
    return data


def main():
    parser = argparse.ArgumentParser(description="Compute Metrics Comparison")
    parser.add_argument("--output_path", type=str, help="Path where pipeline has output all results", default="attribution_results")
    parser.add_argument("--floatfmt", type=str, help="Float format for printing", default=".4f")
    parser.add_argument("--show_all", action="store_true", help="Show all results including precision and recall")
    parser.add_argument("--plot_ranges", action="store_true", help="Make range plots")
    args = parser.parse_args()

    # Get all config folders in the output path. A config folder is named
    # "{llm}_{ablation}": the first two underscore-separated tokens identify
    # the LLM, the rest form the ablation suffix (possibly empty).
    config_names = sorted([f for f in os.listdir(args.output_path) if os.path.isdir(os.path.join(args.output_path, f))])
    llms = sorted(set("_".join(config_name.split("_")[:2]) for config_name in config_names))
    ablations = sorted(set("_".join(config_name.split("_")[2:]) for config_name in config_names))
    systems = ["prompt_based", "sliding_window", "gradient_based"]

    for span_type in ["citation", "conflict"]:
        results = []
        for system in systems:
            for ablation in ablations:
                # The prompt-based system has no ablation variants.
                if system == "prompt_based" and ablation != "":
                    continue
                result = {}
                result["system"] = f"{system}_{ablation}" if ablation != "" else system
                for llm in llms:
                    llm_ablation_dirname = f"{llm}_{ablation}" if ablation != "" else llm
                    summary_stats_path = os.path.join(args.output_path, llm_ablation_dirname, system, "eli5_summary_stats.json")
                    if not os.path.exists(summary_stats_path):
                        continue
                    summary_stats = load_json_data(summary_stats_path)
                    for level in ["char", "doc"]:
                        # F1 is always reported; precision and recall only with --show_all.
                        if args.show_all:
                            result[f"{llm}_{level}_precision"] = summary_stats[span_type][level]["precision"]
                            result[f"{llm}_{level}_recall"] = summary_stats[span_type][level]["recall"]
                        result[f"{llm}_{level}_f1"] = summary_stats[span_type][level]["f1"]
                # Only keep rows that picked up at least one metric.
                if len(result) > 1:
                    results.append(result)

        df = pd.DataFrame(results)
        print("\nSpan Type:", span_type)
        print()
        print(df.to_markdown(index=False, floatfmt=args.floatfmt))

        if args.plot_ranges:
            # Derive the swept parameter's name and numeric values from the
            # ablation suffixes, e.g. "window_5" -> name "window", value 5.0.
            range_name = "_".join(ablations[0].split("_")[:-1])
            range_values = [float(ab.split("_")[-1]) for ab in ablations]
            df[range_name] = range_values
            # Sort rows by the swept value so lines plot left to right.
            df.sort_values(by=[range_name], inplace=True)
            for llm in llms:
                for level in ["char", "doc"]:
                    # Plot precision/recall/F1 against the swept value.
                    plot = df.plot(
                        x=range_name,
                        y=[f"{llm}_{level}_precision", f"{llm}_{level}_recall", f"{llm}_{level}_f1"],
                        kind="line",
                        title=f"Effect of {range_name}\n({llm}, {span_type} spans, {level}-level scores)",
                        xticks=[v for v in range_values if v.is_integer()],
                        ylabel="Precision, Recall, F1",
                        ylim=(0.0, 0.8),
                        figsize=(5, 4),
                    )
                    # Save one plot per (llm, span_type, level) combination.
                    plot.get_figure().savefig(
                        os.path.join(args.output_path, f"{llm}_{span_type}_{level}_range_plot.png")
                    )


if __name__ == "__main__":
    main()
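
The script reads a results tree of the form `{output_path}/{llm}_{ablation}/{system}/eli5_summary_stats.json`, where the JSON maps span type to level to metric. Below is a minimal sketch of a helper that fabricates one such file for a smoke test; the config name, system choice, and metric values are hypothetical and not part of the repository.

# make_dummy_results.py -- hypothetical smoke-test helper, not part of this repo.
import json
import os

output_path = "attribution_results"
config = "llama_7b_window_5"   # parsed as llm="llama_7b", ablation="window_5"
system = "sliding_window"      # one of: prompt_based, sliding_window, gradient_based

# Schema consumed by compute_comparison.py: span_type -> level -> metric.
stats = {
    span_type: {
        level: {"precision": 0.5, "recall": 0.4, "f1": 0.44}
        for level in ("char", "doc")
    }
    for span_type in ("citation", "conflict")
}

target_dir = os.path.join(output_path, config, system)
os.makedirs(target_dir, exist_ok=True)
with open(os.path.join(target_dir, "eli5_summary_stats.json"), "w") as f:
    json.dump(stats, f, indent=2)

With a few such config folders in place, the comparison tables can then be printed with, for example, `python compute_comparison.py --output_path attribution_results --show_all`.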