jimmygao3218 commited on
Commit
bbec003
·
verified ·
1 Parent(s): c33990e

Upload 4 files

Browse files
model_STSSDL/STSSDL.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import math
5
+ import numpy as np
6
+
7
class AGCN(nn.Module):
    """Chebyshev graph convolution over a list of (possibly batched) supports.

    Every support matrix is expanded into its Chebyshev polynomial basis up to
    order ``cheb_k``; each basis graph propagates the node features, and the
    concatenated results are mixed by one shared linear projection.
    """

    def __init__(self, dim_in, dim_out, cheb_k, num_support):
        super(AGCN, self).__init__()
        self.cheb_k = cheb_k
        # A single weight matrix covers all supports and all Chebyshev orders:
        # the per-graph outputs are concatenated along the feature axis first.
        self.weights = nn.Parameter(torch.FloatTensor(num_support * cheb_k * dim_in, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        nn.init.xavier_normal_(self.weights)
        nn.init.constant_(self.bias, val=0)

    def forward(self, x, supports):
        """x: (B, N, C); supports: iterable of (N, N) or (B, N, N) adjacency tensors."""
        propagated = []
        for adj in supports:
            batched = len(adj.shape) != 2
            if batched:
                identity = torch.eye(adj.shape[1]).repeat(adj.shape[0], 1, 1).to(adj.device)
            else:
                identity = torch.eye(adj.shape[0]).to(adj.device)
            # Chebyshev recursion: T0 = I, T1 = A, Tk = 2*A*T(k-1) - T(k-2).
            basis = [identity, adj]
            for _ in range(2, self.cheb_k):
                basis.append(torch.matmul(2 * adj, basis[-1]) - basis[-2])
            equation = "bnm,bmc->bnc" if batched else "nm,bmc->bnc"
            propagated.extend(torch.einsum(equation, graph, x) for graph in basis)
        stacked = torch.cat(propagated, dim=-1)
        # (B, N, num_support*cheb_k*dim_in) projected to (B, N, dim_out).
        return torch.einsum('bni,io->bno', stacked, self.weights) + self.bias
35
+
36
class AGCRNCell(nn.Module):
    """GRU-style recurrent cell whose gates are graph convolutions (AGCN)."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, num_support):
        super(AGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # `gate` emits both update (z) and reset (r) activations in one pass.
        self.gate = AGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, num_support)
        self.update = AGCN(dim_in + self.hidden_dim, dim_out, cheb_k, num_support)

    def forward(self, x, state, supports):
        """x: (B, N, dim_in); state: (B, N, hidden_dim) -> next state, same shape."""
        state = state.to(x.device)
        combined = torch.cat((x, state), dim=-1)
        gates = torch.sigmoid(self.gate(combined, supports))
        z, r = torch.split(gates, self.hidden_dim, dim=-1)
        candidate_input = torch.cat((x, z * state), dim=-1)
        candidate = torch.tanh(self.update(candidate_input, supports))
        # Convex combination of the previous state and the candidate.
        return r * state + (1 - r) * candidate

    def init_hidden_state(self, batch_size):
        """Zero-initialised hidden state of shape (batch_size, N, hidden_dim)."""
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
58
+
59
class ADCRNN_Encoder(nn.Module):
    """Stacked AGCRN encoder: each layer is unrolled over the whole sequence."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, rnn_layers, num_support):
        super(ADCRNN_Encoder, self).__init__()
        assert rnn_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.rnn_layers = rnn_layers
        self.dcrnn_cells = nn.ModuleList()
        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, num_support))
        for _ in range(1, rnn_layers):
            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, num_support))

    def forward(self, x, init_state, supports):
        """Encode a sequence.

        x: (B, T, N, D); init_state: per-layer list of (B, N, hidden).
        Returns the top layer's per-step outputs (B, T, N, hidden) and the
        final state of every layer (list of (B, N, hidden)).
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        steps = x.shape[1]
        layer_input = x
        final_states = []
        for layer, cell in enumerate(self.dcrnn_cells):
            state = init_state[layer]
            outputs = []
            for t in range(steps):
                state = cell(layer_input[:, t, :, :], state, supports)
                outputs.append(state)
            final_states.append(state)
            # The full output sequence of this layer feeds the next layer.
            layer_input = torch.stack(outputs, dim=1)
        return layer_input, final_states

    def init_hidden(self, batch_size):
        """Zero initial states for every layer (list of (B, N, hidden))."""
        return [cell.init_hidden_state(batch_size) for cell in self.dcrnn_cells]
96
+
97
class ADCRNN_Decoder(nn.Module):
    """Stacked AGCRN decoder: advances every layer by one time step per call."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, rnn_layers, num_support):
        super(ADCRNN_Decoder, self).__init__()
        assert rnn_layers >= 1, 'At least one DCRNN layer in the Decoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.rnn_layers = rnn_layers
        self.dcrnn_cells = nn.ModuleList()
        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, num_support))
        for _ in range(1, rnn_layers):
            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, num_support))

    def forward(self, xt, init_state, supports):
        """Single decoding step.

        xt: (B, N, D) input for this step; init_state: per-layer states.
        Returns the top layer's output and the updated state of every layer.
        """
        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
        layer_input = xt
        new_states = []
        for cell, state in zip(self.dcrnn_cells, init_state):
            layer_input = cell(layer_input, state, supports)
            new_states.append(layer_input)
        return layer_input, new_states
