Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/get_CS_AF_ensemble.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
162 lines (135 sloc)
5.57 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import csv | |
import ast | |
import os | |
from pathlib import Path | |
def parse_list(value):
    """Return *value* evaluated as a Python literal, or unchanged when it isn't one."""
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Not a literal (e.g. a plain word) — keep the raw string.
        return value
    return parsed
def parse_csv(file_name): | |
"""Parse a CSV file into a list of dictionaries with proper type conversion.""" | |
data = [] | |
try: | |
with open(file_name, newline='') as csvfile: | |
reader = csv.DictReader(csvfile) | |
for row in reader: | |
data.append({k: parse_list(v) for k, v in row.items()}) | |
return data | |
except FileNotFoundError: | |
print(f"Error: Could not find file {file_name}") | |
return None | |
except Exception as e: | |
print(f"Error parsing {file_name}: {str(e)}") | |
return None | |
def get_pdb_id(pdb_file):
    """Derive a lowercase PDB ID from a file path.

    Strips the directory and extension, then drops a leading 'exp_' marker
    used for experimental-structure filenames.
    """
    stem, _ext = os.path.splitext(os.path.basename(pdb_file))
    pdb_id = stem.lower()
    return pdb_id[4:] if pdb_id.startswith('exp_') else pdb_id
def calculate_entry_scc(entry):
    """Average the N and H shift correlation coefficients of a statistics row.

    Missing or non-numeric values make the whole score fall back to 0.
    """
    try:
        scores = (float(entry.get('N_SCC', 0)), float(entry.get('H_SCC', 0)))
    except (ValueError, TypeError):
        return 0
    return sum(scores) / 2
def calculate_confidence(csp_entry):
    """Blend iptm/ptm into one confidence score: iptm*0.8 + ptm*0.2.

    Returns 0 when either value is missing or cannot be read as a float.
    """
    try:
        iptm_val = float(csp_entry.get('iptm', 0))
        ptm_val = float(csp_entry.get('ptm', 0))
    except (ValueError, TypeError):
        return 0
    return iptm_val * 0.8 + ptm_val * 0.2
def find_best_models_per_cluster(clustering_data, csp_data, stats_data):
    """Pick, for every cluster, the PDB file with the highest posterior score.

    Posterior = confidence (iptm/ptm blend) * average SCC. Returns a dict
    mapping cluster index -> (pdb_file, score).
    """
    # SCC per model file; statistics rows store the stem, so add '.pdb' to match.
    scc_by_file = {}
    for stat_row in stats_data:
        model = stat_row.get('model_name')
        if model:
            scc_by_file[model + '.pdb'] = calculate_entry_scc(stat_row)

    # Posterior score per model file (models absent from stats score 0).
    posterior_by_file = {}
    for csp_row in csp_data:
        fname = os.path.basename(csp_row.get('holo_model_path', ''))
        if not fname:
            continue
        conf = calculate_confidence(csp_row)
        scc = scc_by_file.get(fname, 0)
        post = conf * scc
        posterior_by_file[fname] = post
        print(f"PDB: {fname}, Confidence: {conf:.4f}, SCC: {scc:.4f}, Posterior: {post:.4f}")

    # Keep the top-scoring file within each cluster.
    best = {}
    for clust_row in clustering_data:
        cluster_id = clust_row.get('Cluster')
        fname = clust_row.get('pdb_file')
        if cluster_id is None or not fname:
            continue
        cluster_id = int(cluster_id)
        score = posterior_by_file.get(fname, 0)
        incumbent = best.get(cluster_id)
        if incumbent is None or score > incumbent[1]:
            best[cluster_id] = (fname, score)
    return best
def process_files(pdb_id):
    """Load the CSP, statistics, and clustering CSVs for *pdb_id*.

    Returns (csp_data, stats_data, clustering_data); an element is None when
    its file could not be read (parse_csv prints the diagnostic).
    """
    # (path, label used in the "Parsing ..." line, label used in the success line)
    specs = [
        (f"./CSP_Rank_Scores/CSP_{pdb_id}_CSpred.csv",
         "CSP prediction file", "CSP prediction file"),
        (f"./plots/shift_analysis/{pdb_id}/statistics.csv",
         "statistics file", "statistics file"),
        (f"./CLUSTERING_RESULTS/{pdb_id}_aligned_CSPREDB_TSNE_chain_B_data.csv",
         "clustering data", "clustering file"),
    ]
    print(f"\nProcessing files for PDB ID: {pdb_id}")
    results = []
    for path, parse_label, success_label in specs:
        print(f"\nParsing {parse_label}...")
        data = parse_csv(path)
        if data:
            print(f"Successfully parsed {len(data)} entries from {success_label}")
        results.append(data)
    csp_data, stats_data, clustering_data = results
    return csp_data, stats_data, clustering_data
def main():
    """CLI entry point: rank models, pick the best per cluster, copy each winner.

    Usage: python get_CS_AF_ensemble.py <pdb_file>

    Exits with status 1 on bad arguments or when any input CSV is missing.
    """
    # Hoisted out of the per-file copy loop, where the original re-imported
    # it on every iteration.
    import shutil

    if len(sys.argv) != 2:
        print("Usage: python get_CS_AF_ensemble.py <pdb_file>")
        sys.exit(1)

    pdb_file = sys.argv[1]
    pdb_id = get_pdb_id(pdb_file)
    print(f"Processing data for PDB ID: {pdb_id}")

    csp_data, stats_data, clustering_data = process_files(pdb_id)
    # NOTE(review): an empty-but-valid CSV is also falsy and treated as a
    # failure here — confirm that is intended.
    if not all([csp_data, stats_data, clustering_data]):
        print("\nError: One or more files could not be processed")
        sys.exit(1)

    best_models = find_best_models_per_cluster(clustering_data, csp_data, stats_data)

    # Collect the per-cluster winners into a dedicated ensemble directory.
    # exist_ok avoids the check-then-create race of the original exists() test.
    cs_af_dir = f'./PDB_FILES/{pdb_id}_CS_AF_files/'
    os.makedirs(cs_af_dir, exist_ok=True)

    print("\nBest models per cluster:")
    for cluster, (pdb_file, score) in sorted(best_models.items()):
        print(f"Cluster {cluster}: {pdb_file} (posterior score: {score:.4f})")
        # Copy the winner from the aligned-models directory.
        src = f'./PDB_FILES/{pdb_id.upper()}_aligned/{pdb_file}'
        dst = os.path.join(cs_af_dir, pdb_file)
        if os.path.exists(src):
            shutil.copy2(src, dst)
            print(f"Copied {pdb_file} to {cs_af_dir}")
        else:
            print(f"Warning: Source file {src} not found")


if __name__ == "__main__":
    main()