Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
# plot_CSP.py
from util import *
from collections import Counter
# Analysis settings handed to the CSP helpers further down the script.
method = "MONTE"
CSmethod = "UCBShift"
# Re-prompt until a supported chemical-shift predictor name is supplied
# (the loop body never runs while the default above is one of the options).
while CSmethod not in ('SPARTA', 'UCBShift', 'ShiftX', 'consensus'):
    CSmethod = input("What CS prediciton method to use? [SPARTA, UCBShift, ShiftX, consensus]")

# Load the lookup table and pull each column used below into a plain list,
# so that a position in `holos` can be used to index the sibling lists.
data_source_file = './CSPRANK.csv'
parsed_data = pd.read_csv(data_source_file)
apos = list(map(str, parsed_data['apo_bmrb']))
apo_pdbs = list(map(str, parsed_data['apo_pdb']))
holos = [entry.lower() for entry in parsed_data['holo_pdb']]
well_defined_residues = list(parsed_data['Well_Defined_Residues'])
match_sequences = list(parsed_data['match_seq'])
def plot(ax, shifts, cutoffs, sequence, pref, consensus, pdb_id = ""):
    """Draw a per-residue CSP bar chart with shaded significance bands on *ax*.

    Args:
        ax: Axes-like object. Only ``bar``, ``axhline``, ``fill_between``,
            ``text`` and the ``set_*`` axis methods are called on it.
        shifts: Per-residue CSP values; negative entries are drawn as 0.
        cutoffs: Three thresholds, indexed lowest band first. They are
            labelled ``z=0``, ``z=1`` and ``z=3`` on the figure.
        sequence: One x-tick label per residue.
        pref: Panel label; the x-axis title is only drawn when it is 'AF2'.
        consensus: Unused; kept so existing callers keep working.
        pdb_id: Unused; kept so existing callers keep working.
    """
    band_colors = ['gray', 'lightcoral', 'red', 'firebrick']
    labels = ['z=0', 'z=1', 'z=3']
    top_cutoff = cutoffs[-1]

    # Bars never go below zero even if the raw shift is negative.
    clipped = [max(0, value) for value in shifts]

    def _band_color(value):
        # Check the highest threshold first so a bar gets the strongest
        # band colour it qualifies for.
        for level in (2, 1, 0):
            if value > cutoffs[level]:
                return band_colors[level + 1]
        return band_colors[0]

    bar_colors = [_band_color(value) for value in clipped]

    if pref == 'AF2':
        ax.set_xlabel('Sequence', fontsize=20)
    ax.bar(range(len(clipped)), clipped, color=bar_colors, edgecolor='black')

    # One threshold line per cutoff, with the area below it (down to the
    # previous cutoff, or zero for the first) shaded in the band colour.
    span = range(-1, len(shifts) + 1)
    band_floor = 0
    for i, cutoff in enumerate(cutoffs):
        ax.axhline(y=cutoff, color='r', linestyle='-', label=labels[i])
        ax.fill_between(span, band_floor, cutoff, color=band_colors[i], alpha=0.3)
        # Label the line near both ends of the plot.
        ax.text(1, cutoff - 0.1, labels[i], color='black', va='bottom', ha='right', fontweight='bold')
        ax.text(len(shifts) - 1, cutoff - 0.1, labels[i], color='black', va='bottom', ha='right', fontweight='bold')
        band_floor = cutoff

    # Shade everything above the top cutoff. `default=` keeps this working
    # for an empty `shifts` list instead of raising from max().
    peak = max(shifts, default=top_cutoff)
    band_ceiling = peak if peak > top_cutoff else top_cutoff + 1
    ax.fill_between(span, top_cutoff, band_ceiling, color=band_colors[-1], alpha=0.3)

    # Use the residue letters as x tick labels.
    ax.set_xticks(range(len(sequence)))
    ax.set_xticklabels(sequence)
    ax.set_xlim(-0.5, len(sequence) - 0.5)

    # Scale the y-axis to the tallest bar that is still below the top cutoff
    # so a few extreme values do not flatten the rest of the chart. When no
    # value is below the top cutoff, fall back to the top cutoff itself
    # rather than calling max() on an empty sequence.
    tallest_regular = max(
        (value for value in shifts if value < top_cutoff),
        default=top_cutoff,
    )
    ax.set_ylim(0, tallest_regular + cutoffs[1])
    ax.set_ylabel('CSP', fontsize=20)
def plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, ax, plots):
    """Plot the first CSP track on *ax* and mark per-residue agreement.

    All tracks are first trimmed to the longest substring shared by every
    sequence so that index ``i`` refers to the same residue in each track.
    The first track is drawn with ``plot``; a row of dots at y=0 then encodes
    how the second and third significance tracks compare with the first:
    green = both agree, yellow = only the second agrees, purple = only the
    third agrees, red = neither agrees.

    Args:
        significances: Exactly three per-residue lists of significance flags.
        sequences: One residue sequence (string or list of characters) per
            track, in the same order as ``significances``.
        shifts: One per-residue CSP list per track.
        cutoff_lists: One list of thresholds per track; only the first is
            used here.
        file_prefs: One label per track; only the first is used here.
        consensuses: One consensus value per track; only the first is used.
        ax: Axes to draw on.
        plots: Sequence of flags; the agreement dots are drawn when
            ``plots[0]`` is truthy.

    Raises:
        ValueError: If the sequences share no common substring, or if
            ``significances`` does not hold exactly three tracks.
    """
    y_offset = 0
    sequences = [''.join(seq) for seq in sequences]

    def _longest_common_substring(candidates):
        # Brute force: try every window of the shortest sequence, longest
        # windows first, and return the first one found in all candidates.
        if not candidates:
            return ""
        shortest = min(candidates, key=len)
        limit = len(shortest)
        for length in range(limit, 0, -1):
            for start in range(limit - length + 1):
                window = shortest[start:start + length]
                if all(window in candidate for candidate in candidates):
                    return window
        return ""

    sequence = _longest_common_substring(sequences)
    print(sequence)
    if not sequence:
        # Without a shared region there is nothing to align or draw; fail
        # with a clear message instead of an error from deeper in the stack.
        raise ValueError("sequences share no common substring; nothing to plot")

    # Locate the shared region once per track and cut every track down to it.
    width = len(sequence)
    offsets = [seq.find(sequence) for seq in sequences]
    shifts = [track[start:start + width] for track, start in zip(shifts, offsets)]
    significances = [track[start:start + width] for track, start in zip(significances, offsets)]
    first_list, second_list, third_list = significances

    # Draw on the axes that was passed in rather than on a module global.
    plot(ax, shifts[0], cutoff_lists[0], sequence, file_prefs[0], consensuses[0])

    # Colour keyed by (second agrees with first, third agrees with first).
    agreement_colors = {
        (True, True): 'green',
        (True, False): 'yellow',
        (False, True): 'purple',
        (False, False): 'red',
    }
    colors = [
        agreement_colors[(bool(second == first), bool(third == first))]
        for first, second, third in zip(first_list, second_list, third_list)
    ]

    if plots[0]:
        # One dot per residue on the x-axis; the marker area shrinks as the
        # sequence gets longer.
        ax.scatter(range(width), [y_offset] * width, color=colors, s=100*(100/width), zorder=2)
# Structure sources to analyse. 'NMR' contributes two tracks (measured and
# predicted shifts) and 'AF2' one, giving the three tracks that
# plot_significance_pattern unpacks.
structures = ['NMR', 'AF2']
#structures = ['NMR', 'AF2', 'AF3']
z_scores = [0, 1, 3]
fig, axes = plt.subplots(1, figsize=(20,5))


def _collect_csp(apo_id, holo_id, residues, seq_to_match, cs_method, structure_source):
    """Compute CSPs, z-score cutoffs and significance flags for one source.

    Calls ``calc_CSP_wrapper`` for the given shift method and structure
    source, derives one threshold per entry of ``z_scores`` from the CSPs
    that are positive and below the cutoff it returned, and flags every
    residue whose CSP reaches the first threshold.

    Returns:
        Tuple ``(CSPs, cutoffs, significance_flags, bound_sequence)``.
    """
    csps, csp_cutoff, bound_seq = calc_CSP_wrapper(
        apo_id, holo_id, residues,
        method=method, CSmethod=cs_method,
        structure_source=structure_source, match_seq=seq_to_match,
    )
    below_thresh = [c for c in csps if c < csp_cutoff and c > 0]
    thresholds = [calculate_z_score_threshold(below_thresh, z) for z in z_scores]
    flags = [bool(c >= thresholds[0]) for c in csps]
    return csps, thresholds, flags, bound_seq


pdbs = input('Provide bound pdb ( e.g. "7jq8" ) ')
pdb = pdbs.strip()
holo = pdb.lower()

# Resolve the table row once; an unknown id ends the run with a readable
# message instead of a bare "is not in list" traceback.
try:
    row = holos.index(holo)
except ValueError:
    raise SystemExit(f'PDB id "{pdb}" was not found in the holo_pdb column of {data_source_file}') from None

apo = str(apos[row]).lower()
well_defined_res = str(well_defined_residues[row])
match_seq = str(match_sequences[row])
consensus = 0

# Parallel per-track lists; index k describes the k-th collected track.
significances = []   # per-residue significance flags
sequences = []
shifts = []
cutoff_lists = []
file_prefs = []
consensuses = []

for structure in structures:
    # Each run is (structure_source, shift method, label, confusion source).
    # A confusion source of None marks the measured-shift reference track,
    # which is recorded with a consensus of 1.
    if structure == 'NMR':
        runs = [
            ('NMR_real', 'REAL', f'NMR Data for {holo.upper()}', None),
            ('NMR_pred', CSmethod, 'NMR', 'NMR'),
        ]
    elif structure == 'AF2':
        runs = [('AF2', CSmethod, 'AF2', 'AF2')]
    elif structure == 'AF3':
        CSPs, cutoffs, _, bound_seq = _collect_csp(
            apo, holo, well_defined_res, match_seq,
            cs_method=CSmethod, structure_source=structure,
        )
        # NOTE(review): `axes` is the single Axes created above, so
        # `axes[3]` fails if 'AF3' is ever enabled in `structures`; this
        # branch also records no track for plot_significance_pattern.
        # Left as-is until the intended AF3 layout is confirmed.
        plot(axes[3], CSPs, cutoffs, bound_seq, structure, consensus)
        continue
    else:
        print("received malformed structure selection: " + structure)
        continue

    for structure_source, cs_method, file_pref, confusion_source in runs:
        CSPs, cutoffs, sigs, bound_seq = _collect_csp(
            apo, holo, well_defined_res, match_seq,
            cs_method=cs_method, structure_source=structure_source,
        )
        significances.append(sigs)
        sequences.append(bound_seq)
        shifts.append(CSPs)
        cutoff_lists.append(cutoffs)
        file_prefs.append(file_pref)
        if confusion_source is None:
            consensuses.append(1)
        else:
            TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source=confusion_source)
            F, MCC, consensus = get_F_MCC_cons(TP, FP, FN, TN)
            consensuses.append(consensus)

plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, axes, [True, True, True])
plt.subplots_adjust(wspace=0.4, hspace=0.6)  # Increase wspace and hspace as needed
plt.savefig(f'./Figures/{pdb.upper()}_CSP_AF2_NMR_plot.png')
plt.show()