120
+
121
+
122
class STSSDL(nn.Module):
    """Spatio-temporal traffic forecaster with prototype-based self-supervision.

    Encodes the current window and a historical window with a shared graph-conv
    GRU encoder, queries a learned prototype bank for both, and decodes the
    forecast autoregressively over an adaptively learned graph. Training may use
    scheduled sampling (curriculum learning).
    """

    def __init__(self, num_nodes=207, input_dim=1, output_dim=1, horizon=12, rnn_units=128, rnn_layers=1, cheb_k=3,
                 ycov_dim=1, prototype_num=20, prototype_dim=64, tod_embed_dim=10, adj_mx=None, cl_decay_steps=2000,
                 TDAY=288, use_curriculum_learning=True, use_STE=False, device="cpu", adaptive_embedding_dim=48, node_embedding_dim=20, input_embedding_dim=128):
        super(STSSDL, self).__init__()
        self.num_nodes = num_nodes
        self.input_dim = input_dim
        self.rnn_units = rnn_units
        self.output_dim = output_dim
        self.horizon = horizon
        self.rnn_layers = rnn_layers
        self.cheb_k = cheb_k
        self.ycov_dim = ycov_dim
        self.tod_embed_dim = tod_embed_dim
        self.cl_decay_steps = cl_decay_steps  # decay constant for scheduled sampling
        self.use_curriculum_learning = use_curriculum_learning
        self.device = device
        self.use_STE = use_STE  # whether spatio-temporal embeddings are appended to the input
        self.TDAY = TDAY  # time slots per day (288 -> 5-minute slots, presumably; confirm against dataset)
        self.adaptive_embedding_dim = adaptive_embedding_dim
        self.node_embedding_dim = node_embedding_dim
        self.input_embedding_dim = input_embedding_dim
        # Sum of all auxiliary embedding widths concatenated to the projected input.
        self.total_embedding_dim = self.tod_embed_dim + self.adaptive_embedding_dim + self.node_embedding_dim
        # prototypes
        self.prototype_num = prototype_num
        self.prototype_dim = prototype_dim
        self.prototypes = self.construct_prototypes()

        # projection & spatio-temporal embedding
        if self.use_STE:
            if self.adaptive_embedding_dim > 0:
                # Learned (T=12, N, d) embedding, broadcast over the batch in forward().
                self.adaptive_embedding = nn.init.xavier_uniform_(
                    nn.Parameter(torch.empty(12, num_nodes, self.adaptive_embedding_dim))
                )
            self.input_proj = nn.Linear(self.input_dim, input_embedding_dim)
            self.node_embedding = nn.Parameter(torch.empty(self.num_nodes, self.node_embedding_dim))
            self.time_embedding = nn.Parameter(torch.empty(self.TDAY, self.tod_embed_dim))
            nn.init.xavier_uniform_(self.node_embedding)
            nn.init.xavier_uniform_(self.time_embedding)

        # encoder
        self.adj_mx = adj_mx
        if self.use_STE:
            self.encoder = ADCRNN_Encoder(self.num_nodes, input_embedding_dim + self.total_embedding_dim, self.rnn_units, self.cheb_k, self.rnn_layers, len(self.adj_mx))
        else:
            self.encoder = ADCRNN_Encoder(self.num_nodes, self.input_dim, self.rnn_units, self.cheb_k, self.rnn_layers, len(self.adj_mx))

        # decoder: its input excludes the adaptive embedding (no future values for it)
        self.decoder_dim = self.rnn_units + self.prototype_dim
        if self.use_STE:
            self.decoder = ADCRNN_Decoder(self.num_nodes, input_embedding_dim + self.total_embedding_dim - self.adaptive_embedding_dim, self.decoder_dim, self.cheb_k, self.rnn_layers, 1)
        else:
            self.decoder = ADCRNN_Decoder(self.num_nodes, self.output_dim + self.ycov_dim, self.decoder_dim, self.cheb_k, self.rnn_layers, 1)

        # output
        self.proj = nn.Sequential(nn.Linear(self.decoder_dim, self.output_dim, bias=True))

        # graph: hypernetwork that maps augmented decoder state to node embeddings
        # from which the decoder's adaptive support is built in forward().
        self.hypernet = nn.Sequential(nn.Linear(self.decoder_dim * 2, self.tod_embed_dim, bias=True))

        # NOTE(review): act_dict/act_fn are defined but not used anywhere in this class.
        self.act_dict = {'relu': nn.ReLU(), 'lrelu': nn.LeakyReLU(), 'sigmoid': nn.Sigmoid()}
        self.act_fn = 'sigmoid'  # 'relu' 'lrelu' 'sigmoid'

    def compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid schedule: probability of feeding ground truth to the decoder."""
        return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def construct_prototypes(self):
        """Create the prototype bank (M, d) and the query projection Wq (rnn_units, d)."""
        prototypes_dict = nn.ParameterDict()
        prototype = torch.randn(self.prototype_num, self.prototype_dim)
        prototypes_dict['prototypes'] = nn.Parameter(prototype, requires_grad=True)  # (M, d)
        prototypes_dict['Wq'] = nn.Parameter(torch.randn(self.rnn_units, self.prototype_dim), requires_grad=True)  # project to query
        for param in prototypes_dict.values():
            nn.init.xavier_normal_(param)

        return prototypes_dict

    def query_prototypes(self, h_t: torch.Tensor):
        """Attend over the prototype bank with the encoder state h_t (B, N, hidden).

        Returns:
            value: attention-weighted prototype mixture (B, N, d)
            query: projected query (B, N, d)
            pos/neg: closest and second-closest prototypes (B, N, d)
            mask: indices of the top-2 prototypes (B, N, 2)
        """
        query = torch.matmul(h_t, self.prototypes['Wq'])  # (B, N, d)
        att_score = torch.softmax(torch.matmul(query, self.prototypes['prototypes'].t()), dim=-1)  # alpha: (B, N, M)
        value = torch.matmul(att_score, self.prototypes['prototypes'])  # (B, N, d)
        _, ind = torch.topk(att_score, k=2, dim=-1)
        pos = self.prototypes['prototypes'][ind[:, :, 0]]  # B, N, d
        neg = self.prototypes['prototypes'][ind[:, :, 1]]  # B, N, d
        mask = torch.stack([ind[:, :, 0], ind[:, :, 1]], dim=-1)  # B, N, 2

        return value, query, pos, neg, mask

    def calculate_distance(self, pos, pos_his, mask=None):
        """L1 distance over the last axis; ``mask`` is passed through unchanged."""
        score = torch.sum(torch.abs(pos - pos_his), dim=-1)
        return score, mask

    def forward(self, x, x_cov, x_his, y_cov, labels=None, batches_seen=None):
        """Forecast ``horizon`` steps.

        x/x_his: (B, T, N, input_dim) current and historical windows;
        x_cov/y_cov: time-of-day covariates for input and output windows;
        labels: scaled targets, used only for scheduled sampling during training.
        Returns (output, query, pos, neg, mask, latent_dis, prototype_dis).
        """
        if self.use_STE:
            if self.input_embedding_dim > 0:
                x = self.input_proj(x)  # [B,T,N,1]->[B,T,N,D]
            features = [x]

            tod = x_cov.squeeze()  # [B, T, N]  (NOTE(review): unused here; time_emb recomputes the squeeze)
            if self.tod_embed_dim > 0:
                # NOTE(review): .type(torch.LongTensor) yields a CPU index tensor;
                # likely breaks on CUDA inputs — confirm (.long() would keep device).
                time_emb = self.time_embedding[(x_cov.squeeze() * self.TDAY).type(torch.LongTensor)]  # [B, T, N, d]
                features.append(time_emb)
            if self.adaptive_embedding_dim > 0:
                adp_emb = self.adaptive_embedding.expand(
                    size=(x.shape[0], *self.adaptive_embedding.shape)
                )
                features.append(adp_emb)
            if self.node_embedding_dim > 0:
                node_emb = self.node_embedding.unsqueeze(0).unsqueeze(1).expand(x.shape[0], self.horizon, -1, -1)  # [B,T,N,d]
                features.append(node_emb)
            x = torch.cat(features, dim=-1)  # [B, T, N, D + total_embedding_dim]
        supports_en = self.adj_mx
        init_state = self.encoder.init_hidden(x.shape[0])
        h_en, state_en = self.encoder(x, init_state, supports_en)  # B, T, N, hidden
        h_t = h_en[:, -1, :, :]  # B, N, hidden (last state)
        v_t, q_t, p_t, n_t, mask = self.query_prototypes(h_t)
        # Same embedding pipeline for the historical window, sharing the encoder.
        if self.use_STE:
            if self.input_embedding_dim > 0:
                x_his = self.input_proj(x_his)  # [B,T,N,1]->[B,T,N,D]
            features = [x_his]
            tod = x_cov.squeeze()  # [B, T, N]
            if self.tod_embed_dim > 0:
                # NOTE(review): uses the *current* window's covariate x_cov for the
                # historical window too — confirm this is intentional.
                time_emb = self.time_embedding[(x_cov.squeeze() * self.TDAY).type(torch.LongTensor)]  # [B, T, N, d]

                features.append(time_emb)
            if self.adaptive_embedding_dim > 0:
                adp_emb = self.adaptive_embedding.expand(
                    size=(x.shape[0], *self.adaptive_embedding.shape)
                )
                features.append(adp_emb)
            if self.node_embedding_dim > 0:
                node_emb = self.node_embedding.unsqueeze(0).unsqueeze(1).expand(x.shape[0], self.horizon, -1, -1)  # [B,T,N,d]
                features.append(node_emb)
            x_his = torch.cat(features, dim=-1)  # [B, T, N, D + total_embedding_dim]
        h_his_en, state_his_en = self.encoder(x_his, init_state, supports_en)  # B, T, N, hidden
        h_a = h_his_en[:, -1, :, :]  # B, N, hidden (last state)
        v_a, q_a, p_a, n_a, mask_his = self.query_prototypes(h_a)

        # Discrepancy between current and historical representations (deviation signals).
        latent_dis, _ = self.calculate_distance(q_t, q_a)
        prototype_dis, mask_dis = self.calculate_distance(p_t, p_a)

        query = torch.stack([q_t, q_a], dim=0)
        pos = torch.stack([p_t, p_a], dim=0)
        neg = torch.stack([n_t, n_a], dim=0)
        # NOTE(review): mask is always a tensor here (query_prototypes never returns
        # None), so the else branch looks unreachable.
        mask = torch.stack([mask, mask_his], dim=0) if mask is not None else [None, None]

        h_de = torch.cat([h_t, v_t], dim=-1)
        h_aug = torch.cat([h_t, v_t, h_a, v_a], dim=-1)  # B, N, D

        # Adaptive decoder graph built from the augmented state via the hypernetwork.
        node_embeddings = self.hypernet(h_aug)  # B, N, e
        support = F.softmax(F.relu(torch.einsum('bnc,bmc->bnm', node_embeddings, node_embeddings)), dim=-1)
        supports_de = [support]

        ht_list = [h_de] * self.rnn_layers
        go = torch.zeros((x.shape[0], self.num_nodes, self.output_dim), device=x.device)

        out = []
        for t in range(self.horizon):
            if self.use_STE:
                if self.input_embedding_dim > 0:
                    go = self.input_proj(go)  # equal to torch.zeros(B,N,D)
                features = [go]
                tod = y_cov[:, t, ...].squeeze()  # [B, N]
                if self.tod_embed_dim > 0:
                    time_emb = self.time_embedding[(tod * self.TDAY).type(torch.LongTensor)]
                    features.append(time_emb)
                if self.node_embedding_dim > 0:
                    node_emb = self.node_embedding.unsqueeze(0).expand(x.shape[0], -1, -1)  # [B,N,d]
                    features.append(node_emb)
                go = torch.cat(features, dim=-1)  # [B, N, D+d]
                h_de, ht_list = self.decoder(go, ht_list, supports_de)
            else:
                h_de, ht_list = self.decoder(torch.cat([go, y_cov[:, t, ...]], dim=-1), ht_list, supports_de)
            go = self.proj(h_de)
            out.append(go)
            # Scheduled sampling: sometimes feed the ground-truth step instead of
            # the model's own prediction (training only).
            if self.training and self.use_curriculum_learning:
                c = np.random.uniform(0, 1)
                if c < self.compute_sampling_threshold(batches_seen):
                    go = labels[:, t, ...]

        output = torch.stack(out, dim=1)

        return output, query, pos, neg, mask, latent_dis, prototype_dis
304
+
305
def print_params(model):
    """Print every trainable parameter of ``model`` and the grand total."""
    print('Trainable parameter list:')
    total = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape, param.numel())
            total += param.numel()
    print(f'In total: {total} trainable parameters.')
    return
315
+
316
def main():
    """Smoke test: build STSSDL on the METR-LA adjacency and print a summary."""
    from torchinfo import summary
    from utils import load_adj

    supports = load_adj('../METRLA/adj_mx.pkl', "symadj")
    supports = [torch.FloatTensor(mat) for mat in supports]
    model = STSSDL(adj_mx=supports)
    summary(model, [[8, 12, 207, 1], [8, 12, 207, 1], [8, 12, 207, 1], [8, 12, 207, 1]], device="cpu")

if __name__ == '__main__':
    main()
model_STSSDL/metrics.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def MSE(y_true, y_pred, mask=None):
    """Squared error between predictions and targets.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, broadcastable against ``y_true``.
        mask: optional 0/1 weight tensor; masked-out (0) positions are zeroed
            before averaging. The divisor of the mean includes masked entries.

    Returns:
        A scalar mean when ``mask`` is given; otherwise the *element-wise*
        squared-error tensor (unreduced). NOTE(review): this asymmetry is
        preserved from the original implementation — callers of the unmasked
        path receive an unreduced tensor.
    """
    # Dead computations from the original (mse_ig, anomaly/normal counts) removed:
    # they were assigned and never used or returned.
    mse = torch.square(y_pred - y_true)
    if mask is not None:
        mse = mse * mask
        mse = torch.mean(mse)
    return mse
15
+
16
def RMSE(y_true, y_pred, mask=None):
    """Root of ``MSE``: scalar when ``mask`` is given, element-wise otherwise.

    Fix: the original first computed an element-wise squared error and
    immediately overwrote it — that dead statement is removed.
    """
    return torch.sqrt(MSE(y_true, y_pred, mask))
21
+
22
def MAE(y_true, y_pred, mask=None):
    """Absolute error between predictions and targets.

    Same masking semantics as ``MSE``: with ``mask`` the masked-out positions
    are zeroed and a scalar mean over *all* elements is returned; without a
    mask the unreduced element-wise tensor is returned (preserved behavior).
    """
    # Dead computations from the original (mae_ig, anomaly/normal counts) removed.
    mae = torch.abs(y_pred - y_true)
    if mask is not None:
        mae = mae * mask
        mae = torch.mean(mae)
    return mae
34
+
35
+
36
+
37
+
38
+
model_STSSDL/train_STSSDL.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import shutil
4
+ import numpy as np
5
+ import pandas as pd
6
+ import time
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.init as init
10
+ import torch.nn.functional as F
11
+ from torchinfo import summary
12
+ import argparse
13
+ import logging
14
+ from utils import StandardScaler, masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss
15
+ from utils import load_adj
16
+ from metrics import RMSE, MAE, MSE
17
+ from STSSDL import STSSDL
18
+ import random
19
class ContrastiveLoss():
    """Thin wrapper around a triplet-margin contrastive loss.

    The ``contra_loss``/``mask``/``temp`` options are stored on the instance
    but ``calculate`` always applies a triplet margin loss.
    """

    def __init__(self, contra_loss='triplet', mask=None, temp=1.0, margin=0.5):
        self.infonce = contra_loss in ['infonce']
        self.mask = mask
        self.temp = temp
        self.margin = margin

    def calculate(self, query, pos, neg, mask):
        """Triplet loss with the query (anchor) detached from the graph.

        query/pos/neg: (batch_size, num_sensor, hidden_dim); ``mask`` is
        accepted for interface compatibility but unused here.
        """
        criterion = nn.TripletMarginLoss(margin=self.margin)
        return criterion(query.detach(), pos, neg)
35
+
36
+
37
def print_model(model):
    """Log the trainable-parameter inventory of ``model`` and its total size.

    Uses the module-level ``logger`` for the header/total and plain ``print``
    for the per-parameter lines (as in the original).
    """
    logger.info('Trainable parameter list:')
    total = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape, param.numel())
            total += param.numel()
    logger.info(f'In total: {total} trainable parameters.')
    return
46
+
47
def get_model():
    """Build an STSSDL instance from the global CLI ``args`` and adjacency file."""
    raw_supports = load_adj(adj_mx_path, args.adj_type)
    supports = [torch.tensor(mat).to(device) for mat in raw_supports]
    model = STSSDL(num_nodes=args.num_nodes, input_dim=args.input_dim, output_dim=args.output_dim,
                   horizon=args.horizon, rnn_units=args.rnn_units, rnn_layers=args.rnn_layers,
                   cheb_k=args.cheb_k, prototype_num=args.prototype_num, prototype_dim=args.prototype_dim,
                   tod_embed_dim=args.tod_embed_dim, adj_mx=supports, cl_decay_steps=args.cl_decay_steps,
                   use_curriculum_learning=args.use_curriculum_learning, use_STE=args.use_STE,
                   adaptive_embedding_dim=args.adaptive_embedding_dim,
                   node_embedding_dim=args.node_embedding_dim,
                   input_embedding_dim=args.input_embedding_dim,
                   device=device).to(device)
    return model
55
+
56
def prepare_x_y(x, y):
    """Split the channel axis of the raw batches into the model's inputs.

    Args:
        x: (batch, seq_len, num_sensor, >=3) — channels are the observed
           value, its time-of-day covariate, and the historical value.
        y: (batch, horizon, num_sensor, >=2) — channels are the target value
           and its covariate.

    Returns:
        (x, x_cov, x_his, y, y_cov), each keeping a trailing singleton channel.
    """
    value, cov, hist = x[..., 0:1], x[..., 1:2], x[..., 2:3]
    target, target_cov = y[..., 0:1], y[..., 1:2]
    return value, cov, hist, target, target_cov
71
+
72
def evaluate(model, mode):
    """Run ``model`` over the ``mode`` ('val' or 'test') loader without gradients.

    Relies on module-level globals: ``data`` (loaders), ``device``, ``scaler``,
    ``logger`` and the masked loss helpers imported from ``utils``.

    Returns:
        (mean of per-batch masked MAE, concatenated targets, concatenated predictions)
    """
    with torch.no_grad():
        model = model.eval()
        data_iter = data[f'{mode}_loader']
        ys_true, ys_pred = [], []
        losses = []
        for x, y in data_iter:
            x = x.to(device)
            y = y.to(device)
            x, x_cov, x_his, y, y_cov = prepare_x_y(x, y)
            # labels/batches_seen omitted: no scheduled sampling at eval time.
            output, _, _, _, _, _, _ = model(x, x_cov, x_his, y_cov)
            y_pred = scaler.inverse_transform(output)  # back to original units
            y_true = y
            ys_true.append(y_true)
            ys_pred.append(y_pred)
            losses.append(masked_mae_loss(y_pred, y_true).item())

        ys_true, ys_pred = torch.cat(ys_true, dim=0), torch.cat(ys_pred, dim=0)
        # NOTE(review): full-dataset MAE is computed here but not returned —
        # the function returns the mean of per-batch losses instead.
        loss = masked_mae_loss(ys_pred, ys_true)

        if mode == 'test':
            # Overall and per-horizon metrics; steps 3/6/12 correspond to
            # 15/30/60 minutes (presumably 5-minute slots — confirm per dataset).
            mae = masked_mae_loss(ys_pred, ys_true).item()
            mape = masked_mape_loss(ys_pred, ys_true).item()
            rmse = masked_rmse_loss(ys_pred, ys_true).item()
            mae_3 = masked_mae_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            mape_3 = masked_mape_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            rmse_3 = masked_rmse_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            mae_6 = masked_mae_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            mape_6 = masked_mape_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            rmse_6 = masked_rmse_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            mae_12 = masked_mae_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()
            mape_12 = masked_mape_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()
            rmse_12 = masked_rmse_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()

            logger.info('Horizon overall: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae, mape * 100, rmse))
            logger.info('Horizon 15mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_3, mape_3 * 100, rmse_3))
            logger.info('Horizon 30mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_6, mape_6 * 100, rmse_6))
            logger.info('Horizon 60mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_12, mape_12 * 100, rmse_12))

        return np.mean(losses), ys_true, ys_pred
112
+
113
+
114
def traintest_model():
    """Full training loop with early stopping, then a final test evaluation.

    Uses module-level globals: ``args``, ``data``, ``device``, ``scaler``,
    ``logger``, ``modelpt_path`` and the loss helpers from ``utils``. The
    best-validation checkpoint is saved to ``modelpt_path`` and reloaded for
    the final test pass.
    """
    model = get_model()
    print_model(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=args.epsilon, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.steps, gamma=args.lr_decay_ratio)
    min_val_loss = float('inf')
    wait = 0  # epochs since last validation improvement (for early stopping)
    batches_seen = 0  # global step, drives the scheduled-sampling threshold
    for epoch_num in range(args.epochs):
        start_time = time.time()
        model = model.train()
        data_iter = data['train_loader']
        losses, mae_losses, contra_losses, deviation_losses = [], [], [], []
        for x, y in data_iter:
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            x, x_cov, x_his, y, y_cov = prepare_x_y(x, y)
            # Model receives scaled labels for scheduled sampling.
            output, query, pos, neg, mask, query_simi, pos_simi = model(x, x_cov, x_his, y_cov, scaler.transform(y), batches_seen)
            y_pred = scaler.inverse_transform(output)
            y_true = y

            mae_loss = masked_mae_loss(y_pred, y_true)  # masked_mae_loss(y_pred, y_true)
            contrastive_loss = ContrastiveLoss(contra_loss=args.contra_loss, mask=mask, temp=args.temp)

            # Index 0 selects the current-window triple (index 1 is the historical one).
            loss_c = contrastive_loss.calculate(query[0], pos[0], neg[0], mask[0])
            # Deviation loss: align prototype distances with (detached) latent distances.
            loss_d = F.l1_loss(query_simi.detach(), pos_simi)
            loss = mae_loss + args.lamb_c * loss_c + args.lamb_d * loss_d

            losses.append(loss.item())
            mae_losses.append(mae_loss.item())
            contra_losses.append(loss_c.item())
            deviation_losses.append(loss_d.item())
            # NOTE(review): duplicate append — each loss value enters `losses` twice.
            # The mean is unaffected, but the list length is doubled.
            losses.append(loss.item())
            batches_seen += 1
            loss.backward()
            if args.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)  # gradient clipping - this does it in place
            optimizer.step()

        end_time2 = time.time()
        train_loss = np.mean(losses)
        train_mae_loss = np.mean(mae_losses)
        train_contra_loss = np.mean(contra_losses)
        train_deviation_loss = np.mean(deviation_losses)
        lr_scheduler.step()
        val_loss, _, _ = evaluate(model, 'val')
        message = 'Epoch [{}/{}] ({}) train_loss: {:.4f}, train_mae_loss: {:.4f}, train_contra_loss: {:.4f}, train_deviation_loss: {:.4f}, val_loss: {:.4f}, lr: {:.6f}, {:.2f}s'.format(epoch_num + 1, args.epochs, batches_seen, train_loss, train_mae_loss, train_contra_loss, train_deviation_loss, val_loss, optimizer.param_groups[0]['lr'], (end_time2 - start_time))
        logger.info(message)

        # NOTE(review): a test evaluation every epoch is expensive and only logged.
        test_loss, _, _ = evaluate(model, 'test')
        logger.info("\n")

        if val_loss < min_val_loss:
            wait = 0
            min_val_loss = val_loss
            torch.save(model.state_dict(), modelpt_path)  # checkpoint the best model
        elif val_loss >= min_val_loss:
            wait += 1
            if wait == args.patience:
                logger.info('Early stopping at epoch: %d' % (epoch_num + 1))
                break

    logger.info('=' * 35 + 'Best val_loss model performance' + '=' * 35)
    logger.info('=' * 22 + 'Better results might be found from model at different epoch' + '=' * 22)
    # Reload the best checkpoint into a fresh model for the final test pass.
    model = get_model()
    model.load_state_dict(torch.load(modelpt_path))
    start = time.time()
    test_loss, _, _ = evaluate(model, 'test')
    end = time.time()
    logger.info(f"Inference Time: {(end-start):.2f}s")
186
+
187
+ #########################################################################################
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument('--dataset', type=str, choices=['METRLA', 'PEMSBAY','PEMS04','PEMS07','PEMS08','PEMSD7M'], default='METRLA', help='which dataset to run')
190
+ parser.add_argument('--num_nodes', type=int, default=207, help='num_nodes')
191
+ parser.add_argument('--seq_len', type=int, default=12, help='input sequence length')
192
+ parser.add_argument('--horizon', type=int, default=12, help='output sequence length')
193
+ parser.add_argument('--input_dim', type=int, default=1, help='number of input channel')
194
+ parser.add_argument('--output_dim', type=int, default=1, help='number of output channel')
195
+ parser.add_argument('--tod_embed_dim', type=int, default=10, help='embedding dimension for adaptive graph')
196
+ parser.add_argument('--cheb_k', type=int, default=3, help='max diffusion step or Cheb K')
197
+ parser.add_argument('--rnn_layers', type=int, default=1, help='number of rnn layers')
198
+ parser.add_argument('--rnn_units', type=int, default=128, help='number of rnn units')
199
+ parser.add_argument('--prototype_num', type=int, default=20, help='number of meta-nodes/prototypes')
200
+ parser.add_argument('--prototype_dim', type=int, default=64, help='dimension of meta-nodes/prototypes')
201
+ parser.add_argument("--loss", type=str, default='mask_mae_loss', help="mask_mae_loss")
202
+ parser.add_argument("--epochs", type=int, default=200, help="number of epochs of training")
203
+ parser.add_argument("--patience", type=int, default=30, help="patience used for early stop")
204
+ parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
205
+ parser.add_argument("--lr", type=float, default=0.01, help="base learning rate")
206
+ parser.add_argument("--steps", type=eval, default=[50, 100], help="steps")
207
+ parser.add_argument("--lr_decay_ratio", type=float, default=0.1, help="lr_decay_ratio")
208
+ parser.add_argument("--weight_decay", type=float, default=0, help="weight_decay_ratio")
209
+ parser.add_argument("--epsilon", type=float, default=1e-3, help="optimizer epsilon")
210
+ parser.add_argument("--max_grad_norm", type=int, default=5, help="max_grad_norm")
211
+ parser.add_argument("--use_curriculum_learning", type=eval, choices=[True, False], default='True', help="use_curriculum_learning")
212
+ parser.add_argument("--adj_type", type=str, default='symadj', help="scalap, normlap, symadj, transition, doubletransition")
213
+ parser.add_argument("--cl_decay_steps", type=int, default=2000, help="cl_decay_steps")
214
+ parser.add_argument('--gpu', type=int, default=0, help='which gpu to use')
215
+ parser.add_argument('--seed', type=int, default=100, help='random seed.')
216
+ parser.add_argument('--temp', type=float, default=1.0, help='temperature parameter')
217
+ parser.add_argument('--lamb_c', type=float, default=0.1, help='contra loss lambda')
218
+ parser.add_argument('--lamb_d', type=float, default=1.0, help='deviation loss lambda')
219
+ parser.add_argument('--contra_loss', type=str, choices=['triplet', 'infonce'], default='triplet', help='whether to triplet or infonce contra loss')
220
+ parser.add_argument("--use_STE", type=eval, choices=[True, False], default='True', help="use spatio-temporal embedding")
221
+ parser.add_argument("--adaptive_embedding_dim", type=int,default=48, help="use spatio-temporal adaptive embedding")
222
+ parser.add_argument("--node_embedding_dim", type=int,default=20, help="use spatio-temporal adaptive embedding")
223
+ parser.add_argument("--input_embedding_dim", type=int,default=128, help="use spatio-temporal adaptive embedding")
224
+
225
+ args = parser.parse_args()
226
+ num_nodes_dict={
227
+ "METRLA": 207,
228
+ "PEMSBAY": 325,
229
+ "PEMS04": 307,
230
+ "PEMS07": 883,
231
+ "PEMS08": 170,
232
+ "PEMSD7M": 228,
233
+ }
234
+ if args.dataset == 'METRLA':
235
+ data_path = f'../{args.dataset}/metr-la.h5'
236
+ adj_mx_path = f'../{args.dataset}/adj_mx.pkl'
237
+ args.num_nodes = 207
238
+ args.use_STE=True
239
+ rand_seed=random.randint(0, 1000000)# 31340
240
+ args.seed=999
241
+ args.lamb_c=0.01
242
+ args.lamb_d=1
243
+ args.steps = [50,70]
244
+ args.input_embedding_dim=3
245
+ args.node_embedding_dim=25
246
+ args.tod_embed_dim=20 #TOD embedding
247
+ args.adaptive_embedding_dim=0
248
+
249
+ elif args.dataset == 'PEMSBAY':
250
+ data_path = f'../{args.dataset}/pems-bay.h5'
251
+ adj_mx_path = f'../{args.dataset}/adj_mx_bay.pkl'
252
+ args.num_nodes = 325
253
+ args.use_STE=True
254
+ args.cl_decay_steps = 8000
255
+ args.steps = [10, 70,150]
256
+ args.seed=666
257
+ args.lamb_c=0.01
258
+ args.lamb_d=1
259
+ args.input_embedding_dim=10
260
+ args.node_embedding_dim=20
261
+ args.tod_embed_dim=20 #TOD embedding
262
+ args.adaptive_embedding_dim=0
263
+
264
+ elif args.dataset == 'PEMS04':
265
+ data_path = f'../{args.dataset}/{args.dataset}.npz'
266
+ adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
267
+ args.num_nodes = num_nodes_dict[args.dataset]
268
+ rand_seed=random.randint(0, 1000000)# 31340
269
+ args.seed=610958
270
+ args.patience=30
271
+ args.batch_size=16
272
+ args.lr=0.001
273
+ args.epochs=200
274
+ args.steps=[50, 100]
275
+ args.weight_decay=0
276
+ args.max_grad_norm=0
277
+ args.rnn_units=32
278
+ args.prototype_num=20
279
+ args.prototype_dim=64
280
+ args.cl_decay_steps=6000
281
+ args.max_diffusion_step=3
282
+ args.input_embedding_dim=32
283
+ args.node_embedding_dim=24
284
+ args.tod_embed_dim=40 #TOD embedding
285
+ args.adaptive_embedding_dim=0
286
+ args.use_curriculum_learning=True
287
+ args.lamb_c=0.01
288
+ args.lamb_d=0.01
289
+
290
+
291
+ elif args.dataset == 'PEMS07':
292
+ data_path = f'../{args.dataset}/{args.dataset}.npz'
293
+ adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
294
+ args.num_nodes = num_nodes_dict[args.dataset]
295
+ args.patience=20
296
+ args.batch_size=16
297
+ args.lr=0.001
298
+ args.steps=[50, 100]
299
+ args.weight_decay=0
300
+ args.max_grad_norm=0
301
+ args.rnn_units=64
302
+ args.prototype_num=20
303
+ args.prototype_dim=64
304
+ args.cl_decay_steps=6000
305
+ args.max_diffusion_step=3
306
+ args.lamb_c=0.01
307
+ args.lamb_d=1
308
+ args.seed=100
309
+ args.input_embedding_dim=64
310
+ args.node_embedding_dim=16
311
+ args.tod_embed_dim=16
312
+ args.adaptive_embedding_dim=0
313
+ elif args.dataset == 'PEMS08':
314
+ data_path = f'../{args.dataset}/{args.dataset}.npz'
315
+ adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
316
+ args.num_nodes = num_nodes_dict[args.dataset]
317
+ args.use_STE=True
318
+ args.patience=20
319
+ args.batch_size=16
320
+ rand_seed=random.randint(0, 1000000)# 31340
321
+ args.seed=rand_seed
322
+ args.cl_decay_steps=6000
323
+ args.max_diffusion_step=3
324
+ args.steps=[70, 100]
325
+ args.prototype_num=20
326
+ args.prototype_dim=64
327
+ args.use_curriculum_learning=True
328
+ args.rnn_units = 12
329
+ args.lamb_c=0.1
330
+ args.lamb_d=1
331
+ args.input_embedding_dim=16
332
+ args.node_embedding_dim=20
333
+ args.tod_embed_dim=20 #TOD embedding
334
+ args.adaptive_embedding_dim=0
335
+
336
+ elif args.dataset == 'PEMSD7M':
337
+ data_path = f'../{args.dataset}/{args.dataset}.npz'
338
+ adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
339
+ args.num_nodes = num_nodes_dict[args.dataset]
340
+ rand_seed=random.randint(0, 1000000)# 31340
341
+ args.seed=119089
342
+ args.patience=30
343
+ args.batch_size=16
344
+ args.lr=0.001
345
+ args.steps=[50, 100]
346
+ args.weight_decay=0
347
+ args.max_grad_norm=0
348
+ args.rnn_units=32
349
+ args.prototype_num=16
350
+ args.prototype_dim=64
351
+ args.cl_decay_steps=4000
352
+ args.max_diffusion_step=3
353
+ args.lamb_c=0.1
354
+ args.lamb_d=1
355
+ args.input_embedding_dim=32
356
+ args.node_embedding_dim=20
357
+ args.tod_embed_dim=16 #TOD embedding
358
+ args.adaptive_embedding_dim=0
359
+
360
+ model_name = 'STSSDL'
361
+ timestring = time.strftime('%Y%m%d%H%M%S', time.localtime())
362
+ path = f'../save/{args.dataset}_{model_name}_{timestring}'
363
+ logging_path = f'{path}/{model_name}_{timestring}_logging.txt'
364
+ score_path = f'{path}/{model_name}_{timestring}_scores.txt'
365
+ epochlog_path = f'{path}/{model_name}_{timestring}_epochlog.txt'
366
+ modelpt_path = f'{path}/{model_name}_{timestring}.pt'
367
+ if not os.path.exists(path): os.makedirs(path)
368
+ shutil.copy2(sys.argv[0], path)
369
+ shutil.copy2(f'{model_name}.py', path)
370
+ shutil.copy2('utils.py', path)
371
+
372
+ logger = logging.getLogger(__name__)
373
+ logger.setLevel(level = logging.INFO)
374
class MyFormatter(logging.Formatter):
    """Log formatter that folds positional args into the message text.

    The project calls ``logger.info(a, b, c)`` with plain values rather than
    a %-format string, so this formatter space-joins everything into one
    message and clears ``record.args`` before the base class formats it.
    """

    def format(self, record):
        sep = ' '
        extra = sep.join(str(a) for a in record.args)
        record.msg = str(record.msg) + sep + extra
        record.args = tuple()  # consumed above — keep %-formatting from firing
        return super().format(record)
380
+ formatter = MyFormatter()
381
+ handler = logging.FileHandler(logging_path, mode='a')
382
+ handler.setLevel(logging.INFO)
383
+ handler.setFormatter(formatter)
384
+ console = logging.StreamHandler()
385
+ console.setLevel(logging.INFO)
386
+ console.setFormatter(formatter)
387
+ logger.addHandler(handler)
388
+ logger.addHandler(console)
389
+ message = ''.join([f'{k}: {v}\n' for k, v in vars(args).items()])
390
+ logger.info(message)
391
+
392
+ cpu_num = 1
393
+ os.environ ['OMP_NUM_THREADS'] = str(cpu_num)
394
+ os.environ ['OPENBLAS_NUM_THREADS'] = str(cpu_num)
395
+ os.environ ['MKL_NUM_THREADS'] = str(cpu_num)
396
+ os.environ ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num)
397
+ os.environ ['NUMEXPR_NUM_THREADS'] = str(cpu_num)
398
+ torch.set_num_threads(cpu_num)
399
+ device = torch.device("cuda:{}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu")
400
+
401
+ np.random.seed(args.seed)
402
+ torch.manual_seed(args.seed)
403
+ if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed)
404
+ #####################################################################################################
405
+
406
+ data = {}
407
+ for category in ['train', 'val', 'test']:
408
+ cat_data = np.load(os.path.join(f'../{args.dataset}', category + 'his.npz'))
409
+ data['x_' + category] = np.nan_to_num(cat_data['x']) if True in np.isnan(cat_data['x']) else cat_data['x']
410
+ data['y_' + category] = np.nan_to_num(cat_data['y']) if True in np.isnan(cat_data['y']) else cat_data['y']
411
+ scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std())
412
+ for category in ['train', 'val', 'test']:
413
+ data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
414
+ data['x_' + category][..., 2] = scaler.transform(data['x_' + category][..., 2]) # x_his
415
+
416
+ data['train_loader'] = torch.utils.data.DataLoader(
417
+ torch.utils.data.TensorDataset(torch.FloatTensor(data['x_train']), torch.FloatTensor(data['y_train'])),
418
+ batch_size=args.batch_size,
419
+ shuffle=True
420
+ )
421
+ data['val_loader'] = torch.utils.data.DataLoader(
422
+ torch.utils.data.TensorDataset(torch.FloatTensor(data['x_val']), torch.FloatTensor(data['y_val'])),
423
+ batch_size=args.batch_size,
424
+ shuffle=False
425
+ )
426
+ data['test_loader'] = torch.utils.data.DataLoader(
427
+ torch.utils.data.TensorDataset(torch.FloatTensor(data['x_test']), torch.FloatTensor(data['y_test'])),
428
+ batch_size=args.batch_size,
429
+ shuffle=False
430
+ )
431
+
432
def main():
    """Log dataset shapes, run training + testing, and log start/end times."""
    logger.info(args.dataset, 'training and testing started', time.ctime())
    for split in ('train', 'val', 'test'):
        logger.info(f'{split} xs.shape, ys.shape',
                    data[f'x_{split}'].shape, data[f'y_{split}'].shape)
    traintest_model()
    logger.info(args.dataset, 'training and testing ended', time.ctime())
439
+
440
+ if __name__ == '__main__':
441
+ main()
442
+
model_STSSDL/utils.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import scipy.sparse as sp
6
+ from scipy.sparse import linalg
7
+
8
class DataLoader(object):
    """Minibatch iterator over paired sample arrays.

    :param xs: inputs; first axis is the sample axis.
    :param ys: targets, aligned with ``xs``.
    :param batch_size: samples per batch.
    :param pad_with_last_sample: repeat the final sample so the sample count
        is divisible by ``batch_size`` (avoids a ragged last batch).
    :param shuffle: permute the samples once, at construction time.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            pad = (batch_size - len(xs) % batch_size) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], pad, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], pad, axis=0)], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            order = np.random.permutation(self.size)
            xs, ys = xs[order], ys[order]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """Return a fresh generator of ``(x_batch, y_batch)`` pairs."""
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                lo = self.batch_size * self.current_ind
                hi = min(self.size, lo + self.batch_size)
                yield (self.xs[lo:hi, ...], self.ys[lo:hi, ...])
                self.current_ind += 1

        return _wrapper()
