erermeev-d committed on
Commit
a4da241
·
1 Parent(s): b8f4763

Updated baseline

Browse files
Files changed (1) hide show
  1. exp/gnn.py +19 -14
exp/gnn.py CHANGED
@@ -20,19 +20,22 @@ class GNNLayer(torch.nn.Module):
20
  self._skip_connection = skip_connection
21
  self._bidirectional = bidirectional
22
 
23
- self._norm = torch.nn.LayerNorm(hidden_dim)
24
  self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
25
  self._activation = torch.nn.ReLU()
26
 
27
  if bidirectional:
28
- self._norm_rev = torch.nn.LayerNorm(hidden_dim)
29
  self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
30
  self._activation_rev = torch.nn.ReLU()
31
 
32
  def forward(self, graph, x):
33
- y = self._activation(self._conv(graph, self._norm(x)))
 
 
34
  if self._bidirectional:
35
- y = y + self._activation_rev(self._conv_rev(dgl.reverse(graph), self._norm_rev(x)))
 
 
 
36
  if self._skip_connection:
37
  return x + y
38
  else:
@@ -76,15 +79,22 @@ class GNNModel(torch.nn.Module):
76
 
77
  subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
78
  prev_ids = set()
 
79
 
80
  for _ in range(num_layers):
81
  frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
82
- new_edges = self._sampler(frontier_ids).edges()
 
 
 
83
  subgraph.add_edges(*new_edges)
 
 
84
  prev_ids |= set(frontier_ids.cpu().tolist())
85
  frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
86
  frontier_ids = list(frontier_ids - prev_ids)
87
 
 
88
  return subgraph
89
 
90
  def forward(self, ids):
@@ -215,12 +225,7 @@ def prepare_gnn_embeddings(
215
 
216
  ### Prepare optimizer & LR scheduler
217
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
218
- total_steps = num_epochs * len(dataloader)
219
- lr_schedule = LRSchedule(
220
- total_steps=total_steps,
221
- warmup_steps=int(0.1*total_steps),
222
- final_factor=0.1) # TODO: move to args
223
- lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
224
 
225
  ### Train loop
226
  model.train()
@@ -284,15 +289,15 @@ if __name__ == "__main__":
284
  parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
285
  parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
286
  parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
287
- parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
288
  parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
289
  parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
290
  parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
291
  parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
292
- parser.add_argument("--num_neighbor", type=int, default=3, help="Number of neighbors in PinSAGE-like sampler")
293
 
294
  # Misc
295
- parser.add_argument("--validate_every_n_epoch", type=int, default=2, help="Perform RecSys validation every n train epochs.")
296
  parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
297
  parser.add_argument("--wandb_name", type=str, help="WandB run name")
298
  parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
 
20
  self._skip_connection = skip_connection
21
  self._bidirectional = bidirectional
22
 
 
23
  self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
24
  self._activation = torch.nn.ReLU()
25
 
26
  if bidirectional:
 
27
  self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
28
  self._activation_rev = torch.nn.ReLU()
29
 
30
  def forward(self, graph, x):
31
+ edge_weights = graph.edata["weights"]
32
+
33
+ y = self._activation(self._conv(graph, x, edge_weights))
34
  if self._bidirectional:
35
+ reversed_graph = dgl.reverse(graph, copy_edata=True)
36
+ edge_weights = reversed_graph.edata["weights"]
37
+ y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
38
+
39
  if self._skip_connection:
40
  return x + y
41
  else:
 
79
 
80
  subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
81
  prev_ids = set()
82
+ weights = []
83
 
84
  for _ in range(num_layers):
85
  frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
86
+ new_sample = self._sampler(frontier_ids)
87
+ new_weights = new_sample.edata["weights"]
88
+ new_edges = new_sample.edges()
89
+
90
  subgraph.add_edges(*new_edges)
91
+ weights.append(new_weights)
92
+
93
  prev_ids |= set(frontier_ids.cpu().tolist())
94
  frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
95
  frontier_ids = list(frontier_ids - prev_ids)
96
 
97
+ subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
98
  return subgraph
99
 
100
  def forward(self, ids):
 
225
 
226
  ### Prepare optimizer & LR scheduler
227
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
228
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
 
 
 
 
229
 
230
  ### Train loop
231
  model.train()
 
289
  parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
290
  parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
291
  parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
292
+ parser.add_argument("--skip_connection", action="store_true", dest="skip_connection", help="Enable skip connections")
293
  parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
294
  parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
295
  parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
296
  parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
297
+ parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
298
 
299
  # Misc
300
+ parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
301
  parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
302
  parser.add_argument("--wandb_name", type=str, help="WandB run name")
303
  parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")