Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/plot_CSP_scatter.py
Go to file
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
176 lines (150 sloc)
6.94 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# plot_CSP_scatter.py: scatter real vs. predicted CSPs for one bound (holo) structure
from util import * | |
from collections import Counter | |
from scipy.stats import spearmanr | |
# Shift-prediction settings. CSmethod is pre-set to an accepted value, so the
# interactive prompt below is skipped unless this default is changed to
# something outside the accepted list.
method = "MONTE"
CSmethod = "UCBShift"
while True:
    if CSmethod in ['SPARTA', 'UCBShift', 'ShiftX', 'consensus']:
        break
    CSmethod = input("What CS prediciton method to use? [SPARTA, UCBShift, ShiftX, consensus]")

# Columns of the CSPRANK metadata table. Each list is built by iterating one
# column, so index i in every list refers to the same CSV row.
data_source_file = './CSPRANK.csv'
parsed_data = pd.read_csv(data_source_file)
apos = list(map(str, parsed_data['apo_bmrb']))
apo_pdbs = list(map(str, parsed_data['apo_pdb']))
holos = [entry.lower() for entry in parsed_data['holo_pdb']]
well_defined_residues = list(parsed_data['Well_Defined_Residues'])
match_sequences = list(parsed_data['match_seq'])
def _common_window(sequences):
    """Find the longest substring shared by every string in *sequences*.

    Candidate substrings are drawn from the shortest sequence, longest first
    and then left to right, so ties go to the earliest position in that
    sequence.

    Returns:
        ``(starts, length)`` where ``starts[i]`` is the offset of the shared
        substring in ``sequences[i]`` (its first occurrence there) and
        ``length`` is its length. If nothing is shared (including when a
        sequence is empty), every start is 0 and ``length`` is 0.
    """
    if not sequences:
        return [], 0
    shortest = min(sequences, key=len)
    for length in range(len(shortest), 0, -1):
        for start in range(len(shortest) - length + 1):
            candidate = shortest[start:start + length]
            if all(candidate in seq for seq in sequences):
                # Offsets are looked up per sequence: the shared stretch need
                # not sit at the same position in each one.
                return [seq.find(candidate) for seq in sequences], length
    return [0] * len(sequences), 0


def scatter_plot(ax, real_shifts, real_cutoffs, real_sequence, pred_shifts, pred_cutoffs, pred_sequence, label):
    """Scatter real vs. predicted CSPs over the residues both sequences share.

    The two sequences are passed through ``align`` and trimmed to their longest
    common substring. The value lists are trimmed with the same per-sequence
    offsets, negative values are clipped to 0, and the Spearman correlation of
    the pair is written onto the axes.

    Args:
        ax: matplotlib Axes to draw on.
        real_shifts, pred_shifts: values indexed like the corresponding
            sequences (the script passes calc_CSP_wrapper CSPs).
        real_cutoffs, pred_cutoffs: thresholds drawn as dashed lines; real
            cutoffs are vertical (x-axis), predicted cutoffs horizontal (y-axis).
        real_sequence, pred_sequence: sequences the value lists belong to.
        label: panel title.
    """
    # NOTE(review): the value lists are sliced with offsets into the *aligned*
    # sequences, which assumes `align` does not move positions relative to the
    # value lists (e.g. by inserting gaps). Confirm against util.align.
    real_seq, pred_seq = align(real_sequence, pred_sequence)
    (real_start, pred_start), length = _common_window([real_seq, pred_seq])

    ax.set_title(label)
    if length == 0:
        # Nothing to pair up: report it on the panel instead of crashing.
        print(f"{label}: no shared sequence window, nothing to plot")
        ax.text(0.5, 0.5, 'No shared sequence window', transform=ax.transAxes,
                horizontalalignment='center', verticalalignment='center')
        return

    real_seq = real_seq[real_start:real_start + length]
    pred_seq = pred_seq[pred_start:pred_start + length]
    real_shifts = real_shifts[real_start:real_start + length]
    pred_shifts = pred_shifts[pred_start:pred_start + length]
    print(real_seq)
    print(pred_seq)
    # Clip negative values to 0 before correlating and plotting.
    real_shifts = [max(0, x) for x in real_shifts]
    pred_shifts = [max(0, x) for x in pred_shifts]

    # Spearman rank correlation between the paired real/predicted values.
    scc, _ = spearmanr(real_shifts, pred_shifts)
    ax.text(0.05, 0.95, f'SCC: {scc:.2f}', transform=ax.transAxes, fontsize=12, verticalalignment='top')

    # Real values on x, predicted values on y.
    ax.scatter(real_shifts, pred_shifts, c='blue', label='Real vs Predicted Shifts')
    # Each cutoff line goes on the axis of the data it thresholds.
    for cutoff in real_cutoffs:
        ax.axvline(x=cutoff, color='red', linestyle='--', linewidth=0.5)
    for cutoff in pred_cutoffs:
        ax.axhline(y=cutoff, color='green', linestyle='--', linewidth=0.5)
    ax.set_xlabel('Real Shifts')
    ax.set_ylabel('Predicted Shifts')
    ax.legend()
