from pathlib import Path
from collections import OrderedDict
import argparse
import ast
import itertools
import json
import logging

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoModel, AutoTokenizer
from huggingface_hub import PyTorchModelHubMixin

from sklearn.model_selection import KFold
from sklearn.metrics import average_precision_score
from sklearn.cluster import AgglomerativeClustering
from scipy.stats import pearsonr, spearmanr
from Bio import Phylo

import hydra
from hydra import compose, initialize, initialize_config_dir
from omegaconf import OmegaConf, DictConfig, ListConfig

import models
import noise_schedule

current_directory = Path('/data2/tianang/projects/Synergy')

# Compose the MDLM Hydra config that defines the diffusion backbone and
# noise schedule.
with initialize_config_dir(config_dir="/data2/tianang/projects/mdlm/configs"):
    config = compose(config_name="config")

class mol_emb_mdlm(nn.Module):
    """Wraps the MDLM DIT backbone so its hidden states can be used as molecule embeddings."""

    def __init__(self, config, vocab_size, ckpt_path, mask_index):
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        self.mask_index = mask_index
        self.ckpt_path = ckpt_path
        self.parameterization = self.config.parameterization
        self.time_conditioning = self.config.time_conditioning
        self.backbone = self.load_DIT()

        self.noise = noise_schedule.get_noise(self.config)

    def _process_sigma(self, sigma):
        """Normalize sigma to a 1-D tensor, or zero it out when time conditioning is disabled."""
        if sigma is None:
            assert self.parameterization == 'ar'
            return sigma
        if sigma.ndim > 1:
            sigma = sigma.squeeze(-1)
        if not self.time_conditioning:
            sigma = torch.zeros_like(sigma)
        assert sigma.ndim == 1, sigma.shape
        return sigma

    def _sample_t(self, n, device):
        # For embedding extraction the timestep is pinned to t = 0 (no masking
        # noise); the original rand-based expression multiplied everything by
        # zero and reduced to exactly this tensor of zeros.
        return torch.zeros(n, device=device)

    def _forward(self, x, sigma, attnmask):
        """Run the DIT backbone and return the final hidden states (no output head)."""
        sigma = self._process_sigma(sigma)
        # Token and sigma embeddings are computed in float32, the transformer
        # blocks in bfloat16, mirroring the MDLM DIT forward pass.
        with torch.cuda.amp.autocast(dtype=torch.float32):
            x = self.backbone.vocab_embed(x)
            c = F.silu(self.backbone.sigma_map(sigma))
            rotary_cos_sin = self.backbone.rotary_emb(x)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            for i in range(len(self.backbone.blocks)):
                x = self.backbone.blocks[i](
                    x, rotary_cos_sin, c, seqlens=None, attnmask=attnmask)

        return x

    def q_xt(self, x, move_chance):
        """Computes the noisy sample xt.

        Args:
            x: int torch.Tensor with shape (batch_size,
                diffusion_model_input_length), input.
            move_chance: float torch.Tensor with shape (batch_size, 1).
        """
        move_indices = torch.rand(*x.shape, device=x.device) < move_chance
        xt = torch.where(move_indices, self.mask_index, x)
        return xt

    def forward(self, input_ids, attention_mask=None):
        """Return per-token hidden states of shape (batch, seq_len, hidden_size).

        With `_sample_t` pinned to t = 0 the masking probability is essentially
        zero, so the backbone effectively sees the unmasked input.
        """
        t = self._sample_t(input_ids.shape[0], input_ids.device)
        sigma, dsigma = self.noise(t)
        unet_conditioning = sigma[:, None]
        move_chance = 1 - torch.exp(-sigma[:, None])
        xt = self.q_xt(input_ids, move_chance)
        outputs = self._forward(xt, unet_conditioning, attnmask=attention_mask)
        return outputs

    def load_DIT(self):
        backbone = models.dit.DIT_non_pad(self.config, vocab_size=self.vocab_size)

        if self.ckpt_path is not None:
            # Load a Lightning checkpoint and strip the 'backbone.' prefix so
            # the keys match the bare DIT module; strict=False tolerates keys
            # that do not line up exactly.
            lightning_ckpt = torch.load(self.ckpt_path, map_location='cpu')
            state_dict = lightning_ckpt['state_dict']

            new_sd = OrderedDict()
            for k, v in state_dict.items():
                if k.startswith('backbone.'):
                    new_key = k[len('backbone.'):]
                else:
                    new_key = k
                new_sd[new_key] = v

            backbone.load_state_dict(new_sd, strict=False)

        return backbone

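# The forward pass above returns per-token hidden states rather than a single
# vector. One way to collapse them into a molecule-level embedding is a
# mask-weighted mean pool. This is a sketch of one possible pooling, not
# necessarily the pooling used downstream in this project; the SELFIES string
# is only an illustrative placeholder, and `model`/`tokenizer` are assumed to
# be constructed as in the __main__ block below.
#
#   batch = tokenizer(["[C][C][O]"], return_tensors="pt", padding=True)
#   with torch.no_grad():
#       hidden = model(batch["input_ids"], batch["attention_mask"])   # (B, L, H)
#   mask = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
#   mol_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)            # (B, H)
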
class MolEmbDLM(mol_emb_mdlm, PyTorchModelHubMixin):
    """
    Subclasses the original model and adds the Hub mixin.
    No methods are overridden; the default _save_pretrained / _from_pretrained are used.
    """

    def __init__(self, config, vocab_size, ckpt_path, mask_index):
        # Skip mol_emb_mdlm.__init__ and run nn.Module.__init__ (the next class
        # in the MRO) directly, so that `config` can be normalized to an
        # OmegaConf object before the rest of the setup is repeated here.
        super(mol_emb_mdlm, self).__init__()
        if not isinstance(config, (DictConfig, ListConfig)):
            self.config = OmegaConf.create(config)
        else:
            self.config = config
        self.vocab_size = vocab_size
        self.mask_index = mask_index
        self.ckpt_path = ckpt_path
        self.parameterization = self.config.parameterization
        self.time_conditioning = self.config.time_conditioning
        self.backbone = self.load_DIT()

        self.noise = noise_schedule.get_noise(self.config)

def to_container_safe(node):
    """Recursively convert an OmegaConf node into plain Python containers.

    Nodes that cannot be resolved (e.g. broken interpolations) fall back to the
    string form of the raw node, so the result stays JSON-serializable.
    """
    if isinstance(node, DictConfig):
        out = {}
        for k in node.keys():
            try:
                out[k] = to_container_safe(node[k])
            except Exception:
                out[k] = str(node._get_node(k))
        return out
    if isinstance(node, ListConfig):
        res = []
        for i, v in enumerate(node):
            try:
                res.append(to_container_safe(v))
            except Exception:
                res.append(str(node._get_node(i)))
        return res
    return node

def build_hf_config(hydra_cfg, tokenizer):
    """Build the plain-dict config that is written to config.json on export."""
    return {
        "model_type": "mol_emb_raw",
        "vocab_size": len(tokenizer.get_vocab()),
        "ckpt_path": None,
        "config": to_container_safe(hydra_cfg),
        "hidden_size": hydra_cfg.model.hidden_size,
        "n_blocks": hydra_cfg.model.n_blocks,
        "n_heads": hydra_cfg.model.n_heads,
        "max_position_embeddings": hydra_cfg.model.length,
        "parameterization": hydra_cfg.parameterization,
        "time_conditioning": hydra_cfg.time_conditioning,
        "noise_schedule_type": hydra_cfg.noise.type,
        "sigma_min": hydra_cfg.noise.sigma_min,
        "sigma_max": hydra_cfg.noise.sigma_max,
        "mask_index": tokenizer.mask_token_id,
        "tokenizer_name_or_path": hydra_cfg.data.tokenizer_name_or_path,
    }

if __name__ == '__main__':
    model_name = "ibm-research/materials.selfies-ted"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    DIT_ckpt_path = '/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt'
    model = MolEmbDLM(config, len(tokenizer.get_vocab()), DIT_ckpt_path, tokenizer.mask_token_id)

    hf_config = build_hf_config(config, tokenizer)

    # Export the weights plus config.json (via the Hub mixin) and the tokenizer files.
    EXPORT_DIR = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"
    model.save_pretrained(EXPORT_DIR, config=hf_config)
    tokenizer.save_pretrained(EXPORT_DIR)
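
# Reload sketch (assumption, not part of this script): whether the default
# PyTorchModelHubMixin.from_pretrained can rebuild MolEmbDLM straight from the
# exported config.json depends on the installed huggingface_hub version and on
# how the config keys map back onto __init__, so treat this as a sketch to
# verify rather than a guaranteed round trip.
#
#   reloaded_tokenizer = AutoTokenizer.from_pretrained(EXPORT_DIR)
#   reloaded_model = MolEmbDLM.from_pretrained(EXPORT_DIR)
#   reloaded_model.eval()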