Upload 4 files
Browse files- model_STSSDL/STSSDL.py +326 -0
- model_STSSDL/metrics.py +38 -0
- model_STSSDL/train_STSSDL.py +442 -0
- model_STSSDL/utils.py +264 -0
model_STSSDL/STSSDL.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class AGCN(nn.Module):
    """Graph convolution over a list of supports with Chebyshev expansion.

    Each support (a static [N, N] adjacency or a batched [B, N, N] adjacency)
    is expanded into `cheb_k` Chebyshev polynomial terms T_k(A); every
    filtered signal is concatenated on the channel axis and mixed by one
    linear map, so the effective input width is num_support * cheb_k * dim_in.
    """

    def __init__(self, dim_in, dim_out, cheb_k, num_support):
        super(AGCN, self).__init__()
        self.cheb_k = cheb_k
        # One weight block per (support, polynomial-order) pair.
        self.weights = nn.Parameter(torch.FloatTensor(num_support*cheb_k*dim_in, dim_out))  # num_support*cheb_k*dim_in is the length of support
        # self.weights = nn.Parameter(torch.FloatTensor(dim_in, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        nn.init.xavier_normal_(self.weights)
        nn.init.constant_(self.bias, val=0)

    def forward(self, x, supports):
        """Apply the convolution.

        Args:
            x: (B, N, dim_in) node features.
            supports: iterable of [N, N] or [B, N, N] tensors.
        Returns:
            (B, N, dim_out) convolved features.

        NOTE(review): the recursion T_k = 2*A*T_{k-1} - T_{k-2} assumes each
        support has spectrum in [-1, 1] (scaled Laplacian / normalized
        adjacency) -- confirm upstream normalization.
        """
        x_g = []
        for support in supports:
            if len(support.shape) == 2:
                # Static graph shared by the whole batch: T_0 = I, T_1 = A.
                support_ks = [torch.eye(support.shape[0]).to(support.device), support]
                for k in range(2, self.cheb_k):
                    support_ks.append(torch.matmul(2 * support, support_ks[-1]) - support_ks[-2])
                for graph in support_ks:
                    # Shared graph contracted against batched features.
                    x_g.append(torch.einsum("nm,bmc->bnc", graph, x))
            else:
                # Per-sample graph: batched identity and batched matmul.
                support_ks = [torch.eye(support.shape[1]).repeat(support.shape[0], 1, 1).to(support.device), support]
                for k in range(2, self.cheb_k):
                    support_ks.append(torch.matmul(2 * support, support_ks[-1]) - support_ks[-2])
                for graph in support_ks:
                    x_g.append(torch.einsum("bnm,bmc->bnc", graph, x))
        # Concatenation order (support-major, then polynomial order) must
        # match the layout assumed by self.weights.
        x_g = torch.cat(x_g, dim=-1)
        x_gconv = torch.einsum('bni,io->bno', x_g, self.weights) + self.bias  # b, N, dim_out
        return x_gconv