# Panels to draw. Each panel scatters the real (experimental) NMR CSPs against
# CSPs predicted with CSmethod from that structure source.
structures = ['NMR', 'AF2']
# structures = ['NMR', 'AF2', 'AF3']  # an 'AF3' panel would also need an entry in _PRED_SOURCES
z_scores = [0, 1, 3]

# Panel name -> `structure_source` passed to calc_CSP_wrapper for the predicted
# CSPs. The panel name itself is what get_confusion receives as `structure_source`.
_PRED_SOURCES = {'NMR': 'NMR_pred', 'AF2': 'AF2'}


def _cutoffs_and_significance(csps, csp_cutoff, z_levels):
    """Return ``(cutoffs, significance_mask)`` for one CSP profile.

    One threshold is computed per entry of *z_levels* with
    calculate_z_score_threshold, using only the CSPs that are > 0 and strictly
    below *csp_cutoff*. A residue is flagged significant (True) when its CSP
    is >= the first threshold.
    """
    below_thresh = [c for c in csps if 0 < c < csp_cutoff]
    cutoffs = [calculate_z_score_threshold(below_thresh, z) for z in z_levels]
    sigs = [bool(c >= cutoffs[0]) for c in csps]
    return cutoffs, sigs


# squeeze=False keeps `axes` 2-D (n_panels x 1) even when there is one panel,
# so the per-panel indexing below always works.
fig, axes = plt.subplots(len(structures), 1, figsize=(6, len(structures) * 6), squeeze=False)

pdbs = input('Provide bound pdb ( e.g. "7jq8" ) ')
pdb = pdbs.strip()
holo = pdb.lower()
try:
    row = holos.index(holo)
except ValueError:
    raise ValueError(f"Bound PDB {holo!r} is not listed in the holo_pdb column of {data_source_file}") from None
apo = str(apos[row]).lower()
well_defined_res = str(well_defined_residues[row])
match_seq = str(match_sequences[row])

# Per-profile records: the real NMR profile first, then one per predicted panel.
# significances is a list of lists of booleans (True = residue passes the first cutoff).
significances = []
sequences = []
shifts = []
cutoff_lists = []
file_prefs = []
consensuses = []
consensus = 0

# The real NMR CSPs are the x-axis of every panel, so load them once up front
# rather than inside the 'NMR' panel (which left other panels depending on it).
real_CSPs, real_CSP_cutoff, real_bound_seq = calc_CSP_wrapper(
    apo, holo, well_defined_res, method=method, CSmethod="REAL",
    structure_source='NMR_real', match_seq=match_seq)
real_cutoffs, real_sigs = _cutoffs_and_significance(real_CSPs, real_CSP_cutoff, z_scores)
significances.append(real_sigs)
sequences.append(real_bound_seq)
shifts.append(real_CSPs)
cutoff_lists.append(real_cutoffs)
file_prefs.append(f'NMR Data for {holo.upper()}')
consensuses.append(1)

# Pair each structure with its own axis instead of hard-coding axes[0]/axes[1].
for ax, structure in zip(axes[:, 0], structures):
    source = _PRED_SOURCES.get(structure)
    if source is None:
        print("received malformed structure selection: " + structure)
        continue
    pred_CSPs, pred_CSP_cutoff, pred_bound_seq = calc_CSP_wrapper(
        apo, holo, well_defined_res, method=method, CSmethod=CSmethod,
        structure_source=source, match_seq=match_seq)
    pred_cutoffs, pred_sigs = _cutoffs_and_significance(pred_CSPs, pred_CSP_cutoff, z_scores)
    significances.append(pred_sigs)
    sequences.append(pred_bound_seq)
    shifts.append(pred_CSPs)
    cutoff_lists.append(pred_cutoffs)
    file_prefs.append(structure)
    TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source=structure)
    _, _, consensus = get_F_MCC_cons(TP, FP, FN, TN)
    consensuses.append(consensus)
    scatter_plot(ax, real_CSPs, real_cutoffs, real_bound_seq, pred_CSPs, pred_cutoffs,
                 pred_bound_seq, f"{pdb} NMR vs {structure} pred")

plt.subplots_adjust(wspace=0.4, hspace=0.6)  # Increase wspace and hspace as needed
plt.show()