46
+
47
class StandardScaler():
    """Z-score normalization with fixed mean/std (fit on the training split)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Map raw values to standardized values: (x - mean) / std."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`: x * std + mean."""
        rescaled = data * self.std
        return rescaled + self.mean
57
+
58
def getTimestamp(data):
    """Fraction-of-day timestamps tiled over nodes.

    ``data`` is a DataFrame with a DatetimeIndex (samples x nodes).
    Returns an array of shape (num_samples, num_nodes) with values in [0, 1).
    """
    num_samples, num_nodes = data.shape
    midnight = data.index.values.astype("datetime64[D]")
    day_frac = (data.index.values - midnight) / np.timedelta64(1, "D")
    return np.tile(day_frac, [num_nodes, 1]).transpose((1, 0))
63
+
64
def getDayTimestamp(data):
    """Normalized weekday-slot timestamps tiled over nodes.

    Assumes a 5-minute interval (288 slots/day): each timestamp maps to
    weekday * 288 + slot-of-day, then is scaled by the max value seen in
    ``data`` so the result lies in (0, 1]. Returns (num_samples, num_nodes).
    """
    df = pd.DataFrame({'timestamp': data.index.values})
    ts = df['timestamp']
    slot = (ts.dt.hour * 60 + ts.dt.minute) // 5
    df['weekdaytime'] = ts.dt.weekday * 288 + slot
    df['weekdaytime'] = df['weekdaytime'] / df['weekdaytime'].max()
    num_samples, num_nodes = data.shape
    per_sample = df['weekdaytime'].values
    return np.tile(per_sample, [num_nodes, 1]).transpose((1, 0))