|
| 35 |
+
|
| 36 |
+
class AGCRNCell(nn.Module):
    """GRU-style recurrent cell whose gates are graph convolutions (AGCN).

    NOTE(review): relative to the textbook GRU the symbol roles look swapped
    here -- `r` acts as the keep/update gate in the final mix and `z` scales
    the previous state inside the candidate. The math is still a valid gated
    cell; confirm against the paper before renaming anything.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, num_support):
        super(AGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # `gate` produces both gates at once (2*dim_out channels).
        self.gate = AGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k, num_support)
        self.update = AGCN(dim_in+self.hidden_dim, dim_out, cheb_k, num_support)

    def forward(self, x, state, supports):
        """One recurrent step.

        Args:
            x: (B, num_nodes, input_dim) input at this time step.
            state: (B, num_nodes, hidden_dim) previous hidden state.
            supports: graph supports forwarded to the AGCN layers.
        Returns:
            (B, num_nodes, hidden_dim) new hidden state.
        """
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, supports))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        # Candidate state sees the gated previous state.
        candidate = torch.cat((x, z*state), dim=-1)
        hc = torch.tanh(self.update(candidate, supports))
        # Convex mix between old state and candidate.
        h = r*state + (1-r)*hc
        return h

    def init_hidden_state(self, batch_size):
        # Zero initial state on CPU; forward() moves it to x.device.
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
| 58 |
+
|
| 59 |
+
class ADCRNN_Encoder(nn.Module):
    """Stacked AGCRNCell encoder unrolled over a full input sequence."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, rnn_layers, num_support):
        super(ADCRNN_Encoder, self).__init__()
        assert rnn_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.rnn_layers = rnn_layers
        self.dcrnn_cells = nn.ModuleList()
        # First layer consumes dim_in; deeper layers consume dim_out.
        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, num_support))
        for _ in range(1, rnn_layers):
            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, num_support))

    def forward(self, x, init_state, supports):
        """Unroll all layers over time.

        Args:
            x: (B, T, N, D) input sequence.
            init_state: per-layer initial states, (rnn_layers, B, N, hidden).
            supports: graph supports shared by every cell.
        Returns:
            current_inputs: (B, T, N, hidden) outputs of the last layer.
            output_hidden: list of final states, one (B, N, hidden) per layer.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for i in range(self.rnn_layers):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
                state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, supports)
                inner_states.append(state)
            output_hidden.append(state)
            # Layer i's per-step outputs feed layer i+1.
            current_inputs = torch.stack(inner_states, dim=1)
        #current_inputs: the outputs of last layer: (B, T, N, hidden_dim)
        #last_state: (B, N, hidden_dim)
        #output_hidden: the last state for each layer: (rnn_layers, B, N, hidden_dim)
        #return current_inputs, torch.stack(output_hidden, dim=0)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """Return a list of zero states, one per layer (CPU tensors)."""
        init_states = []
        for i in range(self.rnn_layers):
            init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size))
        return init_states
|
| 96 |
+
|
| 97 |
+
class ADCRNN_Decoder(nn.Module):
    """Stacked AGCRNCell decoder advancing one time step per forward call."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, rnn_layers, num_support):
        super(ADCRNN_Decoder, self).__init__()
        assert rnn_layers >= 1, 'At least one DCRNN layer in the Decoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.rnn_layers = rnn_layers
        # First cell consumes dim_in, every deeper cell consumes dim_out.
        layer_in_dims = [dim_in] + [dim_out] * (rnn_layers - 1)
        self.dcrnn_cells = nn.ModuleList(
            AGCRNCell(node_num, d, dim_out, cheb_k, num_support)
            for d in layer_in_dims
        )

    def forward(self, xt, init_state, supports):
        """Run one decoding step through all layers.

        Args:
            xt: (B, N, D) input at the current step.
            init_state: per-layer previous states, (rnn_layers, B, N, hidden).
            supports: graph supports shared by every cell.
        Returns:
            last layer's output (B, N, hidden) and the list of new per-layer states.
        """
        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
        hidden = xt
        output_hidden = []
        for cell, prev_state in zip(self.dcrnn_cells, init_state):
            hidden = cell(hidden, prev_state, supports)
            output_hidden.append(hidden)
        return hidden, output_hidden
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class STSSDL(nn.Module):
    """Spatio-temporal forecasting model with prototype memory and a
    self-supervised auxiliary branch.

    Two AGCRNN encoder passes are run (current window `x` and its historical
    counterpart `x_his`); both query a shared prototype memory. The decoder
    uses a graph generated on the fly from the concatenated representations
    and unrolls `horizon` steps with optional scheduled sampling.
    """

    def __init__(self, num_nodes=207, input_dim=1, output_dim=1, horizon=12, rnn_units=128, rnn_layers=1, cheb_k=3,
                 ycov_dim=1, prototype_num=20, prototype_dim=64, tod_embed_dim=10, adj_mx=None, cl_decay_steps=2000,
                 TDAY=288, use_curriculum_learning=True, use_STE=False, device="cpu", adaptive_embedding_dim=48, node_embedding_dim=20, input_embedding_dim=128):
        super(STSSDL, self).__init__()
        self.num_nodes = num_nodes
        self.input_dim = input_dim
        self.rnn_units = rnn_units
        self.output_dim = output_dim
        self.horizon = horizon
        self.rnn_layers = rnn_layers
        self.cheb_k = cheb_k
        self.ycov_dim = ycov_dim
        self.tod_embed_dim = tod_embed_dim
        # Scheduled-sampling decay constant (larger = slower decay).
        self.cl_decay_steps = cl_decay_steps
        self.use_curriculum_learning = use_curriculum_learning
        self.device = device
        self.use_STE = use_STE
        # Time steps per day (288 = 5-minute resolution).
        self.TDAY = TDAY
        self.adaptive_embedding_dim = adaptive_embedding_dim
        self.node_embedding_dim = node_embedding_dim
        self.input_embedding_dim = input_embedding_dim
        self.total_embedding_dim = self.tod_embed_dim + self.adaptive_embedding_dim + self.node_embedding_dim
        # prototypes
        self.prototype_num = prototype_num
        self.prototype_dim = prototype_dim
        self.prototypes = self.construct_prototypes()

        # projection & spatio-temporal embedding
        if self.use_STE:
            if self.adaptive_embedding_dim > 0:
                # NOTE(review): first dim hard-codes an input length of 12 --
                # breaks if the encoder window differs from 12; confirm.
                self.adaptive_embedding = nn.init.xavier_uniform_(
                    nn.Parameter(torch.empty(12, num_nodes, self.adaptive_embedding_dim))
                )
            self.input_proj = nn.Linear(self.input_dim, input_embedding_dim)
            self.node_embedding = nn.Parameter(torch.empty(self.num_nodes, self.node_embedding_dim))
            self.time_embedding = nn.Parameter(torch.empty(self.TDAY, self.tod_embed_dim))
            nn.init.xavier_uniform_(self.node_embedding)
            nn.init.xavier_uniform_(self.time_embedding)

        # encoder
        self.adj_mx = adj_mx
        if self.use_STE:
            self.encoder = ADCRNN_Encoder(self.num_nodes, input_embedding_dim + self.total_embedding_dim, self.rnn_units, self.cheb_k, self.rnn_layers, len(self.adj_mx))
        else:
            self.encoder = ADCRNN_Encoder(self.num_nodes, self.input_dim, self.rnn_units, self.cheb_k, self.rnn_layers, len(self.adj_mx))

        # decoder
        # Decoder state width = encoder hidden + retrieved prototype value.
        self.decoder_dim = self.rnn_units + self.prototype_dim
        if self.use_STE:
            # Decoder steps have no adaptive embedding (it is per-input-step).
            self.decoder = ADCRNN_Decoder(self.num_nodes, input_embedding_dim + self.total_embedding_dim - self.adaptive_embedding_dim, self.decoder_dim, self.cheb_k, self.rnn_layers, 1)
        else:
            self.decoder = ADCRNN_Decoder(self.num_nodes, self.output_dim + self.ycov_dim, self.decoder_dim, self.cheb_k, self.rnn_layers, 1)

        # output
        self.proj = nn.Sequential(nn.Linear(self.decoder_dim, self.output_dim, bias=True))

        # graph: generates node embeddings for the decoder-side learned graph.
        self.hypernet = nn.Sequential(nn.Linear(self.decoder_dim*2, self.tod_embed_dim, bias=True))

        self.act_dict = {'relu': nn.ReLU(), 'lrelu': nn.LeakyReLU(), 'sigmoid': nn.Sigmoid()}
        self.act_fn = 'sigmoid'  # 'relu' 'lrelu' 'sigmoid'

    def compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid schedule for teacher forcing probability."""
        return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def construct_prototypes(self):
        """Create the prototype memory (M, d) and the query projection Wq."""
        prototypes_dict = nn.ParameterDict()
        prototype = torch.randn(self.prototype_num, self.prototype_dim)
        prototypes_dict['prototypes'] = nn.Parameter(prototype, requires_grad=True)  # (M, d)
        prototypes_dict['Wq'] = nn.Parameter(torch.randn(self.rnn_units, self.prototype_dim), requires_grad=True)  # project to query
        for param in prototypes_dict.values():
            nn.init.xavier_normal_(param)

        return prototypes_dict

    def query_prototypes(self, h_t: torch.Tensor):
        """Attend over the prototype memory with the encoder state.

        Args:
            h_t: (B, N, rnn_units) last encoder state.
        Returns:
            value: (B, N, d) attention-weighted prototype readout.
            query: (B, N, d) projected query.
            pos/neg: (B, N, d) top-1 / top-2 prototypes (contrastive pair).
            mask: (B, N, 2) indices of the two chosen prototypes.
        """
        query = torch.matmul(h_t, self.prototypes['Wq'])  # (B, N, d)
        att_score = torch.softmax(torch.matmul(query, self.prototypes['prototypes'].t()), dim=-1)  # alpha: (B, N, M)
        value = torch.matmul(att_score, self.prototypes['prototypes'])  # (B, N, d)
        _, ind = torch.topk(att_score, k=2, dim=-1)
        pos = self.prototypes['prototypes'][ind[:, :, 0]]  # B, N, d
        neg = self.prototypes['prototypes'][ind[:, :, 1]]  # B, N, d
        mask = torch.stack([ind[:, :, 0], ind[:, :, 1]], dim=-1)  # B, N, 2

        return value, query, pos, neg, mask

    def calculate_distance(self, pos, pos_his, mask=None):
        """Per-node L1 distance between two (B, N, d) representations."""
        score = torch.sum(torch.abs(pos - pos_his), dim=-1)
        return score, mask

    def forward(self, x, x_cov, x_his, y_cov, labels=None, batches_seen=None):
        """Full forward pass.

        Args:
            x: (B, T, N, input_dim) current window.
            x_cov: (B, T, N, 1) time-of-day covariate in [0, 1).
            x_his: (B, T, N, input_dim) historical-average window.
            y_cov: (B, horizon, N, ycov_dim) future covariates.
            labels: scaled ground truth for scheduled sampling (training only).
            batches_seen: global step for the sampling schedule.
        Returns:
            output, query, pos, neg, mask, latent_dis, prototype_dis
        """
        # ---- encode current window (optionally with ST embeddings) ----
        if self.use_STE:
            if self.input_embedding_dim > 0:
                x = self.input_proj(x)  # [B,T,N,1]->[B,T,N,D]
            features = [x]

            tod = x_cov.squeeze()  # [B, T, N]  (computed but unused here)
            if self.tod_embed_dim > 0:
                # NOTE(review): .type(torch.LongTensor) forces the indices to
                # CPU -- will fail if x_cov lives on CUDA; confirm on GPU runs.
                time_emb = self.time_embedding[(x_cov.squeeze() * self.TDAY).type(torch.LongTensor)]  # [B, T, N, d]
                features.append(time_emb)
            if self.adaptive_embedding_dim > 0:
                adp_emb = self.adaptive_embedding.expand(
                    size=(x.shape[0], *self.adaptive_embedding.shape)
                )
                features.append(adp_emb)
            if self.node_embedding_dim > 0:
                node_emb = self.node_embedding.unsqueeze(0).unsqueeze(1).expand(x.shape[0], self.horizon, -1, -1)  # [B,T,N,d]
                features.append(node_emb)
            x = torch.cat(features, dim=-1)  # [B, T, N, D+d+80]
        supports_en = self.adj_mx
        init_state = self.encoder.init_hidden(x.shape[0])
        h_en, state_en = self.encoder(x, init_state, supports_en)  # B, T, N, hidden
        h_t = h_en[:, -1, :, :]  # B, N, hidden (last state)
        v_t, q_t, p_t, n_t, mask = self.query_prototypes(h_t)

        # ---- encode historical window through the same encoder ----
        if self.use_STE:
            if self.input_embedding_dim > 0:
                x_his = self.input_proj(x_his)  # [B,T,N,1]->[B,T,N,D]
            features = [x_his]
            tod = x_cov.squeeze()  # [B, T, N]
            if self.tod_embed_dim > 0:
                time_emb = self.time_embedding[(x_cov.squeeze() * self.TDAY).type(torch.LongTensor)]  # [B, T, N, d]

                features.append(time_emb)
            if self.adaptive_embedding_dim > 0:
                adp_emb = self.adaptive_embedding.expand(
                    size=(x.shape[0], *self.adaptive_embedding.shape)
                )
                features.append(adp_emb)
            if self.node_embedding_dim > 0:
                node_emb = self.node_embedding.unsqueeze(0).unsqueeze(1).expand(x.shape[0], self.horizon, -1, -1)  # [B,T,N,d]
                features.append(node_emb)
            x_his = torch.cat(features, dim=-1)  # [B, T, N, D+d+80]
        h_his_en, state_his_en = self.encoder(x_his, init_state, supports_en)  # B, T, N, hidden
        h_a = h_his_en[:, -1, :, :]  # B, N, hidden (last state)
        v_a, q_a, p_a, n_a, mask_his = self.query_prototypes(h_a)

        # Self-supervised deviation signals between the two branches.
        latent_dis, _ = self.calculate_distance(q_t, q_a)
        prototype_dis, mask_dis = self.calculate_distance(p_t, p_a)

        # Stack branch pairs so the caller can index [0]=current, [1]=history.
        query = torch.stack([q_t, q_a], dim=0)
        pos = torch.stack([p_t, p_a], dim=0)
        neg = torch.stack([n_t, n_a], dim=0)
        # NOTE(review): mask from query_prototypes is always a tensor, so the
        # else-branch ([None, None]) is dead code -- confirm intent.
        mask = torch.stack([mask, mask_his], dim=0) if mask is not None else [None, None]

        # ---- build the decoder-side learned graph ----
        h_de = torch.cat([h_t, v_t], dim=-1)
        h_aug = torch.cat([h_t, v_t, h_a, v_a], dim=-1)  # B, N, D

        node_embeddings = self.hypernet(h_aug)  # B, N, e
        support = F.softmax(F.relu(torch.einsum('bnc,bmc->bnm', node_embeddings, node_embeddings)), dim=-1)
        supports_de = [support]

        # Decoder initial states: same augmented state for every layer.
        ht_list = [h_de]*self.rnn_layers
        go = torch.zeros((x.shape[0], self.num_nodes, self.output_dim), device=x.device)

        out = []
        for t in range(self.horizon):
            if self.use_STE:
                if self.input_embedding_dim > 0:
                    go = self.input_proj(go)  # equal to torch.zeros(B,N,D)
                features = [go]
                tod = y_cov[:, t, ...].squeeze()  # [B, T, N]
                if self.tod_embed_dim > 0:
                    time_emb = self.time_embedding[(tod * self.TDAY).type(torch.LongTensor)]
                    features.append(time_emb)
                if self.node_embedding_dim > 0:
                    node_emb = self.node_embedding.unsqueeze(0).expand(x.shape[0], -1, -1)  # [B,N,d]
                    features.append(node_emb)
                go = torch.cat(features, dim=-1)  # [B, T, N, D+d]
                h_de, ht_list = self.decoder(go, ht_list, supports_de)
            else:
                h_de, ht_list = self.decoder(torch.cat([go, y_cov[:, t, ...]], dim=-1), ht_list, supports_de)
            go = self.proj(h_de)
            out.append(go)
            # Scheduled sampling: sometimes feed ground truth (scaled by the
            # caller) instead of the model's own prediction.
            if self.training and self.use_curriculum_learning:
                c = np.random.uniform(0, 1)
                if c < self.compute_sampling_threshold(batches_seen):
                    go = labels[:, t, ...]

        output = torch.stack(out, dim=1)

        return output, query, pos, neg, mask, latent_dis, prototype_dis
|
| 304 |
+
|
| 305 |
+
def print_params(model):
    """Print every trainable parameter of `model` and the grand total.

    Args:
        model: any nn.Module.
    Returns:
        None.
    """
    total = 0
    print('Trainable parameter list:')
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        print(name, param.shape, param.numel())
        total += param.numel()
    print(f'In total: {total} trainable parameters.')
    return
|
| 315 |
+
|
| 316 |
+
def main():
    """Smoke test: build STSSDL on the METR-LA adjacency and print a summary."""
    from torchinfo import summary
    from utils import load_adj

    supports = [torch.FloatTensor(a) for a in load_adj('../METRLA/adj_mx.pkl', "symadj")]
    model = STSSDL(adj_mx=supports)
    # Four identical dummy inputs: x, x_cov, x_his, y_cov.
    input_sizes = [[8, 12, 207, 1] for _ in range(4)]
    summary(model, input_sizes, device="cpu")
|
| 324 |
+
|
| 325 |
+
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
model_STSSDL/metrics.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def MSE(y_true, y_pred, mask=None):
    """Mean squared error, optionally restricted by a 0/1 mask.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same shape as `y_true`.
        mask: optional 0/1 tensor; positions with 0 are zeroed out before
            averaging.
    Returns:
        Scalar tensor.

    NOTE: when `mask` is given the mean is still taken over ALL elements
    (not mask.sum()), so the value scales with mask density -- this matches
    the original behaviour; use a mask-normalized mean if that is not wanted.
    """
    mse = torch.square(y_pred - y_true)
    if mask is not None:
        # Dead diagnostics from the original (mse_ig, anomaly/normal counts)
        # were removed: they were computed and never used or returned.
        mse = mse * mask
    return torch.mean(mse)
|
| 15 |
+
|
| 16 |
+
def RMSE(y_true, y_pred, mask=None):
    """Root mean squared error: sqrt of MSE with the same masking semantics.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same shape as `y_true`.
        mask: optional 0/1 tensor forwarded to MSE.
    Returns:
        Scalar tensor.
    """
    # The original first computed torch.square(torch.abs(...)) and then
    # immediately overwrote it -- that dead statement is removed.
    return torch.sqrt(MSE(y_true, y_pred, mask))
|
| 21 |
+
|
| 22 |
+
def MAE(y_true, y_pred, mask=None):
    """Mean absolute error, optionally restricted by a 0/1 mask.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same shape as `y_true`.
        mask: optional 0/1 tensor; positions with 0 are zeroed out before
            averaging.
    Returns:
        Scalar tensor.

    NOTE: when `mask` is given the mean is still taken over ALL elements
    (not mask.sum()), so the value scales with mask density -- this matches
    the original behaviour.
    """
    mae = torch.abs(y_pred - y_true)
    if mask is not None:
        # Dead diagnostics from the original (mae_ig, anomaly/normal counts)
        # were removed: they were computed and never used or returned.
        mae = mae * mask
    return torch.mean(mae)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
model_STSSDL/train_STSSDL.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.init as init
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torchinfo import summary
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
from utils import StandardScaler, masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss
|
| 15 |
+
from utils import load_adj
|
| 16 |
+
from metrics import RMSE, MAE, MSE
|
| 17 |
+
from STSSDL import STSSDL
|
| 18 |
+
import random
|
| 19 |
+
class ContrastiveLoss():
|
| 20 |
+
def __init__(self, contra_loss='triplet', mask=None, temp=1.0, margin=0.5):
|
| 21 |
+
self.infonce = contra_loss in ['infonce']
|
| 22 |
+
self.mask = mask
|
| 23 |
+
self.temp = temp
|
| 24 |
+
self.margin = margin
|
| 25 |
+
|
| 26 |
+
def calculate(self, query, pos, neg, mask):
|
| 27 |
+
"""
|
| 28 |
+
:param query: shape (batch_size, num_sensor, hidden_dim)
|
| 29 |
+
:param pos: shape (batch_size, num_sensor, hidden_dim)
|
| 30 |
+
:param neg: shape (batch_size, num_sensor, hidden_dim) or (batch_size, num_sensor, num_prototypes, hidden_dim)
|
| 31 |
+
:param mask: shape (batch_size, num_sensor, num_prototypes) True means positives
|
| 32 |
+
"""
|
| 33 |
+
contrastive_loss = nn.TripletMarginLoss(margin=self.margin)
|
| 34 |
+
return contrastive_loss(query.detach(), pos, neg)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def print_model(model):
    """Report trainable parameters: names/shapes to stdout, totals to logger.

    Args:
        model: any nn.Module.
    Returns:
        None.
    """
    total = 0
    logger.info('Trainable parameter list:')
    for name, weight in model.named_parameters():
        if not weight.requires_grad:
            continue
        print(name, weight.shape, weight.numel())
        total += weight.numel()
    logger.info(f'In total: {total} trainable parameters.')
    return
|
| 46 |
+
|
| 47 |
+
def get_model():
    """Build an STSSDL instance from the global CLI `args` and move it to `device`."""
    supports = [torch.tensor(a).to(device) for a in load_adj(adj_mx_path, args.adj_type)]
    model_kwargs = dict(
        num_nodes=args.num_nodes,
        input_dim=args.input_dim,
        output_dim=args.output_dim,
        horizon=args.horizon,
        rnn_units=args.rnn_units,
        rnn_layers=args.rnn_layers,
        cheb_k=args.cheb_k,
        prototype_num=args.prototype_num,
        prototype_dim=args.prototype_dim,
        tod_embed_dim=args.tod_embed_dim,
        adj_mx=supports,
        cl_decay_steps=args.cl_decay_steps,
        use_curriculum_learning=args.use_curriculum_learning,
        use_STE=args.use_STE,
        adaptive_embedding_dim=args.adaptive_embedding_dim,
        node_embedding_dim=args.node_embedding_dim,
        input_embedding_dim=args.input_embedding_dim,
        device=device,
    )
    return STSSDL(**model_kwargs).to(device)
|
| 55 |
+
|
| 56 |
+
def prepare_x_y(x, y):
    """Split the stacked feature channels of the input/target tensors.

    Args:
        x: (batch, seq_len, num_sensor, 3) -- data, time-of-day covariate,
           and historical-average channels.
        y: (batch, horizon, num_sensor, 2) -- target and time-of-day channels.
    Returns:
        (x, x_cov, x_his, y, y_cov), each keeping a trailing singleton
        channel dimension.
    """
    x_channels = [x[..., c:c + 1] for c in range(3)]
    y_channels = [y[..., c:c + 1] for c in range(2)]
    return (*x_channels, *y_channels)
|
| 71 |
+
|
| 72 |
+
def evaluate(model, mode):
    """Run `model` over the '{mode}' loader and report masked MAE/MAPE/RMSE.

    Args:
        model: trained STSSDL instance.
        mode: loader key prefix, e.g. 'val' or 'test'.
    Returns:
        (mean of per-batch MAE, concatenated ground truth, concatenated
        predictions).
    Relies on module-level globals: data, device, scaler, logger.
    """
    with torch.no_grad():
        model = model.eval()
        data_iter = data[f'{mode}_loader']
        ys_true, ys_pred = [], []
        losses = []
        for x, y in data_iter:
            x = x.to(device)
            y = y.to(device)
            x, x_cov, x_his, y, y_cov = prepare_x_y(x, y)
            # No labels/batches_seen: teacher forcing is disabled at eval time.
            output, _, _, _, _, _, _ = model(x, x_cov, x_his, y_cov)
            # Predictions come out in normalized space; invert to real units.
            y_pred = scaler.inverse_transform(output)
            y_true = y
            ys_true.append(y_true)
            ys_pred.append(y_pred)
            losses.append(masked_mae_loss(y_pred, y_true).item())

        ys_true, ys_pred = torch.cat(ys_true, dim=0), torch.cat(ys_pred, dim=0)
        # NOTE(review): `loss` (whole-set MAE) is computed but the function
        # returns np.mean(losses), a per-batch average that weights a smaller
        # final batch equally -- confirm which is intended.
        loss = masked_mae_loss(ys_pred, ys_true)

        if mode == 'test':
            # Horizon indices 2/5/11 = steps 3/6/12 = 15/30/60 min at 5-min
            # resolution (assumes horizon >= 12 -- confirm for other datasets).
            mae = masked_mae_loss(ys_pred, ys_true).item()
            mape = masked_mape_loss(ys_pred, ys_true).item()
            rmse = masked_rmse_loss(ys_pred, ys_true).item()
            mae_3 = masked_mae_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            mape_3 = masked_mape_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            rmse_3 = masked_rmse_loss(ys_pred[:, 2, ...], ys_true[:, 2, ...]).item()
            mae_6 = masked_mae_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            mape_6 = masked_mape_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            rmse_6 = masked_rmse_loss(ys_pred[:, 5, ...], ys_true[:, 5, ...]).item()
            mae_12 = masked_mae_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()
            mape_12 = masked_mape_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()
            rmse_12 = masked_rmse_loss(ys_pred[:, 11, ...], ys_true[:, 11, ...]).item()

            logger.info('Horizon overall: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae, mape * 100, rmse))
            logger.info('Horizon 15mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_3, mape_3 * 100, rmse_3))
            logger.info('Horizon 30mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_6, mape_6 * 100, rmse_6))
            logger.info('Horizon 60mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mae_12, mape_12 * 100, rmse_12))

        return np.mean(losses), ys_true, ys_pred
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def traintest_model():
    """Train with early stopping on validation loss, then test the best checkpoint.

    Relies on module-level globals defined elsewhere in this script:
    ``args``, ``data``, ``device``, ``scaler``, ``logger``, ``modelpt_path``
    and the helpers ``get_model``, ``print_model``, ``prepare_x_y``,
    ``masked_mae_loss``, ``ContrastiveLoss`` and ``evaluate``.
    """
    model = get_model()
    print_model(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=args.epsilon, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.steps, gamma=args.lr_decay_ratio)
    min_val_loss = float('inf')
    wait = 0          # epochs since last validation improvement (early stopping)
    batches_seen = 0  # global step; passed to the model (curriculum learning)
    for epoch_num in range(args.epochs):
        start_time = time.time()
        model = model.train()
        data_iter = data['train_loader']
        losses, mae_losses, contra_losses, deviation_losses = [], [], [], []
        for x, y in data_iter:
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            x, x_cov, x_his, y, y_cov = prepare_x_y(x, y)
            output, query, pos, neg, mask, query_simi, pos_simi = model(x, x_cov, x_his, y_cov, scaler.transform(y), batches_seen)
            y_pred = scaler.inverse_transform(output)
            y_true = y

            mae_loss = masked_mae_loss(y_pred, y_true)
            contrastive_loss = ContrastiveLoss(contra_loss=args.contra_loss, mask=mask, temp=args.temp)
            loss_c = contrastive_loss.calculate(query[0], pos[0], neg[0], mask[0])
            # Deviation loss: pull pos_simi toward query_simi; query side is
            # detached so gradients flow only through pos_simi.
            loss_d = F.l1_loss(query_simi.detach(), pos_simi)
            loss = mae_loss + args.lamb_c * loss_c + args.lamb_d * loss_d

            # FIX: the original appended loss.item() to `losses` twice per batch.
            losses.append(loss.item())
            mae_losses.append(mae_loss.item())
            contra_losses.append(loss_c.item())
            deviation_losses.append(loss_d.item())
            batches_seen += 1
            loss.backward()
            if args.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)  # gradient clipping - this does it in place
            optimizer.step()

        end_time2 = time.time()
        train_loss = np.mean(losses)
        train_mae_loss = np.mean(mae_losses)
        train_contra_loss = np.mean(contra_losses)
        train_deviation_loss = np.mean(deviation_losses)
        lr_scheduler.step()
        val_loss, _, _ = evaluate(model, 'val')
        message = 'Epoch [{}/{}] ({}) train_loss: {:.4f}, train_mae_loss: {:.4f}, train_contra_loss: {:.4f}, train_deviation_loss: {:.4f}, val_loss: {:.4f}, lr: {:.6f}, {:.2f}s'.format(epoch_num + 1, args.epochs, batches_seen, train_loss, train_mae_loss, train_contra_loss, train_deviation_loss, val_loss, optimizer.param_groups[0]['lr'], (end_time2 - start_time))
        logger.info(message)

        test_loss, _, _ = evaluate(model, 'test')
        logger.info("\n")

        if val_loss < min_val_loss:
            wait = 0
            min_val_loss = val_loss
            torch.save(model.state_dict(), modelpt_path)  # checkpoint the best model so far
        else:  # FIX: was a redundant `elif val_loss >= min_val_loss`
            wait += 1
            if wait == args.patience:
                logger.info('Early stopping at epoch: %d' % (epoch_num + 1))
                break

    logger.info('=' * 35 + 'Best val_loss model performance' + '=' * 35)
    logger.info('=' * 22 + 'Better results might be found from model at different epoch' + '=' * 22)
    model = get_model()
    model.load_state_dict(torch.load(modelpt_path))
    start = time.time()
    test_loss, _, _ = evaluate(model, 'test')
    end = time.time()
    logger.info(f"Inference Time: {(end-start):.2f}s")
|
| 186 |
+
|
| 187 |
+
#########################################################################################
# Command-line configuration.
# NOTE(review): several of these defaults are overridden unconditionally in the
# per-dataset section below, so the CLI value for those flags has no effect.
# NOTE(review): `type=eval` (for --steps, --use_curriculum_learning, --use_STE)
# executes arbitrary Python from the command line; acceptable for a research
# script, but do not expose to untrusted input.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, choices=['METRLA', 'PEMSBAY','PEMS04','PEMS07','PEMS08','PEMSD7M'], default='METRLA', help='which dataset to run')
parser.add_argument('--num_nodes', type=int, default=207, help='num_nodes')
parser.add_argument('--seq_len', type=int, default=12, help='input sequence length')
parser.add_argument('--horizon', type=int, default=12, help='output sequence length')
parser.add_argument('--input_dim', type=int, default=1, help='number of input channel')
parser.add_argument('--output_dim', type=int, default=1, help='number of output channel')
parser.add_argument('--tod_embed_dim', type=int, default=10, help='embedding dimension for adaptive graph')
parser.add_argument('--cheb_k', type=int, default=3, help='max diffusion step or Cheb K')
parser.add_argument('--rnn_layers', type=int, default=1, help='number of rnn layers')
parser.add_argument('--rnn_units', type=int, default=128, help='number of rnn units')
parser.add_argument('--prototype_num', type=int, default=20, help='number of meta-nodes/prototypes')
parser.add_argument('--prototype_dim', type=int, default=64, help='dimension of meta-nodes/prototypes')
parser.add_argument("--loss", type=str, default='mask_mae_loss', help="mask_mae_loss")
parser.add_argument("--epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--patience", type=int, default=30, help="patience used for early stop")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.01, help="base learning rate")
parser.add_argument("--steps", type=eval, default=[50, 100], help="steps")
parser.add_argument("--lr_decay_ratio", type=float, default=0.1, help="lr_decay_ratio")
parser.add_argument("--weight_decay", type=float, default=0, help="weight_decay_ratio")
parser.add_argument("--epsilon", type=float, default=1e-3, help="optimizer epsilon")
parser.add_argument("--max_grad_norm", type=int, default=5, help="max_grad_norm")
parser.add_argument("--use_curriculum_learning", type=eval, choices=[True, False], default='True', help="use_curriculum_learning")
parser.add_argument("--adj_type", type=str, default='symadj', help="scalap, normlap, symadj, transition, doubletransition")
parser.add_argument("--cl_decay_steps", type=int, default=2000, help="cl_decay_steps")
parser.add_argument('--gpu', type=int, default=0, help='which gpu to use')
parser.add_argument('--seed', type=int, default=100, help='random seed.')
parser.add_argument('--temp', type=float, default=1.0, help='temperature parameter')
parser.add_argument('--lamb_c', type=float, default=0.1, help='contra loss lambda')
parser.add_argument('--lamb_d', type=float, default=1.0, help='deviation loss lambda')
parser.add_argument('--contra_loss', type=str, choices=['triplet', 'infonce'], default='triplet', help='whether to triplet or infonce contra loss')
parser.add_argument("--use_STE", type=eval, choices=[True, False], default='True', help="use spatio-temporal embedding")
parser.add_argument("--adaptive_embedding_dim", type=int,default=48, help="use spatio-temporal adaptive embedding")
parser.add_argument("--node_embedding_dim", type=int,default=20, help="use spatio-temporal adaptive embedding")
parser.add_argument("--input_embedding_dim", type=int,default=128, help="use spatio-temporal adaptive embedding")

args = parser.parse_args()
|
| 226 |
+
# Number of sensors/nodes per supported dataset.
num_nodes_dict={
    "METRLA": 207,
    "PEMSBAY": 325,
    "PEMS04": 307,
    "PEMS07": 883,
    "PEMS08": 170,
    "PEMSD7M": 228,
}
# Per-dataset hyperparameter overrides. These clobber the CLI defaults above.
# NOTE(review): in several branches `rand_seed` is computed and then never
# used (args.seed is assigned a fixed value right after); left as-is because
# removing the random.randint call would change the random module's state.
if args.dataset == 'METRLA':
    data_path = f'../{args.dataset}/metr-la.h5'
    adj_mx_path = f'../{args.dataset}/adj_mx.pkl'
    args.num_nodes = 207
    args.use_STE=True
    rand_seed=random.randint(0, 1000000)# 31340
    args.seed=999
    args.lamb_c=0.01
    args.lamb_d=1
    args.steps = [50,70]
    args.input_embedding_dim=3
    args.node_embedding_dim=25
    args.tod_embed_dim=20 #TOD embedding
    args.adaptive_embedding_dim=0

elif args.dataset == 'PEMSBAY':
    data_path = f'../{args.dataset}/pems-bay.h5'
    adj_mx_path = f'../{args.dataset}/adj_mx_bay.pkl'
    args.num_nodes = 325
    args.use_STE=True
    args.cl_decay_steps = 8000
    args.steps = [10, 70,150]
    args.seed=666
    args.lamb_c=0.01
    args.lamb_d=1
    args.input_embedding_dim=10
    args.node_embedding_dim=20
    args.tod_embed_dim=20 #TOD embedding
    args.adaptive_embedding_dim=0

elif args.dataset == 'PEMS04':
    data_path = f'../{args.dataset}/{args.dataset}.npz'
    adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
    args.num_nodes = num_nodes_dict[args.dataset]
    rand_seed=random.randint(0, 1000000)# 31340
    args.seed=610958
    args.patience=30
    args.batch_size=16
    args.lr=0.001
    args.epochs=200
    args.steps=[50, 100]
    args.weight_decay=0
    args.max_grad_norm=0
    args.rnn_units=32
    args.prototype_num=20
    args.prototype_dim=64
    args.cl_decay_steps=6000
    # NOTE(review): --max_diffusion_step is not declared in the parser; this
    # attaches a new attribute onto args. Presumably consumed by get_model().
    args.max_diffusion_step=3
    args.input_embedding_dim=32
    args.node_embedding_dim=24
    args.tod_embed_dim=40 #TOD embedding
    args.adaptive_embedding_dim=0
    args.use_curriculum_learning=True
    args.lamb_c=0.01
    args.lamb_d=0.01


elif args.dataset == 'PEMS07':
    data_path = f'../{args.dataset}/{args.dataset}.npz'
    adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
    args.num_nodes = num_nodes_dict[args.dataset]
    args.patience=20
    args.batch_size=16
    args.lr=0.001
    args.steps=[50, 100]
    args.weight_decay=0
    args.max_grad_norm=0
    args.rnn_units=64
    args.prototype_num=20
    args.prototype_dim=64
    args.cl_decay_steps=6000
    args.max_diffusion_step=3
    args.lamb_c=0.01
    args.lamb_d=1
    args.seed=100
    args.input_embedding_dim=64
    args.node_embedding_dim=16
    args.tod_embed_dim=16
    args.adaptive_embedding_dim=0
elif args.dataset == 'PEMS08':
    data_path = f'../{args.dataset}/{args.dataset}.npz'
    adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
    args.num_nodes = num_nodes_dict[args.dataset]
    args.use_STE=True
    args.patience=20
    args.batch_size=16
    # PEMS08 is the only branch that actually trains with a random seed.
    rand_seed=random.randint(0, 1000000)# 31340
    args.seed=rand_seed
    args.cl_decay_steps=6000
    args.max_diffusion_step=3
    args.steps=[70, 100]
    args.prototype_num=20
    args.prototype_dim=64
    args.use_curriculum_learning=True
    args.rnn_units = 12
    args.lamb_c=0.1
    args.lamb_d=1
    args.input_embedding_dim=16
    args.node_embedding_dim=20
    args.tod_embed_dim=20 #TOD embedding
    args.adaptive_embedding_dim=0

elif args.dataset == 'PEMSD7M':
    data_path = f'../{args.dataset}/{args.dataset}.npz'
    adj_mx_path = f'../{args.dataset}/adj_{args.dataset}_distance.pkl'
    args.num_nodes = num_nodes_dict[args.dataset]
    rand_seed=random.randint(0, 1000000)# 31340
    args.seed=119089
    args.patience=30
    args.batch_size=16
    args.lr=0.001
    args.steps=[50, 100]
    args.weight_decay=0
    args.max_grad_norm=0
    args.rnn_units=32
    args.prototype_num=16
    args.prototype_dim=64
    args.cl_decay_steps=4000
    args.max_diffusion_step=3
    args.lamb_c=0.1
    args.lamb_d=1
    args.input_embedding_dim=32
    args.node_embedding_dim=20
    args.tod_embed_dim=16 #TOD embedding
    args.adaptive_embedding_dim=0
| 360 |
+
# Output paths: one timestamped run directory under ../save/ per invocation.
model_name = 'STSSDL'
timestring = time.strftime('%Y%m%d%H%M%S', time.localtime())
path = f'../save/{args.dataset}_{model_name}_{timestring}'
logging_path = f'{path}/{model_name}_{timestring}_logging.txt'
score_path = f'{path}/{model_name}_{timestring}_scores.txt'
epochlog_path = f'{path}/{model_name}_{timestring}_epochlog.txt'
modelpt_path = f'{path}/{model_name}_{timestring}.pt'
if not os.path.exists(path): os.makedirs(path)
# Snapshot the running script, the model definition and utils into the run
# directory for reproducibility.
shutil.copy2(sys.argv[0], path)
shutil.copy2(f'{model_name}.py', path)
shutil.copy2('utils.py', path)

logger = logging.getLogger(__name__)
logger.setLevel(level = logging.INFO)
class MyFormatter(logging.Formatter):
    # Custom formatter: instead of %-interpolating record.args into the
    # message, join all args onto the message with spaces. This is what lets
    # calls like logger.info(a, b, c) elsewhere in this script work.
    def format(self, record):
        spliter = ' '
        record.msg = str(record.msg) + spliter + spliter.join(map(str, record.args))
        record.args = tuple() # set empty to args
        return super().format(record)
formatter = MyFormatter()
# Log to both the run directory file and the console.
handler = logging.FileHandler(logging_path, mode='a')
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(formatter)
logger.addHandler(handler)
logger.addHandler(console)
# Dump the full (post-override) configuration at startup.
message = ''.join([f'{k}: {v}\n' for k, v in vars(args).items()])
logger.info(message)
|
| 391 |
+
|
| 392 |
+
# Pin every BLAS/OpenMP backend (and torch) to a single CPU thread so runs
# are less noisy and don't oversubscribe shared machines.
cpu_num = 1
os.environ ['OMP_NUM_THREADS'] = str(cpu_num)
os.environ ['OPENBLAS_NUM_THREADS'] = str(cpu_num)
os.environ ['MKL_NUM_THREADS'] = str(cpu_num)
os.environ ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num)
os.environ ['NUMEXPR_NUM_THREADS'] = str(cpu_num)
torch.set_num_threads(cpu_num)
device = torch.device("cuda:{}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu")

# Seed numpy and torch for reproducibility (the `random` module itself is
# deliberately left unseeded; see rand_seed usage above).
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed)
#####################################################################################################
|
| 405 |
+
|
| 406 |
+
data = {}
# Load the pre-generated (x, y) sample files for each split; replace NaNs
# with zeros only when any are present (keeps the no-NaN fast path copy-free).
for category in ['train', 'val', 'test']:
    cat_data = np.load(os.path.join(f'../{args.dataset}', category + 'his.npz'))
    # FIX(idiom): `True in np.isnan(arr)` -> `np.isnan(arr).any()`
    data['x_' + category] = np.nan_to_num(cat_data['x']) if np.isnan(cat_data['x']).any() else cat_data['x']
    data['y_' + category] = np.nan_to_num(cat_data['y']) if np.isnan(cat_data['y']).any() else cat_data['y']
# Standardize with training-split statistics only (no test-set leakage).
scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std())
for category in ['train', 'val', 'test']:
    data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
    data['x_' + category][..., 2] = scaler.transform(data['x_' + category][..., 2])  # x_his channel

def _make_loader(split, shuffle):
    """Wrap one split's (x, y) arrays in a torch DataLoader."""
    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(data['x_' + split]),
        torch.FloatTensor(data['y_' + split]),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle)

