Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/calc_TM_top_rank.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
321 lines (267 sloc)
11.8 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: the star import stays first so every explicit import below takes
# precedence over any same-named object util happens to export (the original
# file imported PDB/sns/PDBParser/math/pd/subprocess/re after util as well).
from util import *

import math
import os
import re
import subprocess

import numpy as np
import pandas as pd
import seaborn as sns
from Bio import PDB
from Bio.PDB import PDBParser

# CSV read for its holo_pdb / apo_bmrb / Well_Defined_Residues columns; the
# main loop writes its results back into this same file via update_row.
data_source_file = './CSPRANK.csv'
# pandas is imported explicitly above, so this no longer relies on util
# re-exporting `pd`.
parsed_data = pd.read_csv(data_source_file)
holo_NMR_structure_dir = './PDB_FILES/NMR_holo/'
AF2_structure_dir = './PDB_FILES/AF2_holo/'
# MMalign reports scores on lines such as
# "TM-score= 0.81234 (normalized by length of Structure_1 ...)".
# The decimal point is escaped: the previous pattern "\d.\d+" accepted any
# character in its place (e.g. "1e5" would have parsed as 100000.0).
_TM_SCORE_RE = re.compile(r"TM-score\s*=\s*(\d+\.\d+)")


def get_tm_score(pdb_path1, pdb_path2, mmalign_path='./MMalign'):
    """Run MMalign on two structures and return the first TM-score it prints.

    MMalign may print more than one TM-score (e.g. one per normalization
    length); only the first match in stdout is returned.

    :param pdb_path1: Path to the first structure (passed first to MMalign).
    :param pdb_path2: Path to the second structure.
    :param mmalign_path: Path to the MMalign executable (default ``./MMalign``).
    :return: The TM-score as a float, or None when the program could not be
        run or no ``TM-score= X.XXXX`` value appears in its output.
    """
    print(f"{mmalign_path} {pdb_path1} {pdb_path2}")
    try:
        result = subprocess.run(
            [mmalign_path, pdb_path1, pdb_path2],
            capture_output=True,
            text=True,
        )
        match = _TM_SCORE_RE.search(result.stdout)
        if match is None:
            print("TM-score not found in the output.")
            return None
        return float(match.group(1))
    except Exception as e:
        # Best-effort by design: one failing pair must not abort the whole run.
        print(f"Error calculating TM-score for {pdb_path1} and {pdb_path2}: {e}")
        return None
def _first_model_ca_vectors(structure):
    """Map (chain_id, residue_number) -> CA Vector for the first model of ``structure``.

    Only standard residues (hetero flag ``' '``) are included, matching the
    residue filter used by ``align_chains`` for superposition; this keeps
    e.g. a calcium ion (whose atom is also named 'CA') out of the pairing.
    Returns an empty dict if the structure contains no model.
    """
    model = next(iter(structure), None)
    if model is None:
        return {}
    coords = {}
    for chain in model:
        for residue in chain:
            hetero_flag, residue_number, _insertion_code = residue.get_id()
            if hetero_flag == ' ' and residue.has_id('CA'):
                coords[(chain.get_id(), residue_number)] = residue['CA'].get_vector()
    return coords


def calculate_gdt_ts(pdb_file_ref, pdb_file_model, cutoffs=(1.0, 2.0, 4.0, 8.0)):
    """
    Calculate the GDT_TS (Global Distance Test - Total Score) between two
    PDB structures (potentially multi-chain).

    No superposition is done here: the structures must already share a frame,
    and residues are paired purely by (chain ID, residue number). Only the
    first model of each file is used (previously, for multi-model files, the
    last model silently overwrote the earlier ones), and only standard
    residues are considered.

    :param pdb_file_ref: Path to the reference PDB file.
    :param pdb_file_model: Path to the model PDB file.
    :param cutoffs: Iterable of distance cutoffs in angstroms
        (default: 1, 2, 4, 8).
    :return: GDT_TS as a percentage: the mean over cutoffs of the fraction of
        matched residues whose CA-CA distance is <= the cutoff, times 100.
    :raises ValueError: if ``cutoffs`` is empty or no residues match.
    """
    cutoffs = tuple(cutoffs)  # also avoids sharing a mutable default
    if not cutoffs:
        raise ValueError("At least one distance cutoff is required.")

    parser = PDBParser(QUIET=True)
    ref_coords = _first_model_ca_vectors(parser.get_structure("reference", pdb_file_ref))
    model_coords = _first_model_ca_vectors(parser.get_structure("model", pdb_file_model))

    common_keys = ref_coords.keys() & model_coords.keys()
    if not common_keys:
        raise ValueError("No matching (chain, residue_number) entries found between the two structures.")

    distances = [(ref_coords[key] - model_coords[key]).norm() for key in common_keys]
    total_matched = len(distances)
    fractions = [sum(d <= c for d in distances) / total_matched for c in cutoffs]
    return (sum(fractions) / len(fractions)) * 100.0