73
+
74
def getDayTimestamp_(start, end, freq, num_nodes):
    """Like :func:`getDayTimestamp`, but generates its own date range.

    Builds timestamps from ``start`` to ``end`` at ``freq``, maps each to
    weekday * 288 + 5-minute slot, normalizes by the max, and tiles over
    ``num_nodes``. Returns (num_samples, num_nodes).
    """
    stamps = pd.date_range(start=start, end=end, freq=freq)
    df = pd.DataFrame({'timestamp': stamps})
    ts = df['timestamp']
    df['weekdaytime'] = ts.dt.weekday * 288 + (ts.dt.hour * 60 + ts.dt.minute) // 5
    df['weekdaytime'] = df['weekdaytime'] / df['weekdaytime'].max()
    per_sample = df['weekdaytime'].values
    return np.tile(per_sample, [num_nodes, 1]).transpose((1, 0))
82
+
83
def masked_mse(preds, labels, null_val=1e-3):
    """Mean squared error over entries whose label exceeds ``null_val``.

    If ``null_val`` is NaN, NaN labels are masked instead. The mask is
    normalized by its mean so the loss scale is comparable across batches;
    NaNs produced by an all-masked batch are zeroed out.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    sq_err = (preds - labels) ** 2 * weight
    sq_err = torch.where(torch.isnan(sq_err), torch.zeros_like(sq_err), sq_err)
    return torch.mean(sq_err)
95
+
96
def masked_rmse(preds, labels, null_val=1e-3):
    """Root of the masked MSE; see :func:`masked_mse` for the masking rule."""
    mse = masked_mse(preds=preds, labels=labels, null_val=null_val)
    return torch.sqrt(mse)
98
+
99
+
100
def masked_mae(preds, labels, null_val=1e-3):
    """Mean absolute error over entries whose label exceeds ``null_val``.

    If ``null_val`` is NaN, NaN labels are masked instead. The mask is
    mean-normalized; any NaNs from an all-masked batch are zeroed.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    abs_err = torch.abs(preds - labels) * weight
    abs_err = torch.where(torch.isnan(abs_err), torch.zeros_like(abs_err), abs_err)
    return torch.mean(abs_err)
