Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSP_Rank/hca.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
158 lines (131 sloc)
5.5 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# hca.py
import os
import shutil

import dash
import hdbscan
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from sklearn.cluster import AgglomerativeClustering
# Input tables and output location for the 7jq8 analysis.
pca_csv = './data/7jq8_aligned_CSPREDB_PCA_chain_B_data.csv'
cspred_csv = 'CSP_7jq8_CSpred.csv'
destination_dir = './cluster/7jq8/'

# Create the destination directory if it doesn't exist
os.makedirs(destination_dir, exist_ok=True)

pca_df = pd.read_csv(pca_csv)
cspred_df = pd.read_csv(cspred_csv)

# Maps model name (file name without extension) -> [pc1, pc2, pc3], extended
# below with [F1, MCC, consensus] for models that also occur in the CSpred table.
features = {}
for pdb_file, pc1, pc2, pc3 in zip(
    pca_df['pdb_file'], pca_df['PC1'], pca_df['PC2'], pca_df['PC3']
):
    # splitext leaves a name without an extension untouched; slicing at
    # rfind('.') == -1 would silently drop its last character instead.
    model_name = os.path.splitext(pdb_file)[0]
    features[model_name] = [float(pc1), float(pc2), float(pc3)]

# Models that already received their CSpred scores. A repeated model would
# otherwise end up with more than six values and break the fixed six-column
# DataFrame that is built from this dictionary; the first occurrence wins.
scored_models = set()
for holo_path, f1, mcc, consensus in zip(
    cspred_df['holo_model_path'],
    cspred_df['F1'],
    cspred_df['MCC'],
    cspred_df['consensus'],
):
    # Skip missing / non-text paths explicitly instead of hiding every
    # possible error behind a bare ``except``.
    if not isinstance(holo_path, str):
        continue
    model_name = os.path.splitext(os.path.basename(holo_path))[0]
    # Direct dict membership test; no per-row list rebuild.
    if model_name not in features or model_name in scored_models:
        continue
    features[model_name].extend([f1, mcc, consensus])
    scored_models.add(model_name)
# Turn the per-model feature lists into a table indexed by model name, expose
# the name as a regular column, and discard incomplete rows (models that never
# received all six features show up with missing values).
features_df = (
    pd.DataFrame.from_dict(
        features,
        orient='index',
        columns=['pc1', 'pc2', 'pc3', 'F1', 'MCC', 'consensus'],
    )
    .assign(pdb_file=lambda frame: frame.index)
    .dropna()
)

# Keep only the models whose consensus score reaches the 90th percentile.
consensus_90th_percentile = features_df['consensus'].quantile(0.90)
in_top_decile = features_df['consensus'] >= consensus_90th_percentile
features_df = features_df.loc[in_top_decile]

# Numeric matrix handed to the clustering algorithms (everything but the name).
clustering_data = features_df.drop(columns='pdb_file')
# Iterative HDBSCAN clustering: refit, drop the points labelled as noise (-1),
# and repeat until a fit assigns every remaining point to a cluster.
min_cluster_size = 2
while True:
    if len(clustering_data) < min_cluster_size:
        # Fail with a clear message instead of an opaque error from inside
        # the clusterer once (almost) every point has been discarded as noise.
        raise ValueError(
            f"Only {len(clustering_data)} point(s) left after noise removal; "
            f"at least {min_cluster_size} are required for HDBSCAN."
        )
    hdbscan_clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size)
    hdbscan_labels = hdbscan_clusterer.fit_predict(clustering_data)
    noise_mask = hdbscan_labels == -1
    if not noise_mask.any():
        break
    # Filter out noise points
    features_df = features_df[~noise_mask]
    clustering_data = clustering_data[~noise_mask]
    print(f"Filtered out noise points, remaining points: {len(features_df)}")

# The last fit of the loop already ran on exactly the remaining points with the
# same parameters and produced no noise, so its labels are the final clustering;
# fitting a second time on identical data is unnecessary.
final_labels = hdbscan_labels

# Work on an independent copy so that adding a column does not write into a
# filtered view of an earlier DataFrame (pandas SettingWithCopyWarning).
features_df = features_df.copy()
features_df['hdbscan_labels'] = final_labels
# Create Dash app
app = dash.Dash(__name__)