def parse_ranges(ranges_str):
    """Parse a residue-range string into per-chain inclusive ranges.

    The expected format is a comma-separated list of ``CHAIN:START..END``
    entries, e.g. ``"A:2..75, B:101..110"``. Whitespace around entries and
    chain IDs is ignored and empty entries are skipped, so ``""`` yields
    ``{}`` instead of failing on an unpack error.

    :param ranges_str: The range specification string.
    :return: Dict mapping chain ID to a list of ``(start, end)`` tuples in
        input order (a chain may appear several times).
    :raises ValueError: if an entry is not of the form ``CHAIN:START..END``.
    """
    ranges = {}
    for part in ranges_str.split(','):
        part = part.strip()
        if not part:
            continue
        chain, sep, start_end = part.partition(':')
        if not sep:
            raise ValueError(f"Malformed residue range {part!r}; expected CHAIN:START..END")
        start_res, end_res = map(int, start_end.split(".."))
        ranges.setdefault(chain.strip(), []).append((start_res, end_res))
    return ranges
def adjust_residue_numbers(ranges):
    """Build per-chain maps from original residue numbers to 1-based consecutive numbers.

    For each chain, the residues of its ranges are enumerated in the order
    the ranges are listed, starting at 1. If ranges overlap, a residue number
    seen again is remapped to its later position (the counter still advances).

    :param ranges: Dict of chain ID -> list of inclusive ``(start, end)`` tuples.
    :return: Dict of chain ID -> {original residue number: new residue number}.
    """
    def _residues_in_order(spans):
        # Every original residue number covered by ``spans``, in listed order.
        for first, last in spans:
            yield from range(first, last + 1)

    return {
        chain: {
            original: renumbered
            for renumbered, original in enumerate(_residues_in_order(spans), start=1)
        }
        for chain, spans in ranges.items()
    }
def trim_pdb_by_residues(pdb_file_path, ranges_str, new_file):
    """Write a copy of a PDB file restricted to the given residues, renumbered.

    For chains named in ``ranges_str`` (format parsed by ``parse_ranges``),
    ATOM/HETATM records whose residue number lies in one of that chain's
    ranges are kept and renumbered consecutively from 1; the chain's other
    records are dropped. ATOM/HETATM records of chains not named in
    ``ranges_str`` are copied unchanged. MODEL/ENDMDL records are copied so a
    multi-model input stays multi-model (previously they were dropped, merging
    every model into one). All other record types are dropped.

    :param pdb_file_path: Input PDB file.
    :param ranges_str: Residue ranges, e.g. ``"A:2..75, B:101..110"``.
    :param new_file: Output path (overwritten).
    :return: ``new_file``.
    """
    ranges = parse_ranges(ranges_str)
    adjustment_maps = adjust_residue_numbers(ranges)
    print(ranges)
    with open(pdb_file_path, 'r') as pdb_file, open(new_file, 'w') as output_file:
        for line in pdb_file:
            if line.startswith(("MODEL", "ENDMDL")):
                output_file.write(line)
                continue
            if not line.startswith(("ATOM", "HETATM")):
                continue
            chain_id = line[21]
            residue_map = adjustment_maps.get(chain_id)
            if residue_map is None:
                # Chain not listed in ranges_str: keep the record as-is.
                output_file.write(line)
                continue
            # The map holds exactly the residues covered by this chain's
            # ranges, so a direct lookup replaces scanning the ranges.
            new_residue_num = residue_map.get(int(line[22:26]))
            if new_residue_num is not None:
                # line[22:26] is the right-justified residue sequence number.
                output_file.write(line[:22] + "{:>4}".format(new_residue_num) + line[26:])
    return new_file
