Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/plot_CSP_ES_one_hist.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
244 lines (206 sloc)
10.3 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# plot_CSP_ES_one_hist.py
from util import * | |
from collections import Counter | |
# Run configuration: CSP calculation method and chemical-shift predictor.
method = "MONTE"
CSmethod = "UCBShift"
# Keep prompting until a supported predictor name is given; with the default
# assigned above the loop body is never entered.
while CSmethod not in ('SPARTA', 'UCBShift', 'ShiftX', 'consensus'):
    CSmethod = input("What CS prediciton method to use? [SPARTA, UCBShift, ShiftX, consensus]")

# Load the lookup table once and pull each needed column into a plain list.
# NOTE(review): `pd` is not imported in this file directly; presumably it is
# re-exported by `from util import *` -- confirm.
data_source_file = './CSPRANK.csv'
parsed_data = pd.read_csv(data_source_file)
apos = list(map(str, parsed_data['apo_bmrb']))
apo_pdbs = list(map(str, parsed_data['apo_pdb']))
holos = [entry.lower() for entry in parsed_data['holo_pdb']]
well_defined_residues = list(parsed_data['Well_Defined_Residues'])
match_sequences = list(parsed_data['match_seq'])
def plot(ax, shifts, cutoffs, sequence, pref, consensus, pdb_id = ""):
    """Draw a per-residue CSP bar chart with shaded significance bands on ``ax``.

    Args:
        ax: Axes-like object; only its drawing and setter methods are called
            (``bar``, ``axhline``, ``fill_between``, ``set_*``).
        shifts: Per-residue CSP values. Negative entries are drawn as 0.
        cutoffs: Three thresholds, indexed 0..2 and labelled z=0, z=1, z=3.
            They are used as band edges in index order, so they are expected
            to be ascending.
        sequence: Tick labels for the x axis, one per residue.
        pref: Title prefix. ``'AF2'`` additionally gets an x-axis label. When
            it equals ``f'NMR Data for {pdb_id.upper()}'`` the title is the
            prefix alone, without a score.
        consensus: Score appended to the title, formatted with ``.2g``.
        pdb_id: PDB identifier used only for the title comparison above.

    Returns:
        None. All output is drawn onto ``ax``.
    """
    # Colours by significance level: below every cutoff, then above
    # cutoffs[0], cutoffs[1] and cutoffs[2] respectively.
    band_colors = ['gray', 'lightcoral', 'red', 'firebrick']
    labels = ['z=0', 'z=1', 'z=3']

    heights = [max(0, value) for value in shifts]

    bar_colors = []
    for height in heights:
        if height > cutoffs[2]:
            bar_colors.append(band_colors[3])
        elif height > cutoffs[1]:
            bar_colors.append(band_colors[2])
        elif height > cutoffs[0]:
            bar_colors.append(band_colors[1])
        else:
            bar_colors.append(band_colors[0])

    if pref == 'AF2':
        ax.set_xlabel('Sequence', fontsize=20)
    ax.bar(range(len(heights)), heights, color=bar_colors, edgecolor='black')

    # One horizontal line per cutoff, with the band beneath it shaded. The
    # span runs one position past each end so the shading reaches the frame.
    span = range(-1, len(shifts) + 1)
    band_floor = 0
    for i, band_ceiling in enumerate(cutoffs):
        ax.axhline(y=band_ceiling, color='r', linestyle='-', label=labels[i])
        ax.fill_between(span, band_floor, band_ceiling, color=band_colors[i], alpha=0.3)
        band_floor = band_ceiling

    # Top band: from the last cutoff up to the tallest value, or one unit
    # above the cutoff when nothing exceeds it (including empty input). The
    # loop above already drew the line at cutoffs[-1]; it is not repeated
    # here, so a legend lists each z level once.
    top_cutoff = cutoffs[-1]
    peak = max(shifts, default=top_cutoff)
    band_top = peak if peak > top_cutoff else top_cutoff + 1
    ax.fill_between(span, top_cutoff, band_top, color=band_colors[-1], alpha=0.3)

    # One tick per residue, labelled with its character.
    ax.set_xticks(range(len(sequence)))
    ax.set_xticklabels(sequence)
    ax.set_xlim(-0.5, len(sequence) - 0.5)

    # The y range is set by the tallest value below the top cutoff, so
    # outliers above it are clipped rather than flattening the other bars.
    # If no value lies below the top cutoff, fall back to the cutoff itself
    # instead of raising on an empty max().
    tallest_below_top = max(
        (value for value in shifts if value < top_cutoff),
        default=top_cutoff,
    )
    ax.set_ylim(0, tallest_below_top + cutoffs[1])
    ax.set_ylabel('CSP')

    if pref == f'NMR Data for {pdb_id.upper()}':
        ax.set_title(f"{pref}", fontsize=20)
    else:
        ax.set_title(f"{pref} CSPRank Score = {consensus:.2g}", fontsize=20)
def plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, ax, plots, pdb_id=None):
    """Plot the first track's CSP histogram with an agreement marker per residue.

    All tracks are cropped to the longest substring common to every sequence,
    so that index ``i`` refers to the same residue in each of them. The first
    track is then drawn with ``plot`` and, if requested, a row of coloured
    dots shows how the other two tracks' significance flags compare with it.

    Args:
        significances: Exactly three per-residue lists of booleans, one per
            track. The script in this file passes them in the order measured
            NMR, predicted NMR, ES.
        sequences: One residue sequence per track (string or iterable of
            characters), parallel to ``significances``.
        shifts: One per-residue CSP list per track, parallel to ``sequences``.
        cutoff_lists: Per-track cutoff lists; only element 0 is used.
        file_prefs: Per-track title prefixes; only element 0 is used.
        consensuses: Per-track scores; only element 0 is used.
        ax: Axes to draw on.
        plots: Sequence of flags; the dots are drawn when ``plots[0]`` is truthy.
        pdb_id: PDB identifier forwarded (upper-cased) to ``plot``. When
            omitted, the module-level ``holo`` is used, as before.

    Raises:
        ValueError: If the sequences share no common substring, or if
            ``significances`` does not hold exactly three lists.
    """
    y_offset = 0  # vertical position of the marker row
    sequences = [''.join(seq) for seq in sequences]

    def merge_sequences(sequences):
        """Return the longest substring present in every sequence, or ""."""
        # Only substrings of the shortest sequence can be common to all.
        reference = min(sequences, key=len)
        # Try longest candidates first so the first hit is the answer.
        for length in range(len(reference), 0, -1):
            for start in range(len(reference) - length + 1):
                candidate = reference[start:start + length]
                if all(candidate in seq for seq in sequences):
                    return candidate
        return ""

    sequence = merge_sequences(sequences)
    print(sequence)
    if not sequence:
        # Without this guard the failure shows up later as a division by
        # zero in the marker size, or as an empty max() inside plot().
        raise ValueError("the provided sequences share no common subsequence to plot")

    # Crop every track to the common window; find() is evaluated once per track.
    window = len(sequence)
    starts = [seq.find(sequence) for seq in sequences]
    shifts = [track[start:start + window] for track, start in zip(shifts, starts)]
    significances = [flags[start:start + window] for flags, start in zip(significances, starts)]
    first_list, second_list, third_list = significances

    if pdb_id is None:
        pdb_id = holo  # module-level identifier set by the script
    plot(ax, shifts[0], cutoff_lists[0], sequence, file_prefs[0], consensuses[0], pdb_id.upper())

    # Marker colour, keyed by (second agrees with first, third agrees with first).
    agreement_colors = {
        (True, True): 'green',
        (True, False): 'yellow',
        (False, True): 'blue',
        (False, False): 'red',
    }
    colors = [
        agreement_colors[(bool(second == first), bool(third == first))]
        for first, second, third in zip(first_list, second_list, third_list)
    ]

    if plots[0]:
        # One dot per residue on the x axis; dot area shrinks as the
        # sequence gets longer.
        ax.scatter(range(window), [y_offset] * window, color=colors, s=100*(100/window), zorder=2)
# Script body: for one bound (holo) PDB entry, compute CSPs from measured NMR
# shifts, from shifts predicted on the NMR structure, and from an ES ensemble,
# then plot them together and save the figure.
structures = ['NMR', 'ES']
z_scores = [0, 1, 3]
fig, axes = plt.subplots(1, figsize=(20,5))
pdbs = input('Provide bound pdb ( e.g. "7jq8" ) ')
pdb = pdbs.strip()
holo = pdb.lower()

# Parallel per-track accumulators. `significances` holds, per track, one
# boolean per residue: True when that residue's CSP is significant.
significances = []
sequences = []
shifts = []
cutoff_lists = []
file_prefs = []
consensuses = []

