Skip to content
Permalink
b9fe25bf2d
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
176 lines (150 sloc) 6.94 KB
# plot_CSP.py
from util import *
from collections import Counter
from scipy.stats import spearmanr
# Options forwarded to calc_CSP_wrapper / get_confusion in the structure loop.
method = "MONTE"
# Chemical-shift predictor; the prompt below re-asks until it is a supported name.
CSmethod = "UCBShift"
CS_METHODS = ['SPARTA', 'UCBShift', 'ShiftX', 'consensus']
while CSmethod not in CS_METHODS:
    CSmethod = input("What CS prediction method to use? [SPARTA, UCBShift, ShiftX, consensus] ").strip()

# Per-complex metadata. Every list below is row-aligned with the CSV, so one
# index into `holos` selects the matching entry in all of them.
data_source_file = './CSPRANK.csv'
parsed_data = pd.read_csv(data_source_file)
apos = [str(data) for data in parsed_data['apo_bmrb']]
apo_pdbs = [str(data) for data in parsed_data['apo_pdb']]
# str() first: empty cells come back from pandas as NaN floats, which have no .lower().
holos = [str(data).lower() for data in parsed_data['holo_pdb']]
well_defined_residues = list(parsed_data['Well_Defined_Residues'])
match_sequences = list(parsed_data['match_seq'])
def common_substring_spans(sequences):
    """Locate the longest substring shared by every string in ``sequences``.

    Candidates are taken from the shortest sequence, longest first and then
    left to right, so ties resolve to the leftmost match in that sequence.

    Returns:
        A list with one ``(start, end)`` slice span per input sequence, giving
        where the shared substring first occurs in *that* sequence, or ``None``
        when the sequences share no non-empty substring (including when
        ``sequences`` is empty or any sequence is empty).
    """
    if not sequences:
        return None
    shortest = min(sequences, key=len)
    for length in range(len(shortest), 0, -1):
        for start in range(len(shortest) - length + 1):
            candidate = shortest[start:start + length]
            if all(candidate in seq for seq in sequences):
                # Offsets are looked up per sequence: the shared stretch can
                # sit at a different index in each one.
                return [(offset, offset + length)
                        for offset in (seq.find(candidate) for seq in sequences)]
    return None


def scatter_plot(ax, real_shifts, real_cutoffs, real_sequence, pred_shifts, pred_cutoffs, pred_sequence, label):
    """Scatter experimental ("real") CSPs against predicted CSPs on ``ax``.

    Both sequences are passed through ``align`` (from util) and trimmed to the
    longest stretch they share. Each shift list is trimmed with its own
    sequence's span, which assumes the shift list is indexed like the aligned
    sequence (NOTE(review): confirm against util.align). Negative values are
    clamped to 0, and the Spearman correlation of the trimmed values is
    written onto the axes.

    Args:
        ax: matplotlib Axes to draw on.
        real_shifts, pred_shifts: per-residue CSP values.
        real_cutoffs: thresholds for the real shifts, drawn as dashed vertical
            lines (real shifts are on the x axis).
        pred_cutoffs: thresholds for the predicted shifts, drawn as dashed
            horizontal lines (predicted shifts are on the y axis).
        real_sequence, pred_sequence: residue sequences behind each shift list.
        label: panel title.

    Raises:
        ValueError: if the aligned sequences share no common stretch.
    """
    real_seq, pred_seq = align(real_sequence, pred_sequence)
    spans = common_substring_spans([real_seq, pred_seq])
    if spans is None:
        raise ValueError(f"{label}: real and predicted sequences share no common stretch")
    (real_start, real_end), (pred_start, pred_end) = spans

    real_seq = real_seq[real_start:real_end]
    pred_seq = pred_seq[pred_start:pred_end]
    real_shifts = real_shifts[real_start:real_end]
    pred_shifts = pred_shifts[pred_start:pred_end]
    print(real_seq)
    print(pred_seq)

    # Negative CSPs are clamped to 0 before correlating and plotting.
    real_shifts = [max(0, x) for x in real_shifts]
    pred_shifts = [max(0, x) for x in pred_shifts]

    scc, _ = spearmanr(real_shifts, pred_shifts)
    ax.text(0.05, 0.95, f'SCC: {scc:.2f}', transform=ax.transAxes, fontsize=12, verticalalignment='top')

    # Real shifts on x, predicted shifts on y.
    ax.scatter(real_shifts, pred_shifts, c='blue', label='Real vs Predicted Shifts')
    # Each cutoff list belongs on its own axis: real thresholds are x values
    # (vertical lines), predicted thresholds are y values (horizontal lines).
    for cutoff in real_cutoffs:
        ax.axvline(x=cutoff, color='red', linestyle='--', linewidth=0.5)
    for cutoff in pred_cutoffs:
        ax.axhline(y=cutoff, color='green', linestyle='--', linewidth=0.5)
    ax.set_xlabel('Real Shifts')
    ax.set_ylabel('Predicted Shifts')
    ax.set_title(label)
    ax.legend()