def align_raw(NMR_file, AF2_file):
    """Debug helper: print the sequence alignment of an NMR and an AF2 structure.

    Each file's sequence is extracted with ``get_pdb_sequence`` and the pair is
    aligned with ``align`` (both presumably provided by ``util`` via the star
    import). The two aligned sequences returned by ``align`` are printed, NMR
    first. Returns None.
    """
    sequences = [get_pdb_sequence(path) for path in (NMR_file, AF2_file)]
    nmr_aligned, af2_aligned = align(*sequences)
    print(nmr_aligned, af2_aligned, sep="\n")
# raise | |
import matplotlib.pyplot as plt


# Helpers below used to be (re)defined inside the loop on every iteration;
# they were module globals either way, so hoisting them changes no names.

def get_largest_chain(structure):
    """Return the chain with the most amino-acid residues across all models.

    Ties keep the first chain encountered; returns None if no chain contains
    an amino acid.
    """
    largest_chain = None
    max_residues = 0
    for model in structure:
        for chain in model:
            num_residues = sum(1 for residue in chain if PDB.is_aa(residue))
            if num_residues > max_residues:
                max_residues = num_residues
                largest_chain = chain
    return largest_chain


def align_chains(ref_chain, mobile_chain):
    """Return a Superimposer fitted on the CA atoms of two chains.

    Only standard residues (hetero flag ``' '``) with a CA atom are used,
    paired in chain order. ``Superimposer.set_atoms`` raises if the two atom
    lists differ in length.
    """
    ref_atoms = [residue['CA'] for residue in ref_chain
                 if 'CA' in residue and residue.id[0] == ' ']
    mobile_atoms = [residue['CA'] for residue in mobile_chain
                    if 'CA' in residue and residue.id[0] == ' ']
    super_imposer = PDB.Superimposer()
    super_imposer.set_atoms(ref_atoms, mobile_atoms)
    return super_imposer


def superpose(NMR_file, AF2_file):
    """Superimpose the AF2 structure onto the NMR structure, overwriting AF2_file.

    The fit uses the largest chain of each structure; the resulting rotation
    and translation are applied to every atom of the AF2 structure, which is
    saved to a ``*_aligned.pdb`` file and then moved over ``AF2_file``.

    :return: ``AF2_file`` (now holding the superposed coordinates).
    """
    parser = PDB.PDBParser(QUIET=True)
    ref_structure = parser.get_structure('NMR', NMR_file)
    mobile_structure = parser.get_structure('AF2', AF2_file)
    super_imposer = align_chains(get_largest_chain(ref_structure),
                                 get_largest_chain(mobile_structure))
    for model in mobile_structure:
        for chain in model:
            super_imposer.apply(chain.get_atoms())
    io = PDB.PDBIO()
    io.set_structure(mobile_structure)
    aligned_AF2_file = AF2_file.replace('.pdb', '_aligned.pdb')
    io.save(aligned_AF2_file)
    # os.replace instead of os.system('mv ...'): no shell, safe with spaces.
    os.replace(aligned_AF2_file, AF2_file)
    return AF2_file


