# (Web-page navigation text captured together with this file was removed here; it is not part of the script.)
# es_ensemble_plddt.py
from util import *
import sys
import pymol
from pymol import cmd
from paths import *
def process_pdb(pdb_file, object_name):
    """Load a structure into PyMOL and apply the default two-chain styling.

    Chain A is coloured green.  Chain B is coloured cyan, shown as sticks
    and has its cartoon representation hidden.  Afterwards the view is
    oriented and the viewport is set to 800 x 800.

    Args:
        pdb_file: Path of the structure file handed to ``cmd.load``.
        object_name: Name of the PyMOL object that is created and styled.
    """
    # The PyMOL launch call is issued on every invocation of this function.
    pymol.finish_launching()
    cmd.load(pdb_file, object_name)

    # Selection strings for the two chains of the freshly loaded object.
    chain_a = object_name + ' and chain A'
    chain_b = object_name + ' and chain B'

    cmd.color('green', chain_a)
    cmd.color('cyan', chain_b)

    # Chain B is drawn as sticks instead of cartoon.
    cmd.show('sticks', chain_b)
    cmd.hide('cartoon', chain_b)

    cmd.orient()
    cmd.viewport(800, 800)
def process_pdb_plddt(pdb_file, object_name, reverse = False):
    """Load a structure into PyMOL, style its chains and colour it by B-factor.

    With ``reverse`` false, chain A is coloured green and chain B is
    coloured cyan, shown as sticks and has its cartoon hidden.  With
    ``reverse`` true the roles of chains A and B are swapped.  The whole
    object is then coloured by its B-factor values with the ``blue_red``
    palette, the view is oriented and the viewport set to 800 x 800.

    Args:
        pdb_file: Path of the structure file handed to ``cmd.load``.
        object_name: Name of the PyMOL object that is created and styled.
        reverse: Swap the styling of chains A and B when truthy.
    """
    pymol.finish_launching()
    cmd.load(pdb_file, object_name)

    chain_a = object_name + ' and chain A'
    chain_b = object_name + ' and chain B'

    # "cartoon chain" keeps its cartoon and is coloured green; the "stick
    # chain" is coloured cyan and switched from cartoon to sticks.
    if reverse:
        cartoon_chain, stick_chain = chain_b, chain_a
    else:
        cartoon_chain, stick_chain = chain_a, chain_b

    cmd.color('green', cartoon_chain)
    cmd.color('cyan', stick_chain)
    cmd.show('sticks', stick_chain)
    cmd.hide('cartoon', stick_chain)

    # Colour the entire object by B-factor; this runs after the per-chain
    # colours above were assigned.
    # NOTE(review): presumably the B-factor column carries pLDDT values
    # here (going by the function name) - confirm against the callers.
    cmd.spectrum("b", "blue_red", f"{object_name}")

    cmd.orient()
    cmd.viewport(800, 800)
def get_original_file(pdb_file, pdb_id):
    """Build the path of the source model that a processed file name refers to.

    The model stem is cut out of ``pdb_file`` and ends at the last
    occurrence of ``'_af2'``:

    * names containing ``'notemplate'``: the stem starts right after the
      last ``'min_'`` and the file is looked up as ``<stem>.pdb`` in the
      ``<pdb_id>_alt`` directory;
    * all other names: the stem starts at the last ``'model'`` and the
      file is looked up as ``unrelaxed_<stem>.pdb`` in the ``<pdb_id>``
      directory.

    Both directories are located under ``PDB_FILES`` and use the
    lower-cased ``pdb_id``.

    Args:
        pdb_file: File name or path of the processed model.
        pdb_id: Identifier used to form the directory name.

    Returns:
        The constructed path as a string (its existence is not checked).
    """
    stem_end = pdb_file.rfind('_af2')

    if 'notemplate' in pdb_file:
        # Skip the four characters of the 'min_' marker itself.
        stem_start = pdb_file.rfind('min_') + 4
        new_file = pdb_file[stem_start:stem_end] + '.pdb'
        return f"{PDB_FILES}{pdb_id.lower()}_alt/{new_file}"

    stem_start = pdb_file.rfind('model')
    new_file = 'unrelaxed_' + pdb_file[stem_start:stem_end] + '.pdb'
    return f"{PDB_FILES}{pdb_id.lower()}/{new_file}"
from collections import defaultdict
from Bio import pairwise2
# Three-letter residue names of the 20 standard amino acids mapped to their
# one-letter codes.  Built once at import time instead of on every call.
_THREE_TO_ONE_LETTER = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}


def convert_aa_name(aa_name):
    """Translate a three-letter residue name into its one-letter code.

    The lookup is exact (upper-case names such as ``'ALA'``); no case or
    whitespace normalisation is applied.

    Args:
        aa_name: Three-letter residue name.

    Returns:
        The one-letter code, or ``'X'`` for any name that is not one of
        the 20 standard amino acids.
    """
    return _THREE_TO_ONE_LETTER.get(aa_name, 'X')
