Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/plot_CSP_hist.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
272 lines (236 sloc)
11.9 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# plot_CSP_hist.py
# Plot per-residue chemical shift perturbations (CSPs) for one holo complex
# and mark where the significance calls of the different panels agree.
from util import *
from collections import Counter

# NOTE(review): `pd` (and `plt`, used below) are not imported in this file;
# they presumably come from util's star import — confirm.

method = "MONTE"
CSmethod = "UCBShift"
# CSmethod is pre-set above, so this prompt only runs if that default is
# changed to a value outside the supported list.
while CSmethod not in ['SPARTA', 'UCBShift', 'ShiftX', 'consensus']:
    CSmethod = input("What CS prediction method to use? [SPARTA, UCBShift, ShiftX, consensus]")

# Per-complex metadata; every list below is index-aligned with the CSV rows.
data_source_file = './CSPRANK.csv'
parsed_data = pd.read_csv(data_source_file)
apos = [str(data) for data in parsed_data['apo_bmrb']]
apo_pdbs = [str(data) for data in parsed_data['apo_pdb']]
# str() first so an empty (NaN) cell does not crash .lower().
holos = [str(data).lower() for data in parsed_data['holo_pdb']]
well_defined_residues = list(parsed_data['Well_Defined_Residues'])
match_sequences = list(parsed_data['match_seq'])
def plot(ax, shifts, cutoffs, sequence, pref, consensus, pdb_id = ""):
    """Draw one panel: per-residue CSP bars over shaded z-score bands.

    Args:
        ax: matplotlib Axes to draw on (assumed; only Axes methods are called).
        shifts: per-residue CSP values; negative values are drawn as 0.
        cutoffs: the three CSP thresholds labelled z=0, z=1 and z=3, expected
            in increasing order (each bar takes the colour of the highest
            threshold it reaches).
        sequence: residue labels, one x tick per element.
        pref: panel title prefix; the 'AF2' panel also gets the x-axis label.
        consensus: CSPRank score appended to the title (formatted with .2g).
        pdb_id: when pref equals "NMR Data for <PDB_ID upper-cased>", the
            title is pref alone, without a score.
    """
    cleaned_data = [max(0, x) for x in shifts]
    # colors[k] is used for band k: below z=0, z=0..1, z=1..3, above z=3.
    colors = ['gray', 'lightcoral', 'firebrick', 'red']
    labels = ['z=0', 'z=1', 'z=3']

    # >= matches how the script flags a residue as significant (CSP >= cutoff).
    bar_colors = []
    for value in cleaned_data:
        if value >= cutoffs[2]:
            bar_colors.append(colors[3])
        elif value >= cutoffs[1]:
            bar_colors.append(colors[2])
        elif value >= cutoffs[0]:
            bar_colors.append(colors[1])
        else:
            bar_colors.append(colors[0])

    if pref == 'AF2':
        ax.set_xlabel('Sequence', fontsize=20)
    ax.bar(range(len(cleaned_data)), cleaned_data, color=bar_colors, edgecolor='black')

    # One line per cutoff, with the band below it shaded in that band's colour.
    x_span = range(-1, len(shifts) + 1)
    lower = 0
    for i, cutoff in enumerate(cutoffs):
        ax.axhline(y=cutoff, color='r', linestyle='-', label=labels[i])
        ax.fill_between(x_span, lower, cutoff, color=colors[i], alpha=0.3)
        lower = cutoff
    # Top band: from the highest cutoff up to the largest CSP (at least 1 tall).
    # default= keeps an empty panel from raising.
    top = max(shifts, default=cutoffs[-1])
    ax.fill_between(x_span, cutoffs[-1], top if top > cutoffs[-1] else cutoffs[-1] + 1,
                    color=colors[-1], alpha=0.3)

    # One tick per residue, labelled with its sequence character.
    ax.set_xticks(range(len(sequence)))
    ax.set_xticklabels(sequence)
    ax.set_xlim(-0.5, len(sequence) - 0.5)

    # Scale the y-axis to the tallest bar below the top cutoff, so a few very
    # large CSPs do not flatten the rest (those bars are clipped). Fall back to
    # the top cutoff when every bar is above it.
    below_top = [value for value in cleaned_data if value < cutoffs[-1]]
    ax.set_ylim(0, max(below_top, default=cutoffs[-1]) + cutoffs[1])
    ax.set_ylabel('CSP')
    if pref == f'NMR Data for {pdb_id.upper()}':
        ax.set_title(f"{pref}", fontsize=20)
    else:
        ax.set_title(f"{pref} CSPRank Score = {consensus:.2g}", fontsize=20)
