CSP_Rank/plot_tm_dockq_es_rms.py
import os

import pandas as pd
import matplotlib.pyplot as plt
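# Columns expected in each CSP rank score CSV (inferred from the usage below;
# treat this as an assumption about the data, not a spec):
#   holo_model_path, tm_score, dockq_score, irms_score, lrms_score,
#   consensus, Confidence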
def determine_special_case(holo_model_path, pdb_id):
    """Determine the plot color for a data point based on its model type.

    The matching legend label for each color is defined in the `colors`
    dict below.
    """
    if 'exp_' + pdb_id.lower() in holo_model_path:
        return 'green'
    elif 'comp_' + pdb_id.lower() in holo_model_path:
        return 'cyan'
    elif 'v3_' in holo_model_path and 'dropout' in holo_model_path:
        return 'blue'
    elif 'v2_' in holo_model_path and 'dropout' in holo_model_path:
        return 'pink'
    elif 'dropout' in holo_model_path:
        return 'red'
    elif 'v2_' in holo_model_path or 'v3_' in holo_model_path:
        # v2_/v3_ without dropout (the dropout cases are caught above)
        return 'purple'
    elif 'notemplate' in holo_model_path:
        return 'orange'
    elif 'multimer' in holo_model_path:
        return 'yellow'
    else:
        return 'gray'
# Define colors and their labels
colors = {
    'green': 'NMR',
    'cyan': 'Baseline AF2',
    'blue': 'AFS v3',
    'pink': 'AFS v2',
    'red': 'AFS v1',
    'purple': 'AFS2 v2',
    'orange': 'AF ALT',
    'yellow': 'AFS2 v1/3',
    'gray': 'NA',
}
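# Note: the keys of `colors` double as matplotlib color names and must stay in
# sync with the strings returned by determine_special_case(); the values are
# the legend labels.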
def generate_bayes_components_plot(pdb_id):
    """
    Generate Bayesian analysis plots for a given PDB ID

    Args:
        pdb_id: PDB identifier
    """
    print(f"Processing {pdb_id}...")

    # Get CSP rank score file path
    csp_rank_score_file = f'./CSP_Rank_Scores/CSP_{pdb_id.lower()}_CSpred.csv'

    # Read the CSP rank scores file
    try:
        df = pd.read_csv(csp_rank_score_file)
    except Exception as e:
        print(f"Error reading CSP rank scores file for {pdb_id}: {e}")
        return

    # Add special cases column
    df['special_cases'] = df['holo_model_path'].apply(lambda x: determine_special_case(x, pdb_id))
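    # Each row is now tagged with a color key; the plotting loops below group
    # points by this column so every model type is drawn in its own color.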
    # Create figure with 12 subplots in 3 rows, 4 columns
    fig, ((ax1, ax2, ax7, ax8), (ax3, ax4, ax9, ax10), (ax5, ax6, ax11, ax12)) = plt.subplots(3, 4, figsize=(32, 24))

    # Add title with PDB ID
    fig.suptitle(f'Bayesian Analysis for {pdb_id.upper()}\n', fontsize=24, y=0.95, weight='bold')

    # Add the Bayes relation centered above the panels (proportionality, not
    # equality, since the plotted score omits the normalizing constant P(data))
    fig.text(0.5, 0.92, 'P(model|data) ∝ P(model) × P(data|model)',
             fontsize=18, ha='center', weight='bold')
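    # Throughout, "Bayes Score" = consensus * Confidence, i.e. likelihood x
    # prior, which is proportional to the posterior because P(data) is
    # constant for a given PDB ID.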
    # Top row: Posterior plots
    # Plot 1: Bayes Score vs TM Score (left)
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax1.scatter(subset['tm_score'], subset['consensus'] * subset['Confidence'],
                    alpha=0.5, label=label, color=case)
    ax1.set_ylabel('P(model|data)\nBayes Score', fontsize=14, weight='bold')
    ax1.set_xlabel('TM Score', fontsize=14, weight='bold')
    ax1.set_title('Posterior (TM)', fontsize=16, weight='bold')
    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)
    ax1.tick_params(labelsize=12)
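    # Only this first panel passes real labels to scatter(); the remaining
    # panels use label=None so the shared figure legend lists each model type
    # exactly once.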
    # Plot 2: Bayes Score vs DockQ Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax2.scatter(subset['dockq_score'], subset['consensus'] * subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax2.set_ylabel('P(model|data)\nBayes Score', fontsize=14, weight='bold')
    ax2.set_xlabel('DockQ Score', fontsize=14, weight='bold')
    ax2.set_title('Posterior (DockQ)', fontsize=16, weight='bold')
    ax2.grid(True, linestyle='--', alpha=0.7)
    ax2.set_xlim(0, 1)
    ax2.set_ylim(0, 1)
    ax2.tick_params(labelsize=12)

    # Plot 7: Bayes Score vs iRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax7.scatter(subset['irms_score'], subset['consensus'] * subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax7.set_ylabel('P(model|data)\nBayes Score', fontsize=14, weight='bold')
    ax7.set_xlabel('iRMS Score', fontsize=14, weight='bold')
    ax7.set_title('Posterior (iRMS)', fontsize=16, weight='bold')
    ax7.grid(True, linestyle='--', alpha=0.7)
    ax7.set_xlim(0, max(df['irms_score']))
    ax7.set_ylim(0, 1)
    ax7.tick_params(labelsize=12)

    # Plot 8: Bayes Score vs LRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax8.scatter(subset['lrms_score'], subset['consensus'] * subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax8.set_ylabel('P(model|data)\nBayes Score', fontsize=14, weight='bold')
    ax8.set_xlabel('LRMS Score', fontsize=14, weight='bold')
    ax8.set_title('Posterior (LRMS)', fontsize=16, weight='bold')
    ax8.grid(True, linestyle='--', alpha=0.7)
    ax8.set_xlim(0, max(df['lrms_score']))
    ax8.set_ylim(0, 1)
    ax8.tick_params(labelsize=12)
    # Middle row: Prior plots
    # Plot 3: Confidence vs TM Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax3.scatter(subset['tm_score'], subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax3.set_ylabel('P(model)\nConfidence', fontsize=14, weight='bold')
    ax3.set_xlabel('TM Score', fontsize=14, weight='bold')
    ax3.set_title('Prior (TM)', fontsize=16, weight='bold')
    ax3.grid(True, linestyle='--', alpha=0.7)
    ax3.set_xlim(0, 1)
    ax3.set_ylim(0, 1)
    ax3.tick_params(labelsize=12)

    # Plot 4: Confidence vs DockQ Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax4.scatter(subset['dockq_score'], subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax4.set_ylabel('P(model)\nConfidence', fontsize=14, weight='bold')
    ax4.set_xlabel('DockQ Score', fontsize=14, weight='bold')
    ax4.set_title('Prior (DockQ)', fontsize=16, weight='bold')
    ax4.grid(True, linestyle='--', alpha=0.7)
    ax4.set_xlim(0, 1)
    ax4.set_ylim(0, 1)
    ax4.tick_params(labelsize=12)

    # Plot 9: Confidence vs iRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax9.scatter(subset['irms_score'], subset['Confidence'],
                    alpha=0.5, label=None, color=case)
    ax9.set_ylabel('P(model)\nConfidence', fontsize=14, weight='bold')
    ax9.set_xlabel('iRMS Score', fontsize=14, weight='bold')
    ax9.set_title('Prior (iRMS)', fontsize=16, weight='bold')
    ax9.grid(True, linestyle='--', alpha=0.7)
    ax9.set_xlim(0, max(df['irms_score']))
    ax9.set_ylim(0, 1)
    ax9.tick_params(labelsize=12)

    # Plot 10: Confidence vs LRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax10.scatter(subset['lrms_score'], subset['Confidence'],
                     alpha=0.5, label=None, color=case)
    ax10.set_ylabel('P(model)\nConfidence', fontsize=14, weight='bold')
    ax10.set_xlabel('LRMS Score', fontsize=14, weight='bold')
    ax10.set_title('Prior (LRMS)', fontsize=16, weight='bold')
    ax10.grid(True, linestyle='--', alpha=0.7)
    ax10.set_xlim(0, max(df['lrms_score']))
    ax10.set_ylim(0, 1)
    ax10.tick_params(labelsize=12)
    # Bottom row: Likelihood plots
    # Plot 5: Consensus vs TM Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax5.scatter(subset['tm_score'], subset['consensus'],
                    alpha=0.5, label=None, color=case)
    ax5.set_ylabel('P(data|model)\nConsensus Score', fontsize=14, weight='bold')
    ax5.set_xlabel('TM Score', fontsize=14, weight='bold')
    ax5.set_title('Likelihood (TM)', fontsize=16, weight='bold')
    ax5.grid(True, linestyle='--', alpha=0.7)
    ax5.set_xlim(0, 1)
    ax5.set_ylim(0, 1)
    ax5.tick_params(labelsize=12)

    # Plot 6: Consensus vs DockQ Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax6.scatter(subset['dockq_score'], subset['consensus'],
                    alpha=0.5, label=None, color=case)
    ax6.set_ylabel('P(data|model)\nConsensus Score', fontsize=14, weight='bold')
    ax6.set_xlabel('DockQ Score', fontsize=14, weight='bold')
    ax6.set_title('Likelihood (DockQ)', fontsize=16, weight='bold')
    ax6.grid(True, linestyle='--', alpha=0.7)
    ax6.set_xlim(0, 1)
    ax6.set_ylim(0, 1)
    ax6.tick_params(labelsize=12)

    # Plot 11: Consensus vs iRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax11.scatter(subset['irms_score'], subset['consensus'],
                     alpha=0.5, label=None, color=case)
    ax11.set_ylabel('P(data|model)\nConsensus Score', fontsize=14, weight='bold')
    ax11.set_xlabel('iRMS Score', fontsize=14, weight='bold')
    ax11.set_title('Likelihood (iRMS)', fontsize=16, weight='bold')
    ax11.grid(True, linestyle='--', alpha=0.7)
    ax11.set_xlim(0, max(df['irms_score']))
    ax11.set_ylim(0, 1)
    ax11.tick_params(labelsize=12)

    # Plot 12: Consensus vs LRMS Score
    for case, label in colors.items():
        subset = df[df['special_cases'] == case]
        ax12.scatter(subset['lrms_score'], subset['consensus'],
                     alpha=0.5, label=None, color=case)
    ax12.set_ylabel('P(data|model)\nConsensus Score', fontsize=14, weight='bold')
    ax12.set_xlabel('LRMS Score', fontsize=14, weight='bold')
    ax12.set_title('Likelihood (LRMS)', fontsize=16, weight='bold')
    ax12.grid(True, linestyle='--', alpha=0.7)
    ax12.set_xlim(0, max(df['lrms_score']))
    ax12.set_ylim(0, 1)
    ax12.tick_params(labelsize=12)
    # Adjust layout and add single legend
    plt.subplots_adjust(right=0.85, hspace=0.3)  # Make room for legend and adjust spacing
    legend = fig.legend(title='Model Types', loc='center right',
                        bbox_to_anchor=(0.98, 0.5), fontsize=12)
    legend.get_title().set_fontsize(14)
    legend.get_title().set_weight('bold')
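    # fig.legend() collects the handles labeled in the first panel;
    # bbox_to_anchor pins the legend inside the right margin reserved by
    # subplots_adjust(right=0.85) above.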
    # Save the plot (creating the output directory if it does not exist yet)
    os.makedirs('./Figures', exist_ok=True)
    plt.savefig(f'./Figures/bayes_components_{pdb_id.lower()}.png',
                bbox_inches='tight', dpi=300)
    plt.close()

    print(f"Completed processing {pdb_id}")
if __name__ == '__main__':
    # List of PDB IDs to process
    pdb_ids = ['2jw1', '2lgk', '2lsk', '2law', '2kwv', '2mnu', '2mps',
               '5tp6', '5urn', '7ovc', '7jyn', '7jq8', '6h8c']

    # Generate plots for each PDB ID
    for pdb_id in pdb_ids:
        generate_bayes_components_plot(pdb_id)