# Locate the CSV row for this entry once and fail with a readable message
# when the id is not in the table.
try:
    entry_row = holos.index(holo)
except ValueError:
    raise ValueError(f"holo PDB '{holo}' is not listed in {data_source_file}") from None
apo = str(apos[entry_row]).lower()
well_defined_res = str(well_defined_residues[entry_row])
match_seq = str(match_sequences[entry_row])
consensus = 0


def _add_track(label, CSPs, CSP_cutoff, bound_seq):
    """Derive cutoffs and significance flags for one track and record it."""
    # Thresholds are estimated only from positive CSPs below the outlier cutoff.
    background = [value for value in CSPs if 0 < value < CSP_cutoff]
    cutoffs = [calculate_z_score_threshold(background, z_score) for z_score in z_scores]
    # A residue is significant when it reaches the first (z=0) threshold.
    sigs = [bool(value >= cutoffs[0]) for value in CSPs]
    significances.append(sigs)
    sequences.append(bound_seq)
    shifts.append(CSPs)
    cutoff_lists.append(cutoffs)
    file_prefs.append(label)


for structure in structures:
    if structure == 'NMR':
        # Track 1: NMR structure with measured shifts. It is the reference,
        # so its score is fixed at 1.
        CSPs, CSP_cutoff, bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod="REAL", structure_source='NMR_real', match_seq=match_seq)
        _add_track(f'NMR Data for {holo.upper()}', CSPs, CSP_cutoff, bound_seq)
        consensuses.append(1)

        # Track 2: NMR structure with predicted shifts.
        CSPs, CSP_cutoff, bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod, structure_source='NMR_pred', match_seq=match_seq)
        _add_track('NMR', CSPs, CSP_cutoff, bound_seq)
        TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source = "NMR")
        _, _, consensus = get_F_MCC_cons(TP, FP, FN, TN)
        consensuses.append(consensus)
    elif structure == "ES":
        # Track 3: ES ensemble, taken from a directory the user picks among
        # those under PDB_FILES whose name contains the PDB id.
        possible_bound_dirs = [PDB_FILES + dire + '/' for dire in listdir(PDB_FILES) if holo in dire and isdir(PDB_FILES + dire)]
        if not possible_bound_dirs:
            raise FileNotFoundError(f"no directory containing '{holo}' found under {PDB_FILES}")
        print("Select a directory from the following list:")
        for i, dire in enumerate(possible_bound_dirs):
            print(f"{i}: {dire}")
        # Re-prompt until the reply is an integer that indexes the list.
        selected_index = None
        while selected_index is None:
            reply = input("Enter the number corresponding to your choice: ")
            try:
                choice = int(reply)
            except ValueError:
                print(f"'{reply}' is not a number.")
                continue
            if 0 <= choice < len(possible_bound_dirs):
                selected_index = choice
            else:
                print(f"Choice must be between 0 and {len(possible_bound_dirs) - 1}.")
        bound_dir = possible_bound_dirs[selected_index]
        pdb_files = [bound_dir + f for f in listdir(bound_dir) if f.endswith('.pdb')]
        basenames = [f[f.rfind('/')+1:f.rfind('.')] for f in pdb_files]
        # NOTE(review): the medoid model and its basename are computed but not
        # used; calc_CSP_wrapper receives the full `basenames` list. Confirm
        # whether `basename=bound_file_basename` was intended.
        bound_file = find_medoid_structure(pdb_files)
        bound_file_basename = bound_file[bound_file.rfind('/')+1:bound_file.rfind('.')]
        CSPs, CSP_cutoff, bound_seq = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod,
                                                       structure_source='ES', match_seq=match_seq, basename=basenames)
        _add_track(structure, CSPs, CSP_cutoff, bound_seq)
        TP, FP, FN, TN = get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source = "ES")
        _, _, consensus = get_F_MCC_cons(TP, FP, FN, TN)
        consensuses.append(consensus)
    else:
        print("received malformed structure selection: " + structure)

plot_significance_pattern(significances, sequences, shifts, cutoff_lists, file_prefs, consensuses, axes, [True, True, True])
plt.tight_layout()
plt.subplots_adjust(hspace=0.1)
plt.savefig(f'./Figures/{pdb.upper()}_ES_CSP_comparison_plot.png')
plt.show()