# util.py
from Bio.Seq import Seq
from Bio.Align import PairwiseAligner
from Bio.PDB import PDBParser, PDBIO, Select
from Bio.PDB.Polypeptide import Polypeptide, is_aa
from paths import *
import pickle
import urllib.request
import json
from tqdm import tqdm
import requests
import os
from os import listdir
from os.path import exists, isdir
import csv
import math
import ast
import re
import statistics
import subprocess
import zipfile
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import mdtraj as md
import pymol
from pymol import cmd, stored
from scipy.optimize import minimize
from scipy.stats import linregress
from scipy import stats
from Bio.PDB import *
holo_entries = ['1d5g', '1gjz', '1jgn', '1jm4', '1klq', '1l3e', '1lxf', '1m4p', '1m4q', '1oo9', '1r8u', '1sy9', '1vj6', '1wa7', '1x9a', '1ywi', '2b6g', '2dyf', '2fwl', '2gzu', '2h7e', '2i94', '2ipa', '2jmf', '2jmx', '2jqf', '2jxc', '2k17', '2k33', '2k6d', '2k79', '2k7a', '2k7i', '2k8f', '2k9j', '2ka4', '2ka6', '2kc8', '2kff', '2kfg', '2kfh', '2kgi', '2khs', '2kid', '2kje', '2kjz', '2kne', '2knh', '2kpl', '2ksp', '2kvm', '2kzu', '2l00', '2l12', '2l14', '2l1b', '2l1c', '2l1r', '2l29', '2l3r', '2l48', '2l4t', '2l6e', '2lag', '2las', '2lbm', '2le8', '2lfh', '2lgf', '2lgg', '2lgk', '2li5', '2ll6', '2ll7', '2llo', '2llq', '2lmc', '2lns', '2lnw', '2loz', '2lp0', '2lp8', '2lqc', '2lsi', '2lsk', '2lue', '2lv6', '2lvo', '2lxm', '2lxp', '2ly4', '2lz6', '2m04', '2m0g', '2m0j', '2m0k', '2m0u', '2m0v', '2m3o', '2m55', '2m56', '2m86', '2m8l', '2mbb', '2mbh', '2mc0', '2mc6', '2mcn', '2mej', '2mes', '2mg5', '2mk9', '2mki', '2mli', '2mlz', '2mma', '2mn6', '2mnz', '2mp2', '2mps', '2ms4', '2msr', '2mtp', '2mur', '2mv7', '2mwo', '2mx9', '2mzd', '2n01', '2n1g', '2n3a', '2n3k', '2n77', '2n8j', '2n8r', '2n8t', '2n9e', '2n9x', '2nd0', '2nd1', '2roz', '2rqg', '2rr4', '2ru4', '2rui', '2rvn', '2ys5', '4a54', '4asv', '4yyp', '5iay', '5j6z', '5j8h', '5lvf', '5m8i', '5m9d', '5mf9', '5owi', '5owj', '5ue5', '5ujn', '5vzm', '5x3z', '6bgg', '6bgh', '6c0a', '6co4', '6ctb', '6e5n', '6e83', '6g04', '6ijq', '6qxz', '6rh6', '6so9', '6tvm', '6u19', '6uz4', '6x4x', '7a0o', '7jq8', '7klr', '7l8v', '7nqc', '7s5j', '7sft', '7t2f', '7zey', '8dgh', '8dgk']
def continue_prompt():
    while True:
        user_input = input("Do you want to continue? (y/n): ").lower()
        if user_input in ['n', 'no']:
            print("Terminating script.")
            return False
        elif user_input in ['y', 'yes']:
            print("Continuing execution.")
            return True
        else:
            print("Invalid input. Please enter 'y' for yes or 'n' for no.")
def get_pdb_file(pdb_id):
url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
try:
response = requests.get(url)
if response.status_code == 200:
with open(f"./bmrb_dat/{pdb_id}.pdb", 'w') as out_file:
out_file.write(response.text)
else:
print(f"Couldn't download PDB file {pdb_id}")
return None
return f"./bmrb_dat/{pdb_id}.pdb"
    except requests.exceptions.RequestException:
        return None
def execute_command_and_get_output(command):
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output = result.stdout.decode().strip() # decode stdout to string and remove trailing whitespace
return output
def merge_chains(input_pdb, output_pdb):
new_chain_id = "A"
last_residue_id = 0
current_chain_id = None
with open(input_pdb, 'r') as f_in, open(output_pdb, 'w') as f_out:
for line in f_in:
if line.startswith("ATOM") or line.startswith("HETATM"):
chain_id = line[21]
residue_id = int(line[22:26].strip())
if current_chain_id != chain_id:
residue_id_offset = last_residue_id
current_chain_id = chain_id
new_residue_id = residue_id + residue_id_offset
new_line = line[:21] + new_chain_id + f"{new_residue_id:>4}" + line[26:]
f_out.write(new_line)
last_residue_id = new_residue_id
else:
f_out.write(line)
#print("wrote to " + output_pdb)
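# Illustrative sketch (hypothetical two-chain input, not from this repo's data):
# chain A residues 1..100 followed by chain B residues 1..50 become a single
# chain "A" numbered 1..150.
#   merge_chains('complex.pdb', 'complex_merged.pdb')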
def convert_aa_name(aa_name):
aa_dict = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
'UNK': 'X', 'GLX': 'G', 'SEC': 'S', 'MSE': 'M', 'ALC': 'A',
'ORN': 'X'}
    # Nonstandard/ambiguous residues (e.g. GLX, SEC, MSE) are mapped to
    # approximate standard codes above; anything else falls back to 'X'.
    return aa_dict.get(aa_name, 'X')
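# Doctest-style examples of the mapping above:
#   >>> convert_aa_name('MET')
#   'M'
#   >>> convert_aa_name('FOO')   # unknown residue names fall back to 'X'
#   'X'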
def get_pdb_sequence_from_file(pdb_path):
sequence = []
current_chain = ""
prev_chain = ""
prev_res_id = None
with open(pdb_path, "r") as pdb_file:
for line in pdb_file:
if line.startswith("ATOM"):
chain = line[21]
res_name = line[17:20].strip()
res_id = int(line[22:26])
if prev_chain != "" and prev_chain != chain:
sequence.append(':')
if prev_res_id != res_id:
sequence.append(convert_aa_name(res_name))
prev_chain = chain
prev_res_id = res_id
# do not return the sequence if there are not exactly 2 sequences
# with at least one sequence having less than 30 residues
# and at most 5 unknown AA's between the two sequences
#if sequence.count(':') != 1:
# return None
#if sequence.index(':') >= 30 and len(sequence) - sequence.index(':') >= 30:
# return None
#if sequence.count('X') > 5:
# return None
return "".join(sequence)
# Finds the best model according to Wallner's confidence metric, 0.8 * iptm + 0.2 * ptm;
# extracts that model from zip_file_path and returns its file name and confidence.
# extract_path sets where the extracted .pdb files are written.
def extract_best_Wallner_model_from_af2_output(zip_file_path, extract_path):
with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
max_file_name = ""
max_confidence = 0
for file_info in zip_file.infolist():
if file_info.filename.endswith('.json'):
with zip_file.open(file_info) as file:
data = json.load(file)
iptm = data.get('iptm', 0)
ptm = data.get('ptm', 0)
confidence = 0.8 * iptm + 0.2 * ptm
if confidence > max_confidence:
max_confidence = confidence
max_file_name = file_info.filename[file_info.filename.rfind('/')+1:file_info.filename.rfind('.')].replace('scores', 'unrelaxed')
print("Zip file path = " + zip_file_path)
print("Found this model to be highest confidence = " + max_file_name)
print("Confidence = " + str(max_confidence))
if max_file_name != "":
for file_info in zip_file.infolist():
file_name = file_info.filename
if file_name.endswith('.pdb'):
if file_name[file_name.rfind('/')+1:file_name.rfind('.')] == max_file_name:
zip_file.extract(file_info.filename, path = extract_path)
print(f"Extracted file: {file_info.filename}")
return max_file_name, max_confidence
print("No file with the specified content was found.")
raise
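# Minimal usage sketch (hypothetical paths): picks the model whose scores JSON
# maximizes 0.8*iptm + 0.2*ptm and extracts the matching unrelaxed PDB.
#   name, conf = extract_best_Wallner_model_from_af2_output('job.result.zip', './output/')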
def extract_medoid_from_af2_output(zip_file_path):
# Open the zip file in read mode
with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
try:
# Initialize an empty list to store the pdb file paths
pdb_files = []
# Iterate over each file in the zip file
for file_info in zip_file.infolist():
# Get the file name
file_name = file_info.filename
#print(file_name)
# If the file is a pdb file, add its path to the pdb_files list
if file_name.endswith('.pdb'):
extract_path = './output/'
pdb_files.append(extract_path + file_name)
zip_file.extract(file_info.filename, path = extract_path)
# print("extracted " + zip_file_path + '/' + file_info.filename + ' --> ' + extract_path + '.')
pdb_files = [f for f in pdb_files if exists(f)]
# print("finding medoid file of " + str(pdb_files))
medoid_file = find_medoid_structure(pdb_files)
for f in pdb_files:
if f != medoid_file:
os.system('rm ' + f)
# print("removed " + f)
# print("left medoid file = " + medoid_file)
return medoid_file # Exit the function after successful extraction
        except Exception:
            # If no medoid structure was found in the zip file, report and re-raise
            print("Did not extract medoid file from " + zip_file_path + ".")
            raise
def add_chain_identifiers(pdb_file):
parser = PDBParser(QUIET=True)
structure = parser.get_structure('pdb', pdb_file)
# define the chain identifiers
chain_identifiers = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
# The model in a PDB file is assumed to be a protein. Usually a PDB file has only one model,
# but it can have more than one (NMR structures, for example)
for model in structure:
chain_id = 0
for chain in model:
chain.id = chain_identifiers[chain_id]
chain_id += 1
io = PDBIO()
io.set_structure(structure)
io.save(pdb_file) # overwrite the original PDB file
def extract_top_rank_from_af2_output(zip_file_path):
# Open the zip file in read mode
with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
try:
# Initialize an empty list to store the pdb file paths
pdb_files = []
# Iterate over each file in the zip file
for file_info in zip_file.infolist():
# Get the file name
file_name = file_info.filename
#print(file_name)
# If the file is a pdb file, add its path to the pdb_files list
if file_name.endswith('.pdb') and file_name.find('rank_001') != -1:
extract_path = './output/'
pdb_files.append(extract_path + file_name)
zip_file.extract(file_info.filename, path = extract_path)
print("extracted " + zip_file_path + '/' + file_info.filename + ' --> ' + extract_path + '.')
        except Exception:
            # If no rank-1 model was found in the zip file, report and re-raise
            print("Did not extract rank 1 model file from " + zip_file_path + ".")
            raise
def get_sequences(pdb_id, file = False):
pdb_filename = pdb_id
if not(file):
pdb_filename = get_pdb_file(pdb_id)
if not pdb_filename:
return None
structure = PDBParser(QUIET=True).get_structure(pdb_id, pdb_filename)
sequences = {}
for model in structure:
for chain in model:
seq = ""
for residue in chain:
if is_aa(residue.get_resname(), standard=True):
seq += convert_aa_name(residue.get_resname())
sequences[chain.id] = seq
# Clean up the PDB file
#os.remove(pdb_filename)
return sequences
def align(bound_seq, apo_seq):
# Define two amino acid sequences to align
seq1 = ""
for c in bound_seq:
seq1 += c
seq2 = ""
for c in apo_seq:
seq2 += c
# Create a pairwise aligner object
aligner = PairwiseAligner()
# Set the parameters for the pairwise alignment
aligner.mode = 'local'
aligner.match_score = 1
aligner.mismatch_score = -1
aligner.open_gap_score = -2
aligner.extend_gap_score = -2
# Align the two sequences
alignments = aligner.align(seq1, seq2)
    try:
        alignment_data = [k.replace(' ', '_').replace('-', '_') for k in str(alignments[0]).split('\n')]
    except IndexError:
        return bound_seq, apo_seq
#print(alignment_data)
bound_aligned_seq = alignment_data[0]
apo_aligned_seq = alignment_data[2]
while len(bound_aligned_seq) < len(apo_aligned_seq):
bound_aligned_seq += '_'
while len(apo_aligned_seq) < len(bound_aligned_seq):
apo_aligned_seq += '_'
return bound_aligned_seq, apo_aligned_seq
def get_pdb_sequence(pdb_path):
sequence = []
current_chain = ""
prev_chain = ""
prev_res_id = None
with open(pdb_path, "r") as pdb_file:
for line in pdb_file:
if line.startswith("ATOM"):
chain = line[21]
res_name = line[17:20].strip()
res_id = int(line[22:26])
if prev_chain != "" and prev_chain != chain:
sequence.append(':')
if prev_res_id != res_id:
sequence.append(convert_aa_name(res_name))
prev_chain = chain
prev_res_id = res_id
return "".join(sequence)
# plotting function
def plot_CSP_data(CSPs, redundant_flag, sequence, bound, apo, cutoff):
output_file = '/home/tiburon/Desktop/ROT4/complex3/CSP/' + bound.upper() + '_' + apo.upper() + '_CSP_plot.png'
fig, ax = plt.subplots()
CSPs = [ CSP if CSP > 0 else 0 for CSP in CSPs ]
bar_plot = ax.bar(np.arange(len(CSPs)), CSPs, align='center', width=0.8)
for i in range(len(CSPs)):
if redundant_flag[i]:
bar_height = bar_plot[i].get_height()
ax.text(i, bar_height * 1.05, "*", ha='center', va='bottom')
if cutoff > 0 :
# Add a red horizontal line to indicate the cutoff value for significance
ax.axhline(y=cutoff, color='r', linestyle='--', label='Significance Cutoff')
    # Label every 10th residue on the x-axis
    ax.set_xticks(np.arange(0, len(sequence), 10))
    ax.set_xticklabels(np.arange(0, len(sequence), 10))
ax.set_xlabel('Residue Number', fontsize = 20)
ax.set_ylabel('ΔδN,H [ppm]', fontsize = 20)
ax.set_title('CSPs between ' + bound + ' and ' + apo + '.', fontsize = 24)
# Modify the size of the plot image
fig.set_size_inches(10, 8)
# Save the plot to a predefined .png file location
fig.savefig(output_file)
print("Generated " + output_file)
fig.clf()
plt.close()
def calculate_rmsd_matrix(ensemble):
n_structures = len(ensemble)
rmsd_matrix = np.zeros((n_structures, n_structures))
for i in range(n_structures):
for j in range(i+1, n_structures):
rmsd = md.rmsd(ensemble[i], ensemble[j])
rmsd_matrix[i, j] = rmsd
rmsd_matrix[j, i] = rmsd
return rmsd_matrix
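# Note: md.rmsd returns one value per frame of its first argument, so the
# scalar assignment above assumes each loaded PDB is a single-frame
# trajectory; multi-model NMR files would need a frame selection first.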
def find_medoid_structure(pdb_files):
if not pdb_files:
raise ValueError("No PDB files found in the given directory")
pdb_paths = [pdb_file for pdb_file in pdb_files if pdb_file.find('_A') == -1 and pdb_file.find('_B') == -1 and pdb_file.endswith('.pdb')]
ensemble = []
for pdb_path in pdb_paths:
# print(pdb_path)
ensemble.append(md.load(pdb_path))
#ensemble = [md.load(pdb_path) for pdb_path in pdb_paths]
rmsd_matrix = calculate_rmsd_matrix(ensemble)
sum_rmsd = np.sum(rmsd_matrix, axis=1)
medoid_index = np.argmin(sum_rmsd)
return pdb_paths[medoid_index]
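# Usage sketch (hypothetical file list): the medoid is the member with the
# smallest summed RMSD to all other structures in the ensemble.
#   medoid = find_medoid_structure(['m1.pdb', 'm2.pdb', 'm3.pdb'])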
def clean_output_dir():
basedir = '/home/tiburon/Desktop/ROT4/complex4/'
pdb_files = [ basedir+'output/'+ f for f in listdir(basedir+ 'output/') if f.endswith('.pdb') ]
    keep_files = [ f for f in pdb_files if f.rfind('.') - f.rfind('/') == 5 ]  # keep files named like '<pdb_id>.pdb'
delete_files = [ f for f in pdb_files if f not in keep_files ]
for f in delete_files:
os.system('rm ' + f)
def clean_boundpdbs():
basedir = '/home/tiburon/Desktop/ROT4/complex4/'
    bound_files = [ basedir + 'boundpdbs/' + f for f in listdir(basedir + 'boundpdbs/') if f.endswith('.pdb') ]
bound_directories = [ basedir + 'boundpdbs/' + f for f in listdir(basedir + 'boundpdbs/') if isdir(basedir + 'boundpdbs/'+f) ]
for direct in bound_directories:
for f in listdir(direct):
if f.endswith('.pdb'):
bound_files.append(direct + '/' + f)
for bound in bound_files:
if len(bound.split('_')) > 2:
os.system('rm ' + bound)
print('rm ' + bound)
def get_bound_path(pdb_id):
pdb_id = pdb_id.lower()
basedir = '/home/tiburon/Desktop/ROT4/complex4/'
if isdir(basedir+'boundpdbs/'+pdb_id):
try:
return find_medoid_structure([basedir + 'boundpdbs/'+pdb_id+'/' + f for f in listdir(basedir + 'boundpdbs/'+pdb_id+'/')])
        except Exception:
            pass
if exists(basedir + 'boundpdbs/'+pdb_id+'.pdb'):
return basedir + 'boundpdbs/'+pdb_id+'.pdb'
else:
if exists(basedir + 'bmrb_dat/'+pdb_id.upper()+'.pdb'):
os.system('cp ' + basedir +'bmrb_dat/'+pdb_id.upper()+'.pdb '+basedir+'boundpdbs/'+pdb_id+'.pdb')
print("PROCESS BOUNDPDBS DIR")
return
#raise
def get_apo_path(pdb_id):
pdb_id = pdb_id.lower()
basedir = '/home/tiburon/Desktop/ROT4/complex4/'
if isdir(basedir + 'apopdbs/'+pdb_id):
return find_medoid_structure([basedir + 'apopdbs/'+pdb_id+'/' + f for f in listdir(basedir + 'apopdbs/'+pdb_id+'/')])
else:
if exists(basedir + 'apopdbs/'+pdb_id+'.pdb'):
return basedir + 'apopdbs/'+pdb_id+'.pdb'
else:
            if exists(basedir + 'bmrb_dat/'+pdb_id.upper()+'.pdb'):
                os.system('cp ' + basedir + 'bmrb_dat/'+pdb_id.upper()+'.pdb ' + basedir + 'apopdbs/'+pdb_id+'.pdb')
                print("PROCESS APOPDBS DIR")
            raise FileNotFoundError("No apo structure found for " + pdb_id)
def get_bfactors(pdb_file):
parser = PDBParser(QUIET=True)
structure = parser.get_structure('PDB', pdb_file)
largest_chain = max(structure.get_chains(), key=lambda chain: len(list(chain.get_residues())))
bfactors = []
# Calculate average B-factor per residue
for residue in largest_chain.get_residues():
bfactor_sum = 0
atom_count = 0
for atom in residue.get_atoms():
bfactor_sum += atom.get_bfactor()
atom_count += 1
average_bfactor = bfactor_sum / atom_count if atom_count > 0 else None
bfactors.append(average_bfactor)
return bfactors
def parse_list(value):
try:
return ast.literal_eval(value)
except ValueError:
return value
except SyntaxError:
return value
def parse_csv(file_name):
data = []
with open(file_name, newline='') as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
data.append({k: parse_list(v) for k, v in row.items()})
return data
def get_model_bound_path(pdb_id):
pdb_id = pdb_id.lower()
basedir = '/home/tiburon/Desktop/ROT4/complex4/'
#files = [ basedir + 'output/' + f for f in listdir(basedir + 'output/') if f.startswith(pdb_id) and f.endswith('.pdb') ]
files = [ basedir + 'o/' + f for f in listdir(basedir + 'o/') if f.startswith(pdb_id) and f.endswith('.pdb') ]
if len(files) == 0 :
print("Error finding AF2 output file for " + pdb_id)
return None
else:
return files[0]
def remove_atoms_without_chain(pdb_file_path):
with open(pdb_file_path, "r") as f:
lines = f.readlines()
new_lines = [line for line in lines if not (line.startswith("ATOM") and line[21].strip() == '')]
with open(pdb_file_path, "w") as f:
f.writelines(new_lines)
def pdb_interchain_cb_distances(pdb_file):
parser = PDBParser(QUIET=True)
structure = parser.get_structure('PDB', pdb_file)
chains = list(structure.get_chains())
    if len(chains) != 2:
        return None  # the PDB file should contain exactly two chains
chains.sort(key=lambda chain: len(list(chain.get_residues())), reverse=True)
chain1_residues = list(chains[0].get_residues())
chain2_residues = list(chains[1].get_residues())
distances = [None] * len(chain1_residues)
# Calculate minimum distances
for i, res1 in enumerate(chain1_residues):
min_distance = np.inf
if res1.has_id("CB"):
atom1 = res1["CB"]
for res2 in chain2_residues:
if res2.has_id("CB"):
atom2 = res2["CB"]
distance = atom1 - atom2
if distance < min_distance:
min_distance = distance
distances[i] = min_distance
return distances
def pdb_interchain_distances(pdb_file, chain=None):
parser = PDBParser(QUIET=True)
structure = parser.get_structure('PDB', pdb_file)
chains = {c.get_id(): c for c in structure.get_chains()}
# If specific chain is provided, return distances for this chain
if chain is not None:
if chain in chains:
chain_residues = list(chains[chain].get_residues())
other_chain = chains[[c for c in chains if c != chain][0]]
other_chain_residues = list(other_chain.get_residues())
        else:
            return None  # invalid chain identifier provided
else:
# Sort chains by length
sorted_chains = sorted(chains.items(), key=lambda c: len(list(c[1].get_residues())), reverse=True)
chain_residues = list(sorted_chains[0][1].get_residues())
other_chain_residues = list(sorted_chains[1][1].get_residues())
distances = [None] * len(chain_residues)
# Calculate minimum distances
for i, res1 in enumerate(chain_residues):
min_distance = np.inf
for res2 in other_chain_residues:
for atom1 in res1.get_atoms():
for atom2 in res2.get_atoms():
distance = atom1 - atom2
if distance < min_distance:
min_distance = distance
distances[i] = min_distance
return distances
def get_indirect_binding_interface( flags, pdb_file, cutoff = 2.5):
parser = PDBParser(QUIET=True)
structure = parser.get_structure('PDB', pdb_file)
largest_chain = max(structure.get_chains(), key=lambda chain: len(list(chain.get_residues())))
assert len(flags) == len(list(largest_chain.get_residues())), "The flags list should have a length equal to the number of residues in the largest chain"
distances = [False] * len(flags)
residues = list(largest_chain.get_residues())
# Calculate minimum distances for flagged residues
for i, res1 in enumerate(residues):
for j, res2 in enumerate(residues):
min_distance = np.inf
if flags[j]:
for atom1 in res1.get_atoms():
for atom2 in res2.get_atoms():
distance = atom1 - atom2
if distance < min_distance:
min_distance = distance
if min_distance < cutoff:
distances[i] = True
return distances
def write_to_csv(file_path, headers, data):
with open(file_path, mode='w', newline='') as csv_file:
writer = csv.writer(csv_file)
writer.writerow(headers)
writer.writerows(data)
def get_CSList(pdb_code):
local_path = '/home/tiburon/Desktop/ROT4/complex3/CSLists/' + pdb_code + ".csv"
    if exists(local_path):
        print("Already have HSQC data for " + pdb_code + ", continuing...")
        return
# Define the URL to access the BMRB API for the given PDB code
# example url = https://api.bmrb.io/v2/search/get_bmrb_ids_from_pdb_id/2JNS
bmrb_url = "https://api.bmrb.io/v2/search/get_bmrb_ids_from_pdb_id/" + pdb_code
# Send a request to the BMRB API and retrieve the response
with urllib.request.urlopen(bmrb_url) as url:
responses = json.loads(url.read().decode())
    try:
        k = responses[0]
    except (IndexError, TypeError):
        print(pdb_code + " not in BMRB.")
        return
return
# Check if the response contains a BMRB ID for the given PDB code
if "bmrb_id" in responses[0]:
bmrb_id = None
for response in responses:
if response["match_types"][0] == 'Exact':
bmrb_id = response["bmrb_id"]
if bmrb_id is None:
print("Couldn't find exact match.")
return
print("BMRB ID for PDB code", pdb_code, "is", bmrb_id)
# get Chemical shift list from bmrb
bmrb_csl_url = "https://api.bmrb.io/v2/entry/"+str(bmrb_id)
# Send a request to the BMRB API and retrieve the response
with urllib.request.urlopen(bmrb_csl_url) as url:
responses = json.loads(url.read().decode())
saveframe_data = responses[bmrb_id]['saveframes']
cs_frame_index = -1
for i, data in enumerate(saveframe_data):
#if data['name'] in ["assigned_chemical_shifts_1", "assigned_chem_shift_list_1", "chemical_shift_1", "ChemShifts"]:
if data['category'] == "assigned_chemical_shifts":
cs_frame_index = i
break
if cs_frame_index == -1:
print("Couldn't find chemical shift lists... continuing...")
return
CSL_data = saveframe_data[cs_frame_index]['loops']
loop_i = -1
for i in range(0, len(CSL_data)):
if CSL_data[i]['category'] == '_Atom_chem_shift':
loop_i = i
if loop_i == -1:
return
CSL_data = CSL_data[loop_i]
tags = CSL_data['tags']
data = CSL_data['data']
dat = [ d for d in data]
write_to_csv(local_path, tags, dat)
print("wrote to file " + local_path)
else:
print("No BMRB ID found for PDB code", pdb_code)
return
def update_row(csv_filename, apo, bound, new_values, new_columns):
try:
# Load the DataFrame if the CSV file exists
df = pd.read_csv(csv_filename, low_memory=False)
except (pd.errors.EmptyDataError, FileNotFoundError):
# Create an empty DataFrame if the CSV file is empty or doesn't exist
df = pd.DataFrame()
data_dict = {col: val for col, val in zip(new_columns, new_values)}
#print(data_dict)
#print(data_dict.items())
    # Check if 'apo' and 'bound' columns exist
    if 'apo_bmrb' not in df.columns or 'holo_model_path' not in df.columns:
        df = pd.concat([df, pd.DataFrame([data_dict])], ignore_index=True)
    else:
        # Update or create the row
        row_index = df[(df['holo_model_path'] == bound)].index
        print("UPDATING ROW INDEX : " + str(row_index))
        if not row_index.empty:
            for col, val in data_dict.items():
                try:
                    df.loc[row_index[0], col] = val
                except Exception:
                    data_dict['apo_bmrb'] = apo
                    data_dict['holo_model_path'] = bound
                    df = pd.concat([df, pd.DataFrame([data_dict])], ignore_index=True)
        else:
            data_dict['apo_bmrb'] = apo
            data_dict['holo_model_path'] = bound
            df = pd.concat([df, pd.DataFrame([data_dict])], ignore_index=True)
# Save the DataFrame back to the CSV file
df.to_csv(csv_filename, index=False)
def plot(xaxs, yaxs, xlab, ylab, zscore, show = False, outfile = "", heats = None):
if heats is not None:
plt.scatter(xaxs, yaxs, c=heats, cmap='hot')
plt.colorbar()
else:
plt.scatter(xaxs, yaxs)
slope, intercept, r_value, p_value, std_err = linregress(xaxs, yaxs)
r_squared = r_value ** 2
print(xlab + ' v ' + ylab)
print("R value:", r_value)
print("R^2 value:", r_squared)
x_fit = np.linspace(0, 1, 50)
y_fit = slope * x_fit + intercept
plt.xlim(0,max(xaxs))
plt.ylim(0,1)
plt.plot(x_fit, y_fit, 'r--')
#plt.text(0.2, 0.7, f"R$^2$ = {r_squared:.2f}", fontsize=12, color='red')
if r_value > 0:
plt.text(0.1, 0.9, f"R = {r_value:.2f}", fontsize=12, color='red')
elif r_value < 0:
plt.text(0.9, 0.9, f"R = {r_value:.2f}", fontsize=12, color='red')
#plt.xlim(min(0,min(xaxs)),max(1, max(xaxs)))
plt.ylim(min(0,min(yaxs)),max(1, max(yaxs)))
# Add axis labels and a title
plt.xlabel(xlab)
plt.ylabel(ylab)
plt.title(xlab + ' vs. ' + ylab)
# Display the scatterplot
if show:
plt.show()
    if outfile:
        plt.savefig(outfile)
    else:
        plt.savefig('./corr/'+xlab+'vs'+ylab+'_'+str(zscore)+'.png')
# Determine outliers using Z-scores
xaxs_z_scores = np.abs(stats.zscore(xaxs))
yaxs_z_scores = np.abs(stats.zscore(yaxs))
xaxs_outliers = np.where(xaxs_z_scores > 3)[0]
yaxs_outliers = np.where(yaxs_z_scores > 3)[0]
plt.cla()
return set(xaxs_outliers).union(yaxs_outliers)
def rmsd(offset, spectrum1, spectrum2):
aligned_spectrum1 = spectrum1 + offset
return np.sqrt(np.mean((aligned_spectrum1 - spectrum2) ** 2))
def read_CSList_csv(csv_path):
sequences = [[]]
H_shifts = [[]]
N_shifts = [[]]
CA_shifts = [[]]
CB_shifts = [[]]
CO_shifts = [[]]
min_n_err = 100
max_n_err = -1
min_h_err = 100
max_h_err = -1
print("reading " + csv_path)
# Open the CSV file and read the data from the first column into the list
ind = 0
seqi = 0
with open(csv_path, "r") as csv_file:
csv_reader = csv.DictReader(csv_file)
for row in csv_reader:
tind = int(row['Seq_ID'])
if tind < ind:
sequences.append([])
H_shifts.append([])
N_shifts.append([])
CA_shifts.append([])
CB_shifts.append([])
CO_shifts.append([])
seqi += 1
ind = 0
if tind > ind + 1:
for j in range(0, tind - ind - 1):
sequences[seqi].append("_")
H_shifts[seqi].append(-1)
N_shifts[seqi].append(-1)
sequences[seqi].append(convert_aa_name(row['Comp_ID']))
if tind == ind + 1:
while len(H_shifts[seqi]) < len(sequences[seqi]):
H_shifts[seqi].append(-1)
while len(N_shifts[seqi]) < len(sequences[seqi]):
N_shifts[seqi].append(-1)
sequences[seqi].append(convert_aa_name(row['Comp_ID']))
ind = tind
if row['Atom_ID'] == "H":
H_shifts[seqi].append(float(row['Val']))
if row['Atom_ID'] == "N":
N_shifts[seqi].append(float(row['Val']))
if row['Atom_ID'] == "CA":
CA_shifts[seqi].append(float(row['Val']))# + 0.1) #+ float(row['Val_err']))
if row['Atom_ID'] == "CB":
CB_shifts[seqi].append(float(row['Val']))# + 0.1) #+ float(row['Val_err']))
if row['Atom_ID'] == "C":
CO_shifts[seqi].append(float(row['Val']))# + 0.1) #+ float(row['Val_err']))
maxind = -1
maxlen = -1
for i,seq in enumerate(sequences):
if len(seq) > maxlen:
maxind = i
maxlen = max(maxlen, len(seq))
sequence = sequences[maxind]
H_shift = H_shifts[maxind]
N_shift = N_shifts[maxind]
CA_shift = CA_shifts[maxind]
CB_shift = CB_shifts[maxind]
CO_shift = CO_shifts[maxind]
while len(H_shift) < len(sequence):
H_shift.append(-1)
while len(N_shift) < len(sequence):
N_shift.append(-1)
while len(CA_shift) < len(sequence):
CA_shift.append(-1)
while len(CB_shift) < len(sequence):
CB_shift.append(-1)
while len(CO_shift) < len(sequence):
CO_shift.append(-1)
#print("N_SHIFT")
#print(N_shift)
#print("H_SHIFT")
#print(H_shift)
return sequence, H_shift, N_shift, CA_shift, CB_shift, CO_shift
def calculate_z_score_threshold(data, z_value):
# Filter out data points less than 0
filtered_data = [d for d in data if d >= 0]
# Ensure there are enough data points remaining
if len(filtered_data) < 2:
raise ValueError("Not enough data points greater than or equal to 0.")
# Calculate the mean of the filtered data
mean = statistics.mean(filtered_data)
# Calculate the standard deviation of the filtered data
std_dev = statistics.stdev(filtered_data)
# Calculate the z-score
z = mean + std_dev * z_value
return z
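# Worked example (illustrative numbers): for [0.05, 0.10, 0.15] the mean is
# 0.10 and the sample stdev is 0.05, so a z_value of 3 gives a threshold of
# 0.10 + 3 * 0.05 = 0.25.
#   >>> calculate_z_score_threshold([0.05, 0.10, 0.15], 3)   # -> 0.25 (up to floating point)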
def calculate_z_score(data, value):
# Filter out data points less than 0
filtered_data = [d for d in data if d >= 0]
# Ensure there are enough data points remaining
if len(filtered_data) < 2:
raise ValueError("Not enough data points greater than or equal to 0.")
# Calculate the mean and standard deviation of the filtered data
mean = statistics.mean(filtered_data)
std_dev = statistics.stdev(filtered_data)
# Ensure the standard deviation is not zero (to avoid division by zero)
if std_dev == 0:
raise ValueError("Standard deviation is zero, z-score calculation is not possible.")
# Calculate the z-score for the given value
z_score = min((value - mean) / std_dev, 3)
return z_score
def calculate_surface_residues(pdb_filename, sphere_radius):
# Parse the PDB file
parser = PDBParser()
structure = parser.get_structure('protein', pdb_filename)
# Get all chains
chains = [chain for model in structure for chain in model]
# Find the largest chain
largest_chain = max(chains, key=lambda chain: len(chain))
# Prepare an atom list
atom_list = [atom for residue in largest_chain for atom in residue]
# Calculate surface residues
surface_residues = []
for i,residue in enumerate(largest_chain):
for atom in residue:
            # Collect all atoms within sphere_radius of this atom. Note: the
            # list always includes the atom itself (distance 0), so this test
            # passes for every residue as written; exclude `a is atom` if a
            # true neighbor test is intended.
            sphere = [a for a in atom_list if a - atom < sphere_radius]
            if sphere:
                surface_residues.append(i)
                break
    # returns the indices of surface residues in the protein
    return surface_residues
def calculate_buried_residues(pdb_filename, sphere_radius):
# Parse the PDB file
parser = PDBParser()
structure = parser.get_structure('protein', pdb_filename)
# Get all chains
chains = [chain for model in structure for chain in model]
# Find the largest chain
largest_chain = max(chains, key=lambda chain: len(chain))
# Compute relative solvent accessibility for each residue
model = largest_chain.get_parent()
rsa = HSExposureCA(model, radius=sphere_radius)
#print(dict(rsa))
# Identify buried residues
buried_residues = []
for i, residue in enumerate(largest_chain):
residue_id = (get_longest_chain(pdb_filename), residue.id)
if residue_id in rsa:
if rsa[residue_id][2] < 0.33: # or whatever threshold you choose
buried_residues.append(i)
print(residue_id)
return buried_residues
def find_optimal_offset(spectrum1, spectrum2):
spectrum1 = np.array(spectrum1)
spectrum2 = np.array(spectrum2)
initial_offset = 0
result = minimize(rmsd, initial_offset, args=(spectrum1, spectrum2), method='L-BFGS-B')
optimal_offset = result.x[0]
return optimal_offset
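# Toy example (hypothetical spectra): two spectra that differ by a constant
# +0.5 ppm are aligned by an optimal offset of about -0.5.
#   >>> find_optimal_offset([8.5, 8.7, 8.9], [8.0, 8.2, 8.4])   # -> ~ -0.5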
def spectra_alignment(bound_shift, t_bound_shift, t_apo_shift, factor = None):
    if factor is None:
        raise ValueError("Define factor to calculate shift perturbations.")
def calculate_shift_perturbations(holo, apo, factor):
return [((holo_val - apo_val) / factor) ** 2 for holo_val, apo_val in zip(holo, apo)]
if all(x == -1 for x in bound_shift) or all(x == -1 for x in t_apo_shift):
print("No alignment to calculate, returning")
return bound_shift
def remove_outliers(shifts):
shifts = np.array(shifts)
indices = np.arange(len(shifts)) # Keep track of original indices
while True:
mean_shift = np.mean(shifts)
std_shift = np.std(shifts)
cutoff = mean_shift + 3 * std_shift
below_cutoff = shifts <= cutoff
filtered_shifts = shifts[below_cutoff]
filtered_indices = indices[below_cutoff]
if len(filtered_shifts) == len(shifts):
break
shifts = filtered_shifts
indices = filtered_indices
return filtered_indices.tolist()
nt_bound_shift = []
nt_apo_shift = []
for i,s in enumerate(t_bound_shift):
if s > 0 and t_apo_shift[i] > 0:
nt_bound_shift.append(s)
nt_apo_shift.append(t_apo_shift[i])
t_bound_shift = [ s for s in nt_bound_shift ]
t_apo_shift = [ s for s in nt_apo_shift ]
shifts = calculate_shift_perturbations(t_bound_shift, t_apo_shift, factor)
indices = remove_outliers(shifts)
filtered_t_bound_shift = [t_bound_shift[i] for i in indices]
filtered_t_apo_shift = [t_apo_shift[i] for i in indices]
#print(t_bound_shift)
#print(t_apo_shift)
offset = find_optimal_offset(filtered_t_bound_shift, filtered_t_apo_shift)
for i in range(0, len(bound_shift)):
if bound_shift[i] != -1:
bound_shift[i] += offset
return bound_shift
def read_UCBShift(csv_file):
df = pd.read_csv(csv_file)
# If the convert_aa_name function needs to handle NaNs as well,
# ensure it's capable of doing so or preprocess df['RESNAME'] as needed.
sequence = [convert_aa_name(aa) for aa in df['RESNAME']]
# Fill NaNs with -1 in the DataFrame for specified columns before list comprehension
columns_to_replace_nans = ['H_UCBShift', 'HA_UCBShift', 'N_UCBShift', 'C_UCBShift', 'CA_UCBShift', 'CB_UCBShift']
df[columns_to_replace_nans] = df[columns_to_replace_nans].fillna(-1)
H_shift = [H for H in df['H_UCBShift']]
HA_shift = [HA for HA in df['HA_UCBShift']]
N_shift = [N for N in df['N_UCBShift']]
C_shift = [C for C in df['C_UCBShift']]
CA_shift = [CA for CA in df['CA_UCBShift']]
CB_shift = [CB for CB in df['CB_UCBShift']]
return sequence, H_shift, N_shift, CA_shift, CB_shift, C_shift
def calc_CSP_new(apo_csv, holo_csv, method = "MONTE", UCBShift_pred_holo = False, UCBShift_pred_apo = False):
def read_and_check(path, UCBS = False):
if not exists(path):
print("couldn't locate", path)
return None
if UCBS:
return read_UCBShift(path)
else:
return read_CSList_csv(path)
def align_and_reformat(aligned_sequence, sequence, H_shift, N_shift, CA_shift, CB_shift, CO_shift):
new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift = [], [], [], [], [], []
seq_index = 0
#print(aligned_sequence)
#print(sequence)
for i in range(len(aligned_sequence)):
if seq_index < len(sequence) and aligned_sequence[i] == sequence[seq_index]:
new_sequence.append(sequence[seq_index])
new_N_shift.append(N_shift[seq_index])
new_H_shift.append(H_shift[seq_index])
new_CA_shift.append(CA_shift[seq_index])
new_CB_shift.append(CB_shift[seq_index])
new_CO_shift.append(CO_shift[seq_index])
seq_index += 1
else:# i >= len(sequence) or (aligned_sequence[i] in ["_", "-"] and sequence[i] not in ['_', '-']):
new_sequence.append("_")
new_N_shift.append(-1)
new_H_shift.append(-1)
new_CA_shift.append(-1)
new_CB_shift.append(-1)
new_CO_shift.append(-1)
return new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift
def calculate_CSPS(N_shift_bound, H_shift_bound, CA_shift_bound, CB_shift_bound, CO_shift_bound, \
N_shift_apo, H_shift_apo, CA_shift_apo, CB_shift_apo, CO_shift_apo, bound_seq, method):
CSPs, redundant_flag = [], []
ind = 0
for bound_N, apo_N, bound_H, apo_H, bound_CA, apo_CA, bound_CB, apo_CB, bound_CO, apo_CO in zip(N_shift_bound, N_shift_apo, \
H_shift_bound, H_shift_apo, CA_shift_bound, CA_shift_apo,\
CB_shift_bound, CB_shift_apo, CO_shift_bound, CO_shift_apo):
if any([val == -1 for val in [bound_N, apo_N, bound_H, apo_H]]):
CSPs.append(-1)
redundant_flag.append(True)
ind += 1
continue
CSP = -1
if method == "MONTE":
CSP = math.sqrt(0.5*(((bound_N-apo_N)/2.56)**2 + ((bound_H-apo_H)/0.54)**2))
elif method == "WILLIAMSON2013":
if bound_seq[ind] == 'G':
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.20)**2 + (bound_H-apo_H)**2))
else:
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.14)**2 + (bound_H-apo_H)**2))
elif method == "EVENAS2001":
CSP = math.sqrt( (1/4) * ( ( ( bound_N-apo_N ) / 6.5 )**2 + \
( ( bound_CA-apo_CA ) / 3.62 )**2 + \
( ( bound_CB-apo_CB ) / 3.62 )**2 + \
( ( bound_CO-apo_CO ) / 2.93 ) **2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
elif method == "GRZESIEK1996":
CSP = math.sqrt( (1/4) * ( ( ( bound_H-apo_H ) ) ** 2 + \
( ( bound_N-apo_N ) / 5 ) ** 2 + \
( ( bound_CA-apo_CA ) / 2 ) ** 2 + \
( ( bound_CB-apo_CB ) / 2 ) ** 2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
CSPs.append(CSP)
redundant_flag.append(False)
ind += 1
if len([CS for CS in CSPs if CS > 0]) == 0:
print("No useful CSPs calculated, continuing...")
return CSPs, -1
new_redundant = []#CSPs.index(min([CS for CS in CSPs if CS > 0]))]
itr = 0
prev_value_3_std_below = 0
value_3_std_below = 0
mean_CSP = 0
while True:
#print(value_3_std_below)
#print(itr)
if itr > 0 and len(new_redundant) == 0:
break
#while len(new_redundant) > 0:
for i in new_redundant:
redundant_flag[i] = True
prev_value_3_std_below = value_3_std_below
current_CSPs = [ CS for i,CS in enumerate(CSPs) if redundant_flag[i] == False ]
#print(current_CSPs)
if len(current_CSPs) == 0:
break
mean_CSP = sum(current_CSPs) / len(current_CSPs)
variance = sum([((x - mean_CSP) ** 2) for x in current_CSPs]) / len(current_CSPs)
std_deviation = math.sqrt(variance)
value_3_std_below = mean_CSP + 3 * std_deviation
new_redundant = [ i for i,CS in enumerate(CSPs) if CS >= value_3_std_below and redundant_flag[i] == False]
itr += 1
#print(value_3_std_below)
return CSPs, value_3_std_below
# Start of calc_CSP function
bound_path = holo_csv
bound = bound_path[bound_path.rfind('/')+1:bound_path.rfind('/')+5]
print("GETTING DATA FROM " + bound_path)
bound_data = read_and_check(bound_path, UCBS = UCBShift_pred_holo)
if bound_data is None:
return None
apo_path = apo_csv
apo_data = read_and_check(apo_path, UCBS = UCBShift_pred_apo)
if apo_data is None:
return None
bound_sequence = bound_data[0]
apo_sequence = apo_data[0]
bound_aligned, apo_aligned = align(bound_sequence, apo_sequence)
print(str(len(bound_aligned)) + ", " + str(len(apo_aligned)))
bound_sequence, bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift = align_and_reformat(bound_aligned, *bound_data[:6])
apo_sequence, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift = align_and_reformat(apo_aligned, *apo_data[:6])
print("N")
print(bound_N_shift)
print("H")
print(bound_H_shift)
#raise
print(str(len(bound_N_shift)) + ", " + str(len(apo_N_shift)))
prev_redundant_flag = [False for i in bound_N_shift]
csp_itr = 0
error_flag = False
N_factor = -1
H_factor = -1
CA_factor = -1
CB_factor = -1
CO_factor = -1
if method == "MONTE":
N_factor = 2.56
H_factor = 0.54
elif method == "WILLIAMSON2013":
N_factor = 0.20
H_factor = 1
elif method == "EVENAS2001":
N_factor = 6.5
CA_factor = 3.62
CB_factor = 3.62
CO_factor = 2.93
elif method == "GRZESIEK1996":
H_factor = 1
N_factor = 5
CA_factor = 2
CB_factor = 2
bound_N_shift = spectra_alignment(bound_N_shift, bound_N_shift, apo_N_shift, factor = N_factor)
bound_H_shift = spectra_alignment(bound_H_shift, bound_H_shift, apo_H_shift, factor = H_factor)
bound_CA_shift = spectra_alignment(bound_CA_shift, bound_CA_shift, apo_CA_shift, factor = CA_factor)
bound_CB_shift = spectra_alignment(bound_CB_shift, bound_CB_shift, apo_CB_shift, factor = CB_factor)
bound_CO_shift = spectra_alignment(bound_CO_shift, bound_CO_shift, apo_CO_shift, factor = CO_factor)
CSPs, CSP_sig_cutoff = calculate_CSPS(bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, \
apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence, method)
#print("Sig cutoff = " +str(CSP_sig_cutoff))
return CSPs, CSP_sig_cutoff, bound_sequence
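# Worked example of the MONTE combined-shift formula used above, with
# hypothetical amide shifts (holo vs apo): N 119.2 -> 118.4 ppm, H 8.30 -> 8.05 ppm:
#   CSP = sqrt(0.5 * ((0.8 / 2.56)**2 + (0.25 / 0.54)**2)) ~ 0.40 ppm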
def calc_CSP(bound, apo, method = "MONTE", buried = True, bound_file = None):
def read_and_check(path):
if not exists(path):
print("couldn't locate", path)
return None
return read_CSList_csv(path)
def align_and_reformat(aligned_sequence, sequence, H_shift, N_shift, CA_shift, CB_shift, CO_shift):
new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift = [], [], [], [], [], []
seq_index = 0
#print("ALIGNED")
#print(aligned_sequence)
for i in range(len(aligned_sequence)):
if seq_index < len(sequence) and aligned_sequence[i] == sequence[seq_index]:
new_sequence.append(sequence[seq_index])
new_N_shift.append(N_shift[seq_index])
new_H_shift.append(H_shift[seq_index])
new_CA_shift.append(CA_shift[seq_index])
new_CB_shift.append(CB_shift[seq_index])
new_CO_shift.append(CO_shift[seq_index])
seq_index += 1
elif i >= len(sequence) or ( aligned_sequence[i] in ["_", "-"] and sequence[i] not in ['_', '-']):
new_sequence.append("_")
new_N_shift.append(-1)
new_H_shift.append(-1)
new_CA_shift.append(-1)
new_CB_shift.append(-1)
new_CO_shift.append(-1)
return new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift
def get_mutual_edge_gaps(bound, apo):
start = -1
end = -1
for i in range(0, len(bound)):
if bound[i] != -1 or apo[i] != -1:
start = i
break
end = next(i for i in reversed(range(max(len(bound), len(apo)))) if (i < len(bound) and bound[i] != -1) or (i < len(apo) and apo[i] != -1))
if end != len(bound)-1:
end += 1
return start, end
def update_shifts(seq, N_shift, H_shift, CA_shift, CB_shift, CO_shift):
nN_shift, nH_shift, nCA_shift, nCB_shift, nCO_shift = [], [], [], [], []
ind = 0
for c in seq:
if c == '_':
nN_shift.append(-1)
nH_shift.append(-1)
nCA_shift.append(-1)
nCB_shift.append(-1)
nCO_shift.append(-1)
else:
nN_shift.append(N_shift[ind])
nH_shift.append(H_shift[ind])
nCA_shift.append(CA_shift[ind])
nCB_shift.append(CB_shift[ind])
nCO_shift.append(CO_shift[ind])
ind += 1
return nN_shift, nH_shift, nCA_shift, nCB_shift, nCO_shift
def calculate_CSPS(N_shift_bound, H_shift_bound, CA_shift_bound, CB_shift_bound, CO_shift_bound, \
N_shift_apo, H_shift_apo, CA_shift_apo, CB_shift_apo, CO_shift_apo, bound_seq, buried_residue_indeces, method):
CSPs = []
redundant_flag = []
ind = 0
for bound_N, apo_N, bound_H, apo_H, bound_CA, apo_CA, bound_CB, apo_CB, bound_CO, apo_CO in zip(N_shift_bound, N_shift_apo, \
H_shift_bound, H_shift_apo, CA_shift_bound, CA_shift_apo,\
CB_shift_bound, CB_shift_apo, CO_shift_bound, CO_shift_apo):
if any([val == -1 for val in [bound_N, apo_N, bound_H, apo_H]]):
CSPs.append(-1)
redundant_flag.append(True)
ind += 1
continue
CSP = -1
if method == "MONTE":
CSP = math.sqrt(0.5*(((bound_N-apo_N)/2.56)**2 + ((bound_H-apo_H)/0.54)**2))
elif method == "WILLIAMSON2013":
if bound_seq[ind] == 'G':
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.20)**2 + (bound_H-apo_H)**2))
else:
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.14)**2 + (bound_H-apo_H)**2))
elif method == "EVENAS2001":
CSP = math.sqrt( (1/4) * ( ( ( bound_N-apo_N ) / 6.5 )**2 + \
( ( bound_CA-apo_CA ) / 3.62 )**2 + \
( ( bound_CB-apo_CB ) / 3.62 )**2 + \
( ( bound_CO-apo_CO ) / 2.93 ) **2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
elif method == "GRZESIEK1996":
CSP = math.sqrt( (1/4) * ( ( ( bound_H-apo_H ) ) ** 2 + \
( ( bound_N-apo_N ) / 5 ) ** 2 + \
( ( bound_CA-apo_CA ) / 2 ) ** 2 + \
( ( bound_CB-apo_CB ) / 2 ) ** 2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
CSPs.append(CSP)
redundant_flag.append(False)
ind += 1
if len([CS for CS in CSPs if CS > 0]) == 0:
print("No useful CSPs calculated, continuing...")
return CSPs, -1
# now iteratively refine CSP by updating redundant flag
new_redundant = [CSPs.index(min([CS for CS in CSPs if CS > 0]))]
if buried:
            try:
                for buried_ind in buried_residue_indeces:
                    new_redundant.append(buried_ind)
            except Exception:
                pass
itr = 0
prev_value_3_std_below = 0
value_3_std_below = 0
mean_CSP = 0
while len(new_redundant) > 0:
#print(redundant_flag)
#print(len(redundant_flag))
for i in new_redundant:
redundant_flag[i] = True
prev_value_3_std_below = value_3_std_below
current_CSPs = [ CS for i,CS in enumerate(CSPs) if redundant_flag[i] == False ]
#print(current_CSPs)
if len(current_CSPs) == 0:
break
mean_CSP = sum(current_CSPs) / len(current_CSPs)
#print("mean = " + str(mean_CSP))
# Calculate the standard deviation of the list
variance = sum([((x - mean_CSP) ** 2) for x in current_CSPs]) / len(current_CSPs)
std_deviation = math.sqrt(variance)
#print("standard deviation = " + str(std_deviation))
            # Calculate the value that is 3 standard deviations above the mean
            value_3_std_below = mean_CSP + 3 * std_deviation
new_redundant = [ i for i,CS in enumerate(CSPs) if CS >= value_3_std_below and redundant_flag[i] == False]
itr += 1
return CSPs, value_3_std_below
# Start of calc_CSP function
bound_path = f'/home/tiburon/Desktop/ROT4/complex4/CSLists/{bound.upper()}.csv'
bound_data = read_and_check(bound_path)
if bound_data is None:
return None
apo_path = f'/home/tiburon/Desktop/ROT4/complex4/CSLists/{apo.upper()}.csv'
apo_data = read_and_check(apo_path)
if apo_data is None:
return None
#print(apo_data)
#print(bound_data)
bound_sequence = bound_data[0]
apo_sequence = apo_data[0]
bound_aligned, apo_aligned = align(bound_sequence, apo_sequence)
#print(str(len(bound_aligned)) + ", " + str(len(apo_aligned)))
bound_sequence, bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift = align_and_reformat(bound_aligned, *bound_data[:6])
apo_sequence, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift = align_and_reformat(apo_aligned, *apo_data[:6])
#print(str(len(bound_N_shift)) + ", " + str(len(apo_N_shift)))
if False:
start, end = get_mutual_edge_gaps(bound_N_shift, apo_N_shift)
bound_N_shift = bound_N_shift[start:end]
apo_N_shift = apo_N_shift[start:end]
bound_H_shift = bound_H_shift[start:end]
apo_H_shift = apo_H_shift[start:end]
bound_CA_shift = bound_CA_shift[start:end]
apo_CA_shift = apo_CA_shift[start:end]
bound_CB_shift = bound_CB_shift[start:end]
apo_CB_shift = apo_CB_shift[start:end]
bound_CO_shift = bound_CO_shift[start:end]
apo_CO_shift = apo_CO_shift[start:end]
prev_redundant_flag = [False for i in bound_N_shift]
csp_itr = 0
error_flag = False
N_factor = -1
H_factor = -1
CA_factor = -1
CB_factor = -1
CO_factor = -1
if method == "MONTE":
N_factor = 2.56
H_factor = 0.54
elif method == "WILLIAMSON2013":
N_factor = 0.20
H_factor = 1
elif method == "EVENAS2001":
N_factor = 6.5
CA_factor = 3.62
CB_factor = 3.62
CO_factor = 2.93
elif method == "GRZESIEK1996":
H_factor = 1
N_factor = 5
CA_factor = 2
CB_factor = 2
bound_N_shift = spectra_alignment(bound_N_shift, bound_N_shift, apo_N_shift, factor = N_factor)
bound_H_shift = spectra_alignment(bound_H_shift, bound_H_shift, apo_H_shift, factor = H_factor)
bound_CA_shift = spectra_alignment(bound_CA_shift, bound_CA_shift, apo_CA_shift, factor = CA_factor)
bound_CB_shift = spectra_alignment(bound_CB_shift, bound_CB_shift, apo_CB_shift, factor = CB_factor)
bound_CO_shift = spectra_alignment(bound_CO_shift, bound_CO_shift, apo_CO_shift, factor = CO_factor)
print(bound_aligned)
print(apo_aligned)
if False:
print("bound_N_shift = ")
print(bound_N_shift)
print("apo_N_shift = ")
print(apo_N_shift)
print("bound_H_shift = ")
print(bound_H_shift)
print("apo_H_shift = ")
print(apo_H_shift)
buried_residue_indeces = None
if buried:
print("getting buried residues from " + bound_file)
buried_residue_indeces = calculate_buried_residues(bound_file, 3)
#if len(bound_aligned) > len(bound_H_shift):
# bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift = update_shifts(bound_aligned, bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift)
#if len(apo_aligned) > len(apo_H_shift):
# apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift = update_shifts(apo_aligned, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift)
#buried_residue_indeces = get_value_from_csv('CSP8.csv', bound, apo, 'buried_residue_indeces')
#print(buried_residue_indeces)
    if buried_residue_indeces is None or len(str(buried_residue_indeces)) == 0:
        buried_residue_indeces = []
try:
CSPs, CSP_sig_cutoff = calculate_CSPS(bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, \
apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence, buried_residue_indeces, method)
    except Exception as e:
        print("GOT EXCEPTION: " + str(e))
        return None, None, None
while True:
redundant_flag = [CSP < CSP_sig_cutoff for CSP in CSPs]
print("calculating CSPs for " + bound + " attempt # " + str(csp_itr+1))
t_bound_N_shift = [bn if not flag else 0 for bn, flag in zip(bound_N_shift, redundant_flag)]
t_apo_N_shift = [an if not flag else 0 for an, flag in zip(apo_N_shift, redundant_flag)]
t_bound_H_shift = [bh if not flag else 0 for bh, flag in zip(bound_H_shift, redundant_flag)]
t_apo_H_shift = [ah if not flag else 0 for ah, flag in zip(apo_H_shift, redundant_flag)]
t_bound_CA_shift = [bca if not flag else 0 for bca, flag in zip(bound_CA_shift, redundant_flag)]
t_apo_CA_shift = [aca if not flag else 0 for aca, flag in zip(apo_CA_shift, redundant_flag)]
t_bound_CB_shift = [bcb if not flag else 0 for bcb, flag in zip(bound_CB_shift, redundant_flag)]
t_apo_CB_shift = [acb if not flag else 0 for acb, flag in zip(apo_CB_shift, redundant_flag)]
t_bound_CO_shift = [bco if not flag else 0 for bco, flag in zip(bound_CO_shift, redundant_flag)]
t_apo_CO_shift = [aco if not flag else 0 for aco, flag in zip(apo_CO_shift, redundant_flag)]
bound_N_shift = spectra_alignment(bound_N_shift, t_bound_N_shift, t_apo_N_shift, factor = N_factor)
bound_H_shift = spectra_alignment(bound_H_shift, t_bound_H_shift, t_apo_H_shift, factor = H_factor)
bound_CA_shift = spectra_alignment(bound_CA_shift, t_bound_CA_shift, t_apo_CA_shift, factor = CA_factor)
bound_CB_shift = spectra_alignment(bound_CB_shift, t_bound_CB_shift, t_apo_CB_shift, factor = CB_factor)
bound_CO_shift = spectra_alignment(bound_CO_shift, t_bound_CO_shift, t_apo_CO_shift, factor = CO_factor)
try:
CSPs, CSP_sig_cutoff = calculate_CSPS(bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, \
apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence, buried_residue_indeces, method)
except Exception as e:
error_flag = True
print("continuing due to exception: " + str(e))
break
if redundant_flag == prev_redundant_flag or csp_itr > 10:
break
else:
prev_redundant_flag = [ i for i in redundant_flag ]
csp_itr += 1
    if error_flag:
        print("RECEIVED ERROR.")
        return None, None, None
print("Sig cutoff = " +str(CSP_sig_cutoff))
return CSPs, CSP_sig_cutoff, bound_aligned
def get_align_boundaries(bound_seq, match_seq):
# Create a pairwise aligner object
#print(bound_seq)
#print(match_seq)
seq1 = Seq(bound_seq)
seq2 = Seq(match_seq)
aligner = PairwiseAligner()
# Set the parameters for the pairwise alignment
aligner.mode = 'local'
aligner.match_score = 1
aligner.mismatch_score = -1
aligner.open_gap_score = -2
aligner.extend_gap_score = -2
# Align the two sequences
alignments = aligner.align(seq1, seq2)
if alignments:
# Get the first alignment (highest score)
first_alignment = alignments[0]
#print(first_alignment)
        # Extract the start and end indices of the alignment within match_seq;
        # use the last aligned block so gapped alignments span the full match
        start_idx = first_alignment.aligned[1][0][0]
        end_idx = first_alignment.aligned[1][-1][1]
return start_idx, end_idx
else:
# Return None if no alignment found
return None
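# Usage sketch (illustrative sequences): locates where the first sequence's
# best local hit starts and ends within the second sequence.
#   >>> get_align_boundaries("ACDEF", "GGACDEFGG")   # -> (2, 7)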
def get_shift_data(apo_csv, bound_csv, CSmethod, match_seq, well_defined_res=None):
def read_and_check(path, CSmethod):
if not exists(path):
print("couldn't locate", path)
return None
if CSmethod == "UCBShift":
return read_UCBShift(path)
elif CSmethod == "ShiftX":
return read_ShiftX(path)
elif CSmethod == "SPARTA":
return read_SPARTA(path)
else:
return read_CSList_csv(path)
def align_and_reformat(aligned_sequence, sequence, H_shift, N_shift, CA_shift, CB_shift, CO_shift):
new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift = [], [], [], [], [], []
seq_index = 0
#print(aligned_sequence)
#print(sequence)
for i in range(len(aligned_sequence)):
if seq_index < len(sequence) and aligned_sequence[i] == sequence[seq_index]:
new_sequence.append(sequence[seq_index])
new_N_shift.append(N_shift[seq_index])
new_H_shift.append(H_shift[seq_index])
new_CA_shift.append(CA_shift[seq_index])
new_CB_shift.append(CB_shift[seq_index])
new_CO_shift.append(CO_shift[seq_index])
seq_index += 1
else:# i >= len(sequence) or (aligned_sequence[i] in ["_", "-"] and sequence[i] not in ['_', '-']):
new_sequence.append("_")
new_N_shift.append(-1)
new_H_shift.append(-1)
new_CA_shift.append(-1)
new_CB_shift.append(-1)
new_CO_shift.append(-1)
return new_sequence, new_N_shift, new_H_shift, new_CA_shift, new_CB_shift, new_CO_shift
bound_data = read_and_check(bound_csv, CSmethod)
if bound_data is None:
return None
apo_data = read_and_check(apo_csv, CSmethod)
if apo_data is None:
return None
if match_seq is None:
new_bound_data = []
for i,dat in enumerate(bound_data):
new_bound_data.append(bound_data[i][well_defined_res[0]:well_defined_res[1]+1])
new_apo_data = []
for i, dat in enumerate(apo_data):
new_apo_data.append(apo_data[i][well_defined_res[0]:well_defined_res[1]+1])
bound_data = tuple(new_bound_data)
apo_data = tuple(new_apo_data)
else:
bound_seq = ''.join(bound_data[0])
print("MATCH SEQ = " + match_seq)
print("BOUND SEQ = " + bound_seq)
start_idx, end_idx = get_align_boundaries(match_seq, bound_seq)
new_bound_data = []
for i,dat in enumerate(bound_data):
new_bound_data.append(bound_data[i][start_idx:end_idx])
apo_seq = ''.join(apo_data[0])
print("APO SEQ = " + apo_seq)
start_idx, end_idx = get_align_boundaries(match_seq, apo_seq)
new_apo_data = []
for i, dat in enumerate(apo_data):
new_apo_data.append(apo_data[i][start_idx:end_idx])
bound_data = tuple(new_bound_data)
apo_data = tuple(new_apo_data)
bound_sequence = bound_data[0]
apo_sequence = apo_data[0]
print('BOUND SEQUENCE = ' + ''.join(bound_sequence))
print('APO SEQUENCE = ' + ''.join(apo_sequence))
#if len(''.join(bound_sequence)) != len(''.join(apo_sequence)):
# continue_prompt()
# #raise
bound_aligned, apo_aligned = align(bound_sequence, apo_sequence)
#print(str(len(bound_aligned)) + ", " + str(len(apo_aligned)))
#print(bound_aligned + "\n" + apo_aligned)
bound_sequence, bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift = align_and_reformat(bound_aligned, *bound_data[:6])
apo_sequence, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift = align_and_reformat(apo_aligned, *apo_data[:6])
#print(str(len(bound_N_shift)) + ", " + str(len(apo_N_shift)))
return bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence
def calculate_CSP_from_shifts(N_shift_bound, H_shift_bound, CA_shift_bound, CB_shift_bound, CO_shift_bound, \
N_shift_apo, H_shift_apo, CA_shift_apo, CB_shift_apo, CO_shift_apo, bound_seq, method):
N_factor = -1
H_factor = -1
CA_factor = -1
CB_factor = -1
CO_factor = -1
if method == "MONTE":
N_factor = 2.56
H_factor = 0.54
elif method == "WILLIAMSON2013":
N_factor = 0.20
H_factor = 1
elif method == "EVENAS2001":
N_factor = 6.5
CA_factor = 3.62
CB_factor = 3.62
CO_factor = 2.93
elif method == "GRZESIEK1996":
H_factor = 1
N_factor = 5
CA_factor = 2
CB_factor = 2
N_shift_bound = spectra_alignment(N_shift_bound, N_shift_bound, N_shift_apo, factor = N_factor)
H_shift_bound = spectra_alignment(H_shift_bound, H_shift_bound, H_shift_apo, factor = H_factor)
CA_shift_bound = spectra_alignment(CA_shift_bound, CA_shift_bound, CA_shift_apo, factor = CA_factor)
CB_shift_bound = spectra_alignment(CB_shift_bound, CB_shift_bound, CB_shift_apo, factor = CB_factor)
CO_shift_bound = spectra_alignment(CO_shift_bound, CO_shift_bound, CO_shift_apo, factor = CO_factor)
CSPs, redundant_flag = [], []
ind = 0
for bound_N, apo_N, bound_H, apo_H, bound_CA, apo_CA, bound_CB, apo_CB, bound_CO, apo_CO in zip(N_shift_bound, N_shift_apo, \
H_shift_bound, H_shift_apo, CA_shift_bound, CA_shift_apo,\
CB_shift_bound, CB_shift_apo, CO_shift_bound, CO_shift_apo):
if any([val == -1 for val in [bound_N, apo_N, bound_H, apo_H]]):
CSPs.append(-1)
redundant_flag.append(True)
ind += 1
continue
CSP = -1
if method == "MONTE":
CSP = math.sqrt(0.5*(((bound_N-apo_N)/2.56)**2 + ((bound_H-apo_H)/0.54)**2))
elif method == "WILLIAMSON2013":
if bound_seq[ind] == 'G':
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.20)**2 + (bound_H-apo_H)**2))
else:
CSP = math.sqrt(0.5*(((bound_N-apo_N)*0.14)**2 + (bound_H-apo_H)**2))
elif method == "EVENAS2001":
CSP = math.sqrt( (1/4) * ( ( ( bound_N-apo_N ) / 6.5 )**2 + \
( ( bound_CA-apo_CA ) / 3.62 )**2 + \
( ( bound_CB-apo_CB ) / 3.62 )**2 + \
( ( bound_CO-apo_CO ) / 2.93 ) **2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
elif method == "GRZESIEK1996":
CSP = math.sqrt( (1/4) * ( ( ( bound_H-apo_H ) ) ** 2 + \
( ( bound_N-apo_N ) / 5 ) ** 2 + \
( ( bound_CA-apo_CA ) / 2 ) ** 2 + \
( ( bound_CB-apo_CB ) / 2 ) ** 2 ) )
if -1 in [bound_CA, bound_CB, bound_CO, apo_CA, apo_CB, apo_CO ]:
CSP = -1
CSPs.append(CSP)
redundant_flag.append(False)
ind += 1
if len([CS for CS in CSPs if CS > 0]) == 0:
print("No useful CSPs calculated, continuing...")
return CSPs, -1
new_redundant = []#CSPs.index(min([CS for CS in CSPs if CS > 0]))]
itr = 0
prev_value_3_std_below = 0
value_3_std_below = 0
mean_CSP = 0
while True:
#print(value_3_std_below)
#print(itr)
if itr > 0 and len(new_redundant) == 0:
break
for i in new_redundant:
redundant_flag[i] = True
prev_value_3_std_below = value_3_std_below
current_CSPs = [ CS for i,CS in enumerate(CSPs) if redundant_flag[i] == False ]
#print(current_CSPs)
if len(current_CSPs) == 0:
break
mean_CSP = sum(current_CSPs) / len(current_CSPs)
variance = sum([((x - mean_CSP) ** 2) for x in current_CSPs]) / len(current_CSPs)
std_deviation = math.sqrt(variance)
value_3_std_below = mean_CSP + 3 * std_deviation
new_redundant = [ i for i,CS in enumerate(CSPs) if CS >= value_3_std_below and redundant_flag[i] == False]
itr += 1
#print(value_3_std_below)
return CSPs, value_3_std_below
def get_shift_files(apo, holo, structure_source, CSmethod, basename = None):
    apo_shift_file = ""
    holo_shift_file = ""
    if structure_source == "NMR_real":
        if CSmethod != "REAL":
            raise ValueError("malformed structure_source/CSmethod combination")
        apo_shift_file = real_CSList_dir + apo.upper() +'.csv'
        holo_shift_file = real_CSList_dir + holo.upper() +'.csv'
    elif structure_source == "NMR_pred":
        if CSmethod == "UCBShift":
            apo_shift_file = apo_AF2_shift_dir + apo +'.csv'
            holo_shift_file = holo_NMR_shift_dir + holo +'.csv'
        elif CSmethod == "SPARTA":
            apo_shift_file = apo_AF2_SPARTA_dir + apo +'.tab'
            holo_shift_file = holo_NMR_SPARTA_dir + holo +'.tab'
        elif CSmethod == "ShiftX":
            apo_shift_file = apo_AF2_ShiftX_dir + apo +'.pdb.cs'
            holo_shift_file = holo_NMR_ShiftX_dir + holo +'.pdb.cs'
        else:
            raise ValueError("malformed structure_source/CSmethod combination")
    elif structure_source == "AF2":
        if CSmethod == "UCBShift":
            apo_shift_file = apo_AF2_shift_dir + apo +'.csv'
            holo_shift_file = holo_AF2_shift_dir + holo +'.csv'
        elif CSmethod == "SPARTA":
            apo_shift_file = apo_AF2_SPARTA_dir + apo +'.tab'
            holo_shift_file = holo_AF2_SPARTA_dir + holo +'.tab'
        elif CSmethod == "ShiftX":
            apo_shift_file = apo_AF2_ShiftX_dir + apo +'.pdb.cs'
            holo_shift_file = holo_AF2_ShiftX_dir + holo +'.pdb.cs'
        else:
            raise ValueError("malformed structure_source/CSmethod combination")
    elif structure_source == "AF3":
        if CSmethod == "UCBShift":
            apo_shift_file = apo_AF3_shift_dir + apo +'.csv'
            holo_shift_file = holo_AF3_shift_dir + holo +'.csv'
        elif CSmethod == "SPARTA":
            apo_shift_file = apo_AF3_SPARTA_dir + apo +'.tab'
            holo_shift_file = holo_AF3_SPARTA_dir + holo +'.tab'
        elif CSmethod == "ShiftX":
            apo_shift_file = apo_AF3_ShiftX_dir + apo +'.pdb.cs'
            holo_shift_file = holo_AF3_ShiftX_dir + holo +'.pdb.cs'
        else:
            raise ValueError("malformed structure_source/CSmethod combination")
    elif structure_source == 'ES':
        if CSmethod == 'UCBShift':
            apo_shift_file = apo_AF2_shift_dir + apo + '.csv'
            holo_shift_file = CS_Predictions + holo + '_AFS_shift_predictions/'+basename+'.csv'
            if not exists(holo_shift_file):
                holo_shift_file = CS_Predictions + holo + '_AFS2_shift_predictions/'+basename+'.csv'
            if not exists(holo_shift_file):
                raise FileNotFoundError("Could not locate " + basename + " ES shift file in AFS_ or AFS2_shift_predictions directories.")
        else:
            raise ValueError("Unimplemented CS Prediction method and Structure Source.")
    else:
        raise ValueError("malformed structure_source.")
    return apo_shift_file, holo_shift_file
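# Quick reference, as derived from the branches above: valid
# (structure_source, CSmethod) combinations resolved by get_shift_files.
#   NMR_real : REAL
#   NMR_pred : UCBShift | SPARTA | ShiftX  (apo side always from AF2 predictions)
#   AF2      : UCBShift | SPARTA | ShiftX
#   AF3      : UCBShift | SPARTA | ShiftX
#   ES       : UCBShift  (requires basename; falls back from AFS_ to AFS2_ dirs)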
def calc_CSP(apo_csv, holo_csv, well_defined_res, method = "MONTE", CSmethod = "REAL", match_seq = None):
    # Extract the 4-character PDB ID from the holo shift file path
    holo = holo_csv[holo_csv.rfind('/')+1:holo_csv.rfind('/')+5]
    print("GETTING HOLO DATA FROM " + holo_csv)
    print("GETTING APO DATA FROM " + apo_csv)
    bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence = \
        get_shift_data(apo_csv, holo_csv, CSmethod, match_seq)
    CSPs, CSP_sig_cutoff = calculate_CSP_from_shifts(bound_N_shift, bound_H_shift, bound_CA_shift, bound_CB_shift, bound_CO_shift, \
                                                     apo_N_shift, apo_H_shift, apo_CA_shift, apo_CB_shift, apo_CO_shift, bound_sequence, method)
    return CSPs, CSP_sig_cutoff, bound_sequence
def calc_CSP_wrapper(apo, holo, well_defined_res, method = "MONTE", CSmethod = "REAL", structure_source = "NMR_real", match_seq = "", basename=""):
    if CSmethod == 'consensus':
        return calc_CSP_consensus(apo, holo, well_defined_res, method = method, structure_source = structure_source, match_seq = match_seq)
    apo_shift_file, holo_shift_file = get_shift_files(apo, holo, structure_source, CSmethod, basename=basename)
    return calc_CSP(apo_shift_file, holo_shift_file, well_defined_res, method=method, CSmethod = CSmethod, match_seq=match_seq)
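# Example call (the apo ID here is hypothetical; '2k17' is one of the holo
# entries listed at the top of this module):
#   CSPs, cutoff, holo_seq = calc_CSP_wrapper('1abc', '2k17', None,
#                                             method="MONTE", CSmethod="UCBShift",
#                                             structure_source="AF2")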
def get_confusion(apo, holo, method, CSmethod, match_seq, well_defined_res, structure_source = "NMR"):
    real_CSPs = []
    real_CSP_outlier_cutoff = -1
    real_holo_sequence = ""
    pred_CSPs = []
    pred_CSP_outlier_cutoff = -1
    pred_holo_sequence = ""
    try:
        if structure_source == "NMR":
            pred_CSPs, pred_CSP_outlier_cutoff, pred_holo_sequence = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod, structure_source=structure_source+'_pred', match_seq=match_seq)
        else:
            pred_CSPs, pred_CSP_outlier_cutoff, pred_holo_sequence = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod=CSmethod, structure_source=structure_source, match_seq=match_seq)
    except Exception as e:
        print(e)
        raise
    try:
        real_CSPs, real_CSP_outlier_cutoff, real_holo_sequence = calc_CSP_wrapper(apo, holo, well_defined_res, method=method, CSmethod='REAL', structure_source='NMR_real', match_seq=match_seq)
    except Exception as e:
        print(e)
        raise
    # Align the predicted and experimental holo sequences so the two CSP lists
    # can be compared residue by residue.
    pred_sequence_aligned, real_sequence_aligned = align(pred_holo_sequence, real_holo_sequence)
    real_CSPs_aligned = align_shifts_to_seq(real_sequence_aligned, real_holo_sequence, real_CSPs)
    pred_CSPs_aligned = align_shifts_to_seq(pred_sequence_aligned, pred_holo_sequence, pred_CSPs)
    # Drop outliers before estimating the significance threshold
    real_CSP_below_thresh = [ C for C in real_CSPs_aligned if C < real_CSP_outlier_cutoff and C > 0 ]
    pred_CSP_below_thresh = [ C for C in pred_CSPs_aligned if C < pred_CSP_outlier_cutoff and C > 0 ]
    pred_sig_cutoff = -1
    real_sig_cutoff = -1
    z_value = 0
    try:
        pred_sig_cutoff = calculate_z_score_threshold(pred_CSP_below_thresh, z_value)
        real_sig_cutoff = calculate_z_score_threshold(real_CSP_below_thresh, z_value)
    except Exception as e:
        print(e)
        pred_sig_cutoff = -1
        real_sig_cutoff = -1
    residues_expected_to_have_sig_CSPS = [f > pred_sig_cutoff for f in pred_CSPs_aligned]
    TP = 0
    TN = 0
    FP = 0
    FN = 0
    try:
        z_score_at_zero = calculate_z_score(real_CSP_below_thresh, 0)
    except Exception:
        return None
    # Soft confusion counts: each residue contributes its z-score (clipped to
    # at most 3, and bounded below by the z-score of a zero CSP) rather than a
    # hard 0/1 vote.
    for j,real_CSP in enumerate(real_CSPs_aligned):
        if real_CSP == -1:
            continue
        CSP_z_score = calculate_z_score(real_CSP_below_thresh, real_CSP)
        if residues_expected_to_have_sig_CSPS[j]:
            if CSP_z_score <= 0:
                FN += -1 * max(CSP_z_score, z_score_at_zero)
            else:
                TP += min(CSP_z_score, 3)
        else: # not predicted to be in the binding site
            if CSP_z_score > 0:
                FP += min(CSP_z_score, 3)
            else:
                TN += -1 * max(CSP_z_score, z_score_at_zero)
    print(f"TP = {TP}, FP = {FP}\nTN = {TN}, FN = {FN}")
    return TP, FP, FN, TN
def get_F_MCC_cons(TP, FP, FN, TN):
    Precision = 0
    if TP + FP > 0:
        Precision = TP / ( TP + FP )
    Recall = 0
    if TP + FN > 0:
        Recall = TP / ( TP + FN )
    print(f"F = 2 * ({Precision} * {Recall}) / ({Precision} + {Recall})")
    try:
        F = 2 * (Precision * Recall) / (Precision + Recall)
    except ZeroDivisionError as e:
        print(e)
        F = 0
    print("("+str(TP)+"*"+str(TN)+"-"+str(FP)+"*"+str(FN)+")/sqrt(("+str(TP)+"+"+str(FP)+")*("+str(TP)+"+"+str(FN)+")*("+str(TN)+"+"+str(FP)+")*("+str(TN)+"+"+str(FN)+"))")
    MCC = 0
    try:
        MCC = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN))
    except ZeroDivisionError as e:
        print(e)
        MCC = -1
    # Consensus score: average of F and MCC after rescaling MCC from [-1, 1] to [0, 1].
    consensus = ( F + (MCC/2) + 0.5 ) / 2
    return F, MCC, consensus
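# Worked example (illustrative numbers): with TP=8, FP=2, FN=1, TN=20,
#   Precision = 8/10 = 0.8, Recall = 8/9 ~ 0.889, F ~ 0.842,
#   MCC = (8*20 - 2*1) / sqrt(10*9*22*21) ~ 0.775,
#   consensus = (0.842 + 0.775/2 + 0.5) / 2 ~ 0.865.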
def get_value_from_csv(csv_file, bound, apo, column):
    try:
        # Load the DataFrame if the CSV file exists
        df = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, FileNotFoundError):
        # An empty or missing CSV cannot contain the requested value
        return None
    # Filter rows where 'bound' and 'apo' match the input
    matching_rows = df[(df['bound'] == bound) & (df['apo'] == apo)]
    # If there's no match, return None
    if matching_rows.empty:
        return None
    # If there's more than one match, return the value from the first match
    return matching_rows.iloc[0][column]
def get_longest_chain(pdb_filepath):
    with open(pdb_filepath, 'r') as pdb_file:
        lines = pdb_file.readlines()
    # Count distinct residue indices per chain (HETATM records included)
    chain_residue_count = defaultdict(set)
    for line in lines:
        if line.startswith('ATOM') or line.startswith('HETATM'):
            chain_id = line[21]
            residue_index = int(line[22:26].strip())
            chain_residue_count[chain_id].add(residue_index)
    # Identify the chain with the most residues
    longest_chain = max(chain_residue_count, key=lambda k: len(chain_residue_count[k]))
    return longest_chain
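# Example: chain = get_longest_chain('./bmrb_dat/2k17.pdb') returns the chain
# ID (e.g. 'A') with the most residues; '2k17' is one of the holo entries
# listed at the top of this module, and the path mirrors get_pdb_file above.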
def get_shortest_chain(pdb_filepath):
    with open(pdb_filepath, 'r') as pdb_file:
        lines = pdb_file.readlines()
    # Count distinct residue indices per chain (HETATM records included)
    chain_residue_count = defaultdict(set)
    for line in lines:
        if line.startswith('ATOM') or line.startswith('HETATM'):
            chain_id = line[21]
            residue_index = int(line[22:26].strip())
            chain_residue_count[chain_id].add(residue_index)
    # Identify the chain with the fewest residues
    shortest_chain = min(chain_residue_count, key=lambda k: len(chain_residue_count[k]))
    return shortest_chain
def update_b_factors_longest_chain(pdb_filepath, b_factors, bound_seq, new_pdb_filepath):
    with open(pdb_filepath, 'r') as pdb_file:
        lines = pdb_file.readlines()
    chain_residue_count = defaultdict(set)
    seqs = {}
    last_res_ind = -1
    for line in lines:
        if line.startswith('ATOM') or line.startswith('HETATM'):
            chain_id = line[21]
            if chain_id not in seqs:
                seqs[chain_id] = ""
            residue_index = int(line[22:26].strip())
            if residue_index != last_res_ind:
                last_res_ind = residue_index
                seqs[chain_id] += convert_aa_name(line[17:20].strip())
            chain_residue_count[chain_id].add(residue_index)
    # Identify the longest chain
    longest_chain = max(chain_residue_count, key=lambda k: len(chain_residue_count[k]))
    new_bfactors = []
    if len(chain_residue_count[longest_chain]) != len(b_factors):
        # Lengths disagree, so align bound_seq to the chain sequence and map
        # the b_factors through the alignment; gaps get a placeholder of -1.
        try:
            bound_aligned1, bound_aligned2 = align(bound_seq, seqs[longest_chain])
        except:
            return
        ind = 0
        for i,c in enumerate(bound_aligned1):
            if bound_aligned2[i] in ['_', '-']:
                ind += 1
                continue
            if bound_aligned1[i] in ['_', '-']:
                new_bfactors.append(-1)
                continue
            new_bfactors.append(b_factors[ind])
            ind += 1
        if len(new_bfactors) != len(chain_residue_count[longest_chain]):
            return
    else:
        new_bfactors = [l for l in b_factors]
    b_factor_dict = {index: b_factor for index, b_factor in zip(sorted(chain_residue_count[longest_chain]), new_bfactors)}
    updated_lines = []
    for line in lines:
        if (line.startswith('ATOM') or line.startswith('HETATM')) and line[21] == longest_chain:
            residue_index = int(line[22:26].strip())
            new_b_factor = b_factor_dict[residue_index]
            # Columns 61-66 of a PDB ATOM record hold the B-factor
            updated_line = line[:60] + f'{new_b_factor:6.2f}' + line[66:]
            updated_lines.append(updated_line)
        else:
            updated_lines.append(line)
    print(new_pdb_filepath)
    with open(new_pdb_filepath, 'w') as pdb_file:
        pdb_file.writelines(updated_lines)
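# Typical use in this pipeline (mirroring the commented-out call in
# get_confusion): write per-residue CSPs into the B-factor column so they can
# be visualized, e.g. colored in PyMOL. Paths here are illustrative:
#   update_b_factors_longest_chain('./bmrb_dat/2k17.pdb', CSPs, holo_seq,
#                                  './bmrb_dat/2k17_csp.pdb')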
def get_interface_patch(contacts):
    # Flatten contact pairs into the set of residues that appear in any contact
    return set(i for contact in contacts for i in contact)
def calculate_ics_ips(pdb_file1, pdb_file2, threshold = 5.0):
    contacts1 = get_interface_contacts(pdb_file1, threshold)
    contacts2 = get_interface_contacts(pdb_file2, threshold)
    # ICS: F1 score over the shared inter-chain contacts
    ics = 0
    try:
        precision = len(contacts1.intersection(contacts2)) / len(contacts1)
        recall = len(contacts2.intersection(contacts1)) / len(contacts2)
        ics = 2 * (precision * recall) / (precision + recall)
    except ZeroDivisionError:
        ics = 0
    # IPS: Jaccard similarity of the two interface residue sets
    patch1 = get_interface_patch(contacts1)
    patch2 = get_interface_patch(contacts2)
    intersection_size = len(patch1.intersection(patch2))
    union_size = len(patch1.union(patch2))
    ips = 0
    try:
        ips = intersection_size / union_size
    except ZeroDivisionError:
        ips = 0
    return ics, ips
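# Usage sketch (illustrative paths): compare a predicted complex against a
# reference structure,
#   ics, ips = calculate_ics_ips('pred_complex.pdb', 'ref_complex.pdb')
# ICS rewards recovering the same residue-residue contacts; IPS only asks that
# the same residues participate in the interface.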
def get_atom_distance(atom1, atom2):
    vec = atom1 - atom2
    distance = (vec[0]**2 + vec[1]**2 + vec[2]**2)**0.5
    return distance
def get_closest_distance(residue1, residue2):
    # Minimum distance over all atom pairs between the two residues
    min_distance = float('inf')
    for atom1 in residue1:
        for atom2 in residue2:
            distance = get_atom_distance(atom1.coord, atom2.coord)
            if distance < min_distance:
                min_distance = distance
    return min_distance
def get_interface_contacts(pdb_file, threshold):
    parser = PDBParser()
    structure = parser.get_structure('protein', pdb_file)
    residues = [residue for model in structure for chain in model for residue in chain]
    num_residues = len(residues)
    pairs = []
    for i in range(num_residues):
        residue1 = residues[i]
        chain1 = residue1.parent.id
        for j in range(i+1, num_residues):
            residue2 = residues[j]
            chain2 = residue2.parent.id
            # Only inter-chain residue pairs within the distance threshold count as contacts
            if chain1 != chain2:
                distance = get_closest_distance(residue1, residue2)
                if distance <= threshold:
                    pairs.append((i, j))
    return set(pairs)
def compute_lddt_score(pdb_file1, pdb_file2):
    # Merge chains first because our version of lddt only accepts single-chain files
    merged_file1 = pdb_file1[:len(pdb_file1)-4]+'_merged.pdb'
    merged_file2 = pdb_file2[:len(pdb_file2)-4]+'_merged.pdb'
    merge_chains(pdb_file1, merged_file1)
    merge_chains(pdb_file2, merged_file2)
    print('./lddt ' + merged_file1 + ' ' + merged_file2)
    result = subprocess.run(['./lddt', merged_file1, merged_file2],
                            capture_output=True, text=True)
    # Parse the global LDDT score from the output
    global_lddt_score = None
    for line in result.stdout.splitlines():
        if "Global LDDT score:" in line:
            global_lddt_score = float(line.split()[3])
            break
    if global_lddt_score is None:
        raise ValueError("Failed to parse global LDDT score from lddt output")
    print("lddt score = " + str(global_lddt_score))
    os.remove(merged_file1)
    os.remove(merged_file2)
    return global_lddt_score
def compute_TM_score(pdb_file1, pdb_file2):
    # Merge chains so TMalign scores the whole complex as a single chain
    merged_file1 = pdb_file1[:len(pdb_file1)-4]+'_merged.pdb'
    merged_file2 = pdb_file2[:len(pdb_file2)-4]+'_merged.pdb'
    merge_chains(pdb_file1, merged_file1)
    merge_chains(pdb_file2, merged_file2)
    print('./TMalign ' + merged_file1 + ' ' + merged_file2)
    # Call TMalign with the model and reference PDB files
    result = subprocess.run(['./TMalign', merged_file1, merged_file2],
                            capture_output=True, text=True)
    # Parse the TM-score (the one normalized by the reference length, LN) from the output
    tm_score = None
    for line in result.stdout.splitlines():
        if "TM-score=" in line and "LN=" in line:
            tm_score = float(line.split()[1])
            break
    if tm_score is None:
        raise ValueError("Failed to parse TM-score from TMalign output")
    print("TM score = " + str(tm_score))
    os.remove(merged_file1)
    os.remove(merged_file2)
    return tm_score
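# Both scorers shell out to binaries expected in the working directory
# ('./lddt' and './TMalign'); merge_chains is defined elsewhere in this module.
# Example (illustrative paths):
#   tm = compute_TM_score('pred_complex.pdb', 'ref_complex.pdb')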
# Imports for the restraint-generation helpers below (PDBParser, numpy, and
# exists are already imported at the top of the module).
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
from itertools import combinations
import random
def get_cb_pairs(structure, distance_threshold=8.0):
    pairs = []
    residue_count_1 = 1
    sorted_chains = sorted(list(structure[0]), key=lambda x: x.id) # sort chains alphabetically
    len_chain1 = len(sorted_chains[0])
    for residue1 in sorted_chains[0]:
        if 'CB' in residue1:
            residue_count_2 = 1
            for residue2 in sorted_chains[1]:
                if 'CB' in residue2:
                    # Biopython overloads '-' on Atom objects to return the distance
                    distance = residue1['CB'] - residue2['CB']
                    if distance < distance_threshold:
                        pairs.append((residue_count_1, len_chain1 + residue_count_2))
                residue_count_2 += 1
        residue_count_1 += 1 # increment the count after visiting each residue
    return pairs if pairs else None
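# Sketch of how the pairs are indexed (the file path is illustrative):
#   parser = PDBParser()
#   structure = parser.get_structure("complex", "example_complex.pdb")
#   pairs = get_cb_pairs(structure)  # e.g. [(12, 150), ...] or None
# Residues are numbered 1..len(chain1) for the first chain (alphabetically by
# chain ID) and len(chain1)+1 onward for the second chain.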
def cluster_and_pick_pairs(pairs, num_pairs):
    # Hierarchically cluster combinations of contact pairs by how far apart
    # their residue indices are, then keep at most one representative per
    # cluster until num_pairs have been picked.
    pairs_comb = list(combinations(pairs, 2))
    distances = [abs(pair1[0] - pair2[0]) + abs(pair1[1] - pair2[1]) for pair1, pair2 in pairs_comb]
    Z = linkage(pdist(np.array(distances).reshape(-1, 1)), 'ward')
    max_d = 1
    clusters = fcluster(Z, max_d, criterion='distance')
    picked_pairs = []
    picked_clusters = set()
    for i, c in enumerate(clusters):
        if c not in picked_clusters:
            picked_pairs.append(pairs_comb[i])
            picked_clusters.add(c)
        if len(picked_pairs) >= num_pairs:
            break
    return picked_pairs if picked_pairs else None
def get_1v1_restraints(pdb_file, num_pairs):
    if exists(pdb_file):
        try:
            parser = PDBParser()
            structure = parser.get_structure("protein_peptide_complex", pdb_file)
            pairs = get_cb_pairs(structure)
            if pairs is None:
                return None
            random.shuffle(pairs) # shuffle the pairs to ensure randomness in the selection
            picked_pairs = cluster_and_pick_pairs(pairs, num_pairs)
            return [(pair[0], pair[1]) for pair in picked_pairs]
        except Exception as e:
            print(e)
            return None
    print('could not locate structure file.')
    return None
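# Example (illustrative path): sample up to 5 spread-out CB-CB restraint pairs
# from a two-chain complex, or None if the file has no cross-chain CB contacts.
#   restraint_pairs = get_1v1_restraints("example_complex.pdb", 5)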
def get_1vN_restraints(pdb_file, restraints, neighborhood_size=4):
    parser = PDBParser()
    structure = parser.get_structure("protein_peptide_complex", pdb_file)
    all_residues = {(res.id[1], chain.id) for model in structure for chain in model for res in chain}
    expanded_restraints = []
    for pair in restraints:
        single_residue = pair[0]
        for restraint in pair[1]:
            # Restraint strings look like '45A-12B'; split into residue/chain parts
            for res in restraint.split('-'):
                res_id, chain_id = int(res[:-1]), res[-1]
                # Expand each residue into a +/- neighborhood_size window of
                # residues that actually exist in the structure
                for i in range(res_id - neighborhood_size, res_id + neighborhood_size + 1):
                    if (i, chain_id) in all_residues:
                        expanded_restraints.append((single_residue, i))
    return expanded_restraints if expanded_restraints else None
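# Example input shape (values hypothetical): one anchor residue paired with
# restraint strings of the form '<resnum><chain>-<resnum><chain>':
#   get_1vN_restraints("example_complex.pdb", [(10, ["45A-12B"])])
# pairs residue 10 with every existing residue within 4 positions of residue
# 45 on chain A and of residue 12 on chain B.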
def get_MvN_restraints(pdb_file, restraints, neighborhood_size=2):
    parser = PDBParser()
    structure = parser.get_structure("protein_peptide_complex", pdb_file)
    all_residues = {(res.id[1], chain.id) for model in structure for chain in model for res in chain}
    expanded_restraints = []
    sorted_chains = sorted(list(structure[0]), key=lambda x: x.id) # sort chains alphabetically
    len_chain1 = len(sorted_chains[0])
    # Each restraint is a pair of (chainA, chainB) index pairs; expand every
    # chain-A anchor against every chain-B anchor within the neighborhood.
    for pair in restraints:
        (pair1_chainA, pair1_chainB), (pair2_chainA, pair2_chainB) = pair
        for anchor_A in (pair1_chainA, pair2_chainA):
            for i in range(anchor_A - neighborhood_size, anchor_A + neighborhood_size + 1):
                if (i, 'A') in all_residues: # assuming the first chain id is 'A'
                    for anchor_B in (pair1_chainB, pair2_chainB):
                        for j in range(anchor_B - neighborhood_size, anchor_B + neighborhood_size + 1):
                            if (j - len_chain1, 'B') in all_residues: # assuming the second chain id is 'B'
                                expanded_restraints.append((i, j))
    # Group by the first residue and sort the second residues
    grouped_restraints = {}
    for pair in expanded_restraints:
        if pair[0] not in grouped_restraints:
            grouped_restraints[pair[0]] = []
        grouped_restraints[pair[0]].append(pair[1])
    for key in grouped_restraints:
        grouped_restraints[key].sort()
    # Collapse the sorted second residues into continuous (start, end) ranges
    final_result = []
    for first_residue, second_residues in grouped_restraints.items():
        range_start = second_residues[0]
        range_end = second_residues[0]
        for second_residue in second_residues[1:]:
            if second_residue == range_end + 1 or second_residue == range_end:
                range_end = second_residue
            else:
                final_result.append([first_residue, (range_start, range_end)])
                range_start = range_end = second_residue
        final_result.append([first_residue, (range_start, range_end)]) # append the last range
    final_result.append(4)
    return final_result
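# End-to-end sketch of the restraint pipeline (path and count illustrative):
#   pairs = get_1v1_restraints("example_complex.pdb", 3)
#   if pairs is not None:
#       mvn = get_MvN_restraints("example_complex.pdb", pairs)
#   # mvn holds [anchor_residue, (start, end)] range entries, plus the trailing
#   # literal 4 appended by get_MvN_restraints.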