```{r}
library(dplyr)
library(openai) # devtools::install_github("irudnyts/openai", ref = "r6")
library(rlist)
library(stringr)
library(purrr)
library(progress)
library(jsonlite)
library(knitr) # for kable() used in the evaluation and benchmarking chunks
```
### Loading the pertinent datasets (Repo or Pub can be used)
```{r}
# Load datasets
CT_Pub.df <- readRDS("../../CTBench_source/corrected_data/ct_pub/CT_Pub_data_updated.Rds")
head(CT_Pub.df, 2)

CT_Repo.df <- readRDS("../../CTBench_source/corrected_data/ct_repo/CT_Repo_data_updated.Rds")
head(CT_Repo.df, 2)

# Model names
model.choices <- c("gpt-4-0613",
                   "gpt-4-turbo-preview",
                   "gpt-4o-mini",
                   "gpt-4o",
                   "o1-preview",
                   "o1-mini",
                   "Meta-Llama-3.1-8B-Instruct")
```
### Single Generation Zero-Shot
The following function builds the LLM prompt for zero-shot generation; it works well for all models. A single row of data (from CT_Pub or CT_Repo) is used to build the prompt and is the only parameter the function requires.
```{r}
build_zeroshot_prompt <- function(row) {
# Prompt structure
system_message <- "You are a helpful assistant with experience in the clinical domain and clinical trial design. \
You'll be asked queries related to clinical trials. These inquiries will be delineated by a '##Question' heading. \
Inside these queries, expect to find comprehensive details about the clinical trial structured within specific subsections, \
indicated by '<>' tags. These subsections include essential information such as the trial's title, brief summary, \
condition under study, inclusion and exclusion criteria, intervention, and outcomes."
# Baseline measure definition
system_message <- paste0(system_message, " In answer to this question, return a list of probable baseline features (each feature should be enclosed within a pair of backticks \
and each feature should be separated by commas from other features) of the clinical trial. \
Baseline features are the set of baseline or demographic characteristics that are assessed at baseline and used in the analysis of the \
primary outcome measure(s) to characterize the study population and assess validity. Clinical trial-related publications typically \
include a table of baseline features assessed by arm or comparison group and for the entire population of participants in the clinical trial.")
# Additional instructions
system_message <- paste0(system_message, " Do not give any additional explanations or use any tags or headers, only return the list of baseline features. ")
# Extract row information to generate the query
title <- row$BriefTitle
brief_summary <- row$BriefSummary
condition <- row$Conditions
eligibility_criteria <- row$EligibilityCriteria
intervention <- row$Interventions
outcome <- row$PrimaryOutcomes
# Construct the question
question <- "##Question:\n"
question <- paste0(question, "<Title> \n", title, "\n")
question <- paste0(question, "<Brief Summary> \n", brief_summary, "\n")
question <- paste0(question, "<Condition> \n", condition, "\n")
question <- paste0(question, "<Eligibility Criteria> \n", eligibility_criteria, "\n")
question <- paste0(question, "<Intervention> \n", intervention, "\n")
question <- paste0(question, "<Outcome> \n", outcome, "\n")
question <- paste0(question, "##Answer:\n")
return(c(system_message, question))
}
```
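Before spending API calls, it can be useful to inspect the prompt text itself. A minimal sketch, using only the objects defined above (no request is sent):
```{r}
# Inspect a zero-shot prompt without calling any model
zs_check <- build_zeroshot_prompt(CT_Pub.df[1, ])
cat(zs_check[1]) # system message
cat(zs_check[2]) # user message containing the trial's metadata
```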
### Single Generation Triple-Shot
Similar to the zero-shot function, but for triple-shot generation. Three trials serve as worked examples in the prompt: their metadata and reference "answers" are included as templates for the LLM.
```{r}
# Given a row, will create an example for the triple shot prompt
build_example_questions_from_row <- function(data, ref_col_name) {
ids = c('NCT00000620', 'NCT01483560', 'NCT04280783')
examples = data[data$NCTId %in% ids, ]
question = ""
for (i in 1:nrow(examples)) {
row <- examples[i, ]
question <- paste0(question, "##Question:\n")
question <- paste0(question, "<Title> \n", row[['BriefTitle']], "\n")
question <- paste0(question, "<Brief Summary> \n", row[['BriefSummary']], "\n")
question <- paste0(question, "<Condition> \n", row[['Conditions']], "\n")
question <- paste0(question, "<Eligibility Criteria> \n", row[['EligibilityCriteria']], "\n")
question <- paste0(question, "<Intervention> \n", row[['Interventions']], "\n")
question <- paste0(question, "<Outcome> \n", row[['PrimaryOutcomes']], "\n")
question <- paste0(question, "##Answer:\n", row[[ref_col_name]], "\n\n")
}
return(question)
}

build_three_shot_prompt <- function(data, row, ref_col_name) {
# Prompt structure
system_message <- "You are a helpful assistant with experience in the clinical domain and clinical trial design. \
You'll be asked queries related to clinical trials. These inquiries will be delineated by a '##Question' heading. \
Inside these queries, expect to find comprehensive details about the clinical trial structured within specific subsections, \
indicated by '<>' tags. These subsections include essential information such as the trial's title, brief summary, \
condition under study, inclusion and exclusion criteria, intervention, and outcomes."
# Baseline measure definition
system_message <- paste0(system_message, " In answer to this question, return a list of probable baseline features (each feature should be enclosed within a pair of backticks \
and each feature should be separated by commas from other features) of the clinical trial. \
Baseline features are the set of baseline or demographic characteristics that are assessed at baseline and used in the analysis of the \
primary outcome measure(s) to characterize the study population and assess validity. Clinical trial-related publications typically \
include a table of baseline features assessed by arm or comparison group and for the entire population of participants in the clinical trial.")
# Additional instructions
system_message <- paste0(system_message, " You will be given three examples. In each example, the question is delineated by '##Question' heading and the corresponding answer is delineated by '##Answer' heading. \
Follow a similar pattern when you generate answers. Do not give any additional explanations or use any tags or headings, only return the list of baseline features.")
# Generate examples
example <- build_example_questions_from_row(data, ref_col_name)
# Divide row information to generate the query
title <- row[['BriefTitle']]
brief_summary <- row[['BriefSummary']]
condition <- row[['Conditions']]
eligibility_criteria <- row[['EligibilityCriteria']]
intervention <- row[['Interventions']]
outcome <- row[['PrimaryOutcomes']]
question <- "##Question:\n"
question <- paste0(question, "<Title> \n", title, "\n")
question <- paste0(question, "<Brief Summary> \n", brief_summary, "\n")
question <- paste0(question, "<Condition> \n", condition, "\n")
question <- paste0(question, "<Eligibility Criteria> \n", eligibility_criteria, "\n")
question <- paste0(question, "<Intervention> \n", intervention, "\n")
question <- paste0(question, "<Outcome> \n", outcome, "\n")
question <- paste0(question, "##Answer:\n")
return(c(system_message, paste0(example, question)))
}
```
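The three example NCT IDs are hard-coded inside build_example_questions_from_row; if any of them are missing from the data, the examples silently shrink. A quick optional check (a sketch, not part of the original workflow):
```{r}
# Confirm the hard-coded few-shot example trials are present in the dataset
example_ids <- c("NCT00000620", "NCT01483560", "NCT04280783")
stopifnot(all(example_ids %in% CT_Pub.df$NCTId))
```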
### API call function
This function is the main method for sending generation prompts to either OpenAI or the local llama instance. It takes the system and user prompts as a two-element vector (prompts) and the name of the model the prompt should be sent to (model).
```{r}
# Set OpenAI API key
#mykey <- insert_your_key_here
#Sys.setenv(OPENAI_API_KEY = mykey)

# Using purrr's insistently() to retry
rate <- rate_delay(5) # retry rate

# HERE'S THE MAGIC: This is how we hit the LLM endpoints
# This is "risky" because it doesn't protect against rate limits
risky_create_completion <- function(prompts, model) {
# Choose the endpoint based on model name
if (startsWith(model, "gpt-") || startsWith(model, "o1-")) {
client <- OpenAI()
} else {
client <- OpenAI(
base_url = "http://idea-llm-01.idea.rpi.edu:5000/v1/"
)
}
# This is where we specify the prompts!
client$chat$completions$create(
model = model,
messages = list(
list(
"role" = "system",
"content" = prompts[1]
),
list(
"role" = "user",
"content" = prompts[2]
)
)
)
}

# This wrapper is CRITICAL to avoid rate limit errors
insistent_create_completion <- insistently(risky_create_completion, rate, quiet = FALSE)
```
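rate_delay(5) retries at a fixed five-second interval. If rate-limit errors persist, purrr's exponential backoff can be swapped in; this is an optional alternative sketch, not what the notebook uses by default:
```{r}
# Optional: exponential backoff between retries instead of a fixed delay
backoff_rate <- rate_backoff(pause_base = 2, pause_cap = 60, max_times = 10)
insistent_create_completion_backoff <- insistently(risky_create_completion, backoff_rate, quiet = FALSE)
```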
### API Call Single Generation
This is an example of a single generation, where we generate candidate features from one trial's metadata. In this example we perform both zero-shot and triple-shot generation.
```{r}
model_index = 7

# Zero shot
zs_prompts = build_zeroshot_prompt(CT_Pub.df[6, ])
single_zs_test = insistent_create_completion(zs_prompts, model.choices[model_index])$choices[[1]]$message$content

# Triple Shot
ts_prompts = build_three_shot_prompt(CT_Pub.df, CT_Pub.df[6, ], "Paper_BaselineMeasures_Corrected")
single_ts_test = insistent_create_completion(ts_prompts, model.choices[model_index])$choices[[1]]$message$content

print("Single Zero Shot Test")
single_zs_test
print("Single Triple Shot Test")
single_ts_test
```
### API Call Batch Generation
Similar to the single generation, this script generates candidate features for all relevant trials in the dataframe in question. It stores the results in a dataframe with the following columns:
* trial_id (NCTId)
* model (name of the generating model)
* gen_response (the generated candidate features)
* match_model (the name of the matching model used for evaluation)
* len_matches (number of matched features)
* len_reference (number of unmatched reference features)
* len_candidate (number of unmatched candidate features)
* precision, recall, f1 (benchmark metrics)
```{r}
n = nrow(CT_Pub.df)
n = 10 # limit to the first 10 trials for a test run; comment this out to process the full dataset

pb = progress_bar$new(
format = " Processing [:bar] :percent in :elapsed",
total = n, clear = FALSE, width = 60
)

model_index = 1

CT_Pub_responses.df = data.frame(matrix(ncol = 10, nrow = 0))
colnames(CT_Pub_responses.df) = c("trial_id", "model", "gen_response", "match_model", "len_matches", "len_reference", "len_candidate", "precision", "recall", "f1")

for (i in 1:n){
#Zero Shot
zs_prompts = build_zeroshot_prompt(CT_Pub.df[i, ])
zs_response = insistent_create_completion(zs_prompts, model.choices[model_index])$choices[[1]]$message$content
entry1 = c(CT_Pub.df[i,]$NCTId, paste0(model.choices[model_index],"-zs"), zs_response, NA, NA, NA, NA, NA, NA, NA)
CT_Pub_responses.df[nrow(CT_Pub_responses.df) + 1, ] = entry1

#Triple Shot
ts_prompts = build_three_shot_prompt(CT_Pub.df, CT_Pub.df[i, ], "Paper_BaselineMeasures_Corrected")
ts_response = insistent_create_completion(ts_prompts, model.choices[model_index])$choices[[1]]$message$content
entry2 = c(CT_Pub.df[i,]$NCTId, paste0(model.choices[model_index],"-ts"), ts_response, NA, NA, NA, NA, NA, NA, NA)
CT_Pub_responses.df[nrow(CT_Pub_responses.df) + 1, ] = entry2

pb$tick()
}
CT_Pub_responses.df
```
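Generation is the slow and costly step, so it can be worth persisting the responses before moving on to evaluation; a minimal sketch (the file name is arbitrary):
```{r}
# Save generated responses so a failed evaluation run does not require regenerating them
saveRDS(CT_Pub_responses.df, "CT_Pub_responses_generated.Rds")
# CT_Pub_responses.df <- readRDS("CT_Pub_responses_generated.Rds")
```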
### Evaluation Prompt functions
These functions create the prompts for evaluation. The goal of evaluation is to count the matches, and thus measure the overall accuracy, of the initially generated candidate features against the original trial's reference features.
Note that the function is tailored for llama, so a modified version of the evaluation prompt is used.
```{r}
build_eval_prompt <- function(reference, candidate, qstart) {
# Define the system message
system <- "
You are an expert assistant in the medical domain and clinical trial design. You are provided with details of a clinical trial.
Your task is to determine which candidate baseline features match any feature in a reference baseline feature list for that trial.
You need to consider the context and semantics while matching the features.
For each candidate feature:
1. Identify a matching reference feature based on similarity in context and semantics.
2. Remember the matched pair.
3. A reference feature can only be matched to one candidate feature and cannot be further considered for any consecutive matches.
4. If there are multiple possible matches (i.e. one reference feature can be matched to multiple candidate features or vice versa), choose the most contextually similar one.
5. Also keep track of which reference and candidate features remain unmatched.
6. DO NOT provide the code to accomplish this and ONLY respond with the following JSON. Perform the matching yourself.
Once the matching is complete, omitting explanations provide the answer only in the following form:
{\"matched_features\": [[\"<reference feature 1>\" , \"<candidate feature 1>\" ],[\"<reference feature 2>\" , \"<candidate feature 2>\"]],\"remaining_reference_features\": [\"<unmatched reference feature 1>\" ,\"<unmatched reference feature 2>\"],\"remaining_candidate_features\" : [\"<unmatched candidate feature 1>\" ,\"<unmatched candidate feature 2>\"]}
7. Please generate a valid JSON object, ensuring it fits within a single JSON code block, with all keys and values properly quoted and all elements closed. Do not include line breaks within array elements."
# Start building the question message
question <- paste("\nHere is the trial information: \n\n", qstart, "\n\n", sep = "")
# Add the reference features
question <- paste(question, "Here is the list of reference features: \n\n", sep = "")
for (i in seq_along(reference)) {
question <- paste(question, i, ". ", reference[[i]], "\n", sep = "")
}
# Add the candidate features
question <- paste(question, "\nCandidate features: \n\n", sep = "")
for (i in seq_along(candidate)) {
question <- paste(question, i, ". ", candidate[[i]], "\n", sep = "")
}
return (c(system, question))
}

get_question_from_row <- function(row) {
# Extract relevant fields from the row
title <- row["BriefTitle"]
brief_summary <- row["BriefSummary"]
condition <- row["Conditions"]
eligibility_criteria <- row["EligibilityCriteria"]
intervention <- row["Interventions"]
outcome <- row["PrimaryOutcomes"]
# Build the question string by concatenating the extracted fields
question <- ""
question <- paste(question, "<Title> \n", title, "\n", sep = "")
question <- paste(question, "<Brief Summary> \n", brief_summary, "\n", sep = "")
question <- paste(question, "<Condition> \n", condition, "\n", sep = "")
question <- paste(question, "<Eligibility Criteria> \n", eligibility_criteria, "\n", sep = "")
question <- paste(question, "<Intervention> \n", intervention, "\n", sep = "")
question <- paste(question, "<Outcome> \n", outcome, "\n", sep = "")
return(question)
}

extract_elements <- function(s) {
# Define the pattern to match text within backticks
pattern <- "`(.*?)`"
# Use the regmatches and gregexpr functions to find all matches
elements <- regmatches(s, gregexpr(pattern, s, perl = TRUE))[[1]]
# Remove the enclosing backticks from the matched elements
elements <- gsub("`", "", elements)
return(elements)
}
```
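extract_elements() simply pulls out every backtick-delimited feature, so its behavior can be confirmed on a toy string (made-up input, no trial data involved):
```{r}
# Example: parsing a backtick-delimited feature list
extract_elements("`Age`, `Sex`, `Body Mass Index (BMI)`")
# returns: "Age" "Sex" "Body Mass Index (BMI)"
```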
### Single Eval Example
This script shows an example of running the evaluation prompt once.
```{r}
qstart = get_question_from_row(CT_Pub.df[1,])
reference_list = extract_elements(CT_Pub.df[1,]["Paper_BaselineMeasures_Corrected"])
candidate_list = extract_elements(CT_Pub_responses.df[1,"gen_response"])
eval_prompts = build_eval_prompt(reference_list, candidate_list, qstart)

# model index set to 7 for llama
matched_json = insistent_create_completion(eval_prompts, model.choices[7])$choices[[1]]$message$content

temp_df = fromJSON(matched_json)
kable(temp_df$matched_features)
kable(temp_df$remaining_reference_features)
kable(temp_df$remaining_candidate_features)
print(temp_df)
```
### Batch Evaluation
This script evaluates multiple trials at once. Note that several helper functions are used to ensure that the evaluation output is legitimate JSON, since raw generation from the evaluation LLM may produce JSON with incorrect syntax.
```{r}
extract_json <- function(text) {
# Regular expression to detect JSON objects or arrays, allowing nested structures
json_pattern <- "\\{(?:[^{}]|(?R))*\\}|\\[(?:[^[\\]]|(?R))*\\]"
# Extract all matches
matches <- regmatches(text, gregexpr(json_pattern, text, perl = TRUE))[[1]]
# Validate JSON strings by attempting to parse
valid_json <- matches[sapply(matches, function(x) {
tryCatch({
fromJSON(x)
TRUE
}, error = function(e) FALSE)
})]
return(valid_json)
}

n = nrow(CT_Pub_responses.df)
pb = progress_bar$new(
format = " Processing [:bar] :percent in :elapsed",
total = n, clear = FALSE, width = 60
)

eval_model_index = 1

for (i in 1:n){
#obtain information from pertinent rows
row = CT_Pub.df[CT_Pub.df$NCTId == CT_Pub_responses.df$trial_id[i], ]
qstart = get_question_from_row(row)
reference_list = extract_elements(row["Paper_BaselineMeasures_Corrected"])
candidate_list = extract_elements(CT_Pub_responses.df[i,]["gen_response"])
eval_prompts = build_eval_prompt(reference_list, candidate_list, qstart)

retry = TRUE
while(retry){
tryCatch(
{
# call the evaluation model and parse the first valid JSON object it returns
matched_json = insistent_create_completion(eval_prompts, model.choices[eval_model_index])$choices[[1]]$message$content
json_data = extract_json(matched_json)
temp_df = fromJSON(json_data[[1]])
retry = FALSE
},
error = function(e) {
print(as.character(e))
})
}

#set match lengths:
CT_Pub_responses.df[i,"match_model"] = model.choices[eval_model_index]
CT_Pub_responses.df[i,"len_matches"] = nrow(temp_df$matched_features)
CT_Pub_responses.df[i,"len_reference"] = length(temp_df$remaining_reference_features)
CT_Pub_responses.df[i,"len_candidate"] = length(temp_df$remaining_candidate_features)

pb$tick()
}
# matches.df
CT_Pub_responses.df
```
### Single Benchmark
This is an example of using the number of matches, unmatched reference features, and unmatched candidate features to compute the benchmark metrics: precision, recall, and F1.
```{r}
match_to_score <- function(matched_pairs, remaining_reference_features, remaining_candidate_features) {
# fromJSON() returns the matched pairs as an n x 2 matrix, so count rows (NROW also handles a plain vector)
TP <- NROW(matched_pairs)
# Calculate precision: TP / (TP + FP)
precision <- TP / (TP + length(remaining_candidate_features))
# Calculate recall: TP / (TP + FN)
recall <- TP / (TP + length(remaining_reference_features))
# Calculate F1 score: 2 * (precision * recall) / (precision + recall)
if (precision == 0 || recall == 0) {
f1 <- 0
} else {
f1 <- 2 * (precision * recall) / (precision + recall)
}
# Return a named vector with precision, recall, and f1
return(c(precision = precision, recall = recall, f1 = f1))
}

gzs_dict <- jsonlite::fromJSON(matched_json)

# Extract matched features, remaining reference features, and remaining candidate features
matches <- gzs_dict$matched_features
remaining_references <- gzs_dict$remaining_reference_features
remaining_candidates <- gzs_dict$remaining_candidate_features

# Calculate precision, recall, and F1 score
score <- match_to_score(matches, remaining_references, remaining_candidates)
kable(score)
```
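As a sanity check on match_to_score(), the arithmetic can be verified with hand-picked counts: 8 matched pairs, 2 unmatched reference features, and 4 unmatched candidate features should give precision 8/12 ≈ 0.67, recall 8/10 = 0.80, and F1 ≈ 0.73. A toy example with made-up feature names:
```{r}
# Toy sanity check: 8 matches, 2 unmatched reference, 4 unmatched candidate
toy_matches <- matrix(paste0("feature_", 1:16), ncol = 2) # 8 matched pairs
toy_ref_left <- paste0("ref_", 1:2)
toy_cand_left <- paste0("cand_", 1:4)
match_to_score(toy_matches, toy_ref_left, toy_cand_left)
```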
### Batch Benchmarking
As in the single benchmarking script, benchmark metrics are derived and stored in the dataframe that holds all the information for each trial.
```{r}
for (i in 1:nrow(CT_Pub_responses.df)){
#Store true positive, false positive, and false negative
TP = as.numeric(CT_Pub_responses.df[i,]["len_matches"])
FP = as.numeric(CT_Pub_responses.df[i,]["len_candidate"])
FN = as.numeric(CT_Pub_responses.df[i,]["len_reference"])

# Calculate and store precision, recall, and f1
precision = TP/(TP + FP)
recall = TP/(TP + FN)
f1 = 2 * (precision * recall)/(precision + recall)

CT_Pub_responses.df[i,]["precision"] = precision
CT_Pub_responses.df[i,]["recall"] = recall
CT_Pub_responses.df[i,]["f1"] = f1
}
CT_Pub_responses.df
```