erermeev-d committed
Commit · b8f4763
1 Parent(s): d4852d9
Updated experiments code

Browse files:
- app/database.py +2 -1
- exp/deepwalk.py +0 -80
- exp/evaluate.py +25 -18
- exp/gnn.py +48 -30
- exp/prepare_embeddings.sh +7 -18
- exp/prepare_index.py +0 -20
- exp/{prepare_db.py → prepare_recsys.py} +24 -3
- exp/process_raw_data.py +17 -8
- exp/requirements.txt +3 -2
- exp/requirements_gpu.txt +2 -2
- exp/sbert.py +1 -0
app/database.py
CHANGED
@@ -1,3 +1,4 @@
+import functools
 import io
 import sqlite3
 
@@ -26,8 +27,8 @@ class ItemDatabase:
         rows = c.fetchall()[:n_items]
         return [row[0] for row in rows]
 
+    @functools.lru_cache(maxsize=2**14)
     def get_item(self, item_id):
-
         with self._connect() as conn:
             c = conn.cursor()
             c.row_factory = sqlite3.Row
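Aside, for illustration only (not part of the commit): a minimal sketch of what the new functools.lru_cache decorator does for get_item, using a hypothetical stand-in class instead of the real sqlite3-backed ItemDatabase.

import functools
import time


class FakeItemDatabase:
    # Hypothetical stand-in; the real get_item opens a sqlite3 connection per call.
    @functools.lru_cache(maxsize=2**14)
    def get_item(self, item_id):
        time.sleep(0.01)  # simulate an expensive lookup
        return {"item_id": item_id}


db = FakeItemDatabase()
db.get_item(42)                  # miss: runs the body
db.get_item(42)                  # hit: served from the in-memory cache
print(db.get_item.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=16384, currsize=1)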
exp/deepwalk.py
DELETED
@@ -1,80 +0,0 @@
-import argparse
-import os
-
-import numpy as np
-import pandas as pd
-import dgl
-import torch
-import wandb
-from tqdm.auto import tqdm
-
-from utils import prepare_graphs, extract_item_embeddings, normalize_embeddings
-
-
-def prepare_deepwalk_embeddings(
-    items_path,
-    ratings_path,
-    embeddings_savepath,
-    emb_dim,
-    window_size,
-    batch_size,
-    lr,
-    num_epochs,
-    device,
-    wandb_name,
-    use_wandb
-):
-    ### Prepare graph
-    bipartite_graph, graph = prepare_graphs(items_path, ratings_path)
-    bipartite_graph = bipartite_graph.to(device)
-    graph = graph.to(device)
-
-    ### Run deepwalk
-    if use_wandb:
-        wandb.init(project="graph-recs-deepwalk", name=wandb_name)
-
-    model = dgl.nn.DeepWalk(graph, emb_dim=emb_dim, window_size=window_size)
-    model = model.to(device)
-    dataloader = torch.utils.data.DataLoader(
-        torch.arange(graph.num_nodes()),
-        batch_size=batch_size,
-        shuffle=True,
-        collate_fn=model.sample)
-
-    optimizer = torch.optim.SparseAdam(model.parameters(), lr=lr)
-    for epoch in range(num_epochs):
-        for batch_walk in tqdm(dataloader):
-            loss = model(batch_walk)
-            if use_wandb:
-                wandb.log({"loss": loss.item()})
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-    if use_wandb:
-        wandb.finish()
-
-    node_embeddings = model.node_embed.weight.detach().to(device)
-
-    ### Extract & save item embeddings
-    item_embeddings = extract_item_embeddings(node_embeddings, bipartite_graph, graph)
-    item_embeddings = normalize_embeddings(item_embeddings)
-    np.save(embeddings_savepath, item_embeddings)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Prepare DeepWalk embeddings.")
-    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
-    parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file.")
-    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where embeddings will be saved.")
-    parser.add_argument("--emb_dim", type=int, default=384, help="Dimensionality of the embeddings.")
-    parser.add_argument("--window_size", type=int, default=4, help="Window size for the DeepWalk algorithm.")
-    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training.")
-    parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for training.")
-    parser.add_argument("--num_epochs", type=int, default=2, help="Number of epochs for training.")
-    parser.add_argument("--device", type=str, default="cpu", help="Device to use for training (cpu or cuda).")
-    parser.add_argument("--wandb_name", type=str, help="Name for WandB run.")
-    parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
-    args = parser.parse_args()
-
-    prepare_deepwalk_embeddings(**vars(args))
exp/evaluate.py
CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import json
+from tqdm.auto import tqdm
 
 import pandas as pd
 import numpy as np
@@ -9,16 +10,26 @@ from app.recommendations import RecommenderSystem
 
 def precision_at_k(recommended_items, relevant_items, k):
     recommended_at_k = set(recommended_items[:k])
-
-
+    return len(recommended_at_k & relevant_items) / k
+
+
+def average_precision_at_k(recommended_items, relevant_items, k):
+    relevant_items = set(relevant_items)
+
+    apk_sum = 0.0
+    for m, item in enumerate(recommended_items):
+        if item in relevant_items:
+            apk_sum += precision_at_k(recommended_items, relevant_items, m+1)
+
+    return apk_sum / min(k, len(relevant_items))
 
 
 def evaluate_recsys(
-    metrics_savepath,
     val_ratings_path,
     faiss_index_path,
     db_path,
-    n_recommend_items,
+    n_recommend_items=10,
+    metrics_savepath=None
 ):
     recsys = RecommenderSystem(
         faiss_index_path=faiss_index_path,
@@ -30,16 +41,14 @@ def evaluate_recsys(
 
 
     metric_arrays = {
-        "
-        "precision@3": [],
-        "precision@10": []
+        "ap@5": [],
     }
 
-    for item_group in grouped_items:
+    for item_group in tqdm(grouped_items):
        if len(item_group) == 1:
            continue
 
-        ###
+        ### Metrics are computed for each edge.
         ### We will first aggregate it over all edges for user
         ### And after that - aggregate over all users
        user_metric_arrays = dict()
@@ -50,12 +59,8 @@ def evaluate_recsys(
            recommend_items = list(recsys.recommend_items(item, n_recommend_items))
            relevant_items = set(item_group) - {item}
 
-            user_metric_arrays["
-
-            user_metric_arrays["precision@3"].append(
-                precision_at_k(recommend_items, relevant_items, k=3))
-            user_metric_arrays["precision@10"].append(
-                precision_at_k(recommend_items, relevant_items, k=10))
+            user_metric_arrays["ap@5"].append(
+                average_precision_at_k(recommend_items, relevant_items, k=5))
 
        for metric in metric_arrays.keys():
            user_metric = np.mean(user_metric_arrays[metric])
@@ -65,9 +70,11 @@ def evaluate_recsys(
    for metric, array in metric_arrays.items():
        metrics[metric] = np.mean(array)
 
-
-
-
+    if metrics_savepath is not None:
+        with open(metrics_savepath, "w") as f:
+            json.dump(metrics, f)
+        print(f"Saved metrics to {metrics_savepath}")
+    return metrics
 
 
 if __name__ == "__main__":
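Aside, for illustration only (not part of the commit): a toy check of the precision@k and average-precision@k definitions added above, on a hand-made ranking.

def precision_at_k(recommended_items, relevant_items, k):
    recommended_at_k = set(recommended_items[:k])
    return len(recommended_at_k & relevant_items) / k


def average_precision_at_k(recommended_items, relevant_items, k):
    relevant_items = set(relevant_items)
    apk_sum = 0.0
    for m, item in enumerate(recommended_items):
        if item in relevant_items:
            apk_sum += precision_at_k(recommended_items, relevant_items, m + 1)
    return apk_sum / min(k, len(relevant_items))


# Relevant items "a" and "c" are ranked 1st and 3rd among five recommendations.
print(precision_at_k(["a", "b", "c", "d", "e"], {"a", "c"}, k=5))          # 2/5 = 0.4
print(average_precision_at_k(["a", "b", "c", "d", "e"], {"a", "c"}, k=5))  # (1/1 + 2/3) / 2 ≈ 0.83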
exp/gnn.py
CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import os
+import tempfile
 
 import numpy as np
 import pandas as pd
@@ -8,7 +9,9 @@ import torch
 import wandb
 from tqdm.auto import tqdm
 
-from utils import prepare_graphs, normalize_embeddings, LRSchedule
+from exp.utils import prepare_graphs, normalize_embeddings, LRSchedule
+from exp.prepare_recsys import prepare_recsys
+from exp.evaluate import evaluate_recsys
 
 
 class GNNLayer(torch.nn.Module):
@@ -41,7 +44,6 @@ class GNNModel(torch.nn.Module):
         self,
         bipartite_graph,
         text_embeddings,
-        deepwalk_embeddings,
         num_layers,
         hidden_dim,
         aggregator_type,
@@ -56,14 +58,12 @@ class GNNModel(torch.nn.Module):
 
         self._bipartite_graph = bipartite_graph
         self._text_embeddings = text_embeddings
-        self._deepwalk_embeddings = deepwalk_embeddings
 
         self._sampler = dgl.sampling.PinSAGESampler(
             bipartite_graph, "Item", "User", num_traversals,
             termination_prob, num_random_walks, num_neighbor)
 
         self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
-        self._deepwalk_encoder = torch.nn.Linear(deepwalk_embeddings.shape[-1], hidden_dim)
 
         self._layers = torch.nn.ModuleList()
         for _ in range(num_layers):
@@ -92,13 +92,10 @@ class GNNModel(torch.nn.Module):
         sampled_subgraph = self._sample_subraph(ids)
         sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
 
-        ### Encode text
+        ### Encode text embeddings
         text_embeddings = self._text_embeddings[
             sampled_subgraph.ndata[dgl.NID]]
-
-            sampled_subgraph.ndata[dgl.NID]]
-        features = self._text_encoder(text_embeddings) \
-            + self._deepwalk_encoder(deepwalk_embeddings)
+        features = self._text_encoder(text_embeddings)
 
         ### GNN goes brr...
         for layer in self._layers:
@@ -142,12 +139,27 @@ def sample_item_batch(user_batch, bipartite_graph):
     return item_batch
 
 
+@torch.no_grad()
+def inference_model(model, bipartite_graph, batch_size, hidden_dim, device):
+    model.eval()
+    item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
+    for items_batch in tqdm(torch.utils.data.DataLoader(
+        torch.arange(bipartite_graph.num_nodes("Item")),
+        batch_size=batch_size,
+        shuffle=True
+    )):
+        item_embeddings[items_batch] = model(items_batch.to(device))
+
+    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+    return item_embeddings
+
+
 def prepare_gnn_embeddings(
     # Paths
     items_path,
-
+    train_ratings_path,
+    val_ratings_path,
     text_embeddings_path,
-    deepwalk_embeddings_path,
     embeddings_savepath,
     # Learning hyperparameters
     temperature,
@@ -165,12 +177,13 @@ def prepare_gnn_embeddings(
     num_random_walks,
     num_neighbor,
     # Misc
+    validate_every_n_epoch,
     device,
     wandb_name,
     use_wandb,
 ):
     ### Prepare graph
-    bipartite_graph, _ = prepare_graphs(items_path,
+    bipartite_graph, _ = prepare_graphs(items_path, train_ratings_path)
     bipartite_graph = bipartite_graph.to(device)
 
     ### Init wandb
@@ -179,11 +192,9 @@ def prepare_gnn_embeddings(
 
     ### Prepare model
     text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
-    deepwalk_embeddings = torch.tensor(np.load(deepwalk_embeddings_path)).to(device)
     model = GNNModel(
         bipartite_graph=bipartite_graph,
         text_embeddings=text_embeddings,
-        deepwalk_embeddings=deepwalk_embeddings,
         num_layers=num_layers,
         hidden_dim=hidden_dim,
         aggregator_type=aggregator_type,
@@ -214,6 +225,7 @@ def prepare_gnn_embeddings(
     ### Train loop
     model.train()
     for epoch in range(num_epochs):
+        ### Train
         for user_batch in tqdm(dataloader):
            item_batch = sample_item_batch(user_batch, bipartite_graph) # (2, |user_batch|)
            item_batch = item_batch.reshape(-1) # (2 * |user_batch|)
@@ -226,24 +238,29 @@ def prepare_gnn_embeddings(
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
+        ### Validation
+        if (validate_every_n_epoch is not None) and (((epoch + 1) % validate_every_n_epoch) == 0):
+            item_embeddings = inference_model(
+                model, bipartite_graph, batch_size, hidden_dim, device)
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
+                np.save(tmp_embeddings_path, item_embeddings)
+                prepare_recsys(items_path, tmp_embeddings_path, tmp_dir_name)
+                metrics = evaluate_recsys(
+                    val_ratings_path,
+                    os.path.join(tmp_dir_name, "index.faiss"),
+                    os.path.join(tmp_dir_name, "items.db"))
+            print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
+            if use_wandb:
+                wandb.log(metrics)
+
+
 
     if use_wandb:
         wandb.finish()
 
     ### Process full dataset
-    model
-    with torch.no_grad():
-        hidden_dim = text_embeddings.shape[-1]
-        item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
-        for items_batch in tqdm(torch.utils.data.DataLoader(
-            torch.arange(bipartite_graph.num_nodes("Item")),
-            batch_size=batch_size,
-            shuffle=True
-        )):
-            item_embeddings[items_batch] = model(items_batch.to(device))
-
-    ### Extract & save item embeddings
-    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+    item_embeddings = inference_model(model, bipartite_graph, batch_size, hidden_dim, device)
     np.save(embeddings_savepath, item_embeddings)
 
 
@@ -252,9 +269,9 @@ if __name__ == "__main__":
 
     # Paths
     parser.add_argument("--items_path", type=str, required=True, help="Path to the items file")
-    parser.add_argument("--
+    parser.add_argument("--train_ratings_path", type=str, required=True, help="Path to the train ratings file")
+    parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
     parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
-    parser.add_argument("--deepwalk_embeddings_path", type=str, required=True, help="Path to the deepwalk embeddings file")
     parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
 
     # Learning hyperparameters
@@ -265,7 +282,7 @@ if __name__ == "__main__":
 
     # Model hyperparameters
     parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
-    parser.add_argument("--hidden_dim", type=int, default=
+    parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
     parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
     parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
     parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
@@ -275,6 +292,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_neighbor", type=int, default=3, help="Number of neighbors in PinSAGE-like sampler")
 
     # Misc
+    parser.add_argument("--validate_every_n_epoch", type=int, default=2, help="Perform RecSys validation every n train epochs.")
     parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
     parser.add_argument("--wandb_name", type=str, help="WandB run name")
     parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
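Aside, for illustration only (not part of the commit): the periodic validation added above writes its throwaway index and database into a tempfile.TemporaryDirectory; a minimal sketch of that pattern, with a random array standing in for the item embeddings.

import os
import tempfile

import numpy as np

item_embeddings = np.random.rand(8, 4).astype("float32")  # hypothetical embeddings

with tempfile.TemporaryDirectory() as tmp_dir_name:
    tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
    np.save(tmp_embeddings_path, item_embeddings)
    # ...build the temporary index/db and evaluate against them here...
    print(os.listdir(tmp_dir_name))  # ['embeddings.npy']
# The directory and its contents are deleted as soon as the block exits.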
exp/prepare_embeddings.sh
CHANGED
@@ -10,14 +10,7 @@ echo Running on "$device".
 PYTHONPATH=. python exp/process_raw_data.py \
     --input_directory "$input_directory" \
     --save_directory "$save_directory" \
-    --
-
-PYTHONPATH=. python exp/deepwalk.py \
-    --items_path "$save_directory/items.csv" \
-    --ratings_path "$save_directory/train_ratings.csv" \
-    --embeddings_savepath "$save_directory/deepwalk_embeddings.npy" \
-    --device $device \
-    --no_wandb
+    --create_train_val_test_split
 
 PYTHONPATH=. python exp/sbert.py \
     --items_path "$save_directory/items.csv" \
@@ -26,25 +19,21 @@ PYTHONPATH=. python exp/sbert.py \
 
 PYTHONPATH=. python exp/gnn.py \
     --items_path "$save_directory/items.csv" \
-    --
+    --train_ratings_path "$save_directory/train_ratings.csv" \
+    --val_ratings_path "$save_directory/val_ratings.csv" \
     --text_embeddings_path "$save_directory/text_embeddings.npy" \
-    --deepwalk_embeddings_path "$save_directory/deepwalk_embeddings.npy" \
     --embeddings_savepath "$save_directory/embeddings.npy"\
     --device $device \
-    --no_wandb
-
-PYTHONPATH=. python exp/prepare_index.py \
-    --embeddings_path "$save_directory/embeddings.npy" \
-    --save_path "$save_directory/index.faiss"
+    --no_wandb
 
-PYTHONPATH=. python exp/
+PYTHONPATH=. python exp/prepare_recsys.py \
     --items_path "$save_directory/items.csv" \
     --embeddings_path "$save_directory/embeddings.npy" \
-    --
+    --save_directory "$save_directory"
 
 PYTHONPATH=. python exp/evaluate.py \
     --metrics_savepath "$save_directory/metrics.json" \
-    --val_ratings_path "$save_directory/
+    --val_ratings_path "$save_directory/test_ratings.csv" \
     --faiss_index_path "$save_directory/index.faiss" \
     --db_path "$save_directory/items.db"
 
exp/prepare_index.py
DELETED
@@ -1,20 +0,0 @@
-import argparse
-
-import faiss
-import numpy as np
-
-
-def build_index(embeddings_path, save_path, n_neighbors):
-    embeddings = np.load(embeddings_path)
-    index = faiss.IndexHNSWFlat(embeddings.shape[-1], 32)
-    index.add(embeddings)
-    faiss.write_index(index, save_path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Build an HNSW index from embeddings.")
-    parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the embeddings file.")
-    parser.add_argument("--save_path", type=str, required=True, help="Path to save the built index.")
-    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
-    args = parser.parse_args()
-    build_index(**vars(args))
exp/{prepare_db.py → prepare_recsys.py}
RENAMED
@@ -1,7 +1,9 @@
 import argparse
 import sqlite3
 import io
+import os
 
+import faiss
 import pandas as pd
 import numpy as np
 
@@ -23,11 +25,30 @@ def prepare_items_db(items_path, embeddings_path, db_path):
         items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})
 
 
+def build_index(embeddings_path, save_path, n_neighbors):
+    embeddings = np.load(embeddings_path)
+    index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
+    index.add(embeddings)
+    faiss.write_index(index, save_path)
+
+
+def prepare_recsys(
+    items_path,
+    embeddings_path,
+    save_directory,
+    n_neighbors=32,
+):
+    prepare_items_db(items_path, embeddings_path, os.path.join(save_directory, "items.db"))
+    build_index(embeddings_path, os.path.join(save_directory, "index.faiss"), n_neighbors)
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Prepare items database from a CSV file.")
+    parser = argparse.ArgumentParser(description="Prepare items database and HNSW index from a CSV file and embeddings.")
+
     parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
     parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
-    parser.add_argument("--
+    parser.add_argument("--save_directory", required=True, type=str, help="Path to the save directory.")
+    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
 
     args = parser.parse_args()
-
+    prepare_recsys(**vars(args))
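Aside, for illustration only (not part of the commit): a minimal sketch of building and querying the kind of HNSW index that build_index writes, using a small random float32 matrix in place of the real item embeddings.

import faiss
import numpy as np

embeddings = np.random.rand(1000, 64).astype("float32")  # hypothetical item embeddings

index = faiss.IndexHNSWFlat(embeddings.shape[-1], 32)  # 32 = neighbors per node, as in build_index
index.add(embeddings)

distances, ids = index.search(embeddings[:1], 5)  # 5 nearest items for the first vector
print(ids[0])  # the query item itself is normally among the results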
exp/process_raw_data.py
CHANGED
@@ -80,37 +80,46 @@ def process_raw_data_goodreads(input_directory, save_directory, positive_rating_
     ratings.to_csv(os.path.join(save_directory, "ratings.csv"), index=False)
 
 
-def
+def create_train_val_test_split(ratings_path, train_savepath, val_savepath, test_savepath, seed=42):
     ratings = pd.read_csv(ratings_path)
     user_ids = ratings["user_id"].unique()
 
     rng = np.random.default_rng(seed=seed)
-    train_size = int(len(user_ids) * 0.
-
+    train_size = int(len(user_ids) * 0.7)
+    val_size = int(len(user_ids) * 0.15)
+
+    indices = rng.permutation(user_ids)
+    train_indices = indices[:train_size]
+    val_indices = indices[train_size:train_size+val_size]
+    test_indices = indices[train_size+val_size:]
 
     train_data = ratings.loc[ratings["user_id"].isin(train_indices)]
-    val_data = ratings.loc[
+    val_data = ratings.loc[ratings["user_id"].isin(val_indices)]
+    test_data = ratings.loc[ratings["user_id"].isin(test_indices)]
 
     print(f"Train size: {len(train_data)}.")
     print(f"Validation size: {len(val_data)}.")
+    print(f"Test size: {len(test_data)}.")
 
     train_data.to_csv(train_savepath, index=False)
     val_data.to_csv(val_savepath, index=False)
+    test_data.to_csv(test_savepath, index=False)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process raw data.")
     parser.add_argument("--input_directory", required=True, type=str, help="Directory containing the raw data.")
     parser.add_argument("--save_directory", required=True, type=str, help="Directory where processed data will be saved.")
-    parser.add_argument("--
+    parser.add_argument("--create_train_val_test_split", action="store_true", help="Flag to indicate whether to create a train-validation split.")
     args = parser.parse_args()
 
     print("Processing raw data...")
     process_raw_data_goodreads(args.input_directory, args.save_directory)
-    if args.
-
+    if args.create_train_val_test_split:
+        create_train_val_test_split(
             os.path.join(args.save_directory, "ratings.csv"),
             os.path.join(args.save_directory, "train_ratings.csv"),
-            os.path.join(args.save_directory, "val_ratings.csv")
+            os.path.join(args.save_directory, "val_ratings.csv"),
+            os.path.join(args.save_directory, "test_ratings.csv")
         )
     print("The raw data has been successfully processed.")
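Aside, for illustration only (not part of the commit): a toy run of the 70/15/15 user-level split logic added above, on ten fake user ids. Splitting by user id (rather than by rating row) keeps all of a user's ratings inside a single split.

import numpy as np

user_ids = np.arange(10)  # hypothetical user ids
rng = np.random.default_rng(seed=42)

train_size = int(len(user_ids) * 0.7)  # 7 users
val_size = int(len(user_ids) * 0.15)   # 1 user

indices = rng.permutation(user_ids)
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]

print(len(train_indices), len(val_indices), len(test_indices))  # 7 1 2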
exp/requirements.txt
CHANGED
@@ -1,6 +1,7 @@
 -r ../requirements.txt # install base requirements of app
-dgl
-
+-f https://data.dgl.ai/wheels/torch-2.2/repo.html
+dgl==1.1.3
+torch==2.2.1
 wandb==0.17.0
 tqdm==4.66.4
 pydantic==2.5.3
exp/requirements_gpu.txt
CHANGED
@@ -1,7 +1,7 @@
 -f https://data.dgl.ai/wheels/cu121/repo.html
 -r ../requirements.txt # install base requirements of app
-dgl==
-torch==2.1
+dgl==1.1.3
+torch==2.2.1
 wandb==0.17.0
 tqdm==4.66.4
 pydantic==2.5.3
|
exp/sbert.py
CHANGED
@@ -19,6 +19,7 @@ def prepare_sbert_embeddings(
|
|
19 |
items = pd.read_csv(items_path).sort_values("item_id")
|
20 |
sentences = items["description"].values
|
21 |
model = SentenceTransformer(model_name).to(device)
|
|
|
22 |
embeddings = []
|
23 |
for start_index in tqdm(range(0, len(sentences), batch_size)):
|
24 |
batch = sentences[start_index:start_index+batch_size]
|
|
|
19 |
items = pd.read_csv(items_path).sort_values("item_id")
|
20 |
sentences = items["description"].values
|
21 |
model = SentenceTransformer(model_name).to(device)
|
22 |
+
model.eval()
|
23 |
embeddings = []
|
24 |
for start_index in tqdm(range(0, len(sentences), batch_size)):
|
25 |
batch = sentences[start_index:start_index+batch_size]
|