data['train_loader'] = _make_loader('train', True)   # shuffle only the training split
data['val_loader'] = _make_loader('val', False)
data['test_loader'] = _make_loader('test', False)
|
| 431 |
+
|
| 432 |
+
def main():
    """Log run metadata, train and test the model, log completion time."""
    # NOTE: passing multiple positional args to logger.info relies on the
    # custom MyFormatter defined above, which space-joins record.args.
    logger.info(args.dataset, 'training and testing started', time.ctime())
    logger.info('train xs.shape, ys.shape', data['x_train'].shape, data['y_train'].shape)
    logger.info('val xs.shape, ys.shape', data['x_val'].shape, data['y_val'].shape)
    logger.info('test xs.shape, ys.shape', data['x_test'].shape, data['y_test'].shape)
    traintest_model()
    logger.info(args.dataset, 'training and testing ended', time.ctime())

if __name__ == '__main__':
    main()
|
| 442 |
+
|
model_STSSDL/utils.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import scipy.sparse as sp
|
| 6 |
+
from scipy.sparse import linalg
|
| 7 |
+
|
| 8 |
+
class DataLoader(object):
    """Minibatch iterator over a pair of aligned sample arrays.

    Iteration order is the (possibly shuffled) sample order; batches are
    contiguous slices of ``batch_size`` samples.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples, sample axis first
        :param ys: target samples, aligned with xs
        :param batch_size: samples per batch
        :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size.
        :param shuffle: permute samples once at construction time
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            pad = (batch_size - (len(xs) % batch_size)) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], pad, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], pad, axis=0)], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            order = np.random.permutation(self.size)
            xs, ys = xs[order], ys[order]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """Return a fresh generator yielding (x_batch, y_batch) tuples."""
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                lo = self.batch_size * self.current_ind
                hi = min(self.size, lo + self.batch_size)
                yield (self.xs[lo:hi, ...], self.ys[lo:hi, ...])
                self.current_ind += 1

        return _wrapper()
|
| 46 |
+
|
| 47 |
+
class StandardScaler():
    """Z-score normalization with fixed, externally supplied statistics."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Standardize: (data - mean) / std."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo standardization: data * std + mean."""
        return data * self.std + self.mean
|
| 57 |
+
|
| 58 |
+
def getTimestamp(data):
    """Return fraction-of-day in [0, 1) per sample, tiled across nodes.

    :param data: DataFrame with a datetime index; shape (samples, nodes)
    :return: ndarray of shape (samples, nodes)
    """
    num_samples, num_nodes = data.shape
    idx = data.index.values
    frac_of_day = (idx - idx.astype("datetime64[D]")) / np.timedelta64(1, "D")
    return np.tile(frac_of_day, [num_nodes, 1]).transpose((1, 0))
|
| 63 |
+
|
| 64 |
+
def getDayTimestamp(data):
    """Weekly time index per node, normalized to [0, 1].

    288 timeslots each day for dataset has 5 minutes time interval.

    :param data: DataFrame with a datetime index; shape (samples, nodes)
    :return: ndarray of shape (samples, nodes)
    """
    ts = pd.DataFrame({'timestamp': data.index.values})
    slot = ts['timestamp'].dt.weekday * 288 + (ts['timestamp'].dt.hour * 60 + ts['timestamp'].dt.minute) // 5
    slot = slot / slot.max()  # normalize by the max slot present in the data
    num_samples, num_nodes = data.shape
    return np.tile(slot.values, [num_nodes, 1]).transpose((1, 0))
|
| 73 |
+
|
| 74 |
+
def getDayTimestamp_(start, end, freq, num_nodes):
    """Weekly time index over a generated date range, normalized to [0, 1].

    288 timeslots each day for dataset has 5 minutes time interval.

    :param start: range start (anything pd.date_range accepts)
    :param end: range end, inclusive
    :param freq: sampling frequency string, e.g. '5min'
    :param num_nodes: number of node columns to tile
    :return: ndarray of shape (timestamps, num_nodes)
    """
    ts = pd.DataFrame({'timestamp': pd.date_range(start=start, end=end, freq=freq)})
    slot = ts['timestamp'].dt.weekday * 288 + (ts['timestamp'].dt.hour * 60 + ts['timestamp'].dt.minute) // 5
    slot = slot / slot.max()
    return np.tile(slot.values, [num_nodes, 1]).transpose((1, 0))
|
| 82 |
+
|
| 83 |
+
def masked_mse(preds, labels, null_val=1e-3):
    """Mean squared error over entries whose label exceeds ``null_val``.

    When ``null_val`` is NaN the mask excludes NaN labels instead. The mask
    is renormalized so the result averages over valid entries only; NaNs
    arising from an all-invalid mask are zeroed.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    sq_err = (preds - labels) ** 2 * weight
    sq_err = torch.where(torch.isnan(sq_err), torch.zeros_like(sq_err), sq_err)
    return torch.mean(sq_err)