112
+
113
+
114
def masked_mape(preds, labels, null_val=1e-3):
    """Mean absolute percentage error over labels exceeding ``null_val``.

    If ``null_val`` is NaN, NaN labels are masked instead. Division by a
    masked (near-zero) label yields inf/NaN, which the final ``where``
    zeroes out because those positions carry zero mask weight.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    pct_err = torch.abs(preds - labels) / labels
    pct_err = pct_err * weight
    pct_err = torch.where(torch.isnan(pct_err), torch.zeros_like(pct_err), pct_err)
    return torch.mean(pct_err)
126
+
127
+ # DCRNN
128
# DCRNN-style losses: mask out exact-zero targets.
def masked_mae_loss(y_pred, y_true):
    """MAE over nonzero targets, with the mask normalized by its mean."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    err = torch.abs(y_pred - y_true) * weight
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    err[err != err] = 0
    return err.mean()
136
+
137
def masked_mape_loss(y_pred, y_true):
    """MAPE over nonzero targets, with the mask normalized by its mean."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    err = torch.abs(torch.div(y_true - y_pred, y_true)) * weight
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    err[err != err] = 0
    return err.mean()
145
+
146
def masked_rmse_loss(y_pred, y_true):
    """RMSE over nonzero targets, with the mask normalized by its mean."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    sq_err = torch.pow(y_true - y_pred, 2) * weight
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    sq_err[sq_err != sq_err] = 0
    return torch.sqrt(sq_err.mean())
