import json
import numpy as np
import argparse
from sklearn.metrics import precision_score, recall_score, f1_score

def load_json_data(json_file):
    """
    Load and parse a JSON file.

    The file is always decoded as UTF-8 so that results do not depend on the
    platform's locale encoding (document texts may contain non-ASCII characters).

    Args:
        json_file (str): Path to the JSON file.

    Returns:
        The parsed JSON content. The callers in this script expect a list of items.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def parse_args():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with ``eli5_path``, ``gt_path`` and ``pred_path``
        (str) and ``doc_level_f1_threshold`` (float).
    """
    cli = argparse.ArgumentParser(description='Compute metrics on the pipeline output')
    # (flag, type, help text, default) -- registered in this order so --help
    # lists the options in the same sequence.
    option_specs = [
        ("--eli5_path", str,
         "Path to the ELI5 dataset JSON file",
         "./data/eli5_eval_bm25_top100_reranked_oracle.json"),
        ('--gt_path', str,
         'Path to the JSON file containing the ground truth annotations',
         './data/eli5_eval_bm25_top100_reranked_oracle_spans_generated_llama31_70B_first_100_joep_joek.json'),
        ('--pred_path', str,
         'Path to the JSON file containing the predicted spans',
         './eli5_salsa_sample.json'),
        ("--doc_level_f1_threshold", float,
         "Threshold for the document-level F1 score to determine if a document can be considered 'correctly' cited",
         0.5),
    ]
    for flag, value_type, help_text, default_value in option_specs:
        cli.add_argument(flag, type=value_type, help=help_text, default=default_value)
    return cli.parse_args()

def get_span_group_char_indicators(doc_len, span_group):
    """
    Build a 0/1 mask marking every character covered by any span in the group.

    Args:
        doc_len (int): Length of the returned mask (number of characters in the document).
        span_group (list): Span dicts with integer "start" (inclusive) and
            "end" (exclusive) character offsets.

    Returns:
        np.ndarray: Integer array of length ``doc_len``; 1 where a span covers
        the character, 0 elsewhere. Overlapping spans simply re-mark positions.
    """
    mask = np.zeros(doc_len, dtype=int)
    for begin, stop in ((span["start"], span["end"]) for span in span_group):
        mask[begin:stop] = 1
    return mask