|
| 95 |
+
|
| 96 |
+
def masked_rmse(preds, labels, null_val=1e-3):
    # Root of the masked MSE; see masked_mse for the masking convention
    # (labels > null_val are considered valid, or non-NaN if null_val is NaN).
    return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def masked_mae(preds, labels, null_val=1e-3):
    """Mean absolute error over entries whose label exceeds ``null_val``.

    When ``null_val`` is NaN the mask excludes NaN labels instead. The mask
    is renormalized so the result averages over valid entries only.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    abs_err = torch.abs(preds - labels) * weight
    abs_err = torch.where(torch.isnan(abs_err), torch.zeros_like(abs_err), abs_err)
    return torch.mean(abs_err)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def masked_mape(preds, labels, null_val=1e-3):
    """Mean absolute percentage error over labels greater than ``null_val``.

    When ``null_val`` is NaN the mask excludes NaN labels instead. NaNs
    produced by zero labels (inf * 0) or an empty mask are zeroed before
    averaging.
    """
    if np.isnan(null_val):
        valid = ~torch.isnan(labels)
    else:
        valid = labels > null_val
    weight = valid.float()
    weight = weight / torch.mean(weight)
    weight = torch.where(torch.isnan(weight), torch.zeros_like(weight), weight)
    pct_err = torch.abs(preds - labels) / labels * weight
    pct_err = torch.where(torch.isnan(pct_err), torch.zeros_like(pct_err), pct_err)
    return torch.mean(pct_err)
|
| 126 |
+
|
| 127 |
+
# DCRNN
|
| 128 |
+
# DCRNN
def masked_mae_loss(y_pred, y_true):
    """MAE over nonzero ground-truth entries (DCRNN-style masking)."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    err = torch.abs(y_pred - y_true) * weight
    # NaNs (e.g. from an all-zero mask) contribute zero to the mean.
    err = torch.where(torch.isnan(err), torch.zeros_like(err), err)
    return err.mean()
