han-xudong commited on
Commit
0cabb09
·
1 Parent(s): 9013318

modified: modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +1 -11
modeling.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import torch
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel, PretrainedConfig
5
- from huggingface_hub import snapshot_download
6
 
7
  class MVAEConfig(PretrainedConfig):
8
  model_type = "mvae"
@@ -111,16 +110,7 @@ class MVAE(PreTrainedModel):
111
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
112
 
113
  prosoro_type = getattr(config, "prosoro_type", None)
114
- print(f"Loading model from {pretrained_model_name_or_path} with prosoro_type={prosoro_type}")
115
 
116
- cached_folder = snapshot_download(pretrained_model_name_or_path)
117
-
118
- if prosoro_type is not None:
119
- sub_path = os.path.join(cached_folder, prosoro_type)
120
- print(f"Loading prosoro type from subfolder: {sub_path}")
121
- if os.path.isdir(sub_path):
122
- pretrained_model_name_or_path = sub_path
123
- else:
124
- raise ValueError(f"No subfolder found for {prosoro_type}")
125
 
126
  return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
2
  import torch
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel, PretrainedConfig
 
5
 
6
  class MVAEConfig(PretrainedConfig):
7
  model_type = "mvae"
 
110
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
111
 
112
  prosoro_type = getattr(config, "prosoro_type", None)
 
113
 
114
+ pretrained_model_name_or_path = pretrained_model_name_or_path + f"/{prosoro_type}"
 
 
 
 
 
 
 
 
115
 
116
  return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)