def trim(NMR_file, pdb_id=None):
    """Trim NMR_file to its well-defined residues and return the new file's path.

    Ranges come from the ``Well_Defined_Residues`` column of the row in
    ``data_source_file`` whose ``holo_pdb`` matches ``pdb_id``
    (case-insensitively; defaults to the module-level ``holo`` set by the
    main loop). If no row matches, or the cell is empty, an empty range
    string is passed on. Output: ``<holo_NMR_structure_dir><name>_trim.pdb``,
    where ``<name>`` is the input's base name up to its first '.'.
    """
    if pdb_id is None:
        pdb_id = holo
    table = pd.read_csv(data_source_file)
    # `holo` is lower-cased by the main loop, so compare lower-cased IDs.
    matches = table.loc[table['holo_pdb'].astype(str).str.lower() == str(pdb_id).lower(),
                        'Well_Defined_Residues']
    well_defined_residues = ""
    if matches.empty or pd.isna(matches.values[0]):
        print(str(pdb_id) + " not found in the data source file.")
    else:
        well_defined_residues = matches.values[0]
    # Built outside the f-string: nesting '/' quotes inside f'...' is a
    # SyntaxError before Python 3.12.
    base_name = os.path.basename(NMR_file).split('.')[0]
    output_file = f'{holo_NMR_structure_dir}{base_name}_trim.pdb'
    trim_pdb_by_residues(NMR_file, well_defined_residues, output_file)
    return output_file


def plot_heatmap(data, title, xlabel, ylabel, output_file):
    """Save an annotated seaborn heatmap of a 2D array to ``output_file``."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(data, annot=True, fmt=".2f", cmap="viridis")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(output_file)
    plt.close()


for row_index, holo_pdb in enumerate(parsed_data['holo_pdb']):
    apo = str(parsed_data['apo_bmrb'][row_index])
    holo = holo_pdb.lower()
    # if holo != '2lsp':
    #     continue
    # if holo in ['2kne', '2lsk', '2m55', '2mps', '5j8h']:
    #     continue

    # Top-ranked AF2 model(s): file names containing "<holo>_1".
    AF2_files = [AF2_structure_dir + f for f in os.listdir(AF2_structure_dir)
                 if holo + "_1" in f]
    if not AF2_files:
        print(f"Could not find AF2 file for {holo}")
        continue

    NMR_files = [holo_NMR_structure_dir + f for f in os.listdir(holo_NMR_structure_dir)
                 if holo in f]
    if not NMR_files:
        print(f"Could not find NMR file for {holo}")
        continue
    # Score only against the structure chosen by find_medoid_structure.
    NMR_files = [find_medoid_structure(NMR_files)]

    tm_scores = []
    gdt_ts_scores = []
    for nmr_index, NMR_file in enumerate(NMR_files):
        tm_scores.append([])
        gdt_ts_scores.append([])
        for AF2_file in AF2_files:
            try:
                AF2_file = superpose(NMR_file, AF2_file)
            except Exception:
                # Superposition fails e.g. when the CA counts differ; retry
                # after trimming the NMR structure to its well-defined residues.
                # NOTE: the trimmed NMR_file is kept for the remaining AF2 files.
                try:
                    print("TRIMMING NMR FILE")
                    NMR_file = trim(NMR_file, holo)
                    print("NMR_FILE TRIMMED")
                    print("New NMR_file = " + NMR_file)
                    print("ALIGNING")
                    AF2_file = superpose(NMR_file, AF2_file)
                except Exception:
                    print("Could not align or trim the structure.")
                    raise
            gdt_ts_scores[nmr_index].append(calculate_gdt_ts(AF2_file, NMR_file))
            tm_scores[nmr_index].append(get_tm_score(AF2_file, NMR_file))

    # Example usage:
    # plot_heatmap(tm_scores, 'TM Scores Heatmap', 'AF2 Files', 'NMR Files', f'./images/{holo}_TM_scores_heatmap.png')

    # get_tm_score returns None on failure; treat those as missing (NaN)
    # instead of letting np.mean raise a TypeError on None.
    tm_array = np.array(tm_scores, dtype=float)
    if np.isnan(tm_array).all():
        print(f"No TM-score could be computed for {holo}; skipping update.")
        continue
    TM = np.nanmean(tm_array)
    GDT_TS = np.mean(gdt_ts_scores)

    # To store GDT_TS instead:
    # new_values = [GDT_TS]
    # new_columns = ['AF2_GDT_TS_top_rank']
    new_values = [TM]
    new_columns = ['AF2_TM_top_rank']
    print(new_values)
    update_row(data_source_file, apo.lower(), holo, new_values, new_columns)