154
+
155
def masked_mse_loss(y_pred, y_true):
    """MSE over nonzero targets, with the mask normalized by its mean."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    sq_err = torch.pow(y_true - y_pred, 2) * weight
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    sq_err[sq_err != sq_err] = 0
    return sq_err.mean()
163
+
164
def load_pickle(pickle_file):
    """Deserialize ``pickle_file``; retry with latin1 for Python-2 pickles.

    Any other failure is printed and re-raised.
    """
    try:
        with open(pickle_file, 'rb') as f:
            return pickle.load(f)
    except UnicodeDecodeError:
        # Python-2-era pickle: decode byte strings as latin1.
        with open(pickle_file, 'rb') as f:
            return pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load data ', pickle_file, ':', e)
        raise
175
+
176
+ def sym_adj(adj):
177
+ """Symmetrically normalize adjacency matrix."""
178
+ adj = sp.coo_matrix(adj)
179
+ rowsum = np.array(adj.sum(1))
180
+ d_inv_sqrt = np.power(rowsum, -0.5).flatten()
181
+ d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
182
+ d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
183
+ return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense()
184
+
185
def asym_adj(adj):
    """Row-normalize adjacency matrix: D^-1 A (dense, float32)."""
    adj = sp.coo_matrix(adj)
    degree = np.asarray(adj.sum(1)).flatten()
    inv_degree = np.power(degree, -1).flatten()
    inv_degree[np.isinf(inv_degree)] = 0.  # isolated nodes: 0 degree -> 0 weight
    d_inv = sp.diags(inv_degree)
    return d_inv.dot(adj).astype(np.float32).todense()
192
+
193
def calculate_normalized_laplacian(adj):
    """Symmetric normalized Laplacian as a sparse matrix.

    L = D^-1/2 (D - A) D^-1/2 = I - D^-1/2 A D^-1/2, with D = diag(A 1).

    :param adj: dense or sparse adjacency matrix.
    :return: sparse normalized Laplacian.
    """
    adj = sp.coo_matrix(adj)
    degree = np.asarray(adj.sum(1))
    inv_sqrt = np.power(degree, -0.5).flatten()
    inv_sqrt[np.isinf(inv_sqrt)] = 0.  # isolated nodes: 0 degree -> 0 weight
    d_half = sp.diags(inv_sqrt)
    normalized_adj = adj.dot(d_half).transpose().dot(d_half).tocoo()
    return sp.eye(adj.shape[0]) - normalized_adj
207
+
208
def calculate_random_walk_matrix(adj_mx):
    """Random-walk transition matrix D^-1 A as sparse COO."""
    adj_mx = sp.coo_matrix(adj_mx)
    degree = np.asarray(adj_mx.sum(1))
    inv_degree = np.power(degree, -1).flatten()
    inv_degree[np.isinf(inv_degree)] = 0.  # isolated nodes get zero rows
    d_inv = sp.diags(inv_degree)
    return d_inv.dot(adj_mx).tocoo()
216
+
217
def calculate_reverse_random_walk_matrix(adj_mx):
    """Random-walk matrix of the edge-reversed (transposed) adjacency."""
    reversed_adj = np.transpose(adj_mx)
    return calculate_random_walk_matrix(reversed_adj)
219
+
220
def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Chebyshev-rescaled Laplacian: (2 / lambda_max) * L - I, float32 sparse.

    :param adj_mx: adjacency matrix.
    :param lambda_max: largest eigenvalue to scale by; if None, it is
        computed from the normalized Laplacian.
    :param undirected: symmetrize the adjacency via element-wise max first.
    """
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    lap = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        eigvals, _ = linalg.eigsh(lap, 1, which='LM')
        lambda_max = eigvals[0]
    lap = sp.csr_matrix(lap)
    n, _ = lap.shape
    identity = sp.identity(n, format='csr', dtype=lap.dtype)
    scaled = (2 / lambda_max * lap) - identity
    return scaled.astype(np.float32)
