Commit
·
ef9533f
1
Parent(s):
c6e45fc
repair config path issue
Browse files- .idea/.gitignore +8 -0
- .idea/ApexOracle.iml +12 -0
- .idea/deployment.xml +16 -0
- .idea/inspectionProfiles/Project_Default.xml +7 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +7 -0
- DLM_emb_model.py +3 -3
- configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
- configs/callbacks/checkpoint_monitor.yaml +10 -0
- configs/callbacks/learning_rate_monitor.yaml +3 -0
- configs/config.yaml +102 -0
- configs/data/ag_news.yaml +6 -0
- configs/data/lambada.yaml +6 -0
- configs/data/lm1b-gpt2.yaml +6 -0
- configs/data/lm1b-streaming.yaml +6 -0
- configs/data/lm1b.yaml +6 -0
- configs/data/openwebtext-split.yaml +6 -0
- configs/data/openwebtext-streaming.yaml +6 -0
- configs/data/openwebtext.yaml +6 -0
- configs/data/ptb.yaml +6 -0
- configs/data/scientific_papers_arxiv.yaml +6 -0
- configs/data/scientific_papers_pubmed.yaml +6 -0
- configs/data/text8-crop.yaml +7 -0
- configs/data/text8.yaml +7 -0
- configs/data/wikitext103.yaml +6 -0
- configs/data/wikitext2.yaml +6 -0
- configs/lr_scheduler/constant_warmup.yaml +2 -0
- configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
- configs/model/medium.yaml +10 -0
- configs/model/small-ar.yaml +11 -0
- configs/model/small.yaml +10 -0
- configs/model/tiny-ar.yaml +11 -0
- configs/model/tiny-dimamba.yaml +11 -0
- configs/model/tiny.yaml +10 -0
- configs/noise/ar.yaml +2 -0
- configs/noise/linear.yaml +3 -0
- configs/noise/loglinear.yaml +3 -0
- configs/noise/polynomial.yaml +5 -0
- configs/strategy/ddp.yaml +2 -0
- configs/strategy/fsdp.yaml +3 -0
- temp_data/polymers_lit_scraped.csv +57 -0
- temp_fangping.py +74 -0
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/ApexOracle.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="ApexOracle_HF_H100" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="PLAIN" />
|
| 10 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/deployment.xml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="PublishConfigData" autoUpload="Always" serverName="ApexOracle HF H100" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
|
| 4 |
+
<option name="confirmBeforeUploading" value="false" />
|
| 5 |
+
<serverData>
|
| 6 |
+
<paths name="ApexOracle HF H100">
|
| 7 |
+
<serverdata>
|
| 8 |
+
<mappings>
|
| 9 |
+
<mapping deploy="/data2/tianang/projects/ApexOracle" local="$PROJECT_DIR$" />
|
| 10 |
+
</mappings>
|
| 11 |
+
</serverdata>
|
| 12 |
+
</paths>
|
| 13 |
+
</serverData>
|
| 14 |
+
<option name="myAutoUpload" value="ALWAYS" />
|
| 15 |
+
</component>
|
| 16 |
+
</project>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
| 5 |
+
<inspection_tool class="PyUnboundLocalVariableInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
| 6 |
+
</profile>
|
| 7 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.9" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="ApexOracle_HF_H100" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/ApexOracle.iml" filepath="$PROJECT_DIR$/.idea/ApexOracle.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 6 |
+
</component>
|
| 7 |
+
</project>
|
DLM_emb_model.py
CHANGED
|
@@ -31,10 +31,10 @@ import ast
|
|
| 31 |
from omegaconf import OmegaConf, DictConfig, ListConfig
|
| 32 |
from huggingface_hub import PyTorchModelHubMixin
|
| 33 |
|
| 34 |
-
|
| 35 |
-
current_directory = Path('/data2/tianang/projects/Synergy')
|
| 36 |
|
| 37 |
-
with initialize_config_dir(config_dir="
|
| 38 |
config = compose(config_name="config")
|
| 39 |
|
| 40 |
class mol_emb_mdlm(nn.Module):
|
|
|
|
| 31 |
from omegaconf import OmegaConf, DictConfig, ListConfig
|
| 32 |
from huggingface_hub import PyTorchModelHubMixin
|
| 33 |
|
| 34 |
+
current_directory = Path(__file__).parent
|
| 35 |
+
# current_directory = Path('/data2/tianang/projects/Synergy')
|
| 36 |
|
| 37 |
+
with initialize_config_dir(config_dir=str(current_directory/"configs")):
|
| 38 |
config = compose(config_name="config")
|
| 39 |
|
| 40 |
class mol_emb_mdlm(nn.Module):
|
configs/callbacks/checkpoint_every_n_steps.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint_every_n_steps:
|
| 2 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
| 3 |
+
save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
|
| 4 |
+
save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
|
| 5 |
+
dirpath: ${checkpointing.save_dir}/checkpoints
|
| 6 |
+
verbose: True
|
| 7 |
+
auto_insert_metric_name: False
|
| 8 |
+
every_n_train_steps: 500
|
configs/callbacks/checkpoint_monitor.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint_monitor:
|
| 2 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
| 3 |
+
monitor: val/nll # name of the logged metric which determines when model is improving
|
| 4 |
+
mode: min # can be "max" or "min"
|
| 5 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 6 |
+
save_last: False # True = additionally always save model from last epoch
|
| 7 |
+
dirpath: ${checkpointing.save_dir}/checkpoints
|
| 8 |
+
filename: best
|
| 9 |
+
auto_insert_metric_name: False
|
| 10 |
+
verbose: True
|
configs/callbacks/learning_rate_monitor.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
learning_rate_monitor:
|
| 2 |
+
_target_: lightning.pytorch.callbacks.LearningRateMonitor
|
| 3 |
+
logging_interval: step
|
configs/config.yaml
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
|
| 4 |
+
- /data: openwebtext
|
| 5 |
+
- /model: small # small / medium
|
| 6 |
+
- /strategy: ddp
|
| 7 |
+
- /noise: loglinear
|
| 8 |
+
- /lr_scheduler: constant_warmup
|
| 9 |
+
|
| 10 |
+
mode: sample_eval # train / ppl_eval / sample_eval
|
| 11 |
+
diffusion: absorbing_state
|
| 12 |
+
backbone: dit # dit / dimamba / ar
|
| 13 |
+
parameterization: subs # subs / d3pm / sedd
|
| 14 |
+
time_conditioning: False
|
| 15 |
+
T: 0 # 0 (continuous time) / 1000
|
| 16 |
+
subs_masking: False
|
| 17 |
+
|
| 18 |
+
seed: 1
|
| 19 |
+
|
| 20 |
+
loader:
|
| 21 |
+
global_batch_size: 512
|
| 22 |
+
eval_global_batch_size: ${.global_batch_size}
|
| 23 |
+
# Note: batch_size and eval_batch_size are **per machine**
|
| 24 |
+
batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 25 |
+
eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 26 |
+
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
|
| 27 |
+
pin_memory: True
|
| 28 |
+
|
| 29 |
+
sampling:
|
| 30 |
+
predictor: ddpm_cache # analytic, ddpm, ddpm_cache
|
| 31 |
+
steps: 128
|
| 32 |
+
noise_removal: True
|
| 33 |
+
# TODO(yair): @subham, why aren't these params under `eval`?
|
| 34 |
+
num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
|
| 35 |
+
num_sample_log: 2
|
| 36 |
+
semi_ar: False
|
| 37 |
+
stride_length: 1
|
| 38 |
+
num_strides: 1
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
training:
|
| 42 |
+
ema: 0.9999
|
| 43 |
+
antithetic_sampling: True
|
| 44 |
+
importance_sampling: False
|
| 45 |
+
sampling_eps: 1e-3
|
| 46 |
+
change_of_variables: False
|
| 47 |
+
|
| 48 |
+
eval:
|
| 49 |
+
checkpoint_path: '/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt' # Used to evaluate a checkpoint after training.
|
| 50 |
+
disable_ema: False
|
| 51 |
+
compute_generative_perplexity: False
|
| 52 |
+
perplexity_batch_size: 8
|
| 53 |
+
compute_perplexity_on_sanity: False
|
| 54 |
+
gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
|
| 55 |
+
generate_samples: True
|
| 56 |
+
|
| 57 |
+
optim:
|
| 58 |
+
weight_decay: 0
|
| 59 |
+
lr: 3e-4
|
| 60 |
+
beta1: 0.9
|
| 61 |
+
beta2: 0.999
|
| 62 |
+
eps: 1e-8
|
| 63 |
+
|
| 64 |
+
trainer:
|
| 65 |
+
_target_: lightning.Trainer
|
| 66 |
+
accelerator: cuda
|
| 67 |
+
num_nodes: 1
|
| 68 |
+
devices: ${device_count:}
|
| 69 |
+
accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
|
| 70 |
+
gradient_clip_val: 1.0
|
| 71 |
+
precision: 'bf16'
|
| 72 |
+
num_sanity_val_steps: 2
|
| 73 |
+
max_steps: 1_000_000
|
| 74 |
+
log_every_n_steps: 10
|
| 75 |
+
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
|
| 76 |
+
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
|
| 77 |
+
val_check_interval: 10000
|
| 78 |
+
|
| 79 |
+
wandb:
|
| 80 |
+
project: text-diffusion
|
| 81 |
+
notes: Mulan for text
|
| 82 |
+
group: null
|
| 83 |
+
job_type: null
|
| 84 |
+
name: null
|
| 85 |
+
id: ${.name}_${seed}
|
| 86 |
+
tags:
|
| 87 |
+
- ${noise.type}
|
| 88 |
+
- ${data.train}
|
| 89 |
+
- ${data.valid}
|
| 90 |
+
|
| 91 |
+
hydra:
|
| 92 |
+
run:
|
| 93 |
+
dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
|
| 94 |
+
job:
|
| 95 |
+
chdir: true
|
| 96 |
+
|
| 97 |
+
checkpointing:
|
| 98 |
+
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
|
| 99 |
+
save_dir: ${cwd:}
|
| 100 |
+
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
|
| 101 |
+
resume_from_ckpt: true
|
| 102 |
+
resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
|
configs/data/ag_news.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: ag_news
|
| 2 |
+
valid: ag_news
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/lambada.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: lambada
|
| 2 |
+
valid: lambada
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/lm1b-gpt2.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: lm1b
|
| 2 |
+
valid: lm1b
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/lm1b-streaming.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: lm1b
|
| 2 |
+
valid: lm1b
|
| 3 |
+
tokenizer_name_or_path: bert-base-uncased
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: False
|
| 6 |
+
streaming: True
|
configs/data/lm1b.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: lm1b
|
| 2 |
+
valid: lm1b
|
| 3 |
+
tokenizer_name_or_path: bert-base-uncased
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: False
|
| 6 |
+
streaming: False
|
configs/data/openwebtext-split.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: openwebtext-train
|
| 2 |
+
valid: openwebtext-valid
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/openwebtext-streaming.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: openwebtext
|
| 2 |
+
valid: wikitext103
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /tmp/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: True
|
configs/data/openwebtext.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: openwebtext
|
| 2 |
+
valid: wikitext103
|
| 3 |
+
tokenizer_name_or_path: ibm-research/materials.selfies-ted
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/ptb.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: ptb
|
| 2 |
+
valid: ptb
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/scientific_papers_arxiv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: scientific_papers_arxiv
|
| 2 |
+
valid: scientific_papers_arxiv
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/scientific_papers_pubmed.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: scientific_papers_pubmed
|
| 2 |
+
valid: scientific_papers_pubmed
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/text8-crop.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
|
| 2 |
+
train: text8-crop
|
| 3 |
+
valid: text8
|
| 4 |
+
tokenizer_name_or_path: text8
|
| 5 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 6 |
+
wrap: True
|
| 7 |
+
streaming: False
|
configs/data/text8.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
|
| 2 |
+
train: text8
|
| 3 |
+
valid: text8
|
| 4 |
+
tokenizer_name_or_path: text8
|
| 5 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 6 |
+
wrap: True
|
| 7 |
+
streaming: False
|
configs/data/wikitext103.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: wikitext103
|
| 2 |
+
valid: wikitext103
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/data/wikitext2.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train: wikitext2
|
| 2 |
+
valid: wikitext2
|
| 3 |
+
tokenizer_name_or_path: gpt2
|
| 4 |
+
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
|
| 5 |
+
wrap: True
|
| 6 |
+
streaming: False
|
configs/lr_scheduler/constant_warmup.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.get_constant_schedule_with_warmup
|
| 2 |
+
num_warmup_steps: 2500
|
configs/lr_scheduler/cosine_decay_warmup.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: utils.CosineDecayWarmupLRScheduler
|
| 2 |
+
t_in_epochs: False
|
| 3 |
+
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
|
| 4 |
+
warmup_prefix: True
|
| 5 |
+
warmup_lr_init: 1e-6
|
| 6 |
+
warmup_t: ${eval:0.1*${trainer.max_steps}}
|
| 7 |
+
lr_min: 1e-6
|
configs/model/medium.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: medium
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 1024
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 24
|
| 7 |
+
n_heads: 16
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
tie_word_embeddings: False
|
configs/model/small-ar.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: small
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 768
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 12
|
| 7 |
+
n_heads: 12
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
causal: True
|
| 11 |
+
tie_word_embeddings: False
|
configs/model/small.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: small
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 768
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 12
|
| 7 |
+
n_heads: 12
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
tie_word_embeddings: False
|
configs/model/tiny-ar.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tiny
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 512
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 8
|
| 7 |
+
n_heads: 8
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
causal: True
|
| 11 |
+
tie_word_embeddings: False
|
configs/model/tiny-dimamba.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tiny
|
| 2 |
+
type: dimamba
|
| 3 |
+
hidden_size: 512
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 14
|
| 7 |
+
n_heads: 8
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
temb_strategy: adaln
|
| 11 |
+
tie_word_embeddings: False
|
configs/model/tiny.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: tiny
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 512
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 8
|
| 7 |
+
n_heads: 8
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
tie_word_embeddings: False
|
configs/noise/ar.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
type: ar
|
| 2 |
+
scale: 6.0
|
configs/noise/linear.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
type: linear
|
| 2 |
+
sigma_min: 1e-3
|
| 3 |
+
sigma_max: 7.0
|
configs/noise/loglinear.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
type: loglinear
|
| 2 |
+
sigma_min: 1e-4
|
| 3 |
+
sigma_max: 20
|
configs/noise/polynomial.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
type: polynomial
|
| 2 |
+
a: -3
|
| 3 |
+
b: 5
|
| 4 |
+
c: -4
|
| 5 |
+
eps: 1e-3
|
configs/strategy/ddp.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: lightning.pytorch.strategies.DDPStrategy
|
| 2 |
+
find_unused_parameters: false # TODO(yair): this seems hacky, I think if things are correct we shouldn't need this
|
configs/strategy/fsdp.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO(yair): Currenly not compatible with grad clipping
|
| 2 |
+
_target_: lightning.pytorch.strategies.FSDPStrategy
|
| 3 |
+
sharding_strategy: SHARD_GRAD_OP
|
temp_data/polymers_lit_scraped.csv
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Notebook reference,Polymer name,monomer A,mol fraction A,monomer B,fraction B,monomer C,fraction C,monomer D,fraction D,monomer E,fraction E,monomer F,fraction F,Distribution,Architecture,Target DP,MIC (E. coli),MIC (S. aureus),MIC (K. pneumoniae),MIC (E. faecium),HC50
|
| 2 |
+
SW1.84.1,L-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,70,>512,>512,,,>2000
|
| 3 |
+
SW1.84.2,L-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,70,>512,>512,,,>2000
|
| 4 |
+
SW1.84.3,L-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
|
| 5 |
+
SW1.89.1,L-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
|
| 6 |
+
SW1.89.2,L-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,70,128,32-64,256,512,>2000
|
| 7 |
+
SW1.89.3,L-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,70,128,32,512,512,>2000
|
| 8 |
+
SW1.110.1,L-Ni13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
|
| 9 |
+
SW1.110.2,L-Ni13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,64-128,,,>2000
|
| 10 |
+
SW1.110.3,L-Phe13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
|
| 11 |
+
SW1.115.1,L-Phe13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
|
| 12 |
+
SW1.115.2,L-Do13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,256-512,32,,,<50
|
| 13 |
+
SW1.115.3,L-Do13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,256,32,256,256,>2000
|
| 14 |
+
SW1.119.1,H-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,115,>512,128,,,>8000
|
| 15 |
+
SW1.119.2,H-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,115,>512,>512,,,>8000
|
| 16 |
+
SW1.119.3,H-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,115,256-512,128-256,64,>512,>8000
|
| 17 |
+
SW1.125.1,H-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,115,256,>512,nd,,>8000
|
| 18 |
+
SW1.119.5,H-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,115,128,32,128-256,256,>8000
|
| 19 |
+
SW1.119.6,H-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,115,128,32,256,>512,6300
|
| 20 |
+
SW2.3.1,L-Bam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.48,C=CC(=O)NCCCC,0.40,C=CC(=O)NCCCOC,0.12,,,,,,,statistical,linear,70,>512,>512,,,>8000
|
| 21 |
+
SW2.3.2,L-Bmam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.52,C=CC(=O)NCOCCCC,0.35,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,6200
|
| 22 |
+
SW2.3.3,L-Tmb31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NC(C)(C)CC(C)(C)C,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,64,64,,,<62.5
|
| 23 |
+
SW2.3.4,L-Oct31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NCCCCCCCC,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,256-128,64,256,>512,4700
|
| 24 |
+
SW2.3.5,L-Olam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.63,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.21,C=CC(=O)NCCCOC,0.16,,,,,,,statistical,linear,70,128,64-32,>512,>512,>8000
|
| 25 |
+
SW3.56.1,L-Do30Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.66,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,512,128,,,3400
|
| 26 |
+
SW3.56.2,L-Tmb5Mo90,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.04,C=CC(=O)NC(C)(C)CC(C)(C)C,0.04,C=CC(=O)N1CCOCC1,0.93,,,,,,,statistical,linear,70,>512,>512,,,>4000
|
| 27 |
+
SW3.56.3,L-Oct5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.87,C=CC(=O)NCCCCCCCC,0.05,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,>512,>512,,,>4000
|
| 28 |
+
SW3.56.4,L-Phe15Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.46,C=CC(=O)Nc1ccccc1,0.18,C=CC(=O)N1CCOCC1,0.37,,,,,,,statistical,linear,70,>512,16,,,>4000
|
| 29 |
+
SW4.14.2,L-Aeg5Phe25Mo50Mep20,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.038,C=CC(=O)Nc1ccccc1,0.246,C=CC(=O)N1CCOCC1,0.514,C=CC(=O)NCCCOC,0.203,,,,,statistical,linear,70,>512,>512,,,2200
|
| 30 |
+
SW4.29.1,L-Do5Mo40Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.416,C=CC(=O)NCCCCCCCCCCCC,0.036,C=CC(=O)N1CCOCC1,0.488,C=CC(=O)NCCCOC,0.060,,,,,statistical,linear,70,>512,>512,,,>4000
|
| 31 |
+
SW4.29.2,L-Phe20Olam5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.645,C=CC(=O)Nc1ccccc1,0.259,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.030,C=CC(=O)NCCCOC,0.067,,,,,statistical,linear,70,128,32,,,>4000
|
| 32 |
+
SW5.20.1,L-Do25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.777,C=CC(=O)NCCCCCCCCCCCC,0.223,,,,,,,,,statistical,linear,70,64,,,,>4000
|
| 33 |
+
SW5.20.2,L-Aeg10Olam30Mo60,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.091,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.164,C=CC(=O)N1CCOCC1,0.745,,,,,,,statistical,linear,70,>512,,,,>4000
|
| 34 |
+
SW5.20.3,L-Ni25Phe20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.427,C=CC(=O)NC(C)C,0.355,C=CC(=O)Nc1ccccc1,0.218,,,,,,,statistical,linear,70,>512,,,,>4000
|
| 35 |
+
SW5.20.4,L-Bam40Oct5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.438,C=CC(=O)NCCCC,0.517,C=CC(=O)NCCCCCCCC,0.045,,,,,,,statistical,linear,70,32,,,,<500
|
| 36 |
+
SW5.20.5,L-Phe23Oct5Mo55,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.126,C=CC(=O)Nc1ccccc1,0.239,C=CC(=O)N1CCOCC1,0.038,C=CC(=O)N1CCOCC1,0.597,,,,,statistical,linear,70,>512,,,,>4000
|
| 37 |
+
SW5.24.1,L-Aeg10Phe20Olam25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.450,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.107,C=CC(=O)Nc1ccccc1,0.281,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.161,,,,,statistical,linear,70,128,,,,1500
|
| 38 |
+
SW5.24.2,L-Aeg20Ni35Tmb10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.266,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.163,C=CC(=O)NC(C)C,0.486,C=CC(=O)NC(C)(C)CC(C)(C)C,0.086,,,,,statistical,linear,70,64,,,,<500
|
| 39 |
+
SW5.24.3,L-Phe35Olam10Mo20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.292,C=CC(=O)Nc1ccccc1,0.410,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)N1CCOCC1,0.244,,,,,statistical,linear,70,128,,,,>4000
|
| 40 |
+
SW5.24.4,L-Aeg17Tmb8Mo37,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.319,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.148,C=CC(=O)NC(C)(C)CC(C)(C)C,0.078,C=CC(=O)N1CCOCC1,0.455,,,,,statistical,linear,70,256,,,,<500
|
| 41 |
+
SW5.24.5,L-Aeg20Ni20Olam25Mo5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.269,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.193,C=CC(=O)NC(C)C,0.328,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.144,C=CC(=O)N1CCOCC1,0.066,,,statistical,linear,70,256,,,,>4000
|
| 42 |
+
SW5.41.1,L-Do10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.912,C=CC(=O)NCCCCCCCCCCCC,0.088,,,,,,,,,statistical,linear,70,256,,,,>4000
|
| 43 |
+
SW5.41.2,L-Phe15Do5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.759,C=CC(=O)Nc1ccccc1,0.200,C=CC(=O)NCCCCCCCCCCCC,0.041,,,,,,,statistical,linear,70,256,,,,>4000
|
| 44 |
+
SW5.41.3,L-Aeg5Phe5Olam5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.845,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.053,C=CC(=O)Nc1ccccc1,0.070,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.032,,,,,statistical,linear,70,128,,,,>4000
|
| 45 |
+
SW5.41.4,L-Ni20Do5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.593,C=CC(=O)NC(C)C,0.309,C=CC(=O)NCCCCCCCCCCCC,0.037,C=CC(=O)NCCCOC,0.061,,,,,statistical,linear,70,256,,,,>4000
|
| 46 |
+
SW5.41.5,L-Phe20Olam5Mo15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.530,C=CC(=O)Nc1ccccc1,0.248,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.028,C=CC(=O)N1CCOCC1,0.194,,,,,statistical,linear,70,128,,,,>4000
|
| 47 |
+
SW5.42.1,L-Phe5Do5Mo50,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.321,C=CC(=O)Nc1ccccc1,0.056,C=CC(=O)NCCCCCCCCCCCC,0.035,C=CC(=O)N1CCOCC1,0.588,,,,,statistical,linear,70,>512,,,,>4000
|
| 48 |
+
SW5.42.2,L-Aeg10Oct15Tmb5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.678,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.104,C=CC(=O)NCCCCCCCC,0.164,C=CC(=O)NC(C)(C)CC(C)(C)C,0.055,,,,,statistical,linear,70,128-256,,,,<500
|
| 49 |
+
SW5.42.3,L-Do5Bam5Mo20Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.570,C=CC(=O)NCCCCCCCCCCCC,0.038,C=CC(=O)NCCCC,0.071,C=CC(=O)N1CCOCC1,0.257,C=CC(=O)NCCCOC,0.063,,,statistical,linear,70,256,,,,>4000
|
| 50 |
+
SW5.42.4,L-Aeg5Phe15Bam30Mo25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.183,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.039,C=CC(=O)Nc1ccccc1,0.154,C=CC(=O)NCCCC,0.356,C=CC(=O)N1CCOCC1,0.268,,,statistical,linear,70,512,,,,>4000
|
| 51 |
+
SW5.42.5,L-Phe5Olam10Bmam10Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.674,C=CC(=O)Nc1ccccc1,0.068,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.062,C=CC(=O)NCOCCCC,0.127,C=CC(=O)NCCCOC,0.070,,,statistical,linear,70,64,,,,>4000
|
| 52 |
+
SW5.65.1,L-Aeg5Ni10Phe5Do30Mep15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.309,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.047,C=CC(=O)NC(C)C,0.161,C=CC(=O)Nc1ccccc1,0.062,C=CC(=O)NCCCCCCCCCCCC,0.229,C=CC(=O)NCCCOC,0.191,statistical,linear,70,64,,,,3300
|
| 53 |
+
SW5.65.5,L-Aeg10Ni15Bam10Olam20Mep20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.206,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.089,C=CC(=O)NC(C)C,0.226,C=CC(=O)NCCCC,0.134,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.106,C=CC(=O)NCCCOC,0.238,statistical,linear,70,128,,,,1400
|
| 54 |
+
SW5.65.7,L-Do15Bam15Oct10Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.245,C=CC(=O)NCCCCCCCCCCCC,0.106,C=CC(=O)NCCCC,0.199,C=CC(=O)NCCCCCCCC,0.092,C=CC(=O)N1CCOCC1,0.358,,,statistical,linear,70,128,,,,>4000
|
| 55 |
+
SW5.65.8,L-Aeg10Ni5Do25Tmb10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.122,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.088,C=CC(=O)NC(C)C,0.075,C=CC(=O)NCCCCCCCCCCCC,0.211,C=CC(=O)NC(C)(C)CC(C)(C)C,0.092,C=CC(=O)NCCCOC,0.412,statistical,linear,70,>512,,,,<500
|
| 56 |
+
SW5.65.9,L-Ni10Do5Mo60,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.185,C=CC(=O)NC(C)C,0.135,C=CC(=O)NCCCCCCCCCCCC,0.032,C=CC(=O)N1CCOCC1,0.649,,,,,statistical,linear,70,>512,,,,>4000
|
| 57 |
+
SW5.65.10,L-Aeg15Ni10Do10Olam10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.167,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.134,C=CC(=O)NC(C)C,0.152,C=CC(=O)NCCCCCCCCCCCC,0.072,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)NCCCOC,0.421,statistical,linear,70,>512,,,,2500
|
temp_fangping.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
from DLM_emb_model import MolEmbDLM
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
import torch
|
| 6 |
+
import selfies as sf
|
| 7 |
+
|
| 8 |
+
MODEL_DIR = "Kiria-Nozan/ApexOracle"
|
| 9 |
+
|
| 10 |
+
# Load model and tokenizer
|
| 11 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
|
| 12 |
+
model = MolEmbDLM.from_pretrained(MODEL_DIR)
|
| 13 |
+
model.eval()
|
| 14 |
+
|
| 15 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
model = model.to(device)
|
| 17 |
+
|
| 18 |
+
# Load CSV data
|
| 19 |
+
df = pd.read_csv("temp_data/polymers_lit_scraped.csv")
|
| 20 |
+
|
| 21 |
+
# Extract all unique monomer SMILES
|
| 22 |
+
monomer_columns = ["monomer A", "monomer B", "monomer C", "monomer D", "monomer E", "monomer F"]
|
| 23 |
+
all_monomers = set()
|
| 24 |
+
|
| 25 |
+
for col in monomer_columns:
|
| 26 |
+
if col in df.columns:
|
| 27 |
+
monomers = df[col].dropna().unique()
|
| 28 |
+
all_monomers.update(monomers)
|
| 29 |
+
|
| 30 |
+
print(f"Total unique monomers: {len(all_monomers)}")
|
| 31 |
+
|
| 32 |
+
# Convert SMILES to SELFIES and prepare for embedding
|
| 33 |
+
monomer_selfies = {}
|
| 34 |
+
valid_monomers = []
|
| 35 |
+
|
| 36 |
+
for smiles in all_monomers:
|
| 37 |
+
try:
|
| 38 |
+
selfies = sf.encoder(smiles)
|
| 39 |
+
monomer_selfies[smiles] = selfies
|
| 40 |
+
valid_monomers.append((smiles, selfies))
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Error converting {smiles} to SELFIES: {e}")
|
| 43 |
+
|
| 44 |
+
print(f"Valid monomers for embedding: {len(valid_monomers)}")
|
| 45 |
+
|
| 46 |
+
# Generate embeddings for all monomers
|
| 47 |
+
monomer_embeddings = {}
|
| 48 |
+
|
| 49 |
+
for smiles, selfies in valid_monomers:
|
| 50 |
+
# Prepare input similar to example.py
|
| 51 |
+
batch = tokenizer(
|
| 52 |
+
selfies.replace('][', '] ['),
|
| 53 |
+
padding=False,
|
| 54 |
+
truncation=False,
|
| 55 |
+
return_tensors="pt",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 59 |
+
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
embeddings = model(
|
| 62 |
+
input_ids=batch["input_ids"],
|
| 63 |
+
attention_mask=batch["attention_mask"],
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Store the embedding (average pooling over sequence length)
|
| 67 |
+
monomer_embeddings[smiles] = embeddings[0][0].cpu().numpy()
|
| 68 |
+
|
| 69 |
+
print(f"Generated embeddings for {len(monomer_embeddings)} monomers")
|
| 70 |
+
print(f"Embedding shape: {list(monomer_embeddings.values())[0].shape}")
|
| 71 |
+
|
| 72 |
+
# Save results
|
| 73 |
+
np.save("temp_data/monomer_embeddings.npy", monomer_embeddings)
|
| 74 |
+
print("Embeddings saved to monomer_embeddings.npy")
|