def parse_pdb(pdb_file):
    """Read per-chain sequence and B-factor values from the CA atoms of a PDB file.

    Only lines that start with ``ATOM`` and whose atom-name characters
    ``line[13:15]`` equal ``CA`` are used.  For each such line the chain
    identifier (``line[21]``), the residue name (``line[17:20]``) and the
    B-factor field (``line[60:66]``) are read.

    Args:
        pdb_file: Path of the PDB file to read.

    Returns:
        A list with one dict per chain, in order of first appearance::

            {'sequence': <one-letter string>,
             'plddt': <list of B-factor floats>,
             'chain_id': <chain identifier>}

        If the first two chains are labelled ``B`` and ``C`` they are
        relabelled ``A`` and ``B``.  A file without matching CA atoms
        yields an empty list.
    """
    chains = defaultdict(lambda: {'sequence': [], 'plddt': []})

    with open(pdb_file, 'r') as f:
        for line in f:
            if not (line.startswith('ATOM') and line[13:15].strip() == 'CA'):
                continue
            chain_id = line[21].strip()
            res_name = line[17:20].strip()
            b_factor = float(line[60:66].strip())
            chains[chain_id]['sequence'].append(convert_aa_name(res_name))
            chains[chain_id]['plddt'].append(b_factor)

    # Convert the per-residue lists to the returned format.
    chain_data = [
        {
            'sequence': ''.join(data['sequence']),
            'plddt': data['plddt'],
            'chain_id': chain_id,
        }
        for chain_id, data in chains.items()
    ]

    # Relabel a B/C pair as A/B.  The length guard keeps files with fewer
    # than two chains from raising IndexError.
    if (
        len(chain_data) >= 2
        and chain_data[0]['chain_id'] == 'B'
        and chain_data[1]['chain_id'] == 'C'
    ):
        chain_data[0]['chain_id'] = 'A'
        chain_data[1]['chain_id'] = 'B'

    return chain_data
def update_b_factors_all_chains(chain_data, pdb_filepath, new_pdb_filepath):
    """Write a copy of a PDB file whose B-factor column is taken from ``chain_data``.

    For every chain in ``chain_data`` that also occurs in the input file,
    the per-residue values in ``chain['plddt']`` are assigned to the
    residues of that chain (in ascending residue-number order):

    * if the number of values equals the number of residues found in the
      file, they are assigned one-to-one;
    * otherwise ``chain['sequence']`` is globally aligned to the sequence
      read from the file and values are transferred along the alignment;
      file residues aligned to a gap receive ``-1``.  If the transferred
      list still does not match the file's residue count, the chain is
      left untouched.

    ``ATOM``/``HETATM`` lines of updated chains get the new value written
    into columns 61-66 (``line[60:66]``); residues without a value get
    ``-1.00``.  All other lines are copied unchanged.

    Args:
        chain_data: Iterable of dicts with the keys ``'sequence'``,
            ``'plddt'`` and ``'chain_id'``.
        pdb_filepath: Path of the PDB file to read.
        new_pdb_filepath: Path of the PDB file to write.
    """
    atom_prefixes = ('ATOM', 'HETATM')

    with open(pdb_filepath, 'r') as pdb_file:
        lines = pdb_file.readlines()

    # First pass: collect, per chain, the set of residue numbers and the
    # one-letter sequence (one letter each time the residue number changes).
    chain_residue_count = defaultdict(set)
    seqs = {}
    # None (not -1) marks "no residue seen yet", so that a chain whose
    # first residue is numbered -1 still contributes that residue.
    last_res_ind = {}
    for line in lines:
        if not line.startswith(atom_prefixes):
            continue
        chain_id = line[21]
        seqs.setdefault(chain_id, "")
        residue_index = int(line[22:26].strip())
        if residue_index != last_res_ind.get(chain_id):
            last_res_ind[chain_id] = residue_index
            seqs[chain_id] += convert_aa_name(line[17:20].strip())
        chain_residue_count[chain_id].add(residue_index)

    # Build a dictionary {residue number: new B-factor} for each chain.
    b_factor_dicts = {}
    gap_chars = ('_', '-')
    for chain in chain_data:
        seq = chain['sequence']
        plddt = chain['plddt']
        chain_id = chain.get('chain_id')
        if chain_id not in seqs:
            continue
        target_seq = seqs[chain_id]

        if len(target_seq) != len(plddt):
            # Lengths differ: transfer the values along a global alignment.
            aligned_source, aligned_target = pairwise2.align.globalxx(seq, target_seq)[0][:2]
            new_bfactors = []
            ind = 0  # position in plddt / the ungapped source sequence
            for source_char, target_char in zip(aligned_source, aligned_target):
                if target_char in gap_chars:
                    # Source residue with no counterpart in the file: skip its value.
                    ind += 1
                    continue
                if source_char in gap_chars:
                    # File residue with no counterpart in the source: placeholder.
                    new_bfactors.append(-1)
                    continue
                new_bfactors.append(plddt[ind])
                ind += 1
            if len(new_bfactors) != len(target_seq):
                continue
        else:
            new_bfactors = list(plddt)

        b_factor_dicts[chain_id] = dict(
            zip(sorted(chain_residue_count[chain_id]), new_bfactors)
        )

    # Second pass: rewrite the B-factor field of the chains that were mapped.
    updated_lines = []
    for line in lines:
        if line.startswith(atom_prefixes) and line[21] in b_factor_dicts:
            residue_index = int(line[22:26].strip())
            new_b_factor = b_factor_dicts[line[21]].get(residue_index, -1.0)
            line = line[:60] + f'{new_b_factor:6.2f}' + line[66:]
        updated_lines.append(line)

    with open(new_pdb_filepath, 'w') as pdb_file:
        pdb_file.writelines(updated_lines)