|
| 136 |
+
|
| 137 |
+
def masked_mape_loss(y_pred, y_true):
    """MAPE over nonzero ground-truth entries (DCRNN-style masking)."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    err = torch.abs((y_true - y_pred) / y_true) * weight
    # Zero labels yield inf * 0 = NaN; zero those out before averaging.
    err = torch.where(torch.isnan(err), torch.zeros_like(err), err)
    return err.mean()
|
| 145 |
+
|
| 146 |
+
def masked_rmse_loss(y_pred, y_true):
    """RMSE over nonzero ground-truth entries (DCRNN-style masking)."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    sq_err = torch.pow(y_true - y_pred, 2) * weight
    # NaNs (e.g. from an all-zero mask) contribute zero to the mean.
    sq_err = torch.where(torch.isnan(sq_err), torch.zeros_like(sq_err), sq_err)
    return torch.sqrt(sq_err.mean())
|
| 154 |
+
|
| 155 |
+
def masked_mse_loss(y_pred, y_true):
    """MSE over nonzero ground-truth entries (DCRNN-style masking)."""
    weight = (y_true != 0).float()
    weight = weight / weight.mean()
    sq_err = torch.pow(y_true - y_pred, 2) * weight
    # NaNs (e.g. from an all-zero mask) contribute zero to the mean.
    sq_err = torch.where(torch.isnan(sq_err), torch.zeros_like(sq_err), sq_err)
    return sq_err.mean()
