han-xudong commited on
Commit ·
0cabb09
1
Parent(s): 9013318
modified: modeling.py
Browse files- modeling.py +1 -11
modeling.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
-
from huggingface_hub import snapshot_download
|
| 6 |
|
| 7 |
class MVAEConfig(PretrainedConfig):
|
| 8 |
model_type = "mvae"
|
|
@@ -111,16 +110,7 @@ class MVAE(PreTrainedModel):
|
|
| 111 |
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
| 112 |
|
| 113 |
prosoro_type = getattr(config, "prosoro_type", None)
|
| 114 |
-
print(f"Loading model from {pretrained_model_name_or_path} with prosoro_type={prosoro_type}")
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
if prosoro_type is not None:
|
| 119 |
-
sub_path = os.path.join(cached_folder, prosoro_type)
|
| 120 |
-
print(f"Loading prosoro type from subfolder: {sub_path}")
|
| 121 |
-
if os.path.isdir(sub_path):
|
| 122 |
-
pretrained_model_name_or_path = sub_path
|
| 123 |
-
else:
|
| 124 |
-
raise ValueError(f"No subfolder found for {prosoro_type}")
|
| 125 |
|
| 126 |
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
| 5 |
|
| 6 |
class MVAEConfig(PretrainedConfig):
|
| 7 |
model_type = "mvae"
|
|
|
|
| 110 |
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
| 111 |
|
| 112 |
prosoro_type = getattr(config, "prosoro_type", None)
|
|
|
|
| 113 |
|
| 114 |
+
pretrained_model_name_or_path = pretrained_model_name_or_path + f"/{prosoro_type}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|