Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
from util import *
from Bio import PDB
from Bio.SVDSuperimposer import SVDSuperimposer
def calculate_gdt_ts(pdb_file1, pdb_file2, thresholds=(1, 2, 4, 8)):
    """Superimpose two structures on their CA atoms and score the fit.

    CA atoms are collected from everything ``structure.get_atoms()`` yields
    and are paired purely by their order of appearance; no residue-number or
    sequence matching is done.

    Args:
        pdb_file1: Path to the reference PDB file.
        pdb_file2: Path to the PDB file that is fitted onto the reference.
        thresholds: Distance cutoffs; for each one the fraction of CA pairs
            whose post-superposition distance is <= the cutoff is computed.

    Returns:
        ``(rms, gdt_ts)``: the RMS reported by ``SVDSuperimposer`` and the
        mean of the per-threshold fractions (a value between 0 and 1).

    Raises:
        ValueError: If the structures have different numbers of CA atoms,
            or contain no CA atoms at all.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure1 = parser.get_structure('structure1', pdb_file1)
    structure2 = parser.get_structure('structure2', pdb_file2)

    ca_atoms1 = [atom for atom in structure1.get_atoms() if atom.get_id() == 'CA']
    ca_atoms2 = [atom for atom in structure2.get_atoms() if atom.get_id() == 'CA']
    if len(ca_atoms1) != len(ca_atoms2):
        raise ValueError("The two structures do not have the same number of CA atoms.")
    if not ca_atoms1:
        # Without this guard the superposition / division below would fail
        # with a far less helpful error.
        raise ValueError("No CA atoms found in the structures; cannot compute GDT_TS.")

    coords1 = np.array([atom.get_coord() for atom in ca_atoms1])
    coords2 = np.array([atom.get_coord() for atom in ca_atoms2])

    # coords1 is the reference set; coords2 is fitted onto it.
    sup = SVDSuperimposer()
    sup.set(coords1, coords2)
    sup.run()
    rms = sup.get_rms()
    rot, tran = sup.get_rotran()

    # Apply the fitted rotation and translation to the second coordinate set.
    aligned_coords2 = np.dot(coords2, rot) + tran

    # The per-pair distances do not depend on the threshold, so compute once.
    distances = np.sqrt(np.sum((coords1 - aligned_coords2) ** 2, axis=1))
    gdt_scores = [
        np.sum(distances <= threshold) / len(coords1)
        for threshold in thresholds
    ]
    gdt_ts = np.mean(gdt_scores)
    return rms, gdt_ts
def _scatter_with_highlight(points, metric, color, label, files_to_highlight):
    """Scatter one model category, outlining rows whose basename is highlighted."""
    highlight = points['basename'].isin(files_to_highlight)
    # Only the non-highlighted call carries the label, so each category
    # appears once in the legend.
    plt.scatter(points[metric][~highlight], 1 - points['consensus'][~highlight],
                c=color, label=label)
    plt.scatter(points[metric][highlight], 1 - points['consensus'][highlight],
                c=color, edgecolor='black', linewidth=1.5)


def _plot_funnel(categories, metric, files_to_highlight, pdb_id, save_file):
    """Draw and save one '<metric> vs 1 - consensus' plot for all categories."""
    for points, color, label in categories:
        _scatter_with_highlight(points, metric, color, label, files_to_highlight)
    plt.legend()
    plt.xlabel(metric)
    plt.ylabel('1 - Consensus Score')
    plt.title('Scatter Plot between ' + metric + ' and Consensus Score of ' + pdb_id)
    plt.savefig(save_file)
    plt.close()


def create_scatterplot(csv_file, files_to_include, files_to_highlight, pdb_id,
                       output_dir='/home/tiburon/Desktop/ROT4/AFS_FINAL/'):
    """Save funnel plots of RMSD and GDT_TS against ``1 - consensus``.

    Rows of ``csv_file`` are kept when the basename of ``holo_model_path`` is
    in ``files_to_include``, then split into categories by substrings of the
    path ('v2', 'exp', 'comp', 'notemplate'; everything else is 'AFS AF2 v1').
    One figure is written per metric column present in the CSV, named
    ``<pdb_id>_funnel_RMSD`` and ``<pdb_id>_funnel_GDT`` inside ``output_dir``.

    Args:
        csv_file: CSV with 'holo_model_path', 'consensus' and the metric
            columns 'RMSD' and/or 'GDT_TS'.
        files_to_include: Basenames of the models to plot.
        files_to_highlight: Basenames drawn with a black outline.
        pdb_id: Identifier used in the plot title and output file names.
        output_dir: Directory the figures are saved to.
    """
    df = pd.read_csv(csv_file)

    # (metric column, output file suffix)
    metrics = (('RMSD', '_funnel_RMSD'), ('GDT_TS', '_funnel_GDT'))

    def report_missing(metric):
        print("Required columns ('" + metric + "', 'consensus', and "
              "'holo_model_path') not found in the CSV.")

    # Both plots need these columns; check before touching them so a missing
    # column produces the intended message rather than a KeyError.
    if 'consensus' not in df.columns or 'holo_model_path' not in df.columns:
        for metric, _ in metrics:
            report_missing(metric)
        return

    df['basename'] = df['holo_model_path'].apply(os.path.basename)
    df_filtered = df[df['basename'].isin(files_to_include)]
    paths = df_filtered['holo_model_path']

    AFS_v21_points = df_filtered[~paths.str.contains('exp|comp|v2|notemplate')]
    AFS_v22_points = df_filtered[paths.str.contains('v2')]
    NMR_points = df_filtered[paths.str.contains('exp')]
    AF2_points = df_filtered[paths.str.contains('comp')]
    AFALT_points = df_filtered[paths.str.contains('notemplate')]

    print("number of AFS_v21 points = " + str(len(AFS_v21_points)))
    print("number of AFS_v22 points = " + str(len(AFS_v22_points)))
    print("number of NMR points = " + str(len(NMR_points)))
    print("number of AF2 points = " + str(len(AF2_points)))
    print("number of AFALT points = " + str(len(AFALT_points)))

    # Listed in drawing order (later categories are drawn on top).
    categories = (
        (AFS_v22_points, 'cyan', 'AFS AF2 v2'),
        (AFS_v21_points, 'yellow', 'AFS AF2 v1'),
        (NMR_points, 'green', 'NMR'),
        (AF2_points, 'purple', 'Baseline AF2'),
        (AFALT_points, 'orange', 'AF Alt'),
    )

    for metric, suffix in metrics:
        if metric not in df.columns:
            report_missing(metric)
            continue
        save_file = os.path.join(output_dir, pdb_id + suffix)
        _plot_funnel(categories, metric, files_to_highlight, pdb_id, save_file)
def _group_models_by_cluster(cluster_rows, first_index, model_paths, consensus_scores):
    """Group model paths and consensus scores by the 'Cluster' of each row."""
    cluster_files = {}
    cluster_scores = {}
    for row in cluster_rows:
        cluster_number = int(row['Cluster'])
        files = cluster_files.setdefault(cluster_number, [])
        scores = cluster_scores.setdefault(cluster_number, [])
        index = first_index.get(row['pdb_file'])
        if index is None:
            # This clustered file has no row in the consensus table.
            continue
        files.append(model_paths[index])
        scores.append(consensus_scores[index])
    return cluster_files, cluster_scores


def _max_consensus_basenames(method, cluster_files, cluster_scores):
    """Return the basename of the highest-consensus model of each cluster."""
    basenames = []
    for cluster, scores in cluster_scores.items():
        if not scores:
            # No model of this cluster was found in the consensus table.
            print("No scored models for " + method + " cluster " + str(cluster) + "; skipping")
            continue
        # max() keeps the first of equal scores.
        best_itr = max(range(len(scores)), key=scores.__getitem__)
        max_score = scores[best_itr]
        max_score_file = cluster_files[cluster][best_itr]
        print("Max consensus for " + method + " cluster " + str(cluster) + ' = '
              + str(max_score) + '. PDB file = ' + max_score_file)
        basenames.append(os.path.basename(max_score_file))
    return basenames


def get_points_in_final_ensemble(bound):
    """Pick the highest-consensus model of every t-SNE and UMAP cluster.

    Reads ``./CSP_<bound>_CSpred.csv`` for model paths and consensus scores,
    and the ``./data/<bound>_aligned_CSPREDB_{TSNE,UMAP}_chain_B_data.csv``
    files for cluster assignments ('pdb_file' and 'Cluster' columns).
    Cluster files are matched to consensus rows by file basename.

    Args:
        bound: Identifier used to build the three input file names.

    Returns:
        List of model file basenames: one per non-empty t-SNE cluster,
        followed by one per non-empty UMAP cluster. The same file may
        appear twice if both methods select it.
    """
    data_source_file = './CSP_' + bound + '_CSpred.csv'
    parsed_data = parse_csv(data_source_file)
    holo_model_files_raw = [data['holo_model_path'] for data in parsed_data]
    consensus_scores = [float(data['consensus']) for data in parsed_data]

    # Map basename -> first row index, so each lookup is O(1) while keeping
    # first-match behaviour for duplicate basenames.
    first_index = {}
    for idx, path in enumerate(holo_model_files_raw):
        first_index.setdefault(os.path.basename(path), idx)

    UMAP_data = parse_csv('./data/' + bound + '_aligned_CSPREDB_UMAP_chain_B_data.csv')
    TSNE_data = parse_csv('./data/' + bound + '_aligned_CSPREDB_TSNE_chain_B_data.csv')

    print("getting TSNE cluster scores")
    TSNE_cluster_files, TSNE_cluster_scores = _group_models_by_cluster(
        TSNE_data, first_index, holo_model_files_raw, consensus_scores)

    print("getting UMAP cluster scores")
    UMAP_cluster_files, UMAP_cluster_scores = _group_models_by_cluster(
        UMAP_data, first_index, holo_model_files_raw, consensus_scores)

    print("getting TSNE cluster medoid structures")
    max_consensus_files = _max_consensus_basenames(
        'TSNE', TSNE_cluster_files, TSNE_cluster_scores)

    print("getting UMAP cluster medoid structures")
    max_consensus_files.extend(_max_consensus_basenames(
        'UMAP', UMAP_cluster_files, UMAP_cluster_scores))

    return max_consensus_files
def get_max_cons_pdb(pdb):
    """Return the 'holo_model_path' of the row with the highest consensus.

    Reads ``./CSP_<pdb>_CSpred.csv`` via ``parse_csv``. Consensus values are
    converted with ``float`` before comparison so they are ranked
    numerically even if the parser yields strings. The first row wins ties.

    Args:
        pdb: Identifier used to build the CSV file name.

    Returns:
        The 'holo_model_path' value of the best-scoring row.

    Raises:
        ValueError: If the CSV has no rows, or a consensus value is not
            numeric.
    """
    data_source_file = './CSP_' + pdb + '_CSpred.csv'
    parsed_data = parse_csv(data_source_file)
    best_row = max(parsed_data,
                   key=lambda data: float(data['consensus']),
                   default=None)
    if best_row is None:
        raise ValueError("No rows found in " + data_source_file)
    return best_row['holo_model_path']
if __name__ == "__main__":
    # Usage: python funnel.py <bound>
    # 1. Pick the highest-consensus model as the alignment reference.
    # 2. Write RMSD / GDT_TS against that reference into the CSpred CSV.
    # 3. Draw the funnel plots for the trimmed model set.
    if len(sys.argv) != 2:
        print("Usage: python funnel.py <bound>")
        sys.exit(1)

    pdb_id = sys.argv[1].lower()

    # The experimental structure
    # ('./experimental_structures/exp_' + pdb_id + '.pdb') was previously
    # assigned as the reference and then immediately overwritten; only the
    # max-consensus model is actually used.
    ref_pdb = get_max_cons_pdb(pdb_id)
    print("reference pdb file for alignment = " + ref_pdb)

    data_source_file = './CSP_' + pdb_id + '_CSpred.csv'
    parsed_data = parse_csv(data_source_file)
    # Pair each model path with its apo id up front so no manual index is needed.
    rows = [(str(data['apo_bmrb']), data['holo_model_path']) for data in parsed_data]

    new_columns = ['RMSD', 'GDT_TS']
    for apo, pdb_file in tqdm(rows):
        AF2_RMSD, AF2_GDT_TS = calculate_gdt_ts(ref_pdb, pdb_file)
        update_row(data_source_file, apo.upper(), pdb_file,
                   [AF2_RMSD, AF2_GDT_TS], new_columns)

    trimmed_source_file = './data/' + pdb_id + '_aligned_CSPREDB_PCA_chain_B_data_trimmed.csv'
    files_to_include = [data['pdb_file'] for data in parse_csv(trimmed_source_file)]
    files_to_highlight = get_points_in_final_ensemble(pdb_id)
    create_scatterplot(data_source_file, files_to_include, files_to_highlight, pdb_id)