# Layout: a colour-mode selector, the PC1/PC2 scatter plot, and a status line.
app.layout = html.Div([
    dcc.RadioItems(
        id='color-toggle',
        options=[
            {'label': 'Color by HDBSCAN Cluster', 'value': 'hdbscan_labels'},
            {'label': 'Color by Consensus', 'value': 'consensus'}
        ],
        value='hdbscan_labels',
        labelStyle={'display': 'inline-block'}
    ),
    dcc.Graph(id='scatter-plot'),
    # Output target of the move_files callback. Without this component the
    # callback refers to an id that does not exist in the layout, so Dash
    # reports an error and the move status is never shown.
    html.Div(id='file-move-output'),
])
@app.callback(
    Output('scatter-plot', 'figure'),
    [Input('color-toggle', 'value')]
)
def update_figure(selected_color):
    """Build the PC1-vs-PC2 scatter plot for the chosen colouring mode.

    ``selected_color`` is the value of the ``color-toggle`` radio group:
    ``'consensus'`` colours the points by consensus score on the ``Bluered``
    scale; any other value colours them by HDBSCAN cluster label.
    """
    by_consensus = selected_color == 'consensus'

    # Only the colouring and the hover columns differ between the two modes.
    if by_consensus:
        mode_options = dict(
            color='consensus',
            color_continuous_scale='Bluered',
            hover_data=['consensus', 'hdbscan_labels', 'pdb_file'],
        )
    else:
        mode_options = dict(
            color='hdbscan_labels',
            hover_data=['consensus', 'pdb_file'],
        )

    fig = px.scatter(
        features_df,
        x='pc1',
        y='pc2',
        title='Final HDBSCAN Clustering',
        **mode_options,
    )
    if not by_consensus:
        fig.update_layout(legend_title_text='Cluster')
    return fig
@app.callback(
    Output('file-move-output', 'children'),
    [Input('scatter-plot', 'clickData')]
)
def move_files(clickData):
    """Move the PDB file behind the clicked scatter point into ``destination_dir``.

    ``clickData`` is the click payload of the ``scatter-plot`` graph (``None``
    until a point has been clicked). Returns the status text to display.
    """
    if not clickData:
        return "Click on a point to move the corresponding PDB file."

    # 'pdb_file' is the LAST hover_data column in both colouring modes of
    # update_figure, whereas index 1 holds 'hdbscan_labels' in consensus mode;
    # taking the last customdata entry therefore works for both.
    customdata = clickData['points'][0].get('customdata') or []
    pdb_file = customdata[-1] if customdata else None
    if not isinstance(pdb_file, str):
        return "Could not determine the PDB file for the selected point."

    # Compare against the exact file stem. A substring / regex match would let
    # e.g. "model_1" select "model_10" and would misread regex metacharacters.
    holo_paths = cspred_df['holo_model_path'].dropna().astype(str)
    stems = holo_paths.map(
        lambda path: os.path.splitext(os.path.basename(path))[0]
    )
    matches = holo_paths[stems == pdb_file]
    if matches.empty:
        return f"No model path found for {pdb_file}."

    original_path = matches.iloc[0]
    target_path = os.path.join(destination_dir, os.path.basename(original_path))
    try:
        shutil.move(original_path, target_path)
    except OSError as exc:
        # Covers a file that was already moved or is otherwise inaccessible
        # (shutil.Error is an OSError subclass) without failing the callback.
        return f"Could not move {original_path}: {exc}"
    return f"Moved file: {original_path} to {destination_dir}"
# Snapshot of the HDBSCAN-only table, written before the hierarchical
# clustering labels are attached below.
features_df.to_csv('final_clustering_results.csv', index=True)

# Hierarchical clustering, scikit-learn variant: three agglomerative clusters.
agg_clustering = AgglomerativeClustering(n_clusters=3)
agg_labels = agg_clustering.fit(clustering_data).labels_

# Hierarchical clustering, SciPy variant: Ward linkage cut into flat clusters
# at a fixed distance threshold.
# (Plotting of the dendrogram via dendrogram(Z) / plt.show() and printing of
# the individual label arrays are intentionally left disabled.)
Z = linkage(clustering_data, method='ward')
max_d = 50
clusters = fcluster(Z, t=max_d, criterion='distance')

# Attach every labelling to the table and write the combined results.
features_df['agg_labels'] = agg_labels
features_df['scipy_clusters'] = clusters
features_df['hdbscan_labels'] = hdbscan_labels
features_df.to_csv('clustering_results.csv', index=True)
if __name__ == '__main__':
    # Newer Dash releases replace ``run_server`` with ``run``; prefer ``run``
    # when it is available and fall back for older installations.
    run_app = getattr(app, 'run', None) or app.run_server
    run_app(debug=True, port=8054)