|
| 163 |
+
|
| 164 |
+
def load_pickle(pickle_file):
    """Load a pickle file, retrying with latin1 for Python-2-era pickles.

    :param pickle_file: path to the pickle file
    :return: the unpickled object
    :raises: re-raises any load error after printing it
    """
    try:
        with open(pickle_file, 'rb') as f:
            return pickle.load(f)
    except UnicodeDecodeError as e:
        with open(pickle_file, 'rb') as f:
            return pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load data ', pickle_file, ':', e)
        raise
|
| 175 |
+
|
| 176 |
+
def sym_adj(adj):
    """Symmetrically normalize adjacency matrix: D^-1/2 A D^-1/2 (dense float32)."""
    adj = sp.coo_matrix(adj)
    degree = np.array(adj.sum(1))
    inv_sqrt = np.power(degree, -0.5).flatten()
    inv_sqrt[np.isinf(inv_sqrt)] = 0.  # isolated nodes get zero weight
    d_half = sp.diags(inv_sqrt)
    return adj.dot(d_half).transpose().dot(d_half).astype(np.float32).todense()
|
| 184 |
+
|
| 185 |
+
def asym_adj(adj):
    """Row-normalize adjacency: D^-1 A, the random-walk transition matrix (dense float32)."""
    adj = sp.coo_matrix(adj)
    degree = np.array(adj.sum(1)).flatten()
    inv_degree = np.power(degree, -1).flatten()
    inv_degree[np.isinf(inv_degree)] = 0.  # rows with zero degree stay zero
    return sp.diags(inv_degree).dot(adj).astype(np.float32).todense()