def _longest_common_substring(sequences):
    """Return the longest substring present in every sequence ('' if none)."""
    if not sequences:
        return ""
    # Only substrings of the shortest sequence can be common to all of them.
    shortest_seq = min(sequences, key=len)
    max_len = len(shortest_seq)
    for length in range(max_len, 0, -1):
        for start in range(max_len - length + 1):
            candidate = shortest_seq[start:start + length]
            if all(candidate in seq for seq in sequences):
                return candidate
    return ""


def _agreement_colors(reference, other):
    """Return 'green' where *other* makes the same call as *reference*, else 'red'."""
    return ['green' if mine == ref else 'red' for ref, mine in zip(reference, other)]


def plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, ax, plots, pdb_id=None):
    """Align three CSP panels on a shared sequence, plot them, and mark agreement.

    All panels are trimmed to the longest segment common to every sequence
    and drawn with plot() on ax[0], ax[1] and ax[2]. Then, for each panel
    whose flag in *plots* is true, a row of dots is drawn at y=0:

    * ax[0]: green = all three significance calls agree; yellow = only the
      second agrees with the first; purple = only the third agrees; red = neither.
    * ax[1]: green where the second panel agrees with the first, else red.
    * ax[2]: green where the third panel agrees with the first, else red.

    In this script the first panel is the experimental NMR data, so the other
    two panels are judged against it.

    Args:
        significances: three per-residue lists of significance flags.
        sequences: three residue sequences (strings or lists of characters).
        shifts: three per-residue CSP lists, aligned with *sequences*.
        cutoff_lists: three cutoff lists, passed through to plot().
        file_prefs: three title prefixes, passed through to plot().
        consensuses: three CSPRank scores, passed through to plot().
        ax: sequence of at least three matplotlib Axes.
        plots: three booleans selecting which panels get the agreement dots.
        pdb_id: ID passed to plot() for its title handling; defaults to the
            module-level ``holo`` set by the script.

    Raises:
        ValueError: if the sequences share no common segment.
    """
    if pdb_id is None:
        pdb_id = holo  # module-level global set by the script
    sequences = [''.join(seq) for seq in sequences]
    sequence = _longest_common_substring(sequences)
    print(sequence)
    if not sequence:
        raise ValueError("the panels share no common sequence segment to align on")

    # Trim every panel to the shared segment so residue i lines up across panels.
    offsets = [seq.find(sequence) for seq in sequences]
    span = len(sequence)
    shifts = [panel[off:off + span] for panel, off in zip(shifts, offsets)]
    significances = [panel[off:off + span] for panel, off in zip(significances, offsets)]
    first_list, second_list, third_list = significances

    for i in range(3):
        plot(ax[i], shifts[i], cutoff_lists[i], sequence, file_prefs[i], consensuses[i], pdb_id.upper())

    summary_colors = []
    for ref, second, third in zip(first_list, second_list, third_list):
        if second == ref and third == ref:
            summary_colors.append('green')
        elif second == ref:
            summary_colors.append('yellow')
        elif third == ref:
            summary_colors.append('purple')
        else:
            summary_colors.append('red')
    panel_colors = [
        summary_colors,
        _agreement_colors(first_list, second_list),
        _agreement_colors(first_list, third_list),
    ]

    y_offset = 0
    # Shrink the dots as the sequence gets longer so they do not overlap.
    dot_size = 100 * (100 / len(sequence))
    for i, (show, colors) in enumerate(zip(plots, panel_colors)):
        if show:
            ax[i].scatter(range(len(sequence)), [y_offset] * len(sequence), color=colors, s=dot_size, zorder=2)
import os