def add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds):
    """
    Compute binary metrics for one (span type, level) pair, store and print them.

    Args:
        summary_stats (dict): Nested results dict. ``summary_stats[span_type]``
            must already exist; a ``level`` entry is written into it in place.
        span_type (str): Span category label used as the outer key (e.g. "citation").
        level (str): Granularity label used as the inner key and in printed
            output (e.g. "char" or "doc").
        gt_inds (array-like): 0/1 ground-truth indicators.
        pred_inds (array-like): 0/1 predicted indicators, aligned with ``gt_inds``.
    """
    # Compute summary statistics
    num_samples = len(gt_inds)
    num_gt_positive = np.sum(gt_inds).item()
    num_pred_positive = np.sum(pred_inds).item()
    # Report 0% for an empty evaluation set instead of raising ZeroDivisionError.
    percent_gt_positive = 100 * (num_gt_positive / num_samples) if num_samples else 0.0
    percent_pred_positive = 100 * (num_pred_positive / num_samples) if num_samples else 0.0
    # zero_division=0 follows the same convention as the per-document F1 in this
    # script: an undefined score (e.g. no predicted or no GT positives) counts as 0
    # without emitting sklearn's UndefinedMetricWarning.
    precision = precision_score(gt_inds, pred_inds, zero_division=0)
    recall = recall_score(gt_inds, pred_inds, zero_division=0)
    f1 = f1_score(gt_inds, pred_inds, zero_division=0)

    # Add to the dictionary
    summary_stats[span_type][level] = {
        'num_samples': num_samples,
        'num_gt_positive': num_gt_positive,
        'num_pred_positive': num_pred_positive,
        'percent_gt_positive': percent_gt_positive,
        'percent_pred_positive': percent_pred_positive,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

    # Print out
    print()
    print(f"({level}-level) Number of GT positives: {num_gt_positive} / {num_samples} ({percent_gt_positive:.2f}%)")
    print(f"({level}-level) Number of predicted positives: {num_pred_positive} / {num_samples} ({percent_pred_positive:.2f}%)")
    print(f"({level}-level) Precision: {precision:.4f}")
    print(f"({level}-level) Recall: {recall:.4f}")
    print(f"({level}-level) F1: {f1:.4f}")
    print()

def main():
    """
    Compute citation/conflict span metrics for a prediction file and save them.

    For each (ELI5 item, GT item, predicted item) triple and each answer sentence,
    every document gets:
      * per-character 0/1 indicators for the ground-truth spans and the predicted spans;
      * a per-document 0/1 "cited" indicator for ground truth and prediction.
    Precision / recall / F1 are then reported at both the character and the
    document level for each span type. The results are printed and written to
    JSON next to ``--pred_path`` (``foo.json`` -> ``foo_summary_stats.json``).

    Note: items, sentences and documents are paired with ``zip``, so evaluation
    stops at the shortest input at each level. For example, a prediction file
    covering only the first N items is scored on those N items only.

    Raises:
        ValueError: if no documents were evaluated for a span type.
    """
    # Parse command line arguments
    args = parse_args()

    print("\n---------------------------------------------------------------------------------------------------------")
    print(f">>> Computing metrics for {args.pred_path}")
    print("---------------------------------------------------------------------------------------------------------\n")

    # Load the dataset (for document texts), the ground truth and the predictions
    eli5_data = load_json_data(args.eli5_path)
    gt_data = load_json_data(args.gt_path)
    pred_data = load_json_data(args.pred_path)

    # One character-indicator array per evaluated document (concatenated later)
    citation_gt_char_inds = []
    citation_pred_char_inds = []
    conflict_gt_char_inds = []
    conflict_pred_char_inds = []

    # One 0/1 "cited" flag per evaluated document
    citation_gt_doc_inds = []
    citation_pred_doc_inds = []
    conflict_gt_doc_inds = []
    conflict_pred_doc_inds = []

    for eli5_item, gt_item, pred_item in zip(eli5_data, gt_data, pred_data):
        for gt_sentence, pred_sentence in zip(gt_item['answer_sentences'], pred_item['answer_sentences']):
            for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
                ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
                ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
            ]:
                for i, (gt_doc_span_groups, pred_doc_span_groups) in enumerate(
                    zip(gt_sentence[f'{span_type}_spans'], pred_sentence[f'{span_type}_spans'])
                ):
                    # Get pred and gt indicators {0, 1} for each character position in the document.
                    # Assumes the i-th entry of the *_spans lists refers to eli5_item["docs"][i].
                    doc_len = len(eli5_item["docs"][i]["text"])
                    # Predicted spans aren't grouped so there will only be one group
                    pred_doc_spans = pred_doc_span_groups[0] if pred_doc_span_groups else []
                    pred_doc_char_inds = get_span_group_char_indicators(doc_len, pred_doc_spans)
                    # Ground truth spans are grouped; no groups is treated as one empty group
                    gt_doc_span_groups = gt_doc_span_groups if gt_doc_span_groups else [[]]
                    gt_doc_group_char_inds = [
                        get_span_group_char_indicators(doc_len, gt_doc_spans)
                        for gt_doc_spans in gt_doc_span_groups
                    ]
                    # Get the F1 score for each group vs the predicted indicators and pick the best matching group
                    # (GT span groups are OR'd together meaning that the predicted spans can match any of the groups to be correct).
                    # np.argmax returns the first index on ties, so equal scores resolve to the earliest group.
                    doc_group_f1_scores = [
                        f1_score(gt_doc_char_inds, pred_doc_char_inds, zero_division=0)
                        for gt_doc_char_inds in gt_doc_group_char_inds
                    ]
                    best_doc_group_f1_idx = np.argmax(doc_group_f1_scores)
                    doc_f1_score = doc_group_f1_scores[best_doc_group_f1_idx]
                    gt_doc_char_inds = gt_doc_group_char_inds[best_doc_group_f1_idx]
                    # Append to the character-level indicator lists
                    gt_char_inds.append(gt_doc_char_inds)
                    pred_char_inds.append(pred_doc_char_inds)

                    # Get pred and gt indicators {0, 1} for each document
                    gt_doc_ind = 1 if np.sum(gt_doc_char_inds) > 0 else 0
                    if gt_doc_ind == 0:
                        # If no GT spans exist then consider the document cited if any predicted spans exist
                        pred_doc_ind = 1 if np.sum(pred_doc_char_inds) > 0 else 0
                    else:
                        # If GT spans exist then consider the document cited if the F1 score is above the threshold
                        # i.e. the predicted citation spans are close enough to the GT spans
                        pred_doc_ind = 1 if doc_f1_score >= args.doc_level_f1_threshold else 0
                    # Append to the document-level indicator lists
                    gt_doc_inds.append(gt_doc_ind)
                    pred_doc_inds.append(pred_doc_ind)

    # np.concatenate fails with an opaque message on an empty list; report the real cause instead.
    for span_type, evaluated_docs in (("citation", citation_gt_char_inds), ("conflict", conflict_gt_char_inds)):
        if not evaluated_docs:
            raise ValueError(
                f"No documents were evaluated for {span_type} spans; "
                "check that the input files are non-empty and aligned."
            )

    citation_gt_char_inds = np.concatenate(citation_gt_char_inds)
    citation_pred_char_inds = np.concatenate(citation_pred_char_inds)
    conflict_gt_char_inds = np.concatenate(conflict_gt_char_inds)
    conflict_pred_char_inds = np.concatenate(conflict_pred_char_inds)

    citation_gt_doc_inds = np.array(citation_gt_doc_inds)
    citation_pred_doc_inds = np.array(citation_pred_doc_inds)
    conflict_gt_doc_inds = np.array(conflict_gt_doc_inds)
    conflict_pred_doc_inds = np.array(conflict_pred_doc_inds)

    # Print summary statistics
    summary_stats = {
        "citation": {},
        "conflict": {},
    }
    for span_type, gt_char_inds, pred_char_inds, gt_doc_inds, pred_doc_inds in [
        ("citation", citation_gt_char_inds, citation_pred_char_inds, citation_gt_doc_inds, citation_pred_doc_inds),
        ("conflict", conflict_gt_char_inds, conflict_pred_char_inds, conflict_gt_doc_inds, conflict_pred_doc_inds),
    ]:
        print("-------------------------")
        print(f"> Span type: {span_type}")
        print("-------------------------")
        for level, gt_inds, pred_inds in [("char", gt_char_inds, pred_char_inds), ("doc", gt_doc_inds, pred_doc_inds)]:
            add_summary_stats(summary_stats, span_type, level, gt_inds, pred_inds)

    # Save the summary statistics next to the predictions.
    # Only a trailing ".json" is swapped out. A blanket str.replace would also rewrite
    # ".json" inside directory names. It would return pred_path unchanged when the
    # path has no ".json", overwriting the predictions file.
    json_suffix = ".json"
    pred_root = args.pred_path[:-len(json_suffix)] if args.pred_path.endswith(json_suffix) else args.pred_path
    output_path = pred_root + "_summary_stats.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Summary statistics saved to: {output_path}")

# Run the evaluation only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()