#import matplotlib.pyplot as plt
def plot_boxplots(data_dict):
    """
    Draw one boxplot per dictionary entry in a single 10 x 6 figure and show it.

    Parameters:
    data_dict (dict): Dictionary with integer keys and lists of floats as values.
                      The keys are used as the box labels.
    """
    # NOTE(review): `plt` is not imported explicitly in this file (the
    # matplotlib import above this function is commented out); presumably
    # it is supplied by one of the star imports - confirm before use.
    labels = []
    series = []
    for label, values in data_dict.items():
        labels.append(label)
        series.append(values)

    plt.figure(figsize=(10, 6))
    plt.boxplot(series, labels=labels)

    # Axis labels and title
    plt.xlabel('Keys')
    plt.ylabel('Values')
    plt.title('Boxplots for Each Key:Query Pair')

    plt.show()
import math
import os
import shutil


def print_sorted_dicts(*dicts):
    """Print every dictionary's items in key order, values rounded to 3 decimals.

    A blank line is printed after each dictionary.
    """
    for d in dicts:
        for k, v in sorted(d.items()):
            print(f"{k}: {round(v, 3)}")
        print()  # Print a newline for better separation between dictionaries


def _first_position_by_name(names):
    """Map each name to the position of its first occurrence in ``names``."""
    positions = {}
    for position, name in enumerate(names):
        positions.setdefault(name, position)
    return positions


def _group_models_by_cluster(nldr_files, nldr_clusters, model_position,
                             model_paths, consensus, confidence):
    """Collect, per cluster, the model paths and their combined scores.

    The combined score of a model is ``sqrt(consensus * confidence)``.
    Files that are not present in ``model_position`` are skipped, but their
    cluster still gets an (possibly empty) entry.

    Returns:
        ``(cluster_scores, cluster_files)`` - two dicts keyed by cluster
        number whose lists are index-aligned with each other.
    """
    cluster_scores = {}
    cluster_files = {}
    for pdb_file, cluster_number in zip(nldr_files, nldr_clusters):
        scores = cluster_scores.setdefault(cluster_number, [])
        files = cluster_files.setdefault(cluster_number, [])
        position = model_position.get(pdb_file)
        if position is None:
            continue
        files.append(model_paths[position])
        scores.append(math.sqrt(consensus[position] * confidence[position]))
    return cluster_scores, cluster_files


def _average_cluster_scores(cluster_scores):
    """Return the mean score per cluster; clusters without scores are omitted."""
    return {
        cluster: sum(scores) / len(scores)
        for cluster, scores in cluster_scores.items()
        if scores
    }


