erermeev-d committed
Commit · c746c39
Parent(s): a4da241

Refactored experiments code
Browse files
- exp/gnn/__init__.py                        +0   -0
- exp/gnn/loss.py                            +17  -0
- exp/gnn/model.py                           +110 -0
- exp/{gnn.py → gnn/train.py}                +7   -154
- exp/gnn/utils.py                           +82  -0
- exp/{prepare_embeddings.sh → pipeline.sh}  +1   -1
- exp/utils.py                               +1   -62
exp/gnn/__init__.py
ADDED
File without changes
exp/gnn/loss.py
ADDED
@@ -0,0 +1,17 @@
+import torch
+
+
+### Based on https://arxiv.org/pdf/2205.03169
+def nt_xent_loss(sim, temperature):
+    sim = sim / temperature
+    n = sim.shape[0] // 2  # n = |user_batch|
+
+    alignment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n) + n])
+
+    mask = torch.diag(torch.ones(2 * n, dtype=torch.bool)).to(sim.device)
+    sim = torch.where(mask, -torch.inf, sim)
+    sim = sim[:n, :]
+    distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))
+
+    loss = alignment_loss + distribution_loss
+    return loss
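Note: nt_xent_loss expects a (2n, 2n) similarity matrix in which rows i and i + n form a positive pair (two items sampled for the same user); every other entry in a row acts as a negative. A minimal sketch of the intended call, not part of the commit, assuming L2-normalized embeddings and an illustrative temperature value:

import torch
from exp.gnn.loss import nt_xent_loss

# 2n = 8 stacked item embeddings: rows 0..3 pair with rows 4..7
embeddings = torch.nn.functional.normalize(torch.randn(8, 16), dim=1)
sim = embeddings @ embeddings.T            # cosine similarities, shape (8, 8)
loss = nt_xent_loss(sim, temperature=0.1)  # temperature value is an assumption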
exp/gnn/model.py
ADDED
@@ -0,0 +1,110 @@
+import torch
+import dgl
+
+
+class GNNLayer(torch.nn.Module):
+    def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
+        super().__init__()
+        self._skip_connection = skip_connection
+        self._bidirectional = bidirectional
+
+        self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+        self._activation = torch.nn.ReLU()
+
+        if bidirectional:
+            self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+            self._activation_rev = torch.nn.ReLU()
+
+    def forward(self, graph, x):
+        edge_weights = graph.edata["weights"]
+
+        y = self._activation(self._conv(graph, x, edge_weights))
+        if self._bidirectional:
+            reversed_graph = dgl.reverse(graph, copy_edata=True)
+            edge_weights = reversed_graph.edata["weights"]
+            y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
+
+        if self._skip_connection:
+            return x + y
+        else:
+            return y
+
+
+class GNNModel(torch.nn.Module):
+    def __init__(
+        self,
+        bipartite_graph,
+        text_embeddings,
+        num_layers,
+        hidden_dim,
+        aggregator_type,
+        skip_connection,
+        bidirectional,
+        num_traversals,
+        termination_prob,
+        num_random_walks,
+        num_neighbor,
+    ):
+        super().__init__()
+
+        self._bipartite_graph = bipartite_graph
+        self._text_embeddings = text_embeddings
+
+        self._sampler = dgl.sampling.PinSAGESampler(
+            bipartite_graph, "Item", "User", num_traversals,
+            termination_prob, num_random_walks, num_neighbor)
+
+        self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
+
+        self._layers = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            self._layers.append(GNNLayer(
+                hidden_dim, aggregator_type, skip_connection, bidirectional))
+
+    def _sample_subgraph(self, frontier_ids):
+        num_layers = len(self._layers)
+        device = self._bipartite_graph.device
+
+        subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
+        prev_ids = set()
+        weights = []
+
+        for _ in range(num_layers):
+            frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
+            new_sample = self._sampler(frontier_ids)
+            new_weights = new_sample.edata["weights"]
+            new_edges = new_sample.edges()
+
+            subgraph.add_edges(*new_edges)
+            weights.append(new_weights)
+
+            prev_ids |= set(frontier_ids.cpu().tolist())
+            frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
+            frontier_ids = list(frontier_ids - prev_ids)
+
+        subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
+        return subgraph
+
+    def forward(self, ids):
+        ### Sample subgraph
+        sampled_subgraph = self._sample_subgraph(ids)
+        sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
+
+        ### Encode text embeddings
+        text_embeddings = self._text_embeddings[
+            sampled_subgraph.ndata[dgl.NID]]
+        features = self._text_encoder(text_embeddings)
+
+        ### GNN goes brr...
+        for layer in self._layers:
+            features = layer(sampled_subgraph, features)
+
+        ### Select features for initial ids
+        # TODO: write it more efficiently?
+        matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
+        ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
+        features = features[ids_in_subgraph]
+
+        ### Normalize and return
+        features = features / torch.linalg.norm(features, dim=1, keepdim=True)
+        return features
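Note: a minimal wiring sketch for GNNModel, not part of the commit; the file paths and every hyperparameter value below are illustrative assumptions:

import numpy as np
import torch
from exp.gnn.model import GNNModel
from exp.gnn.utils import prepare_graphs

bipartite_graph, _ = prepare_graphs("items.csv", "ratings.csv")  # hypothetical paths
text_embeddings = torch.from_numpy(np.load("text_embeddings.npy")).float()
model = GNNModel(
    bipartite_graph, text_embeddings,
    num_layers=2, hidden_dim=64, aggregator_type="mean",
    skip_connection=True, bidirectional=True,
    num_traversals=4, termination_prob=0.5,
    num_random_walks=16, num_neighbor=8)

# forward takes item ids and returns one L2-normalized embedding per id
item_embeddings = model(torch.tensor([0, 1, 2]))  # shape (3, 64)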
exp/{gnn.py → gnn/train.py}
RENAMED
@@ -9,159 +9,14 @@ import torch
 import wandb
 from tqdm.auto import tqdm
 
-from exp.utils import
+from exp.utils import normalize_embeddings
 from exp.prepare_recsys import prepare_recsys
 from exp.evaluate import evaluate_recsys
-
-
-class GNNLayer(torch.nn.Module):
-    def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
-        super().__init__()
-        self._skip_connection = skip_connection
-        self._bidirectional = bidirectional
-
-        self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
-        self._activation = torch.nn.ReLU()
-
-        if bidirectional:
-            self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
-            self._activation_rev = torch.nn.ReLU()
-
-    def forward(self, graph, x):
-        edge_weights = graph.edata["weights"]
-
-        y = self._activation(self._conv(graph, x, edge_weights))
-        if self._bidirectional:
-            reversed_graph = dgl.reverse(graph, copy_edata=True)
-            edge_weights = reversed_graph.edata["weights"]
-            y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
-
-        if self._skip_connection:
-            return x + y
-        else:
-            return y
-
-
-class GNNModel(torch.nn.Module):
-    def __init__(
-        self,
-        bipartite_graph,
-        text_embeddings,
-        num_layers,
-        hidden_dim,
-        aggregator_type,
-        skip_connection,
-        bidirectional,
-        num_traversals,
-        termination_prob,
-        num_random_walks,
-        num_neighbor,
-    ):
-        super().__init__()
-
-        self._bipartite_graph = bipartite_graph
-        self._text_embeddings = text_embeddings
-
-        self._sampler = dgl.sampling.PinSAGESampler(
-            bipartite_graph, "Item", "User", num_traversals,
-            termination_prob, num_random_walks, num_neighbor)
-
-        self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
-
-        self._layers = torch.nn.ModuleList()
-        for _ in range(num_layers):
-            self._layers.append(GNNLayer(
-                hidden_dim, aggregator_type, skip_connection, bidirectional))
-
-    def _sample_subraph(self, frontier_ids):
-        num_layers = len(self._layers)
-        device = self._bipartite_graph.device
-
-        subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
-        prev_ids = set()
-        weights = []
-
-        for _ in range(num_layers):
-            frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
-            new_sample = self._sampler(frontier_ids)
-            new_weights = new_sample.edata["weights"]
-            new_edges = new_sample.edges()
-
-            subgraph.add_edges(*new_edges)
-            weights.append(new_weights)
-
-            prev_ids |= set(frontier_ids.cpu().tolist())
-            frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
-            frontier_ids = list(frontier_ids - prev_ids)
-
-        subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
-        return subgraph
-
-    def forward(self, ids):
-        ### Sample subgraph
-        sampled_subgraph = self._sample_subraph(ids)
-        sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
-
-        ### Encode text embeddings
-        text_embeddings = self._text_embeddings[
-            sampled_subgraph.ndata[dgl.NID]]
-        features = self._text_encoder(text_embeddings)
-
-        ### GNN goes brr...
-        for layer in self._layers:
-            features = layer(sampled_subgraph, features)
-
-        ### Select features for initial ids
-        # TODO: write it more efficiently?
-        matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
-        ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
-        features = features[ids_in_subgraph]
-
-        ### Normalize and return
-        features = features / torch.linalg.norm(features, dim=1, keepdim=True)
-        return features
-
-
-### Based on https://arxiv.org/pdf/2205.03169
-def nt_xent_loss(sim, temperature):
-    sim = sim / temperature
-    n = sim.shape[0] // 2  # n = |user_batch|
-
-    aligment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n)+n])
-
-    mask = torch.diag(torch.ones(2*n, dtype=torch.bool)).to(sim.device)
-    sim = torch.where(mask, -torch.inf, sim)
-    sim = sim[:n, :]
-    distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))
-
-    loss = aligment_loss + distribution_loss
-    return loss
-
-
-def sample_item_batch(user_batch, bipartite_graph):
-    sampled_edges = dgl.sampling.sample_neighbors(
-        bipartite_graph, {"User": user_batch}, fanout=2
-    ).edges(etype="ItemUser")
-    item_batch = sampled_edges[0]
-    item_batch = item_batch[torch.argsort(sampled_edges[1])]
-    item_batch = item_batch.reshape(-1, 2)
-    item_batch = item_batch.T
-    return item_batch
-
-
-@torch.no_grad()
-def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
-    model.eval()
-    item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
-    for items_batch in tqdm(torch.utils.data.DataLoader(
-        torch.arange(bipartite_graph.num_nodes("Item")),
-        batch_size=batch_size,
-        shuffle=True
-    )):
-        item_embeddings[items_batch] = model(items_batch.to(device))
-
-    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
-    return item_embeddings
+from exp.gnn.model import GNNModel
+from exp.gnn.loss import nt_xent_loss
+from exp.gnn.utils import (
+    prepare_graphs, LRSchedule,
+    sample_item_batch, inference_model)
 
 
 def prepare_gnn_embeddings(

@@ -228,9 +83,9 @@ def prepare_gnn_embeddings(
     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
     ### Train loop
-    model.train()
     for epoch in range(num_epochs):
         ### Train
+        model.train()
         for user_batch in tqdm(dataloader):
            item_batch = sample_item_batch(user_batch, bipartite_graph) # (2, |user_batch|)
            item_batch = item_batch.reshape(-1) # (2 * |user_batch|)

@@ -258,8 +113,6 @@ def prepare_gnn_embeddings(
         print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
         if use_wandb:
             wandb.log(metrics)
-
-
 
     if use_wandb:
        wandb.finish()
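Note: the diff omits the rest of the training-step body. Given the imports regrouped above, one step presumably looks roughly like the sketch below; the similarity computation, temperature, and optimizer wiring are assumptions, not shown in these hunks:

item_batch = sample_item_batch(user_batch, bipartite_graph)  # (2, |user_batch|)
item_batch = item_batch.reshape(-1)       # (2n,): rows i and i + n share a user
embeddings = model(item_batch.to(device)) # L2-normalized item embeddings
sim = embeddings @ embeddings.T           # (2n, 2n) similarity matrix
loss = nt_xent_loss(sim, temperature)

optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()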
exp/gnn/utils.py
ADDED
@@ -0,0 +1,82 @@
+import torch
+import dgl
+import pandas as pd
+import numpy as np
+from tqdm.auto import tqdm
+
+from exp.utils import normalize_embeddings
+
+
+class LRSchedule:
+    def __init__(self, total_steps, warmup_steps, final_factor):
+        self._total_steps = total_steps
+        self._warmup_steps = warmup_steps
+        self._final_factor = final_factor
+
+    def __call__(self, step):
+        if step >= self._total_steps:
+            return self._final_factor
+
+        if self._warmup_steps > 0:
+            warmup_factor = step / self._warmup_steps
+        else:
+            warmup_factor = 1.0
+
+        steps_after_warmup = step - self._warmup_steps
+        total_steps_after_warmup = self._total_steps - self._warmup_steps
+        after_warmup_factor = 1 \
+            - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)
+
+        factor = min(warmup_factor, after_warmup_factor)
+        return min(max(factor, 0), 1)
+
+
+def prepare_graphs(items_path, ratings_path):
+    items = pd.read_csv(items_path)
+    ratings = pd.read_csv(ratings_path)
+
+    n_users = np.max(ratings["user_id"].unique()) + 1
+    item_ids = torch.tensor(sorted(items["item_id"].unique()))
+
+    edges = torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"])
+    reverse_edges = (edges[1], edges[0])
+
+    bipartite_graph = dgl.heterograph(
+        data_dict={
+            ("User", "UserItem", "Item"): edges,
+            ("Item", "ItemUser", "User"): reverse_edges
+        },
+        num_nodes_dict={
+            "User": n_users,
+            "Item": len(item_ids)
+        }
+    )
+    graph = dgl.to_homogeneous(bipartite_graph)
+    graph = dgl.add_self_loop(graph)
+    return bipartite_graph, graph
+
+
+def sample_item_batch(user_batch, bipartite_graph):
+    sampled_edges = dgl.sampling.sample_neighbors(
+        bipartite_graph, {"User": user_batch}, fanout=2
+    ).edges(etype="ItemUser")
+    item_batch = sampled_edges[0]
+    item_batch = item_batch[torch.argsort(sampled_edges[1])]
+    item_batch = item_batch.reshape(-1, 2)
+    item_batch = item_batch.T
+    return item_batch
+
+
+@torch.no_grad()
+def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
+    model.eval()
+    item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
+    for items_batch in tqdm(torch.utils.data.DataLoader(
+        torch.arange(bipartite_graph.num_nodes("Item")),
+        batch_size=batch_size,
+        shuffle=True
+    )):
+        item_embeddings[items_batch] = model(items_batch.to(device))
+
+    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+    return item_embeddings
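Note: LRSchedule returns a multiplicative learning-rate factor (linear warmup to 1.0, then linear decay to final_factor), so it plugs directly into LambdaLR; train.py above currently passes a constant lambda instead. A usage sketch, not part of the commit, with illustrative step counts:

import torch
from exp.gnn.utils import LRSchedule

model = torch.nn.Linear(8, 8)  # stand-in module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
schedule = LRSchedule(total_steps=1000, warmup_steps=100, final_factor=0.1)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule)

assert schedule(0) == 0.0     # start of warmup
assert schedule(100) == 1.0   # warmup done, full base lr
assert schedule(1000) == 0.1  # decayed to final_factor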
exp/{prepare_embeddings.sh → pipeline.sh}
RENAMED
@@ -17,7 +17,7 @@ PYTHONPATH=. python exp/sbert.py \
     --embeddings_savepath "$save_directory/text_embeddings.npy" \
     --device $device
 
-PYTHONPATH=. python exp/gnn.py \
+PYTHONPATH=. python exp/gnn/train.py \
     --items_path "$save_directory/items.csv" \
     --train_ratings_path "$save_directory/train_ratings.csv" \
     --val_ratings_path "$save_directory/val_ratings.csv" \
exp/utils.py
CHANGED
@@ -1,69 +1,8 @@
 import numpy as np
-import pandas as pd
-import dgl
-import torch
 
 
 def normalize_embeddings(embeddings):
     embeddings_norm = np.linalg.norm(embeddings, axis=1)
     nonzero_embeddings = embeddings_norm > 0.0
     embeddings[nonzero_embeddings] /= embeddings_norm[nonzero_embeddings, None]
-    return embeddings
-
-
-def prepare_graphs(items_path, ratings_path):
-    items = pd.read_csv(items_path)
-    ratings = pd.read_csv(ratings_path)
-
-    n_users = np.max(ratings["user_id"].unique()) + 1
-    item_ids = torch.tensor(sorted(items["item_id"].unique()))
-
-    edges = torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"])
-    reverse_edges = (edges[1], edges[0])
-
-    bipartite_graph = dgl.heterograph(
-        data_dict={
-            ("User", "UserItem", "Item"): edges,
-            ("Item", "ItemUser", "User"): reverse_edges
-        },
-        num_nodes_dict={
-            "User": n_users,
-            "Item": len(item_ids)
-        }
-    )
-    graph = dgl.to_homogeneous(bipartite_graph)
-    graph = dgl.add_self_loop(graph)
-    return bipartite_graph, graph
-
-
-def extract_item_embeddings(node_embeddings, bipartite_graph, graph):
-    item_ntype = bipartite_graph.ntypes.index("Item")
-    item_mask = graph.ndata[dgl.NTYPE] == item_ntype
-    item_embeddings = node_embeddings[item_mask]
-    original_ids = graph.ndata[dgl.NID][item_mask]
-    item_embeddings = item_embeddings[torch.argsort(original_ids)]
-    return item_embeddings.cpu().numpy()
-
-
-class LRSchedule:
-    def __init__(self, total_steps, warmup_steps, final_factor):
-        self._total_steps = total_steps
-        self._warmup_steps = warmup_steps
-        self._final_factor = final_factor
-
-    def __call__(self, step):
-        if step >= self._total_steps:
-            return self._final_factor
-
-        if self._warmup_steps > 0:
-            warmup_factor = step / self._warmup_steps
-        else:
-            warmup_factor = 1.0
-
-        steps_after_warmup = step - self._warmup_steps
-        total_steps_after_warmup = self._total_steps - self._warmup_steps
-        after_warmup_factor = 1 \
-            - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)
-
-        factor = min(warmup_factor, after_warmup_factor)
-        return min(max(factor, 0), 1)
+    return embeddings
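Note: normalize_embeddings L2-normalizes rows in place and skips all-zero rows to avoid division by zero; e.g.:

import numpy as np
from exp.utils import normalize_embeddings

embeddings = np.array([[3.0, 4.0], [0.0, 0.0]])
normalize_embeddings(embeddings)  # -> [[0.6, 0.8], [0.0, 0.0]], zero row untouched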