232
+
233
def load_adj(pkl_filename, adjtype):
    """Load an adjacency pickle and return a list of normalized supports.

    :param pkl_filename: path to the pickle. PEMS0x / PEMSD7M files store the
        matrix directly; METR-LA / PEMS-BAY files store the tuple
        (sensor_ids, sensor_id_to_ind, adj_mx).
    :param adjtype: one of "scalap", "normlap", "symadj", "transition",
        "doubletransition", "identity".
    :return: list of one or two normalized adjacency matrices.
    :raises ValueError: if ``adjtype`` is not recognized.
    """
    if "PEMS0" in pkl_filename or "D7" in pkl_filename:
        adj_mx = load_pickle(pkl_filename)
    else:
        sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
    if adjtype == "scalap":
        adj = [calculate_scaled_laplacian(adj_mx)]
    elif adjtype == "normlap":
        adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32).todense()]
    elif adjtype == "symadj":
        adj = [sym_adj(adj_mx)]
    elif adjtype == "transition":
        adj = [asym_adj(adj_mx)]
    elif adjtype == "doubletransition":
        # Forward and reverse transition matrices (directed diffusion).
        adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))]
    elif adjtype == "identity":
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    else:
        # Was `error = 0; assert error, "adj type not defined"` — an assert
        # used for control flow is stripped under `python -O`, after which
        # the undefined `adj` would raise a confusing NameError instead.
        raise ValueError(f"adj type not defined: {adjtype!r}")
    return adj
254
+
255
def print_params(model):
    """Print each trainable parameter (name, shape, element count) and the total."""
    total = 0
    print('Trainable parameter list:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape, param.numel())
            total += param.numel()
    print(f'\n In total: {total} trainable parameters. \n')
    return