# WebAppPTS / app.py
# Bhanushray — "Update app.py" (commit f22f24a, verified)
import os
import time
import requests
import numpy as np
from flask import Flask, render_template, request, send_file
from rdkit import Chem
from transformers import AutoModelForMaskedLM, AutoTokenizer
from modelstrc import CVanilla_RNN_Builder, get_mol_from_graph_list
from transformers import AutoModel, AutoTokenizer
import torch
import re
import torch.nn as nn
# DIRECTORIES
bio_model_dir = "/app/modelsBioembedSmall"
cvn_model_dir = "/app/models_folder"
UPLOAD_FOLDER = "/app/Samples"
UF = "/tmp/"
# Ensure every working directory exists before anything tries to use it.
for _dir in (bio_model_dir, cvn_model_dir, UPLOAD_FOLDER):
    os.makedirs(_dir, exist_ok=True)
# ENV VARIABLES
# All three temp-dir variables are pointed at the model dir; caches for
# numba and transformers get dedicated locations under /app.
for _var in ("TMPDIR", "TEMP", "TMP"):
    os.environ[_var] = bio_model_dir
os.environ["NUMBA_CACHE_DIR"] = "/app/numba_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
# ESM2 MODEL AND TOKENIZER
# Load the tokenizer/model pair from the local snapshot directory; on any
# failure both globals fall back to None so request handlers can degrade
# gracefully instead of crashing at import time.
model = None
tokenizer = None
model_name = "facebook/esm2_t6_8M_UR50D"  # Smaller model with 320-dim embedding
try:
    print("Loading ESM2 model...")
    tokenizer = AutoTokenizer.from_pretrained(bio_model_dir)
    model = AutoModel.from_pretrained(bio_model_dir)
    model.eval()
    print("ESM2 model loaded.")
except Exception as e:
    print(f"Error loading ESM2 model: {e}")
    model = None
    tokenizer = None
# linear transformation to map 320D embeddings to 1024D
class EmbeddingTransformer(nn.Module):
    """Project a 320-d ESM2 embedding into the 1024-d space the decoder expects."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # NOTE(review): no saved weights are loaded for this layer anywhere in
        # this file — it appears to run with random initialization; confirm.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


transformer = EmbeddingTransformer(input_dim=320, output_dim=1024)
# UDF TO GENERATE EMBEDDINGS
def generate_bio_embeddings(sequence):
    """
    Generate protein sequence embeddings using ESM2 model.
    Maps the 320-dimensional embedding to 1024 dimensions.

    Returns a (1, 1024) numpy array, or None on any failure.
    """
    # Guard clauses: model loading may have failed at startup, and an
    # empty sequence cannot be embedded.
    if model is None or tokenizer is None:
        print("Model or tokenizer not loaded.")
        return None
    if not sequence:
        print("Sequence is empty after cleaning.")
        return None
    try:
        tokens = tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        with torch.no_grad():
            token_states = model(**tokens).last_hidden_state
        # Mean-pool over the token axis, then project 320 -> 1024.
        pooled = token_states.mean(dim=1).squeeze()
        projected = transformer(pooled).detach().numpy()
        return projected.reshape(1, -1)
    except Exception as e:
        print(f"Embedding Error: {e}")
        return None
# UDF FOR SMILES GENERATION
def generate_smiles(sequence, n_samples=100):
    """Sample candidate molecules conditioned on a protein sequence.

    Returns (filepath, elapsed_seconds) on success, or (None, error_message)
    when embedding or generation fails.
    """
    started = time.time()
    embedding = generate_bio_embeddings(sequence)
    if embedding is None:
        return None, "Embedding generation failed!"
    # Local name `generator` avoids shadowing the module-level ESM2 `model`.
    generator = CVanilla_RNN_Builder(cvn_model_dir, gpu_id=None)
    graphs = generator.sample(n_samples, c=embedding[0], output_type='graph')
    graphs = [g for g in graphs if g is not None]
    mols = get_mol_from_graph_list(graphs, sanitize=True)
    smiles_list = [Chem.MolToSmiles(m) for m in mols if m is not None]
    if not smiles_list:
        return None, "No valid SMILES generated!"
    out_path = os.path.join(UF, "SMILES_GENERATED.txt")
    with open(out_path, "w") as fh:
        fh.write("\n".join(smiles_list))
    return out_path, time.time() - started
app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the input form; on POST, run SMILES generation for the sequence.

    Always responds with the rendered index template carrying a status
    message and, on success, the output file path and elapsed time.
    """
    if request.method == "POST":
        # .get() instead of ["sequence"]: a missing form field would otherwise
        # raise and produce a 400 BadRequest; treat it like an empty input.
        sequence = request.form.get("sequence", "").strip()
        if not sequence:
            return render_template("index.html", message="Please enter a valid sequence.")
        file_path, result = generate_smiles(sequence)
        if file_path is None:
            # On failure `result` holds the error message from generate_smiles.
            return render_template("index.html", message=f"Error: {result}")
        return render_template("index.html", message="SMILES generated successfully!", file_path=file_path, time_taken=result)
    return render_template("index.html")
@app.route("/download")
def download_file():
    """Serve the most recently generated SMILES file as an attachment.

    Returns a 404 response when no file has been generated yet —
    send_file on a missing path would otherwise surface as a 500.
    """
    file_path = os.path.join(UF, "SMILES_GENERATED.txt")
    if not os.path.exists(file_path):
        return "No SMILES file has been generated yet.", 404
    return send_file(file_path, as_attachment=True)
if __name__ == "__main__":
    # Bind on all interfaces so the containerized app is reachable; 7860 is
    # presumably the Hugging Face Spaces port convention — confirm deployment.
    app.run(host="0.0.0.0", port=7860)