|
| 192 |
+
|
| 193 |
+
def calculate_normalized_laplacian(adj):
    """
    # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
    # D = diag(A 1)
    :param adj: adjacency matrix (dense or sparse)
    :return: sparse normalized Laplacian
    """
    adj = sp.coo_matrix(adj)
    degree = np.array(adj.sum(1))
    inv_sqrt = np.power(degree, -0.5).flatten()
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    d_half = sp.diags(inv_sqrt)
    return sp.eye(adj.shape[0]) - adj.dot(d_half).transpose().dot(d_half).tocoo()
|
| 207 |
+
|
| 208 |
+
def calculate_random_walk_matrix(adj_mx):
    """Return the random-walk transition matrix D^-1 A as sparse COO."""
    adj_mx = sp.coo_matrix(adj_mx)
    degree = np.array(adj_mx.sum(1))
    inv_degree = np.power(degree, -1).flatten()
    inv_degree[np.isinf(inv_degree)] = 0.  # zero-degree rows stay zero
    return sp.diags(inv_degree).dot(adj_mx).tocoo()
|
| 216 |
+
|
| 217 |
+
def calculate_reverse_random_walk_matrix(adj_mx):
    # Random-walk transition matrix of the reversed graph (transpose of A).
    return calculate_random_walk_matrix(np.transpose(adj_mx))
