erermeev-d committed
Commit c746c39 · 1 Parent(s): a4da241

Refactored experiments code
exp/gnn/__init__.py ADDED
File without changes
exp/gnn/loss.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+
+
+ ### Based on https://arxiv.org/pdf/2205.03169
+ def nt_xent_loss(sim, temperature):
+     sim = sim / temperature
+     n = sim.shape[0] // 2  # n = |user_batch|
+
+     # Alignment: pull each positive pair (i, i+n) together
+     alignment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n) + n])
+
+     # Uniformity: log-sum-exp over all similarities, self-similarity masked out
+     mask = torch.diag(torch.ones(2 * n, dtype=torch.bool)).to(sim.device)
+     sim = torch.where(mask, -torch.inf, sim)
+     sim = sim[:n, :]
+     distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))
+
+     loss = alignment_loss + distribution_loss
+     return loss
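
Note on the loss above: with temperature τ and similarity matrix s, it computes -mean_i s(i, i+n)/τ + mean_i log Σ_{j≠i} exp(s(i, j)/τ), i.e. an alignment term over the positive pairs (i, i+n) plus a log-sum-exp uniformity term. A minimal sanity-check sketch, assuming rows 0..n-1 and n..2n-1 hold two L2-normalized views of the same n users (shapes and temperature are illustrative, not from the commit):

import torch
from exp.gnn.loss import nt_xent_loss

n, dim = 4, 8
views = torch.nn.functional.normalize(torch.randn(2 * n, dim), dim=1)
sim = views @ views.T  # (2n, 2n) cosine similarities; (i, i+n) are positives
loss = nt_xent_loss(sim, temperature=0.1)
print(loss.item())  # scalar; decreases as paired rows align
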
exp/gnn/model.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import dgl
+
+
+ class GNNLayer(torch.nn.Module):
+     def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
+         super().__init__()
+         self._skip_connection = skip_connection
+         self._bidirectional = bidirectional
+
+         self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+         self._activation = torch.nn.ReLU()
+
+         if bidirectional:
+             self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+             self._activation_rev = torch.nn.ReLU()
+
+     def forward(self, graph, x):
+         edge_weights = graph.edata["weights"]
+
+         y = self._activation(self._conv(graph, x, edge_weights))
+         if self._bidirectional:
+             # Aggregate along reversed edges as well, with a separate convolution
+             reversed_graph = dgl.reverse(graph, copy_edata=True)
+             edge_weights = reversed_graph.edata["weights"]
+             y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
+
+         if self._skip_connection:
+             return x + y
+         else:
+             return y
+
+
+ class GNNModel(torch.nn.Module):
+     def __init__(
+         self,
+         bipartite_graph,
+         text_embeddings,
+         num_layers,
+         hidden_dim,
+         aggregator_type,
+         skip_connection,
+         bidirectional,
+         num_traversals,
+         termination_prob,
+         num_random_walks,
+         num_neighbor,
+     ):
+         super().__init__()
+
+         self._bipartite_graph = bipartite_graph
+         self._text_embeddings = text_embeddings
+
+         # PinSAGE-style sampler: Item -> User -> Item random walks
+         self._sampler = dgl.sampling.PinSAGESampler(
+             bipartite_graph, "Item", "User", num_traversals,
+             termination_prob, num_random_walks, num_neighbor)
+
+         self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
+
+         self._layers = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self._layers.append(GNNLayer(
+                 hidden_dim, aggregator_type, skip_connection, bidirectional))
+
+     def _sample_subgraph(self, frontier_ids):
+         num_layers = len(self._layers)
+         device = self._bipartite_graph.device
+
+         subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
+         prev_ids = set()
+         weights = []
+
+         # Expand the frontier once per GNN layer, reusing the PinSAGE sampler
+         for _ in range(num_layers):
+             frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
+             new_sample = self._sampler(frontier_ids)
+             new_weights = new_sample.edata["weights"]
+             new_edges = new_sample.edges()
+
+             subgraph.add_edges(*new_edges)
+             weights.append(new_weights)
+
+             prev_ids |= set(frontier_ids.cpu().tolist())
+             frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
+             frontier_ids = list(frontier_ids - prev_ids)
+
+         subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
+         return subgraph
+
+     def forward(self, ids):
+         ### Sample subgraph
+         sampled_subgraph = self._sample_subgraph(ids)
+         sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
+
+         ### Encode text embeddings
+         text_embeddings = self._text_embeddings[
+             sampled_subgraph.ndata[dgl.NID]]
+         features = self._text_encoder(text_embeddings)
+
+         ### GNN goes brr...
+         for layer in self._layers:
+             features = layer(sampled_subgraph, features)
+
+         ### Select features for initial ids
+         # TODO: write it more efficiently?
+         matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
+         ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
+         features = features[ids_in_subgraph]
+
+         ### Normalize and return
+         features = features / torch.linalg.norm(features, dim=1, keepdim=True)
+         return features
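
A minimal construction sketch for GNNModel (hypothetical graph and hyperparameters; in the training script the heterograph comes from prepare_graphs in exp/gnn/utils.py):

import torch
import dgl
from exp.gnn.model import GNNModel

# Toy bipartite graph: 3 users, 4 items, both edge directions as in prepare_graphs
users = torch.tensor([0, 0, 1, 2, 2])
items = torch.tensor([0, 1, 1, 2, 3])
graph = dgl.heterograph({
    ("User", "UserItem", "Item"): (users, items),
    ("Item", "ItemUser", "User"): (items, users),
}, num_nodes_dict={"User": 3, "Item": 4})

model = GNNModel(
    bipartite_graph=graph,
    text_embeddings=torch.randn(4, 16),  # one text vector per item (illustrative)
    num_layers=2, hidden_dim=32, aggregator_type="mean",
    skip_connection=True, bidirectional=True,
    num_traversals=3, termination_prob=0.5,
    num_random_walks=10, num_neighbor=3,
)
item_vectors = model(torch.tensor([0, 2]))  # (2, 32); rows come out L2-normalized
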
exp/{gnn.py → gnn/train.py} RENAMED
@@ -9,159 +9,14 @@ import torch
  import wandb
  from tqdm.auto import tqdm
 
- from exp.utils import prepare_graphs, normalize_embeddings, LRSchedule
+ from exp.utils import normalize_embeddings
  from exp.prepare_recsys import prepare_recsys
  from exp.evaluate import evaluate_recsys
+ from exp.gnn.model import GNNModel
+ from exp.gnn.loss import nt_xent_loss
+ from exp.gnn.utils import (
+     prepare_graphs, LRSchedule,
+     sample_item_batch, inference_model)
-
-
- class GNNLayer(torch.nn.Module):
-     def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
-         super().__init__()
-         self._skip_connection = skip_connection
-         self._bidirectional = bidirectional
-
-         self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
-         self._activation = torch.nn.ReLU()
-
-         if bidirectional:
-             self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
-             self._activation_rev = torch.nn.ReLU()
-
-     def forward(self, graph, x):
-         edge_weights = graph.edata["weights"]
-
-         y = self._activation(self._conv(graph, x, edge_weights))
-         if self._bidirectional:
-             reversed_graph = dgl.reverse(graph, copy_edata=True)
-             edge_weights = reversed_graph.edata["weights"]
-             y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
-
-         if self._skip_connection:
-             return x + y
-         else:
-             return y
-
-
- class GNNModel(torch.nn.Module):
-     def __init__(
-         self,
-         bipartite_graph,
-         text_embeddings,
-         num_layers,
-         hidden_dim,
-         aggregator_type,
-         skip_connection,
-         bidirectional,
-         num_traversals,
-         termination_prob,
-         num_random_walks,
-         num_neighbor,
-     ):
-         super().__init__()
-
-         self._bipartite_graph = bipartite_graph
-         self._text_embeddings = text_embeddings
-
-         self._sampler = dgl.sampling.PinSAGESampler(
-             bipartite_graph, "Item", "User", num_traversals,
-             termination_prob, num_random_walks, num_neighbor)
-
-         self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
-
-         self._layers = torch.nn.ModuleList()
-         for _ in range(num_layers):
-             self._layers.append(GNNLayer(
-                 hidden_dim, aggregator_type, skip_connection, bidirectional))
-
-     def _sample_subgraph(self, frontier_ids):
-         num_layers = len(self._layers)
-         device = self._bipartite_graph.device
-
-         subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
-         prev_ids = set()
-         weights = []
-
-         for _ in range(num_layers):
-             frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
-             new_sample = self._sampler(frontier_ids)
-             new_weights = new_sample.edata["weights"]
-             new_edges = new_sample.edges()
-
-             subgraph.add_edges(*new_edges)
-             weights.append(new_weights)
-
-             prev_ids |= set(frontier_ids.cpu().tolist())
-             frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
-             frontier_ids = list(frontier_ids - prev_ids)
-
-         subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
-         return subgraph
-
-     def forward(self, ids):
-         ### Sample subgraph
-         sampled_subgraph = self._sample_subgraph(ids)
-         sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
-
-         ### Encode text embeddings
-         text_embeddings = self._text_embeddings[
-             sampled_subgraph.ndata[dgl.NID]]
-         features = self._text_encoder(text_embeddings)
-
-         ### GNN goes brr...
-         for layer in self._layers:
-             features = layer(sampled_subgraph, features)
-
-         ### Select features for initial ids
-         # TODO: write it more efficiently?
-         matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
-         ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
-         features = features[ids_in_subgraph]
-
-         ### Normalize and return
-         features = features / torch.linalg.norm(features, dim=1, keepdim=True)
-         return features
-
-
- ### Based on https://arxiv.org/pdf/2205.03169
- def nt_xent_loss(sim, temperature):
-     sim = sim / temperature
-     n = sim.shape[0] // 2  # n = |user_batch|
-
-     alignment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n) + n])
-
-     mask = torch.diag(torch.ones(2 * n, dtype=torch.bool)).to(sim.device)
-     sim = torch.where(mask, -torch.inf, sim)
-     sim = sim[:n, :]
-     distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))
-
-     loss = alignment_loss + distribution_loss
-     return loss
-
-
- def sample_item_batch(user_batch, bipartite_graph):
-     sampled_edges = dgl.sampling.sample_neighbors(
-         bipartite_graph, {"User": user_batch}, fanout=2
-     ).edges(etype="ItemUser")
-     item_batch = sampled_edges[0]
-     item_batch = item_batch[torch.argsort(sampled_edges[1])]
-     item_batch = item_batch.reshape(-1, 2)
-     item_batch = item_batch.T
-     return item_batch
-
-
- @torch.no_grad()
- def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
-     model.eval()
-     item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
-     for items_batch in tqdm(torch.utils.data.DataLoader(
-         torch.arange(bipartite_graph.num_nodes("Item")),
-         batch_size=batch_size,
-         shuffle=True
-     )):
-         item_embeddings[items_batch] = model(items_batch.to(device))
-
-     item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
-     return item_embeddings
 
 
  def prepare_gnn_embeddings(
@@ -228,9 +83,9 @@ def prepare_gnn_embeddings(
      lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
      ### Train loop
-     model.train()
      for epoch in range(num_epochs):
          ### Train
+         model.train()
          for user_batch in tqdm(dataloader):
              item_batch = sample_item_batch(user_batch, bipartite_graph)  # (2, |user_batch|)
              item_batch = item_batch.reshape(-1)  # (2 * |user_batch|)
@@ -258,8 +113,6 @@ def prepare_gnn_embeddings(
          print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
          if use_wandb:
              wandb.log(metrics)
-
-
 
      if use_wandb:
          wandb.finish()
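
Pieced together from the hunks above, one training step looks roughly like this (a hedged reconstruction: the similarity computation and optimizer wiring are not shown in the diff and are assumed):

features = model(item_batch.to(device))  # (2n, hidden_dim), L2-normalized rows
sim = features @ features.T              # (2n, 2n); positives sit at (i, i+n)
loss = nt_xent_loss(sim, temperature)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
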
exp/gnn/utils.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ import dgl
+ import pandas as pd
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ from exp.utils import normalize_embeddings
+
+
+ class LRSchedule:
+     # Linear warmup, then linear decay towards final_factor
+     def __init__(self, total_steps, warmup_steps, final_factor):
+         self._total_steps = total_steps
+         self._warmup_steps = warmup_steps
+         self._final_factor = final_factor
+
+     def __call__(self, step):
+         if step >= self._total_steps:
+             return self._final_factor
+
+         if self._warmup_steps > 0:
+             warmup_factor = step / self._warmup_steps
+         else:
+             warmup_factor = 1.0
+
+         steps_after_warmup = step - self._warmup_steps
+         total_steps_after_warmup = self._total_steps - self._warmup_steps
+         after_warmup_factor = 1 \
+             - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)
+
+         factor = min(warmup_factor, after_warmup_factor)
+         return min(max(factor, 0), 1)
+
+
+ def prepare_graphs(items_path, ratings_path):
+     items = pd.read_csv(items_path)
+     ratings = pd.read_csv(ratings_path)
+
+     n_users = np.max(ratings["user_id"].unique()) + 1
+     item_ids = torch.tensor(sorted(items["item_id"].unique()))
+
+     edges = torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"])
+     reverse_edges = (edges[1], edges[0])
+
+     bipartite_graph = dgl.heterograph(
+         data_dict={
+             ("User", "UserItem", "Item"): edges,
+             ("Item", "ItemUser", "User"): reverse_edges
+         },
+         num_nodes_dict={
+             "User": n_users,
+             "Item": len(item_ids)
+         }
+     )
+     graph = dgl.to_homogeneous(bipartite_graph)
+     graph = dgl.add_self_loop(graph)
+     return bipartite_graph, graph
+
+
+ def sample_item_batch(user_batch, bipartite_graph):
+     # Sample two interacted items per user; the two rows of the result are the two "views"
+     sampled_edges = dgl.sampling.sample_neighbors(
+         bipartite_graph, {"User": user_batch}, fanout=2
+     ).edges(etype="ItemUser")
+     item_batch = sampled_edges[0]
+     item_batch = item_batch[torch.argsort(sampled_edges[1])]
+     item_batch = item_batch.reshape(-1, 2)
+     item_batch = item_batch.T
+     return item_batch
+
+
+ @torch.no_grad()
+ def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
+     model.eval()
+     item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
+     for items_batch in tqdm(torch.utils.data.DataLoader(
+         torch.arange(bipartite_graph.num_nodes("Item")),
+         batch_size=batch_size,
+         shuffle=True
+     )):
+         item_embeddings[items_batch] = model(items_batch.to(device))
+
+     item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+     return item_embeddings
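
Since LRSchedule is a plain callable returning a multiplicative factor, it plugs directly into torch.optim.lr_scheduler.LambdaLR (illustrative numbers; the train script currently wires LambdaLR with a constant lambda _: 1.0):

import torch
from exp.gnn.utils import LRSchedule

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
optimizer = torch.optim.Adam(params, lr=1e-3)
schedule = LRSchedule(total_steps=1000, warmup_steps=100, final_factor=0.1)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule)

for step in range(1000):
    optimizer.step()
    lr_scheduler.step()  # lr = 1e-3 * factor: warms up over 100 steps, then decays to 1e-4
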
exp/{prepare_embeddings.sh → pipeline.sh} RENAMED
@@ -17,7 +17,7 @@ PYTHONPATH=. python exp/sbert.py \
      --embeddings_savepath "$save_directory/text_embeddings.npy" \
      --device $device
 
- PYTHONPATH=. python exp/gnn.py \
+ PYTHONPATH=. python exp/gnn/train.py \
      --items_path "$save_directory/items.csv" \
      --train_ratings_path "$save_directory/train_ratings.csv" \
      --val_ratings_path "$save_directory/val_ratings.csv" \
exp/utils.py CHANGED
@@ -1,69 +1,8 @@
  import numpy as np
- import pandas as pd
- import dgl
- import torch
 
 
  def normalize_embeddings(embeddings):
      embeddings_norm = np.linalg.norm(embeddings, axis=1)
      nonzero_embeddings = embeddings_norm > 0.0
      embeddings[nonzero_embeddings] /= embeddings_norm[nonzero_embeddings, None]
-     return embeddings
-
-
- def prepare_graphs(items_path, ratings_path):
-     items = pd.read_csv(items_path)
-     ratings = pd.read_csv(ratings_path)
-
-     n_users = np.max(ratings["user_id"].unique()) + 1
-     item_ids = torch.tensor(sorted(items["item_id"].unique()))
-
-     edges = torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"])
-     reverse_edges = (edges[1], edges[0])
-
-     bipartite_graph = dgl.heterograph(
-         data_dict={
-             ("User", "UserItem", "Item"): edges,
-             ("Item", "ItemUser", "User"): reverse_edges
-         },
-         num_nodes_dict={
-             "User": n_users,
-             "Item": len(item_ids)
-         }
-     )
-     graph = dgl.to_homogeneous(bipartite_graph)
-     graph = dgl.add_self_loop(graph)
-     return bipartite_graph, graph
-
-
- def extract_item_embeddings(node_embeddings, bipartite_graph, graph):
-     item_ntype = bipartite_graph.ntypes.index("Item")
-     item_mask = graph.ndata[dgl.NTYPE] == item_ntype
-     item_embeddings = node_embeddings[item_mask]
-     original_ids = graph.ndata[dgl.NID][item_mask]
-     item_embeddings = item_embeddings[torch.argsort(original_ids)]
-     return item_embeddings.cpu().numpy()
-
-
- class LRSchedule:
-     def __init__(self, total_steps, warmup_steps, final_factor):
-         self._total_steps = total_steps
-         self._warmup_steps = warmup_steps
-         self._final_factor = final_factor
-
-     def __call__(self, step):
-         if step >= self._total_steps:
-             return self._final_factor
-
-         if self._warmup_steps > 0:
-             warmup_factor = step / self._warmup_steps
-         else:
-             warmup_factor = 1.0
-
-         steps_after_warmup = step - self._warmup_steps
-         total_steps_after_warmup = self._total_steps - self._warmup_steps
-         after_warmup_factor = 1 \
-             - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)
-
-         factor = min(warmup_factor, after_warmup_factor)
-         return min(max(factor, 0), 1)
+     return embeddings
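
normalize_embeddings works in place on a NumPy array and skips all-zero rows rather than dividing by zero; a toy check (values illustrative):

import numpy as np
from exp.utils import normalize_embeddings

emb = np.array([[3.0, 4.0], [0.0, 0.0]])
out = normalize_embeddings(emb)
# out[0] -> [0.6, 0.8] (unit norm); out[1] stays [0.0, 0.0]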