def _export_best_models(method_label, object_prefix, cluster_scores,
                        cluster_files, outdir, pdb_id):
    """Copy, re-annotate and display the best-scoring model of every cluster.

    For each cluster the model with the highest score is copied into
    ``outdir``.  If the corresponding original model file exists, its
    B-factor values are transferred onto the copy (written next to it as
    ``*_plddt.pdb``) and the result is loaded into PyMOL under the name
    ``object_prefix + str(cluster)``.

    Returns:
        The list of selected source model paths, in cluster order.
    """
    selected = []
    for cluster, scores in cluster_scores.items():
        if not scores:
            # No model of this cluster was found in the score table.
            print("No scored models for " + method_label + " cluster " + str(cluster) + "; skipping")
            continue

        # max() returns the first highest score, so a cluster whose scores
        # are all 0 still yields a well-defined choice (its first model).
        best = max(range(len(scores)), key=scores.__getitem__)
        max_score = scores[best]
        max_score_file = cluster_files[cluster][best]
        print("Max consensus for " + method_label + " cluster " + str(cluster) + ' = ' + str(max_score) + '. PDB file = ' + max_score_file)
        selected.append(max_score_file)

        consensus_file = outdir + max_score_file[max_score_file.rfind('/') + 1:]
        try:
            shutil.copy(max_score_file, consensus_file)
        except OSError as err:
            # Best effort: report the failed copy and go on with the next cluster.
            print("Could not copy " + max_score_file + " to " + consensus_file + ": " + str(err))
            continue

        original_file = get_original_file(consensus_file, pdb_id)
        if not os.path.exists(original_file):
            continue
        original_data = parse_pdb(original_file)
        new_max_scores_file = consensus_file[:consensus_file.rfind('.')] + '_plddt.pdb'
        update_b_factors_all_chains(original_data, consensus_file, new_max_scores_file)
        process_pdb_plddt(new_max_scores_file, object_prefix + str(cluster))
    return selected


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python es_ensemble_plddt.py <bound>")
        sys.exit(1)

    bound = sys.argv[1].lower()
    pdb_id = bound

    # Score table: one row per model with its path and two score columns.
    data_source_file = f'{CSP_Rank_Scores}CSP_' + bound + '_CSpred.csv'
    parsed_data = parse_csv(data_source_file)
    holo_model_files_raw = [data['holo_model_path'] for data in parsed_data]
    holo_model_files = [path[path.rfind('/') + 1:] for path in holo_model_files_raw]

    # Cluster assignments from the two embeddings.
    UMAP_file = f"{CLUSTERING_RESULTS}{bound.upper()}_aligned_CSPREDB_UMAP_chain_B_data.csv"
    UMAP_data = parse_csv(UMAP_file)
    UMAP_files = [data['pdb_file'] for data in UMAP_data]
    UMAP_clusters = [int(data['Cluster']) for data in UMAP_data]

    TSNE_file = f"{CLUSTERING_RESULTS}{bound.upper()}_aligned_CSPREDB_TSNE_chain_B_data.csv"
    TSNE_data = parse_csv(TSNE_file)
    TSNE_files = [data['pdb_file'] for data in TSNE_data]
    TSNE_clusters = [int(data['Cluster']) for data in TSNE_data]

    # Rows whose scores are missing or not numeric count as 0 for both scores.
    consensus_scores = []
    Confidence_scores = []
    for data in parsed_data:
        try:
            conf = float(data['Confidence'])
            cons = float(data['consensus'])
        except (KeyError, TypeError, ValueError):
            conf = 0
            cons = 0
        consensus_scores.append(cons)
        Confidence_scores.append(conf)

    # File name -> row position (first occurrence), replacing a linear
    # list.index() search per clustered file.
    holo_model_position = _first_position_by_name(holo_model_files)

    print("getting TSNE cluster scores")
    TSNE_cluster_scores, TSNE_cluster_files = _group_models_by_cluster(
        TSNE_files, TSNE_clusters, holo_model_position,
        holo_model_files_raw, consensus_scores, Confidence_scores)

    print("getting UMAP cluster scores")
    UMAP_cluster_scores, UMAP_cluster_files = _group_models_by_cluster(
        UMAP_files, UMAP_clusters, holo_model_position,
        holo_model_files_raw, consensus_scores, Confidence_scores)

    print("getting TSNE cluster score averages")
    TSNE_cluster_score_averages = _average_cluster_scores(TSNE_cluster_scores)

    print("getting UMAP cluster score averages")
    UMAP_cluster_score_averages = _average_cluster_scores(UMAP_cluster_scores)

    print_sorted_dicts(TSNE_cluster_score_averages, UMAP_cluster_score_averages)

    # Start from an empty output directory.  Done with shutil/os instead of
    # shell strings because the path contains the command-line argument.
    outdir = f"{PDB_FILES}{pdb_id.lower()}_max_RPF_NLDR_consensus_files/"
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)
    os.makedirs(outdir)

    max_consensus_files = []

    print("getting TSNE cluster medoid structures")
    max_consensus_files += _export_best_models(
        'TSNE', 'tSNE_max', TSNE_cluster_scores, TSNE_cluster_files, outdir, pdb_id)

    print("getting UMAP cluster medoid structures")
    max_consensus_files += _export_best_models(
        'UMAP', 'UMAP_max', UMAP_cluster_scores, UMAP_cluster_files, outdir, pdb_id)

    # NOTE(review): this path is restored from an assignment that was
    # commented out while the variable was still being used - confirm that
    # './experimental_structures/' is the intended location.
    experimental_medoid_file = './experimental_structures/exp_' + pdb_id + '.pdb'
    if os.path.exists(experimental_medoid_file):
        process_pdb(experimental_medoid_file, 'exp_' + pdb_id)
    else:
        print("Experimental structure not found: " + experimental_medoid_file)

    cmd.hide('everything', 'hydro')