|
| 219 |
+
|
| 220 |
+
def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Chebyshev-scaled Laplacian: 2 L / lambda_max - I, as sparse float32.

    :param adj_mx: adjacency matrix
    :param lambda_max: largest eigenvalue to scale by; None computes it
    :param undirected: symmetrize A with elementwise max(A, A^T) first
    """
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    lap = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        eigval, _ = linalg.eigsh(lap, 1, which='LM')
        lambda_max = eigval[0]
    lap = sp.csr_matrix(lap)
    n, _ = lap.shape
    identity = sp.identity(n, format='csr', dtype=lap.dtype)
    scaled = (2 / lambda_max * lap) - identity
    return scaled.astype(np.float32)
|
| 232 |
+
|
| 233 |
+
def load_adj(pkl_filename, adjtype):
    """Load a pickled adjacency matrix and normalize it per ``adjtype``.

    PEMS0*/D7 pickles store the adjacency matrix directly; the other
    (METR-LA / PEMS-BAY style) pickles store a
    (sensor_ids, sensor_id_to_ind, adj_mx) triple.

    :param pkl_filename: path to the pickled adjacency data
    :param adjtype: one of scalap, normlap, symadj, transition,
        doubletransition, identity
    :return: list of normalized adjacency matrices
    :raises ValueError: if ``adjtype`` is not recognized
    """
    if "PEMS0" in pkl_filename or "D7" in pkl_filename:
        adj_mx = load_pickle(pkl_filename)
    else:
        sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
    if adjtype == "scalap":
        adj = [calculate_scaled_laplacian(adj_mx)]
    elif adjtype == "normlap":
        adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32).todense()]
    elif adjtype == "symadj":
        adj = [sym_adj(adj_mx)]
    elif adjtype == "transition":
        adj = [asym_adj(adj_mx)]
    elif adjtype == "doubletransition":
        adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))]
    elif adjtype == "identity":
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    else:
        # FIX: was `error = 0; assert error, ...` — asserts are stripped
        # under `python -O`, silently returning an unbound `adj`.
        raise ValueError(f"adj type not defined: {adjtype!r}")
    return adj
|
| 254 |
+
|
| 255 |
+
def print_params(model):
    """Print each trainable parameter's name/shape/size and the total count."""
    print('Trainable parameter list:')
    total = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape, param.numel())
            total += param.numel()
    print(f'\n In total: {total} trainable parameters. \n')
    return
|