adnlp committed (verified)
Commit beff72b · 1 parent: 81a0917

Update modeling_multicasttimer.py

Files changed (1):
  1. modeling_multicasttimer.py  +184 -180

modeling_multicasttimer.py CHANGED
@@ -1,181 +1,185 @@
-import torch
-from torch import nn
-from transformers import PreTrainedModel, PretrainedConfig
-from safetensors.torch import load_file
-
-# CLIP
-from .modeling_clipPT import CLIPVisionTransformer
-from transformers import CLIPImageProcessor
-
-from transformers import AutoTokenizer
-# Qwen
-from .modeling_qwen2 import Qwen2Model
-
-# Timer
-from .modeling_timer import TimerForPrediction
-
-class MulTiCastTimerConfig(PretrainedConfig):
-    def __init__(
-        self,
-        forecasting_length = None,
-        vision_model_name = None,
-        text_model_name = None,
-        vision_model_prompt_len = None,
-        text_model_prompt_len = None,
-        timer_prompt_len = None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.forecasting_length = forecasting_length
-        self.vision_model_name = vision_model_name
-        self.text_model_name = text_model_name
-
-        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
-        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
-        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4
-
-class MulTiCastTimerModel(PreTrainedModel):
-
-    config_class = MulTiCastTimerConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.config = config
-
-        # Vision Model
-        if config.vision_model_name is None:
-            pass
-        elif config.vision_model_name == 'CLIP':
-            from transformers import AutoModel
-            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
-            state_dict = vision_model.state_dict()
-            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
-            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
-            self.vision_model.load_state_dict(state_dict, strict=False)
-            self.processor = CLIPImageProcessor()
-            for name, param in self.vision_model.named_parameters():  # Freeze layers other than prompts
-                if "encoder.prompts" in name:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-        else:
-            pass
-
-        # Text Model
-        if config.text_model_name is None:
-            pass
-        elif config.text_model_name == 'Qwen':
-            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
-            from transformers import AutoModelForCausalLM
-            text_model = AutoModelForCausalLM.from_pretrained(
-                "Qwen/Qwen2-1.5B-Instruct",
-                torch_dtype=torch.bfloat16,
-                device_map="cpu",
-                attn_implementation="sdpa"
-            ).model
-            state_dict = text_model.state_dict()
-            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
-            self.text_model.load_state_dict(state_dict, strict=False)
-            for name, param in self.text_model.named_parameters():  # Freeze layers other than prompts
-                if "prompts" in name:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-        else:
-            pass
-
-        # Timer
-        from transformers import AutoModelForCausalLM
-        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
-        state_dict = timer.state_dict()
-        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
-        self.timer.load_state_dict(state_dict, strict=False)
-        for name, param in self.timer.named_parameters():  # Freeze layers other than prompts
-            if "model.prompts" in name:
-                param.requires_grad = True
-            else:
-                param.requires_grad = False
-
-        # Vision Interaction Layer
-        if config.vision_model_name is None:
-            pass
-        else:
-            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
-
-        # Text Interaction Layer
-        if config.text_model_name is None:
-            pass
-        else:
-            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)
-
-    def predict(self, input_ids = None, images = None, texts = None):
-        images = self.processor.preprocess(images)['pixel_values'][0]
-        images = torch.tensor(images)
-        images = images.unsqueeze(0)
-
-        if self.config.vision_model_name is None and images is None:
-            vision_embedding = None
-        else:
-            vision_output = self.vision_model(images, output_attentions=True)
-            vision_attentions = vision_output.attentions
-            vision_embedding = vision_output.pooler_output
-            vision_embedding = self.vision_interaction_layer(vision_embedding)
-
-        if self.config.text_model_name is None and all(x is None for x in texts):
-            text_embedding = None
-        else:
-            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
-            text_embedding = self.text_model(**tokenized_texts)
-            text_embedding = text_embedding.last_hidden_state[:, 0, :]
-            text_embedding = self.text_interaction_layer(text_embedding)
-
-        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
-
-        return {
-            "logits": out.logits,
-            "vision_attentions": vision_attentions,
-            "time_series_attentions": out.attentions
-        }
-
-    def forward(self, input_ids = None, images = None, texts = None, labels = None):
-        if self.config.vision_model_name is None and images is None:
-            vision_embedding = None
-        else:
-            vision_embedding = self.vision_model(images)
-            vision_embedding = vision_embedding.pooler_output
-            vision_embedding = self.vision_interaction_layer(vision_embedding)
-
-        if self.config.text_model_name is None and all(x is None for x in texts):
-            text_embedding = None
-        else:
-            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
-            text_embedding = self.text_model(**tokenized_texts)
-            text_embedding = text_embedding.last_hidden_state[:, 0, :]
-            text_embedding = self.text_interaction_layer(text_embedding)
-
-        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
-        out = out["logits"]
-
-        if labels is not None:
-            if self.config.forecasting_length == out.shape[-1]:
-                loss = torch.mean(torch.square(out - labels))  # MSE
-            else:  # The pretrained Timer forecasts 96 steps; this handles a shorter forecasting length. A forecasting length larger than 96 will raise an error.
-                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
-        else:
-            loss = None
-
-        return {
-            "loss": loss,
-            "logits": out
-        }
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        from transformers.utils import cached_file
-        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
-        model = MulTiCastTimerModel(config)
-        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
-        state_dict = load_file(resolved_file)
-        model.load_state_dict(state_dict, strict=False)
-
+import torch
+from torch import nn
+from transformers import PreTrainedModel, PretrainedConfig
+from safetensors.torch import load_file
+
+# CLIP
+from .modeling_clipPT import CLIPVisionTransformer
+from transformers import CLIPImageProcessor
+
+from transformers import AutoTokenizer
+# Qwen
+from .modeling_qwen2 import Qwen2Model
+
+# Timer
+from .modeling_timer import TimerForPrediction
+
+class MulTiCastTimerConfig(PretrainedConfig):
+    def __init__(
+        self,
+        forecasting_length = None,
+        vision_model_name = None,
+        text_model_name = None,
+        vision_model_prompt_len = None,
+        text_model_prompt_len = None,
+        timer_prompt_len = None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.forecasting_length = forecasting_length
+        self.vision_model_name = vision_model_name
+        self.text_model_name = text_model_name
+
+        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
+        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
+        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4
+
+class MulTiCastTimerModel(PreTrainedModel):
+
+    config_class = MulTiCastTimerConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # Vision Model
+        if config.vision_model_name is None:
+            pass
+        elif config.vision_model_name == 'CLIP':
+            from transformers import AutoModel
+            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
+            state_dict = vision_model.state_dict()
+            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
+            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
+            self.vision_model.load_state_dict(state_dict, strict=False)
+            self.processor = CLIPImageProcessor()
+            for name, param in self.vision_model.named_parameters():  # Freeze layers other than prompts
+                if "encoder.prompts" in name:
+                    param.requires_grad = True
+                else:
+                    param.requires_grad = False
+        else:
+            pass
+
+        # Text Model
+        if config.text_model_name is None:
+            pass
+        elif config.text_model_name == 'Qwen':
+            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
+            from transformers import AutoModelForCausalLM
+            text_model = AutoModelForCausalLM.from_pretrained(
+                "Qwen/Qwen2-1.5B-Instruct",
+                torch_dtype=torch.bfloat16,
+                device_map="cpu",
+                attn_implementation="sdpa"
+            ).model
+            state_dict = text_model.state_dict()
+            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
+            self.text_model.load_state_dict(state_dict, strict=False)
+            for name, param in self.text_model.named_parameters():  # Freeze layers other than prompts
+                if "prompts" in name:
+                    param.requires_grad = True
+                else:
+                    param.requires_grad = False
+        else:
+            pass
+
+        # Timer
+        from transformers import AutoModelForCausalLM
+        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
+        state_dict = timer.state_dict()
+        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
+        self.timer.load_state_dict(state_dict, strict=False)
+        for name, param in self.timer.named_parameters():  # Freeze layers other than prompts
+            if "model.prompts" in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+
+        # Vision Interaction Layer
+        if config.vision_model_name is None:
+            pass
+        else:
+            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
+
+        # Text Interaction Layer
+        if config.text_model_name is None:
+            pass
+        else:
+            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)
+
+    def predict(self, input_ids = None, images = None, texts = None):
+        images = self.processor.preprocess(images)['pixel_values'][0]
+        images = torch.tensor(images)
+        images = images.unsqueeze(0)
+
+        if self.config.vision_model_name is None and images is None:
+            vision_embedding = None
+        else:
+            vision_output = self.vision_model(images, output_attentions=True)
+            vision_attentions = vision_output.attentions
+            vision_embedding = vision_output.pooler_output
+            vision_embedding = self.vision_interaction_layer(vision_embedding)
+
+        if self.config.text_model_name is None and all(x is None for x in texts):
+            text_embedding = None
+        else:
+            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
+            text_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_texts["input_ids"][0])
+            text_output = self.text_model(**tokenized_texts, output_attentions=True)
+            text_attentions = text_output.attentions
+            text_embedding = text_output.last_hidden_state[:, 0, :]
+            text_embedding = self.text_interaction_layer(text_embedding)
+
+        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
+
+        return {
+            "logits": out.logits,
+            "vision_attentions": vision_attentions,
+            "text_tokens": text_tokens,
+            "text_attentions": text_attentions,
+            "time_series_attentions": out.attentions
+        }
+
+    def forward(self, input_ids = None, images = None, texts = None, labels = None):
+        if self.config.vision_model_name is None and images is None:
+            vision_embedding = None
+        else:
+            vision_embedding = self.vision_model(images)
+            vision_embedding = vision_embedding.pooler_output
+            vision_embedding = self.vision_interaction_layer(vision_embedding)
+
+        if self.config.text_model_name is None and all(x is None for x in texts):
+            text_embedding = None
+        else:
+            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
+            text_embedding = self.text_model(**tokenized_texts)
+            text_embedding = text_embedding.last_hidden_state[:, 0, :]
+            text_embedding = self.text_interaction_layer(text_embedding)
+
+        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
+        out = out["logits"]
+
+        if labels is not None:
+            if self.config.forecasting_length == out.shape[-1]:
+                loss = torch.mean(torch.square(out - labels))  # MSE
+            else:  # The pretrained Timer forecasts 96 steps; this handles a shorter forecasting length. A forecasting length larger than 96 will raise an error.
+                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
+        else:
+            loss = None
+
+        return {
+            "loss": loss,
+            "logits": out
+        }
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        from transformers.utils import cached_file
+        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
+        model = MulTiCastTimerModel(config)
+        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
+        state_dict = load_file(resolved_file)
+        model.load_state_dict(state_dict, strict=False)
+
         return model
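
With this change, `predict` additionally returns the tokenized text (`text_tokens`) and the Qwen2 attention maps (`text_attentions`) alongside the vision and time-series attentions it already exposed. For reference, a minimal usage sketch of the updated call pattern and returned keys — not part of the commit, assuming the module sits on the Python path, a placeholder checkpoint path, and dummy inputs shaped for the Timer backbone:

import torch
from PIL import Image
from modeling_multicasttimer import MulTiCastTimerModel

# Placeholder path: point it at the actual repo id or a local directory
# containing config.json and model.safetensors.
model = MulTiCastTimerModel.from_pretrained("path/to/multicast-timer-checkpoint")
model.eval()

past_values = torch.randn(1, 672)               # dummy past series; length should match what the Timer backbone expects
image = Image.open("context_image.png")         # any image accepted by CLIPImageProcessor.preprocess
texts = ["Example textual context for the forecast."]

with torch.no_grad():
    out = model.predict(input_ids=past_values, images=image, texts=texts)

out["logits"]                 # forecast from the Timer head
out["vision_attentions"]      # CLIP attention maps (returned before this commit)
out["text_tokens"]            # new: tokens aligned with the text attention maps
out["text_attentions"]        # new: per-layer Qwen2 attention maps
out["time_series_attentions"] # Timer attention maps

Returning `text_tokens` together with `text_attentions` keeps the text side symmetric with the vision side, so attention maps can be inspected per token without re-tokenizing outside the model.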