# Structure sources compared against the experimental NMR CSPs; one panel each.
structures = ['NMR', 'AF2']
#structures = ['NMR', 'AF2', 'AF3']
# Z-scores for the CSP cutoffs; z_scores[0] is also the significance cutoff.
z_scores = [0, 1, 3]
fig, axes = plt.subplots(len(structures), 1, figsize=(6, len(structures) * 6))

pdbs = input('Provide bound pdb ( e.g. "7jq8" ) ')
pdb = pdbs.strip()
holo = pdb.lower()

# Parallel accumulators filled in by the structure loop, one entry per CSP series.
# significances: list of lists of booleans, True where a residue's CSP is significant.
significances = []
sequences = []
shifts = []
cutoff_lists = []
file_prefs = []
consensuses = []

# Look the complex up once. An unknown PDB gets a clear message instead of
# list.index's bare "x is not in list".
try:
    row_index = holos.index(holo)
except ValueError:
    raise ValueError(f'bound pdb "{holo}" not found in the holo_pdb column of {data_source_file}') from None
apo = str(apos[row_index]).lower()
well_defined_res = str(well_defined_residues[row_index])
match_seq = str(match_sequences[row_index])
consensus = 0
def _cutoffs_and_significance(CSPs, CSP_cutoff):
    """Return ``(cutoffs, sigs)`` for one CSP series.

    ``cutoffs`` holds one calculate_z_score_threshold value per entry of
    ``z_scores``, computed only from CSPs strictly between 0 and
    ``CSP_cutoff``. ``sigs`` flags each residue whose CSP is at least the
    first (``z_scores[0]``) cutoff.
    """
    below_thresh = [C for C in CSPs if 0 < C < CSP_cutoff]
    cutoffs = [calculate_z_score_threshold(below_thresh, z_score) for z_score in z_scores]
    sigs = [bool(CSP >= cutoffs[0]) for CSP in CSPs]
    return cutoffs, sigs


def _record_series(sigs, seq, CSPs, cutoffs, pref, cons):
    """Append one CSP series to the module-level accumulator lists."""
    significances.append(sigs)
    sequences.append(seq)
    shifts.append(CSPs)
    cutoff_lists.append(cutoffs)
    file_prefs.append(pref)
    consensuses.append(cons)


# Experimental (REAL) NMR reference. It is set by the 'NMR' branch and reused
# by every later comparison panel.
real_CSPs = real_cutoffs = real_bound_seq = None

for ax_index, structure in enumerate(structures):
    if structure == 'NMR':
        # Experimental shifts on the NMR structure.
        real_CSPs, real_CSP_cutoff, real_bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod="REAL", structure_source='NMR_real', match_seq=match_seq)
        real_cutoffs, sigs = _cutoffs_and_significance(real_CSPs, real_CSP_cutoff)
        _record_series(sigs, real_bound_seq, real_CSPs, real_cutoffs, f'NMR Data for {holo.upper()}', 1)

        # Predicted shifts (CSmethod) on the same NMR structure.
        pred_CSPs, pred_CSP_cutoff, pred_bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod, structure_source='NMR_pred', match_seq=match_seq)
        pred_cutoffs, sigs = _cutoffs_and_significance(pred_CSPs, pred_CSP_cutoff)
        TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source="NMR")
        _, _, consensus = get_F_MCC_cons(TP, FP, FN, TN)
        _record_series(sigs, pred_bound_seq, pred_CSPs, pred_cutoffs, 'NMR', consensus)
        scatter_plot(axes[ax_index], real_CSPs, real_cutoffs, real_bound_seq, pred_CSPs, pred_cutoffs, pred_bound_seq, f"{pdb} NMR vs NMR pred")
    elif structure == 'AF2':
        # The comparison needs the experimental reference from the 'NMR' branch.
        if real_CSPs is None:
            print("AF2 comparison needs the experimental NMR reference; list 'NMR' before 'AF2' in structures")
            continue
        CSPs, CSP_cutoff, bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod, structure_source='AF2', match_seq=match_seq)
        AF2_cutoffs, sigs = _cutoffs_and_significance(CSPs, CSP_cutoff)
        TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source="AF2")
        _, _, consensus = get_F_MCC_cons(TP, FP, FN, TN)
        _record_series(sigs, bound_seq, CSPs, AF2_cutoffs, 'AF2', consensus)
        scatter_plot(axes[ax_index], real_CSPs, real_cutoffs, real_bound_seq, CSPs, AF2_cutoffs, bound_seq, f"{pdb} NMR vs AF2 pred")
    else:
        print("received malformed structure selection: " + structure)

plt.subplots_adjust(wspace=0.4, hspace=0.6)  # spacing between panels; increase as needed
plt.show()