def _csp_panel(apo, holo, well_defined_res, cs_method, structure_source, match_seq):
    """Compute one panel's CSPs, z-score cutoffs and significance flags.

    Calls util's calc_CSP_wrapper with the module-level ``method``. Each cutoff
    comes from calculate_z_score_threshold, applied for each value in
    ``z_scores`` to the CSPs that are > 0 and below the CSP cutoff returned by
    the wrapper. A residue is flagged significant when its CSP is >= the first
    cutoff.

    Returns:
        (CSPs, cutoffs, sigs, bound_seq)
    """
    CSPs, CSP_cutoff, bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=cs_method,
                                                   structure_source=structure_source, match_seq=match_seq)
    CSP_below_thresh = [C for C in CSPs if 0 < C < CSP_cutoff]
    cutoffs = [calculate_z_score_threshold(CSP_below_thresh, z_score) for z_score in z_scores]
    sigs = [CSP >= cutoffs[0] for CSP in CSPs]
    return CSPs, cutoffs, sigs, bound_seq


# Structures to compare. 'NMR' yields two panels: the experimental shifts and
# the shifts predicted from the NMR structure. 'AF3' is also supported, drawn
# on a fourth axis; to enable it use structures = ['NMR', 'AF2', 'AF3'].
structures = ['NMR', 'AF2']
z_scores = [0, 1, 3]

if __name__ == "__main__":
    fig, axes = plt.subplots(len(structures)+1, 1, figsize=(18, (len(structures)+1)*7))
    pdbs = input('Provide bound pdb ( e.g. "7jq8" ) ')
    pdb = pdbs.strip()
    holo = pdb.lower()
    if holo not in holos:
        raise ValueError(f"{pdb!r} is not listed in {data_source_file}")
    row = holos.index(holo)
    apo = str(apos[row]).lower()
    well_defined_res = str(well_defined_residues[row])
    match_seq = str(match_sequences[row])

    # One tuple per aligned panel, in drawing order:
    # (significance flags, bound sequence, CSPs, cutoffs, title prefix, CSPRank score)
    panels = []
    consensus = 0
    for structure in structures:
        if structure == 'NMR':
            # Experimental shifts. This panel's score is fixed at 1; plot()
            # omits the score from its title.
            CSPs, cutoffs, sigs, bound_seq = _csp_panel(apo, holo, well_defined_res, "REAL", 'NMR_real', match_seq)
            panels.append((sigs, bound_seq, CSPs, cutoffs, f'NMR Data for {holo.upper()}', 1))
            # Shifts predicted from the NMR structure with the chosen CS method.
            CSPs, cutoffs, sigs, bound_seq = _csp_panel(apo, holo, well_defined_res, CSmethod, 'NMR_pred', match_seq)
            TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source = "NMR")
            _F, _MCC, consensus = get_F_MCC_cons(TP, FP, FN, TN)
            panels.append((sigs, bound_seq, CSPs, cutoffs, 'NMR', consensus))
        elif structure == 'AF2':
            CSPs, cutoffs, sigs, bound_seq = _csp_panel(apo, holo, well_defined_res, CSmethod, 'AF2', match_seq)
            TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source = "AF2")
            _F, _MCC, consensus = get_F_MCC_cons(TP, FP, FN, TN)
            panels.append((sigs, bound_seq, CSPs, cutoffs, 'AF2', consensus))
        elif structure == 'AF3':
            # Drawn on its own axis and left out of the agreement comparison;
            # it reuses the most recent consensus score.
            CSPs, cutoffs, _sigs, bound_seq = _csp_panel(apo, holo, well_defined_res, CSmethod, 'AF3', match_seq)
            plot(axes[3], CSPs, cutoffs, bound_seq, 'AF3', consensus)
        else:
            print("received malformed structure selection: " + structure)

    # Transpose the per-panel tuples into the parallel lists the plotter expects.
    significances, sequences, shifts, cutoff_lists, file_prefs, consensuses = (list(column) for column in zip(*panels))
    plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, axes, [True, True, True])

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.1)  # space between the stacked panels
    os.makedirs('./Figures', exist_ok=True)  # savefig does not create directories
    plt.savefig(f'./Figures/{pdb.upper()}_CSP_comparison_plots.png')
    plt.close(fig)