Kiria-Nozan committed
Commit ef9533f · Parent: c6e45fc

repair config path issue

Files changed (44)
  1. .idea/.gitignore +8 -0
  2. .idea/ApexOracle.iml +12 -0
  3. .idea/deployment.xml +16 -0
  4. .idea/inspectionProfiles/Project_Default.xml +7 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +7 -0
  7. .idea/modules.xml +8 -0
  8. .idea/vcs.xml +7 -0
  9. DLM_emb_model.py +3 -3
  10. configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
  11. configs/callbacks/checkpoint_monitor.yaml +10 -0
  12. configs/callbacks/learning_rate_monitor.yaml +3 -0
  13. configs/config.yaml +102 -0
  14. configs/data/ag_news.yaml +6 -0
  15. configs/data/lambada.yaml +6 -0
  16. configs/data/lm1b-gpt2.yaml +6 -0
  17. configs/data/lm1b-streaming.yaml +6 -0
  18. configs/data/lm1b.yaml +6 -0
  19. configs/data/openwebtext-split.yaml +6 -0
  20. configs/data/openwebtext-streaming.yaml +6 -0
  21. configs/data/openwebtext.yaml +6 -0
  22. configs/data/ptb.yaml +6 -0
  23. configs/data/scientific_papers_arxiv.yaml +6 -0
  24. configs/data/scientific_papers_pubmed.yaml +6 -0
  25. configs/data/text8-crop.yaml +7 -0
  26. configs/data/text8.yaml +7 -0
  27. configs/data/wikitext103.yaml +6 -0
  28. configs/data/wikitext2.yaml +6 -0
  29. configs/lr_scheduler/constant_warmup.yaml +2 -0
  30. configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
  31. configs/model/medium.yaml +10 -0
  32. configs/model/small-ar.yaml +11 -0
  33. configs/model/small.yaml +10 -0
  34. configs/model/tiny-ar.yaml +11 -0
  35. configs/model/tiny-dimamba.yaml +11 -0
  36. configs/model/tiny.yaml +10 -0
  37. configs/noise/ar.yaml +2 -0
  38. configs/noise/linear.yaml +3 -0
  39. configs/noise/loglinear.yaml +3 -0
  40. configs/noise/polynomial.yaml +5 -0
  41. configs/strategy/ddp.yaml +2 -0
  42. configs/strategy/fsdp.yaml +3 -0
  43. temp_data/polymers_lit_scraped.csv +57 -0
  44. temp_fangping.py +74 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/ApexOracle.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="ApexOracle_HF_H100" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
.idea/deployment.xml ADDED
@@ -0,0 +1,16 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PublishConfigData" autoUpload="Always" serverName="ApexOracle HF H100" remoteFilesAllowedToDisappearOnAutoupload="false" confirmBeforeUploading="false">
+     <option name="confirmBeforeUploading" value="false" />
+     <serverData>
+       <paths name="ApexOracle HF H100">
+         <serverdata>
+           <mappings>
+             <mapping deploy="/data2/tianang/projects/ApexOracle" local="$PROJECT_DIR$" />
+           </mappings>
+         </serverdata>
+       </paths>
+     </serverData>
+     <option name="myAutoUpload" value="ALWAYS" />
+   </component>
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,7 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+     <inspection_tool class="PyUnboundLocalVariableInspection" enabled="false" level="WARNING" enabled_by_default="false" />
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.9" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="ApexOracle_HF_H100" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/ApexOracle.iml" filepath="$PROJECT_DIR$/.idea/ApexOracle.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>
DLM_emb_model.py CHANGED
@@ -31,10 +31,10 @@ import ast
  from omegaconf import OmegaConf, DictConfig, ListConfig
  from huggingface_hub import PyTorchModelHubMixin

- # current_directory = Path(__file__).parent
- current_directory = Path('/data2/tianang/projects/Synergy')
+ current_directory = Path(__file__).parent
+ # current_directory = Path('/data2/tianang/projects/Synergy')

- with initialize_config_dir(config_dir="/data2/tianang/projects/mdlm/configs"):
+ with initialize_config_dir(config_dir=str(current_directory/"configs")):
      config = compose(config_name="config")

  class mol_emb_mdlm(nn.Module):
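This hunk is the commit's actual fix: the Hydra config directory is now derived from the module's own location rather than a hard-coded cluster path, so the config composes correctly from any checkout. A minimal standalone sketch of the pattern (assuming Hydra >= 1.2 for the `version_base` argument; `initialize_config_dir` requires an absolute path):

from pathlib import Path

from hydra import compose, initialize_config_dir

# Resolve configs/ relative to this file, not the CWD or a machine-specific
# absolute path; .resolve() guards against a relative __file__.
current_directory = Path(__file__).resolve().parent

with initialize_config_dir(config_dir=str(current_directory / "configs"), version_base=None):
    config = compose(config_name="config")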
configs/callbacks/checkpoint_every_n_steps.yaml ADDED
@@ -0,0 +1,8 @@
+ checkpoint_every_n_steps:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
+   save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
+   dirpath: ${checkpointing.save_dir}/checkpoints
+   verbose: True
+   auto_insert_metric_name: False
+   every_n_train_steps: 500
configs/callbacks/checkpoint_monitor.yaml ADDED
@@ -0,0 +1,10 @@
+ checkpoint_monitor:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   monitor: val/nll # name of the logged metric which determines when model is improving
+   mode: min # can be "max" or "min"
+   save_top_k: 1 # save k best models (determined by above metric)
+   save_last: False # True = additionally always save model from last epoch
+   dirpath: ${checkpointing.save_dir}/checkpoints
+   filename: best
+   auto_insert_metric_name: False
+   verbose: True
configs/callbacks/learning_rate_monitor.yaml ADDED
@@ -0,0 +1,3 @@
+ learning_rate_monitor:
+   _target_: lightning.pytorch.callbacks.LearningRateMonitor
+   logging_interval: step
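All three callback configs above are Hydra nodes whose `_target_` names a `lightning.pytorch.callbacks` class. A hedged sketch of how such a group is commonly instantiated (the helper name and wiring are illustrative assumptions, not necessarily this repo's code):

import hydra
from omegaconf import DictConfig

def build_callbacks(cfg: DictConfig) -> list:
    # hydra.utils.instantiate imports each node's _target_ class and passes
    # the remaining keys (monitor, dirpath, ...) as constructor kwargs.
    return [hydra.utils.instantiate(node) for node in cfg.callbacks.values()]

# Hypothetical usage: trainer = hydra.utils.instantiate(cfg.trainer, callbacks=build_callbacks(cfg))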
configs/config.yaml ADDED
@@ -0,0 +1,102 @@
+ defaults:
+   - _self_
+   - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
+   - /data: openwebtext
+   - /model: small # small / medium
+   - /strategy: ddp
+   - /noise: loglinear
+   - /lr_scheduler: constant_warmup
+
+ mode: sample_eval # train / ppl_eval / sample_eval
+ diffusion: absorbing_state
+ backbone: dit # dit / dimamba / ar
+ parameterization: subs # subs / d3pm / sedd
+ time_conditioning: False
+ T: 0 # 0 (continuous time) / 1000
+ subs_masking: False
+
+ seed: 1
+
+ loader:
+   global_batch_size: 512
+   eval_global_batch_size: ${.global_batch_size}
+   # Note: batch_size and eval_batch_size are **per machine**
+   batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
+   pin_memory: True
+
+ sampling:
+   predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+   steps: 128
+   noise_removal: True
+   # TODO(yair): @subham, why aren't these params under `eval`?
+   num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+   num_sample_log: 2
+   semi_ar: False
+   stride_length: 1
+   num_strides: 1
+
+
+ training:
+   ema: 0.9999
+   antithetic_sampling: True
+   importance_sampling: False
+   sampling_eps: 1e-3
+   change_of_variables: False
+
+ eval:
+   checkpoint_path: '/data2/tianang/projects/mdlm/Checkpoints_fangping/1-255000-fine-tune.ckpt' # Used to evaluate a checkpoint after training.
+   disable_ema: False
+   compute_generative_perplexity: False
+   perplexity_batch_size: 8
+   compute_perplexity_on_sanity: False
+   gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+   generate_samples: True
+
+ optim:
+   weight_decay: 0
+   lr: 3e-4
+   beta1: 0.9
+   beta2: 0.999
+   eps: 1e-8
+
+ trainer:
+   _target_: lightning.Trainer
+   accelerator: cuda
+   num_nodes: 1
+   devices: ${device_count:}
+   accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+   gradient_clip_val: 1.0
+   precision: 'bf16'
+   num_sanity_val_steps: 2
+   max_steps: 1_000_000
+   log_every_n_steps: 10
+   limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+   limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
+   val_check_interval: 10000
+
+ wandb:
+   project: text-diffusion
+   notes: Mulan for text
+   group: null
+   job_type: null
+   name: null
+   id: ${.name}_${seed}
+   tags:
+     - ${noise.type}
+     - ${data.train}
+     - ${data.valid}
+
+ hydra:
+   run:
+     dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
+   job:
+     chdir: true
+
+ checkpointing:
+   # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+   save_dir: ${cwd:}
+   # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+   resume_from_ckpt: true
+   resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
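Note that `config.yaml` relies on resolvers that are not built into OmegaConf: `${cwd:}`, `${device_count:}`, `${eval:...}`, and `${div_up:...}`. They must be registered before the config is composed (the upstream mdlm codebase does this in its entry point); a sketch of definitions consistent with how the interpolations are used above, with the exact bodies an assumption:

import os

import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver('cwd', os.getcwd)                         # ${cwd:}
OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)  # ${device_count:}
OmegaConf.register_new_resolver('eval', eval)                             # ${eval:...}
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)  # ceiling division for batch sizing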
configs/data/ag_news.yaml ADDED
@@ -0,0 +1,6 @@
+ train: ag_news
+ valid: ag_news
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/lambada.yaml ADDED
@@ -0,0 +1,6 @@
+ train: lambada
+ valid: lambada
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/lm1b-gpt2.yaml ADDED
@@ -0,0 +1,6 @@
+ train: lm1b
+ valid: lm1b
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/lm1b-streaming.yaml ADDED
@@ -0,0 +1,6 @@
+ train: lm1b
+ valid: lm1b
+ tokenizer_name_or_path: bert-base-uncased
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: True
configs/data/lm1b.yaml ADDED
@@ -0,0 +1,6 @@
+ train: lm1b
+ valid: lm1b
+ tokenizer_name_or_path: bert-base-uncased
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: False
configs/data/openwebtext-split.yaml ADDED
@@ -0,0 +1,6 @@
+ train: openwebtext-train
+ valid: openwebtext-valid
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/openwebtext-streaming.yaml ADDED
@@ -0,0 +1,6 @@
+ train: openwebtext
+ valid: wikitext103
+ tokenizer_name_or_path: gpt2
+ cache_dir: /tmp/data
+ wrap: True
+ streaming: True
configs/data/openwebtext.yaml ADDED
@@ -0,0 +1,6 @@
+ train: openwebtext
+ valid: wikitext103
+ tokenizer_name_or_path: ibm-research/materials.selfies-ted
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/ptb.yaml ADDED
@@ -0,0 +1,6 @@
+ train: ptb
+ valid: ptb
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/scientific_papers_arxiv.yaml ADDED
@@ -0,0 +1,6 @@
+ train: scientific_papers_arxiv
+ valid: scientific_papers_arxiv
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/scientific_papers_pubmed.yaml ADDED
@@ -0,0 +1,6 @@
+ train: scientific_papers_pubmed
+ valid: scientific_papers_pubmed
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/text8-crop.yaml ADDED
@@ -0,0 +1,7 @@
+ # TODO: When using this dataset, set model.length = 256 to match D3PM setup
+ train: text8-crop
+ valid: text8
+ tokenizer_name_or_path: text8
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/text8.yaml ADDED
@@ -0,0 +1,7 @@
+ # TODO: When using this dataset, set model.length = 256 to match D3PM setup
+ train: text8
+ valid: text8
+ tokenizer_name_or_path: text8
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/wikitext103.yaml ADDED
@@ -0,0 +1,6 @@
+ train: wikitext103
+ valid: wikitext103
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/data/wikitext2.yaml ADDED
@@ -0,0 +1,6 @@
+ train: wikitext2
+ valid: wikitext2
+ tokenizer_name_or_path: gpt2
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
configs/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: transformers.get_constant_schedule_with_warmup
+ num_warmup_steps: 2500
configs/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
+ _target_: utils.CosineDecayWarmupLRScheduler
+ t_in_epochs: False
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+ warmup_prefix: True
+ warmup_lr_init: 1e-6
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
+ lr_min: 1e-6
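Given `trainer.max_steps: 1_000_000` from `config.yaml`, the two `${eval:...}` interpolations above resolve to a 100k-step warmup followed by a 900k-step cosine decay; a quick check of the arithmetic:

max_steps = 1_000_000
warmup_t = 0.1 * max_steps        # ${eval:0.1*${trainer.max_steps}} -> 100000.0
t_initial = max_steps - warmup_t  # ${eval:${trainer.max_steps}-${.warmup_t}} -> 900000.0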
configs/model/medium.yaml ADDED
@@ -0,0 +1,10 @@
+ name: medium
+ type: ddit
+ hidden_size: 1024
+ cond_dim: 128
+ length: 1024
+ n_blocks: 24
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/small-ar.yaml ADDED
@@ -0,0 +1,11 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: 1024
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ causal: True
+ tie_word_embeddings: False
configs/model/small.yaml ADDED
@@ -0,0 +1,10 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: 1024
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/tiny-ar.yaml ADDED
@@ -0,0 +1,11 @@
+ name: tiny
+ type: ddit
+ hidden_size: 512
+ cond_dim: 128
+ length: 1024
+ n_blocks: 8
+ n_heads: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ causal: True
+ tie_word_embeddings: False
configs/model/tiny-dimamba.yaml ADDED
@@ -0,0 +1,11 @@
+ name: tiny
+ type: dimamba
+ hidden_size: 512
+ cond_dim: 128
+ length: 1024
+ n_blocks: 14
+ n_heads: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ temb_strategy: adaln
+ tie_word_embeddings: False
configs/model/tiny.yaml ADDED
@@ -0,0 +1,10 @@
+ name: tiny
+ type: ddit
+ hidden_size: 512
+ cond_dim: 128
+ length: 1024
+ n_blocks: 8
+ n_heads: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/noise/ar.yaml ADDED
@@ -0,0 +1,2 @@
+ type: ar
+ scale: 6.0
configs/noise/linear.yaml ADDED
@@ -0,0 +1,3 @@
+ type: linear
+ sigma_min: 1e-3
+ sigma_max: 7.0
configs/noise/loglinear.yaml ADDED
@@ -0,0 +1,3 @@
+ type: loglinear
+ sigma_min: 1e-4
+ sigma_max: 20
configs/noise/polynomial.yaml ADDED
@@ -0,0 +1,5 @@
+ type: polynomial
+ a: -3
+ b: 5
+ c: -4
+ eps: 1e-3
configs/strategy/ddp.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: lightning.pytorch.strategies.DDPStrategy
+ find_unused_parameters: false # TODO(yair): this seems hacky, I think if things are correct we shouldn't need this
configs/strategy/fsdp.yaml ADDED
@@ -0,0 +1,3 @@
+ # TODO(yair): Currently not compatible with grad clipping
+ _target_: lightning.pytorch.strategies.FSDPStrategy
+ sharding_strategy: SHARD_GRAD_OP
temp_data/polymers_lit_scraped.csv ADDED
@@ -0,0 +1,57 @@
+ Notebook reference,Polymer name,monomer A,mol fraction A,monomer B,fraction B,monomer C,fraction C,monomer D,fraction D,monomer E,fraction E,monomer F,fraction F,Distribution,Architecture,Target DP,MIC (E. coli),MIC (S. aureus),MIC (K. pneumoniae),MIC (E. faecium),HC50
+ SW1.84.1,L-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,70,>512,>512,,,>2000
+ SW1.84.2,L-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,70,>512,>512,,,>2000
+ SW1.84.3,L-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
+ SW1.89.1,L-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,>2000
+ SW1.89.2,L-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,70,128,32-64,256,512,>2000
+ SW1.89.3,L-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,70,128,32,512,512,>2000
+ SW1.110.1,L-Ni13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+ SW1.110.2,L-Ni13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.73,C=CC(=O)NC(C)C,0.21,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,64-128,,,>2000
+ SW1.110.3,L-Phe13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+ SW1.115.1,L-Phe13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.77,C=CC(=O)Nc1ccccc1,0.17,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,>512,32,,,>2000
+ SW1.115.2,L-Do13Mo4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)N1CCOCC1,0.06,,,,,,,statistical,linear,70,256-512,32,,,<50
+ SW1.115.3,L-Do13Mep4,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.83,C=CC(=O)NCCCCCCCCCCCC,0.11,C=CC(=O)NCCCOC,0.06,,,,,,,statistical,linear,70,256,32,256,256,>2000
+ SW1.119.1,H-Ni31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)N1CCOCC1,0.12,,,,,,,statistical,linear,115,>512,128,,,>8000
+ SW1.119.2,H-Ni31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.45,C=CC(=O)NC(C)C,0.43,C=CC(=O)NCCCOC,0.11,,,,,,,statistical,linear,115,>512,>512,,,>8000
+ SW1.119.3,H-Phe31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.50,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)N1CCOCC1,0.13,,,,,,,statistical,linear,115,256-512,128-256,64,>512,>8000
+ SW1.125.1,H-Phe31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.51,C=CC(=O)Nc1ccccc1,0.37,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,115,256,>512,nd,,>8000
+ SW1.119.5,H-Do31Mo10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)N1CCOCC1,0.15,,,,,,,statistical,linear,115,128,32,128-256,256,>8000
+ SW1.119.6,H-Do31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.59,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.15,,,,,,,statistical,linear,115,128,32,256,>512,6300
+ SW2.3.1,L-Bam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.48,C=CC(=O)NCCCC,0.40,C=CC(=O)NCCCOC,0.12,,,,,,,statistical,linear,70,>512,>512,,,>8000
+ SW2.3.2,L-Bmam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.52,C=CC(=O)NCOCCCC,0.35,C=CC(=O)NCCCOC,0.13,,,,,,,statistical,linear,70,256,>512,,,6200
+ SW2.3.3,L-Tmb31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NC(C)(C)CC(C)(C)C,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,64,64,,,<62.5
+ SW2.3.4,L-Oct31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.54,C=CC(=O)NCCCCCCCC,0.32,C=CC(=O)NCCCOC,0.14,,,,,,,statistical,linear,70,256-128,64,256,>512,4700
+ SW2.3.5,L-Olam31Mep10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.63,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.21,C=CC(=O)NCCCOC,0.16,,,,,,,statistical,linear,70,128,64-32,>512,>512,>8000
+ SW3.56.1,L-Do30Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.66,C=CC(=O)NCCCCCCCCCCCC,0.26,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,512,128,,,3400
+ SW3.56.2,L-Tmb5Mo90,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.04,C=CC(=O)NC(C)(C)CC(C)(C)C,0.04,C=CC(=O)N1CCOCC1,0.93,,,,,,,statistical,linear,70,>512,>512,,,>4000
+ SW3.56.3,L-Oct5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.87,C=CC(=O)NCCCCCCCC,0.05,C=CC(=O)NCCCOC,0.07,,,,,,,statistical,linear,70,>512,>512,,,>4000
+ SW3.56.4,L-Phe15Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.46,C=CC(=O)Nc1ccccc1,0.18,C=CC(=O)N1CCOCC1,0.37,,,,,,,statistical,linear,70,>512,16,,,>4000
+ SW4.14.2,L-Aeg5Phe25Mo50Mep20,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.038,C=CC(=O)Nc1ccccc1,0.246,C=CC(=O)N1CCOCC1,0.514,C=CC(=O)NCCCOC,0.203,,,,,statistical,linear,70,>512,>512,,,2200
+ SW4.29.1,L-Do5Mo40Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.416,C=CC(=O)NCCCCCCCCCCCC,0.036,C=CC(=O)N1CCOCC1,0.488,C=CC(=O)NCCCOC,0.060,,,,,statistical,linear,70,>512,>512,,,>4000
+ SW4.29.2,L-Phe20Olam5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.645,C=CC(=O)Nc1ccccc1,0.259,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.030,C=CC(=O)NCCCOC,0.067,,,,,statistical,linear,70,128,32,,,>4000
+ SW5.20.1,L-Do25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.777,C=CC(=O)NCCCCCCCCCCCC,0.223,,,,,,,,,statistical,linear,70,64,,,,>4000
+ SW5.20.2,L-Aeg10Olam30Mo60,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.091,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.164,C=CC(=O)N1CCOCC1,0.745,,,,,,,statistical,linear,70,>512,,,,>4000
+ SW5.20.3,L-Ni25Phe20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.427,C=CC(=O)NC(C)C,0.355,C=CC(=O)Nc1ccccc1,0.218,,,,,,,statistical,linear,70,>512,,,,>4000
+ SW5.20.4,L-Bam40Oct5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.438,C=CC(=O)NCCCC,0.517,C=CC(=O)NCCCCCCCC,0.045,,,,,,,statistical,linear,70,32,,,,<500
+ SW5.20.5,L-Phe23Oct5Mo55,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.126,C=CC(=O)Nc1ccccc1,0.239,C=CC(=O)N1CCOCC1,0.038,C=CC(=O)N1CCOCC1,0.597,,,,,statistical,linear,70,>512,,,,>4000
+ SW5.24.1,L-Aeg10Phe20Olam25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.450,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.107,C=CC(=O)Nc1ccccc1,0.281,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.161,,,,,statistical,linear,70,128,,,,1500
+ SW5.24.2,L-Aeg20Ni35Tmb10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.266,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.163,C=CC(=O)NC(C)C,0.486,C=CC(=O)NC(C)(C)CC(C)(C)C,0.086,,,,,statistical,linear,70,64,,,,<500
+ SW5.24.3,L-Phe35Olam10Mo20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.292,C=CC(=O)Nc1ccccc1,0.410,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)N1CCOCC1,0.244,,,,,statistical,linear,70,128,,,,>4000
+ SW5.24.4,L-Aeg17Tmb8Mo37,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.319,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.148,C=CC(=O)NC(C)(C)CC(C)(C)C,0.078,C=CC(=O)N1CCOCC1,0.455,,,,,statistical,linear,70,256,,,,<500
+ SW5.24.5,L-Aeg20Ni20Olam25Mo5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.269,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.193,C=CC(=O)NC(C)C,0.328,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.144,C=CC(=O)N1CCOCC1,0.066,,,statistical,linear,70,256,,,,>4000
+ SW5.41.1,L-Do10,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.912,C=CC(=O)NCCCCCCCCCCCC,0.088,,,,,,,,,statistical,linear,70,256,,,,>4000
+ SW5.41.2,L-Phe15Do5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.759,C=CC(=O)Nc1ccccc1,0.200,C=CC(=O)NCCCCCCCCCCCC,0.041,,,,,,,statistical,linear,70,256,,,,>4000
+ SW5.41.3,L-Aeg5Phe5Olam5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.845,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.053,C=CC(=O)Nc1ccccc1,0.070,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.032,,,,,statistical,linear,70,128,,,,>4000
+ SW5.41.4,L-Ni20Do5Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.593,C=CC(=O)NC(C)C,0.309,C=CC(=O)NCCCCCCCCCCCC,0.037,C=CC(=O)NCCCOC,0.061,,,,,statistical,linear,70,256,,,,>4000
+ SW5.41.5,L-Phe20Olam5Mo15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.530,C=CC(=O)Nc1ccccc1,0.248,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.028,C=CC(=O)N1CCOCC1,0.194,,,,,statistical,linear,70,128,,,,>4000
+ SW5.42.1,L-Phe5Do5Mo50,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.321,C=CC(=O)Nc1ccccc1,0.056,C=CC(=O)NCCCCCCCCCCCC,0.035,C=CC(=O)N1CCOCC1,0.588,,,,,statistical,linear,70,>512,,,,>4000
+ SW5.42.2,L-Aeg10Oct15Tmb5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.678,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.104,C=CC(=O)NCCCCCCCC,0.164,C=CC(=O)NC(C)(C)CC(C)(C)C,0.055,,,,,statistical,linear,70,128-256,,,,<500
+ SW5.42.3,L-Do5Bam5Mo20Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.570,C=CC(=O)NCCCCCCCCCCCC,0.038,C=CC(=O)NCCCC,0.071,C=CC(=O)N1CCOCC1,0.257,C=CC(=O)NCCCOC,0.063,,,statistical,linear,70,256,,,,>4000
+ SW5.42.4,L-Aeg5Phe15Bam30Mo25,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.183,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.039,C=CC(=O)Nc1ccccc1,0.154,C=CC(=O)NCCCC,0.356,C=CC(=O)N1CCOCC1,0.268,,,statistical,linear,70,512,,,,>4000
+ SW5.42.5,L-Phe5Olam10Bmam10Mep5,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.674,C=CC(=O)Nc1ccccc1,0.068,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.062,C=CC(=O)NCOCCCC,0.127,C=CC(=O)NCCCOC,0.070,,,statistical,linear,70,64,,,,>4000
+ SW5.65.1,L-Aeg5Ni10Phe5Do30Mep15,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.309,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.047,C=CC(=O)NC(C)C,0.161,C=CC(=O)Nc1ccccc1,0.062,C=CC(=O)NCCCCCCCCCCCC,0.229,C=CC(=O)NCCCOC,0.191,statistical,linear,70,64,,,,3300
+ SW5.65.5,L-Aeg10Ni15Bam10Olam20Mep20,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.206,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.089,C=CC(=O)NC(C)C,0.226,C=CC(=O)NCCCC,0.134,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.106,C=CC(=O)NCCCOC,0.238,statistical,linear,70,128,,,,1400
+ SW5.65.7,L-Do15Bam15Oct10Mo30,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.245,C=CC(=O)NCCCCCCCCCCCC,0.106,C=CC(=O)NCCCC,0.199,C=CC(=O)NCCCCCCCC,0.092,C=CC(=O)N1CCOCC1,0.358,,,statistical,linear,70,128,,,,>4000
+ SW5.65.8,L-Aeg10Ni5Do25Tmb10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.122,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.088,C=CC(=O)NC(C)C,0.075,C=CC(=O)NCCCCCCCCCCCC,0.211,C=CC(=O)NC(C)(C)CC(C)(C)C,0.092,C=CC(=O)NCCCOC,0.412,statistical,linear,70,>512,,,,<500
+ SW5.65.9,L-Ni10Do5Mo60,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.185,C=CC(=O)NC(C)C,0.135,C=CC(=O)NCCCCCCCCCCCC,0.032,C=CC(=O)N1CCOCC1,0.649,,,,,statistical,linear,70,>512,,,,>4000
+ SW5.65.10,L-Aeg15Ni10Do10Olam10Mep35,C=CC(=O)NCC[N+](C)(C)C.[Cl-],0.167,C=CC(=O)NCCNC(N)=[NH2+].[Cl-],0.134,C=CC(=O)NC(C)C,0.152,C=CC(=O)NCCCCCCCCCCCC,0.072,C=CC(=O)NCCCCCCCC/C=C\CCCCCCCC,0.054,C=CC(=O)NCCCOC,0.421,statistical,linear,70,>512,,,,2500
temp_fangping.py ADDED
@@ -0,0 +1,74 @@
+ import pandas as pd
+ import numpy as np
+ from DLM_emb_model import MolEmbDLM
+ from transformers import AutoTokenizer
+ import torch
+ import selfies as sf
+
+ MODEL_DIR = "Kiria-Nozan/ApexOracle"
+
+ # Load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+ model = MolEmbDLM.from_pretrained(MODEL_DIR)
+ model.eval()
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ # Load CSV data
+ df = pd.read_csv("temp_data/polymers_lit_scraped.csv")
+
+ # Extract all unique monomer SMILES
+ monomer_columns = ["monomer A", "monomer B", "monomer C", "monomer D", "monomer E", "monomer F"]
+ all_monomers = set()
+
+ for col in monomer_columns:
+     if col in df.columns:
+         monomers = df[col].dropna().unique()
+         all_monomers.update(monomers)
+
+ print(f"Total unique monomers: {len(all_monomers)}")
+
+ # Convert SMILES to SELFIES and prepare for embedding
+ monomer_selfies = {}
+ valid_monomers = []
+
+ for smiles in all_monomers:
+     try:
+         selfies = sf.encoder(smiles)
+         monomer_selfies[smiles] = selfies
+         valid_monomers.append((smiles, selfies))
+     except Exception as e:
+         print(f"Error converting {smiles} to SELFIES: {e}")
+
+ print(f"Valid monomers for embedding: {len(valid_monomers)}")
+
+ # Generate embeddings for all monomers
+ monomer_embeddings = {}
+
+ for smiles, selfies in valid_monomers:
+     # Prepare input similar to example.py
+     batch = tokenizer(
+         selfies.replace('][', '] ['),
+         padding=False,
+         truncation=False,
+         return_tensors="pt",
+     )
+
+     batch = {k: v.to(device) for k, v in batch.items()}
+
+     with torch.no_grad():
+         embeddings = model(
+             input_ids=batch["input_ids"],
+             attention_mask=batch["attention_mask"],
+         )
+
+     # Store the embedding (average pooling over sequence length)
+     monomer_embeddings[smiles] = embeddings[0][0].cpu().numpy()
+
+ print(f"Generated embeddings for {len(monomer_embeddings)} monomers")
+ print(f"Embedding shape: {list(monomer_embeddings.values())[0].shape}")
+
+ # Save results
+ np.save("temp_data/monomer_embeddings.npy", monomer_embeddings)
+ print("Embeddings saved to monomer_embeddings.npy")