erermeev-d committed · Commit a4da241 · Parent(s): b8f4763

Updated baseline

Files changed: exp/gnn.py (+19 -14)

exp/gnn.py CHANGED
@@ -20,19 +20,22 @@ class GNNLayer(torch.nn.Module):
         self._skip_connection = skip_connection
         self._bidirectional = bidirectional
 
-        self._norm = torch.nn.LayerNorm(hidden_dim)
         self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
         self._activation = torch.nn.ReLU()
 
         if bidirectional:
-            self._norm_rev = torch.nn.LayerNorm(hidden_dim)
             self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
             self._activation_rev = torch.nn.ReLU()
 
     def forward(self, graph, x):
+        edge_weights = graph.edata["weights"]
+
+        y = self._activation(self._conv(graph, x, edge_weights))
         if self._bidirectional:
+            reversed_graph = dgl.reverse(graph, copy_edata=True)
+            edge_weights = reversed_graph.edata["weights"]
+            y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
+
         if self._skip_connection:
             return x + y
         else:
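Note: the updated forward pass reads per-edge weights from graph.edata["weights"], passes them to SAGEConv (which accepts an optional edge-weight argument), and in bidirectional mode repeats the convolution on the reversed graph with edge data copied so the weights follow the reversed edges. Below is a minimal sketch of the same pattern on a toy graph; it is not part of the commit, and since GNNLayer's constructor is not visible in this hunk, hidden_dim and aggregator_type are placeholder values.

# Sketch (not from the commit): weighted SAGEConv forward/backward-edge pattern.
import dgl
import dgl.nn
import torch

hidden_dim = 8
conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, "mean")
conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, "mean")
activation = torch.nn.ReLU()

# Toy directed graph whose per-edge weights live in edata["weights"],
# which is where the layer now expects them.
graph = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
graph.edata["weights"] = torch.rand(graph.num_edges())
x = torch.randn(graph.num_nodes(), hidden_dim)

y = activation(conv(graph, x, graph.edata["weights"]))
reversed_graph = dgl.reverse(graph, copy_edata=True)   # copies "weights" onto reversed edges
y = y + activation(conv_rev(reversed_graph, x, reversed_graph.edata["weights"]))
out = x + y                                            # skip-connection branch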
@@ -76,15 +79,22 @@ class GNNModel(torch.nn.Module):
 
         subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
         prev_ids = set()
+        weights = []
 
         for _ in range(num_layers):
            frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
+            new_sample = self._sampler(frontier_ids)
+            new_weights = new_sample.edata["weights"]
+            new_edges = new_sample.edges()
+
             subgraph.add_edges(*new_edges)
+            weights.append(new_weights)
+
             prev_ids |= set(frontier_ids.cpu().tolist())
             frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
             frontier_ids = list(frontier_ids - prev_ids)
 
+        subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
         return subgraph
 
     def forward(self, ids):
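Note: the loop now keeps the sampled block itself so it can collect edata["weights"] alongside the edges, then attaches the concatenated (and float-cast) weights to the accumulated subgraph before returning it. The sampler options added further down (num_traversals, termination_prob, num_random_walks, num_neighbor) match DGL's PinSAGESampler, which stores random-walk visit counts under edata["weights"]. The sketch below is not from the commit; the "User" node type and the edge-type names are assumptions, since only "Item" appears in this file.

# Sketch (not from the commit): a PinSAGE-like sampler whose output carries
# per-edge visit counts in edata["weights"], as consumed by the loop above.
import dgl
import torch

bipartite_graph = dgl.heterograph({
    ("User", "Clicked", "Item"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])),
    ("Item", "ClickedBy", "User"): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
})

sampler = dgl.sampling.PinSAGESampler(
    bipartite_graph, "Item", "User",
    num_traversals=4, termination_prob=0.5,
    num_random_walks=200, num_neighbors=10,
)

frontier_ids = torch.tensor([0, 1], dtype=torch.int64)
frontier = sampler(frontier_ids)      # homogeneous Item-Item graph
print(frontier.edata["weights"])      # integer visit counts; cast to float32 in the model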
@@ -215,12 +225,7 @@ def prepare_gnn_embeddings(
 
     ### Prepare optimizer & LR scheduler
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-
-    lr_schedule = LRSchedule(
-        total_steps=total_steps,
-        warmup_steps=int(0.1*total_steps),
-        final_factor=0.1) # TODO: move to args
-    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
+    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
     ### Train loop
     model.train()
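Note: LambdaLR multiplies the optimizer's base learning rate by whatever the lambda returns at each step, so replacing the warmup/decay LRSchedule with lambda _: 1.0 keeps the learning rate constant while leaving the scheduler call sites in the train loop unchanged. A minimal sketch (not from the commit) of that behaviour:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

for _ in range(3):
    optimizer.step()                    # loss.backward() would precede this in the real loop
    lr_scheduler.step()
    print(lr_scheduler.get_last_lr())   # [0.001] every step: the factor is always 1.0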
@@ -284,15 +289,15 @@ if __name__ == "__main__":
     parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
     parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
     parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
-    parser.add_argument("--
+    parser.add_argument("--skip_connection", action="store_true", dest="skip_connection", help="Use skip connections in convolution")
     parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
     parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
     parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
     parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
-    parser.add_argument("--num_neighbor", type=int, default=
+    parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
 
     # Misc
-    parser.add_argument("--validate_every_n_epoch", type=int, default=
+    parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
     parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
     parser.add_argument("--wandb_name", type=str, help="WandB run name")
     parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
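Note: with these flags, skip_connection defaults to False (store_true), while bidirectional and use_wandb default to True (store_false with an explicit dest). A quick sketch, not from the commit, showing how the booleans resolve:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--skip_connection", action="store_true", dest="skip_connection")
parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional")
parser.add_argument("--no_wandb", action="store_false", dest="use_wandb")

print(parser.parse_args([]))
# Namespace(skip_connection=False, bidirectional=True, use_wandb=True)
print(parser.parse_args(["--skip_connection", "--no_bidirectional"]))
# Namespace(skip_connection=True, bidirectional=False, use_wandb=True)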