# (Hugging Face Spaces page header — "Spaces: Sleeping" — captured by the page
# scrape; not part of the program. Kept as a comment so the file parses.)
# --- stdlib ---
import os
import re
import time

# --- third-party ---
import numpy as np
import requests
import torch
import torch.nn as nn
from flask import Flask, render_template, request, send_file
from rdkit import Chem
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

# --- local ---
from modelstrc import CVanilla_RNN_Builder, get_mol_from_graph_list
# DIRECTORIES | |
bio_model_dir = "/app/modelsBioembedSmall" | |
cvn_model_dir = "/app/models_folder" | |
UPLOAD_FOLDER = "/app/Samples" | |
UF="/tmp/" | |
os.makedirs(bio_model_dir, exist_ok=True) | |
os.makedirs(cvn_model_dir, exist_ok=True) | |
os.makedirs(UPLOAD_FOLDER, exist_ok=True) | |
# ENV VARIABLES | |
os.environ["TMPDIR"] = bio_model_dir | |
os.environ["TEMP"] = bio_model_dir | |
os.environ["TMP"] = bio_model_dir | |
os.environ['NUMBA_CACHE_DIR'] = '/app/numba_cache' | |
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_cache' | |
# ESM2 MODEL AND TOKENIZER | |
try: | |
print("Loading ESM2 model...") | |
model_name = "facebook/esm2_t6_8M_UR50D" # Smaller model with 320-dim embedding | |
tokenizer = AutoTokenizer.from_pretrained(bio_model_dir) | |
model = AutoModel.from_pretrained(bio_model_dir) | |
model.eval() | |
print("ESM2 model loaded.") | |
except Exception as e: | |
print(f"Error loading ESM2 model: {e}") | |
model = None | |
tokenizer = None | |
# linear transformation to map 320D embeddings to 1024D | |
class EmbeddingTransformer(nn.Module): | |
def __init__(self, input_dim, output_dim): | |
super(EmbeddingTransformer, self).__init__() | |
self.linear = nn.Linear(input_dim, output_dim) | |
def forward(self, x): | |
return self.linear(x) | |
transformer = EmbeddingTransformer(input_dim=320, output_dim=1024) | |
# UDF TO GENERATE EMBEDDINGS | |
def generate_bio_embeddings(sequence): | |
""" | |
Generate protein sequence embeddings using ESM2 model. | |
Maps the 320-dimensional embedding to 1024 dimensions. | |
""" | |
if model is None or tokenizer is None: | |
print("Model or tokenizer not loaded.") | |
return None | |
if not sequence: | |
print("Sequence is empty after cleaning.") | |
return None | |
try: | |
inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=True) | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
embeddings = outputs.last_hidden_state | |
mean_embedding = embeddings.mean(dim=1).squeeze() | |
transformed_embedding = transformer(mean_embedding) | |
transformed_embedding = transformed_embedding.detach().numpy() | |
return transformed_embedding.reshape(1, -1) | |
except Exception as e: | |
print(f"Embedding Error: {e}") | |
return None | |
# UDF FOR SMILES GENERATION | |
def generate_smiles(sequence, n_samples=100): | |
start_time = time.time() | |
protein_embedding = generate_bio_embeddings(sequence) | |
if protein_embedding is None: | |
return None, "Embedding generation failed!" | |
model = CVanilla_RNN_Builder(cvn_model_dir, gpu_id=None) | |
samples = model.sample(n_samples, c=protein_embedding[0], output_type='graph') | |
valid_samples = [sample for sample in samples if sample is not None] | |
smiles_list = [ | |
Chem.MolToSmiles(mol) for mol in get_mol_from_graph_list(valid_samples, sanitize=True) if mol is not None | |
] | |
if not smiles_list: | |
return None, "No valid SMILES generated!" | |
filename = os.path.join(UF, "SMILES_GENERATED.txt") | |
with open(filename, "w") as file: | |
file.write("\n".join(smiles_list)) | |
elapsed_time = time.time() - start_time | |
return filename, elapsed_time | |
app = Flask(__name__) | |
def index(): | |
if request.method == "POST": | |
sequence = request.form["sequence"].strip() | |
if not sequence: | |
return render_template("index.html", message="Please enter a valid sequence.") | |
file_path, result = generate_smiles(sequence) | |
if file_path is None: | |
return render_template("index.html", message=f"Error: {result}") | |
return render_template("index.html", message="SMILES generated successfully!", file_path=file_path, time_taken=result) | |
return render_template("index.html") | |
def download_file(): | |
file_path = os.path.join(UF, "SMILES_GENERATED.txt") | |
return send_file(file_path, as_attachment=True) | |
if __name__ == "__main__": | |
app.run(host="0.0.0.0", port=7860) | |