erermeev-d committed
Commit b8f4763 · 1 Parent(s): d4852d9

Updated experiments code

app/database.py CHANGED
@@ -1,3 +1,4 @@
+import functools
 import io
 import sqlite3
 
@@ -26,8 +27,8 @@ class ItemDatabase:
             rows = c.fetchall()[:n_items]
             return [row[0] for row in rows]
 
+    @functools.lru_cache(maxsize=2**14)
     def get_item(self, item_id):
-
         with self._connect() as conn:
             c = conn.cursor()
             c.row_factory = sqlite3.Row
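
A side note on the new cache: functools.lru_cache on a bound method keys entries on (self, item_id) and keeps the instance alive for the cache's lifetime, which is usually fine for a long-lived database object. A minimal sketch of the caching behaviour, using a hypothetical slow_lookup stand-in for the SQLite call:

import functools

@functools.lru_cache(maxsize=2**14)
def slow_lookup(item_id):
    # Stand-in for the real database query; only executed on a cache miss.
    print(f"querying item {item_id}")
    return {"item_id": item_id}

slow_lookup(1)                    # miss: prints and stores the result
slow_lookup(1)                    # hit: served from the cache, no print
print(slow_lookup.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=16384, currsize=1)
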
exp/deepwalk.py DELETED
@@ -1,80 +0,0 @@
-import argparse
-import os
-
-import numpy as np
-import pandas as pd
-import dgl
-import torch
-import wandb
-from tqdm.auto import tqdm
-
-from utils import prepare_graphs, extract_item_embeddings, normalize_embeddings
-
-
-def prepare_deepwalk_embeddings(
-    items_path,
-    ratings_path,
-    embeddings_savepath,
-    emb_dim,
-    window_size,
-    batch_size,
-    lr,
-    num_epochs,
-    device,
-    wandb_name,
-    use_wandb
-):
-    ### Prepare graph
-    bipartite_graph, graph = prepare_graphs(items_path, ratings_path)
-    bipartite_graph = bipartite_graph.to(device)
-    graph = graph.to(device)
-
-    ### Run deepwalk
-    if use_wandb:
-        wandb.init(project="graph-recs-deepwalk", name=wandb_name)
-
-    model = dgl.nn.DeepWalk(graph, emb_dim=emb_dim, window_size=window_size)
-    model = model.to(device)
-    dataloader = torch.utils.data.DataLoader(
-        torch.arange(graph.num_nodes()),
-        batch_size=batch_size,
-        shuffle=True,
-        collate_fn=model.sample)
-
-    optimizer = torch.optim.SparseAdam(model.parameters(), lr=lr)
-    for epoch in range(num_epochs):
-        for batch_walk in tqdm(dataloader):
-            loss = model(batch_walk)
-            if use_wandb:
-                wandb.log({"loss": loss.item()})
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-    if use_wandb:
-        wandb.finish()
-
-    node_embeddings = model.node_embed.weight.detach().to(device)
-
-    ### Extract & save item embeddings
-    item_embeddings = extract_item_embeddings(node_embeddings, bipartite_graph, graph)
-    item_embeddings = normalize_embeddings(item_embeddings)
-    np.save(embeddings_savepath, item_embeddings)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Prepare DeepWalk embeddings.")
-    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
-    parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file.")
-    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where embeddings will be saved.")
-    parser.add_argument("--emb_dim", type=int, default=384, help="Dimensionality of the embeddings.")
-    parser.add_argument("--window_size", type=int, default=4, help="Window size for the DeepWalk algorithm.")
-    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training.")
-    parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for training.")
-    parser.add_argument("--num_epochs", type=int, default=2, help="Number of epochs for training.")
-    parser.add_argument("--device", type=str, default="cpu", help="Device to use for training (cpu or cuda).")
-    parser.add_argument("--wandb_name", type=str, help="Name for WandB run.")
-    parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
-    args = parser.parse_args()
-
-    prepare_deepwalk_embeddings(**vars(args))

exp/evaluate.py CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import json
+from tqdm.auto import tqdm
 
 import pandas as pd
 import numpy as np
@@ -9,16 +10,26 @@ from app.recommendations import RecommenderSystem
 
 def precision_at_k(recommended_items, relevant_items, k):
     recommended_at_k = set(recommended_items[:k])
-    relevant_set = set(relevant_items)
-    return len(recommended_at_k & relevant_set) / k
+    return len(recommended_at_k & relevant_items) / k
+
+
+def average_precision_at_k(recommended_items, relevant_items, k):
+    relevant_items = set(relevant_items)
+
+    apk_sum = 0.0
+    for m, item in enumerate(recommended_items):
+        if item in relevant_items:
+            apk_sum += precision_at_k(recommended_items, relevant_items, m+1)
+
+    return apk_sum / min(k, len(relevant_items))
 
 
 def evaluate_recsys(
-    metrics_savepath,
     val_ratings_path,
     faiss_index_path,
     db_path,
-    n_recommend_items,
+    n_recommend_items=10,
+    metrics_savepath=None
 ):
     recsys = RecommenderSystem(
         faiss_index_path=faiss_index_path,
@@ -30,16 +41,14 @@ def evaluate_recsys(
 
 
     metric_arrays = {
-        "precision@1": [],
-        "precision@3": [],
-        "precision@10": []
+        "ap@5": [],
     }
 
-    for item_group in grouped_items:
+    for item_group in tqdm(grouped_items):
         if len(item_group) == 1:
             continue
 
-        ### Precision@k is computed for each edge.
+        ### Metrics are computed for each edge.
         ### We will first aggregate it over all edges for user
         ### And after that - aggregate over all users
         user_metric_arrays = dict()
@@ -50,12 +59,8 @@ def evaluate_recsys(
             recommend_items = list(recsys.recommend_items(item, n_recommend_items))
             relevant_items = set(item_group) - {item}
 
-            user_metric_arrays["precision@1"].append(
-                precision_at_k(recommend_items, relevant_items, k=1))
-            user_metric_arrays["precision@3"].append(
-                precision_at_k(recommend_items, relevant_items, k=3))
-            user_metric_arrays["precision@10"].append(
-                precision_at_k(recommend_items, relevant_items, k=10))
+            user_metric_arrays["ap@5"].append(
+                average_precision_at_k(recommend_items, relevant_items, k=5))
 
         for metric in metric_arrays.keys():
             user_metric = np.mean(user_metric_arrays[metric])
@@ -65,9 +70,11 @@ def evaluate_recsys(
     for metric, array in metric_arrays.items():
         metrics[metric] = np.mean(array)
 
-    with open(metrics_savepath, "w") as f:
-        json.dump(metrics, f)
-    print(f"Saved metrics to {metrics_savepath}")
+    if metrics_savepath is not None:
+        with open(metrics_savepath, "w") as f:
+            json.dump(metrics, f)
+        print(f"Saved metrics to {metrics_savepath}")
+    return metrics
 
 
 if __name__ == "__main__":
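
For reference, a quick worked check of the new metric using the definitions above (values chosen purely for illustration): with recommendations ["a", "b", "c", "d", "e"] and relevant items {"a", "c"}, hits occur at ranks 1 and 3, so AP@5 = (1/1 + 2/3) / min(5, 2) = 5/6 ≈ 0.83. Note that the loop runs over the full recommendation list, so a relevant item ranked below k still contributes to the sum.

recommended = ["a", "b", "c", "d", "e"]
relevant = {"a", "c"}
print(average_precision_at_k(recommended, relevant, k=5))  # 0.8333...
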
exp/gnn.py CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import os
+import tempfile
 
 import numpy as np
 import pandas as pd
@@ -8,7 +9,9 @@ import torch
 import wandb
 from tqdm.auto import tqdm
 
-from utils import prepare_graphs, normalize_embeddings, LRSchedule
+from exp.utils import prepare_graphs, normalize_embeddings, LRSchedule
+from exp.prepare_recsys import prepare_recsys
+from exp.evaluate import evaluate_recsys
 
 
 class GNNLayer(torch.nn.Module):
@@ -41,7 +44,6 @@ class GNNModel(torch.nn.Module):
         self,
         bipartite_graph,
         text_embeddings,
-        deepwalk_embeddings,
         num_layers,
         hidden_dim,
         aggregator_type,
@@ -56,14 +58,12 @@ class GNNModel(torch.nn.Module):
 
         self._bipartite_graph = bipartite_graph
         self._text_embeddings = text_embeddings
-        self._deepwalk_embeddings = deepwalk_embeddings
 
         self._sampler = dgl.sampling.PinSAGESampler(
             bipartite_graph, "Item", "User", num_traversals,
             termination_prob, num_random_walks, num_neighbor)
 
         self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
-        self._deepwalk_encoder = torch.nn.Linear(deepwalk_embeddings.shape[-1], hidden_dim)
 
         self._layers = torch.nn.ModuleList()
         for _ in range(num_layers):
@@ -92,13 +92,10 @@ class GNNModel(torch.nn.Module):
         sampled_subgraph = self._sample_subraph(ids)
         sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
 
-        ### Encode text & DeepWalk embeddings
+        ### Encode text embeddings
         text_embeddings = self._text_embeddings[
             sampled_subgraph.ndata[dgl.NID]]
-        deepwalk_embeddings = self._deepwalk_embeddings[
-            sampled_subgraph.ndata[dgl.NID]]
-        features = self._text_encoder(text_embeddings) \
-            + self._deepwalk_encoder(deepwalk_embeddings)
+        features = self._text_encoder(text_embeddings)
 
         ### GNN goes brr...
         for layer in self._layers:
@@ -142,12 +139,27 @@ def sample_item_batch(user_batch, bipartite_graph):
     return item_batch
 
 
+@torch.no_grad()
+def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
+    model.eval()
+    item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
+    for items_batch in tqdm(torch.utils.data.DataLoader(
+        torch.arange(bipartite_graph.num_nodes("Item")),
+        batch_size=batch_size,
+        shuffle=True
+    )):
+        item_embeddings[items_batch] = model(items_batch.to(device))
+
+    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+    return item_embeddings
+
+
 def prepare_gnn_embeddings(
     # Paths
     items_path,
-    ratings_path,
+    train_ratings_path,
+    val_ratings_path,
     text_embeddings_path,
-    deepwalk_embeddings_path,
     embeddings_savepath,
     # Learning hyperparameters
     temperature,
@@ -165,12 +177,13 @@ def prepare_gnn_embeddings(
     num_random_walks,
     num_neighbor,
     # Misc
+    validate_every_n_epoch,
     device,
     wandb_name,
     use_wandb,
 ):
     ### Prepare graph
-    bipartite_graph, _ = prepare_graphs(items_path, ratings_path)
+    bipartite_graph, _ = prepare_graphs(items_path, train_ratings_path)
     bipartite_graph = bipartite_graph.to(device)
 
     ### Init wandb
@@ -179,11 +192,9 @@ def prepare_gnn_embeddings(
 
     ### Prepare model
     text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
-    deepwalk_embeddings = torch.tensor(np.load(deepwalk_embeddings_path)).to(device)
     model = GNNModel(
         bipartite_graph=bipartite_graph,
         text_embeddings=text_embeddings,
-        deepwalk_embeddings=deepwalk_embeddings,
         num_layers=num_layers,
         hidden_dim=hidden_dim,
         aggregator_type=aggregator_type,
@@ -214,6 +225,7 @@ def prepare_gnn_embeddings(
     ### Train loop
     model.train()
     for epoch in range(num_epochs):
+        ### Train
         for user_batch in tqdm(dataloader):
             item_batch = sample_item_batch(user_batch, bipartite_graph) # (2, |user_batch|)
             item_batch = item_batch.reshape(-1) # (2 * |user_batch|)
@@ -226,24 +238,29 @@ def prepare_gnn_embeddings(
             loss.backward()
             optimizer.step()
             lr_scheduler.step()
+        ### Validation
+        if (validate_every_n_epoch is not None) and (((epoch + 1) % validate_every_n_epoch) == 0):
+            item_embeddings = inference_model(
+                model, bipartite_graph, batch_size, hidden_dim, device)
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
+                np.save(tmp_embeddings_path, item_embeddings)
+                prepare_recsys(items_path, tmp_embeddings_path, tmp_dir_name)
+                metrics = evaluate_recsys(
+                    val_ratings_path,
+                    os.path.join(tmp_dir_name, "index.faiss"),
+                    os.path.join(tmp_dir_name, "items.db"))
+                print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
+                if use_wandb:
+                    wandb.log(metrics)
+
+
 
     if use_wandb:
         wandb.finish()
 
     ### Process full dataset
-    model.eval()
-    with torch.no_grad():
-        hidden_dim = text_embeddings.shape[-1]
-        item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
-        for items_batch in tqdm(torch.utils.data.DataLoader(
-            torch.arange(bipartite_graph.num_nodes("Item")),
-            batch_size=batch_size,
-            shuffle=True
-        )):
-            item_embeddings[items_batch] = model(items_batch.to(device))
-
-        ### Extract & save item embeddings
-        item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+    item_embeddings = inference_model(model, bipartite_graph, batch_size, hidden_dim, device)
     np.save(embeddings_savepath, item_embeddings)
 
 
@@ -252,9 +269,9 @@ if __name__ == "__main__":
 
     # Paths
     parser.add_argument("--items_path", type=str, required=True, help="Path to the items file")
-    parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file")
+    parser.add_argument("--train_ratings_path", type=str, required=True, help="Path to the train ratings file")
+    parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
     parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
-    parser.add_argument("--deepwalk_embeddings_path", type=str, required=True, help="Path to the deepwalk embeddings file")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
 
     # Learning hyperparameters
@@ -265,7 +282,7 @@ if __name__ == "__main__":
 
     # Model hyperparameters
     parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
-    parser.add_argument("--hidden_dim", type=int, default=384, help="Hidden dimension size")
+    parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
     parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
     parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
     parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
@@ -275,6 +292,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_neighbor", type=int, default=3, help="Number of neighbors in PinSAGE-like sampler")
 
     # Misc
+    parser.add_argument("--validate_every_n_epoch", type=int, default=2, help="Perform RecSys validation every n train epochs.")
     parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
     parser.add_argument("--wandb_name", type=str, help="WandB run name")
     parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
exp/prepare_embeddings.sh CHANGED
@@ -10,14 +10,7 @@ echo Running on "$device".
 PYTHONPATH=. python exp/process_raw_data.py \
     --input_directory "$input_directory" \
     --save_directory "$save_directory" \
-    --create_train_val_split
-
-PYTHONPATH=. python exp/deepwalk.py \
-    --items_path "$save_directory/items.csv" \
-    --ratings_path "$save_directory/train_ratings.csv" \
-    --embeddings_savepath "$save_directory/deepwalk_embeddings.npy" \
-    --device $device \
-    --no_wandb
+    --create_train_val_test_split
 
 PYTHONPATH=. python exp/sbert.py \
     --items_path "$save_directory/items.csv" \
@@ -26,25 +19,21 @@ PYTHONPATH=. python exp/sbert.py \
 
 PYTHONPATH=. python exp/gnn.py \
     --items_path "$save_directory/items.csv" \
-    --ratings_path "$save_directory/train_ratings.csv" \
+    --train_ratings_path "$save_directory/train_ratings.csv" \
+    --val_ratings_path "$save_directory/val_ratings.csv" \
     --text_embeddings_path "$save_directory/text_embeddings.npy" \
-    --deepwalk_embeddings_path "$save_directory/deepwalk_embeddings.npy" \
     --embeddings_savepath "$save_directory/embeddings.npy"\
     --device $device \
-    --no_wandb
-
-PYTHONPATH=. python exp/prepare_index.py \
-    --embeddings_path "$save_directory/embeddings.npy" \
-    --save_path "$save_directory/index.faiss"
+    --no_wandb
 
-PYTHONPATH=. python exp/prepare_db.py \
+PYTHONPATH=. python exp/prepare_recsys.py \
     --items_path "$save_directory/items.csv" \
     --embeddings_path "$save_directory/embeddings.npy" \
-    --db_path "$save_directory/items.db"
+    --save_directory "$save_directory"
 
 PYTHONPATH=. python exp/evaluate.py \
     --metrics_savepath "$save_directory/metrics.json" \
-    --val_ratings_path "$save_directory/val_ratings.csv" \
+    --val_ratings_path "$save_directory/test_ratings.csv" \
     --faiss_index_path "$save_directory/index.faiss" \
     --db_path "$save_directory/items.db"
 
exp/prepare_index.py DELETED
@@ -1,20 +0,0 @@
-import argparse
-
-import faiss
-import numpy as np
-
-
-def build_index(embeddings_path, save_path, n_neighbors):
-    embeddings = np.load(embeddings_path)
-    index = faiss.IndexHNSWFlat(embeddings.shape[-1], 32)
-    index.add(embeddings)
-    faiss.write_index(index, save_path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Build an HNSW index from embeddings.")
-    parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the embeddings file.")
-    parser.add_argument("--save_path", type=str, required=True, help="Path to save the built index.")
-    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
-    args = parser.parse_args()
-    build_index(**vars(args))

exp/{prepare_db.py → prepare_recsys.py} RENAMED
@@ -1,7 +1,9 @@
 import argparse
 import sqlite3
 import io
+import os
 
+import faiss
 import pandas as pd
 import numpy as np
 
@@ -23,11 +25,30 @@ def prepare_items_db(items_path, embeddings_path, db_path):
         items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})
 
 
+def build_index(embeddings_path, save_path, n_neighbors):
+    embeddings = np.load(embeddings_path)
+    index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
+    index.add(embeddings)
+    faiss.write_index(index, save_path)
+
+
+def prepare_recsys(
+    items_path,
+    embeddings_path,
+    save_directory,
+    n_neighbors=32,
+):
+    prepare_items_db(items_path, embeddings_path, os.path.join(save_directory, "items.db"))
+    build_index(embeddings_path, os.path.join(save_directory, "index.faiss"), n_neighbors)
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Prepare items database from a CSV file.")
+    parser = argparse.ArgumentParser(description="Prepare items database and HNSW index from a CSV file and embeddings.")
+
     parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
     parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
-    parser.add_argument("--db_path", required=True, type=str, help="Path to the SQLite database file.")
+    parser.add_argument("--save_directory", required=True, type=str, help="Path to the save directory.")
+    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
 
     args = parser.parse_args()
-    prepare_items_db(**vars(args))
+    prepare_recsys(**vars(args))
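
For context on the merged index-building step: the second argument to faiss.IndexHNSWFlat is the HNSW connectivity parameter M (exposed here as n_neighbors), and faiss expects a float32 array of shape (n_items, dim). A minimal sketch of querying the index this script writes, with illustrative paths:

import faiss
import numpy as np

index = faiss.read_index("data/index.faiss")                  # path is illustrative
query = np.load("data/embeddings.npy")[:1].astype("float32")  # one query vector
distances, neighbors = index.search(query, 5)                 # ids of the 5 nearest items
print(neighbors[0])
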
exp/process_raw_data.py CHANGED
@@ -80,37 +80,46 @@ def process_raw_data_goodreads(input_directory, save_directory, positive_rating_
     ratings.to_csv(os.path.join(save_directory, "ratings.csv"), index=False)
 
 
-def create_train_val_split(ratings_path, train_savepath, val_savepath, seed=42):
+def create_train_val_test_split(ratings_path, train_savepath, val_savepath, test_savepath, seed=42):
     ratings = pd.read_csv(ratings_path)
     user_ids = ratings["user_id"].unique()
 
     rng = np.random.default_rng(seed=seed)
-    train_size = int(len(user_ids) * 0.9)
-    train_indices = rng.choice(user_ids, size=train_size, replace=False)
+    train_size = int(len(user_ids) * 0.7)
+    val_size = int(len(user_ids) * 0.15)
+
+    indices = rng.permutation(user_ids)
+    train_indices = indices[:train_size]
+    val_indices = indices[train_size:train_size+val_size]
+    test_indices = indices[train_size+val_size:]
 
     train_data = ratings.loc[ratings["user_id"].isin(train_indices)]
-    val_data = ratings.loc[~ratings["user_id"].isin(train_indices)]
+    val_data = ratings.loc[ratings["user_id"].isin(val_indices)]
+    test_data = ratings.loc[ratings["user_id"].isin(test_indices)]
 
     print(f"Train size: {len(train_data)}.")
     print(f"Validation size: {len(val_data)}.")
+    print(f"Test size: {len(test_data)}.")
 
     train_data.to_csv(train_savepath, index=False)
     val_data.to_csv(val_savepath, index=False)
+    test_data.to_csv(test_savepath, index=False)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process raw data.")
     parser.add_argument("--input_directory", required=True, type=str, help="Directory containing the raw data.")
     parser.add_argument("--save_directory", required=True, type=str, help="Directory where processed data will be saved.")
-    parser.add_argument("--create_train_val_split", action="store_true", help="Flag to indicate whether to create a train-validation split.")
+    parser.add_argument("--create_train_val_test_split", action="store_true", help="Flag to indicate whether to create a train-validation split.")
     args = parser.parse_args()
 
     print("Processing raw data...")
     process_raw_data_goodreads(args.input_directory, args.save_directory)
-    if args.create_train_val_split:
-        create_train_val_split(
+    if args.create_train_val_test_split:
+        create_train_val_test_split(
             os.path.join(args.save_directory, "ratings.csv"),
             os.path.join(args.save_directory, "train_ratings.csv"),
-            os.path.join(args.save_directory, "val_ratings.csv")
+            os.path.join(args.save_directory, "val_ratings.csv"),
+            os.path.join(args.save_directory, "test_ratings.csv")
         )
     print("The raw data has been successfully processed.")
exp/requirements.txt CHANGED
@@ -1,6 +1,7 @@
 -r ../requirements.txt # install base requirements of app
-dgl==2.1.0
-torch==2.1.2
+-f https://data.dgl.ai/wheels/torch-2.2/repo.html
+dgl==1.1.3
+torch==2.2.1
 wandb==0.17.0
 tqdm==4.66.4
 pydantic==2.5.3
exp/requirements_gpu.txt CHANGED
@@ -1,7 +1,7 @@
 -f https://data.dgl.ai/wheels/cu121/repo.html
 -r ../requirements.txt # install base requirements of app
-dgl==2.1.0
-torch==2.1.0
+dgl==1.1.3
+torch==2.2.1
 wandb==0.17.0
 tqdm==4.66.4
 pydantic==2.5.3
exp/sbert.py CHANGED
@@ -19,6 +19,7 @@ def prepare_sbert_embeddings(
     items = pd.read_csv(items_path).sort_values("item_id")
     sentences = items["description"].values
     model = SentenceTransformer(model_name).to(device)
+    model.eval()
     embeddings = []
     for start_index in tqdm(range(0, len(sentences), batch_size)):
         batch = sentences[start_index:start_index+batch_size]