erermeev-d committed on
Commit d4852d9 · 0 Parent(s)
Initial commit
.gitattributes ADDED
@@ -0,0 +1 @@
+ data/* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM python:3.10.13
+
+ COPY requirements.txt .
+ RUN pip3 install -r requirements.txt
+
+ RUN mkdir /data
+ RUN wget https://storage.yandexcloud.net/eremeev-d-bucket-main/1722760266.tar -O data.tar
+ RUN tar -xf data.tar -C /data
+ RUN rm data.tar
+
+ RUN mkdir /app
+ COPY app app
+
+ EXPOSE 8501
+
+ ENV PYTHONPATH .
+
+ ENTRYPOINT ["streamlit", "run", "app/main.py", "--server.port=8501"]
Makefile ADDED
@@ -0,0 +1,3 @@
+ run-app:
+ 	docker build -t graph-rec-app .
+ 	docker run --rm -p 8501:8501 graph-rec-app
README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ title: A simple Graph-based Recommender System
+ emoji: 📚
+ colorFrom: purple
+ colorTo: yellow
+ sdk: docker
+ app_port: 8501
+ ---
+ # A simple Graph-based Recommender System
+
+ ### What is it?
+
+ This app is a simple graph-based recommender system that searches for items and recommends similar ones. It can be applied to any dataset. For demonstration purposes, we use the (filtered) [Goodreads](https://mengtingwan.github.io/data/goodreads#datasets) dataset.
+
+ ### Where can I try this app?
+
+ The app is currently deployed at HuggingFace Spaces ([link](https://huggingface.co/spaces/eremeev-d/graph-rec)). You will probably need to wait a minute or two for the app to start running.
+
+ ### How to use it?
+
+ Simply enter a keyword (e.g., "Brave") into the search bar and press the "Search" button. The app will display relevant books along with their short descriptions.
+
+ For each book, you can click "Recommend Similar Items" to see other books you might enjoy if you liked the selected one.
+
+ ### How to reproduce embeddings computation?
+
+ First, install the required packages from `exp/requirements.txt` (or `exp/requirements_gpu.txt` for GPU).
+
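+ For example, from the root of the repo, something like the following should work (assuming `pip` points at the environment you want to use):
+ ```
+ pip install -r exp/requirements.txt
+ ```
+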
+ Then, download the raw data from the [Goodreads website](https://mengtingwan.github.io/data/goodreads#datasets). We will need the following files: `book_id_map.csv`, `goodreads_books.json`, `goodreads_interactions.csv` and `user_id_map.csv`. You can download these files manually or use this [Kaggle dataset](https://www.kaggle.com/datasets/eremeevd/graph-rec-goodreads).
+
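+ For instance, with the [Kaggle CLI](https://github.com/Kaggle/kaggle-api) configured, something like the following should fetch and unpack the dataset (the dataset slug is taken from the link above; the target directory name is just an example):
+ ```
+ kaggle datasets download -d eremeevd/graph-rec-goodreads
+ unzip graph-rec-goodreads.zip -d goodreads-raw
+ ```
+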
+ Finally, run the following command at the root of the repo:
+ ```
+ sh exp/prepare_embeddings.sh INPUT_DIRECTORY SAVE_DIRECTORY
+ ```
+ where `INPUT_DIRECTORY` is the path to the directory with the raw data (e.g. `/kaggle/input/graph-rec-goodreads/goodreads-books`) and `SAVE_DIRECTORY` is the path to the directory where the results will be saved (e.g. `/kaggle/working/embeddings`). To use the obtained embeddings, copy `index.faiss` and `items.db` to `app/data`, as shown below.
+
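+ For instance, with the example paths above, the copy step could look like this (adjust the paths to your setup):
+ ```
+ mkdir -p app/data
+ cp /kaggle/working/embeddings/index.faiss /kaggle/working/embeddings/items.db app/data/
+ ```
+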
+ To run on GPU, use the following command instead:
+ ```
+ sh exp/prepare_embeddings.sh INPUT_DIRECTORY SAVE_DIRECTORY cuda
+ ```
+
+ For further information, refer to the `exp` directory in this repo.
app/__init__.py ADDED
File without changes
app/database.py ADDED
@@ -0,0 +1,35 @@
+ import io
+ import sqlite3
+
+ import numpy as np
+
+
+ class ItemDatabase:
+     def __init__(self, db_path):
+         # Convert BLOBs in columns declared as "embedding" back into numpy arrays.
+         sqlite3.register_converter("embedding", self._text_to_numpy_array)
+         self._db_path = db_path
+
+     @staticmethod
+     def _text_to_numpy_array(text):
+         out = io.BytesIO(text)
+         out.seek(0)
+         return np.load(out)
+
+     def _connect(self):
+         return sqlite3.connect(
+             self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)
+
+     def search_items(self, query, n_items=10):
+         with self._connect() as conn:
+             c = conn.cursor()
+             # Parameterized query: the search string comes from user input.
+             c.execute(
+                 "select item_id from items where title like ?",
+                 (f"%{query}%",))
+             rows = c.fetchall()[:n_items]
+             return [row[0] for row in rows]
+
+     def get_item(self, item_id):
+         with self._connect() as conn:
+             c = conn.cursor()
+             c.row_factory = sqlite3.Row
+             # int() also accepts numpy integer ids returned by the FAISS index.
+             c.execute(
+                 "select * from items where item_id = ?",
+                 (int(item_id),))
+             return c.fetchone()
app/main.py ADDED
@@ -0,0 +1,60 @@
+ import streamlit as st
+
+ from app.database import ItemDatabase
+ from app.recommendations import RecommenderSystem
+
+
+ def show_item(item_id):
+     item = st.session_state["db"].get_item(item_id)
+     title = item["title"]
+     with st.container(border=True):
+         st.write(f"**{title}**")
+         st.write(item["description"])
+         if st.button("Recommend similar items", key=item["item_id"]):
+             st.session_state["recommendation_query"] = item["item_id"]
+             st.session_state["search_query"] = None  # reset
+             st.rerun()
+
+
+ def main():
+     st.title("Graph-based RecSys")
+
+     # Cache the database handle and the recommender in the session state.
+     if "db" not in st.session_state:
+         st.session_state["db"] = ItemDatabase(
+             db_path="/data/items.db")
+     if "recsys" not in st.session_state:
+         st.session_state["recsys"] = RecommenderSystem(
+             faiss_index_path="/data/index.faiss",
+             db_path="/data/items.db")
+
+     if "search_query" not in st.session_state:
+         st.session_state["search_query"] = None
+     if "recommendation_query" not in st.session_state:
+         st.session_state["recommendation_query"] = None
+
+     search_query = st.text_input("Enter item name", st.session_state["search_query"])
+
+     if st.button("Search"):
+         st.session_state["search_query"] = search_query
+         st.session_state["recommendation_query"] = None  # reset
+
+     if st.session_state["recommendation_query"] is not None:
+         query = st.session_state["recommendation_query"]
+         base_item_title = st.session_state["db"].get_item(query)["title"]
+         st.subheader(f'Recommendation Results for "{base_item_title}"')
+         results = st.session_state["recsys"].recommend_items(query)
+         for item_id in results:
+             show_item(item_id)
+
+     elif st.session_state["search_query"] is not None:
+         query = st.session_state["search_query"]
+         st.subheader(f'Search Results for "{query}"')
+         results = st.session_state["db"].search_items(query)
+         for item_id in results:
+             show_item(item_id)
+
+
+ if __name__ == "__main__":
+     main()
app/recommendations.py ADDED
@@ -0,0 +1,19 @@
+ import itertools
+
+ import faiss
+
+ from app.database import ItemDatabase
+
+
+ class RecommenderSystem:
+     def __init__(self, faiss_index_path, db_path):
+         self._index = faiss.read_index(faiss_index_path)
+         self._db = ItemDatabase(db_path)
+
+     def recommend_items(self, query, n_items=10):
+         # The stored embedding has shape (1, dim), which is what FAISS expects.
+         query_embedding = self._db.get_item(query)["embedding"]
+         _, results = self._index.search(query_embedding, k=n_items+1)
+         # Drop the query item itself and keep at most n_items neighbours.
+         results = filter(lambda item: item != query, results[0])
+         return itertools.islice(results, n_items)
exp/__init__.py ADDED
File without changes
exp/deepwalk.py ADDED
@@ -0,0 +1,80 @@
+ import argparse
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import dgl
+ import torch
+ import wandb
+ from tqdm.auto import tqdm
+
+ from utils import prepare_graphs, extract_item_embeddings, normalize_embeddings
+
+
+ def prepare_deepwalk_embeddings(
+     items_path,
+     ratings_path,
+     embeddings_savepath,
+     emb_dim,
+     window_size,
+     batch_size,
+     lr,
+     num_epochs,
+     device,
+     wandb_name,
+     use_wandb
+ ):
+     ### Prepare graph
+     bipartite_graph, graph = prepare_graphs(items_path, ratings_path)
+     bipartite_graph = bipartite_graph.to(device)
+     graph = graph.to(device)
+
+     ### Run deepwalk
+     if use_wandb:
+         wandb.init(project="graph-recs-deepwalk", name=wandb_name)
+
+     model = dgl.nn.DeepWalk(graph, emb_dim=emb_dim, window_size=window_size)
+     model = model.to(device)
+     dataloader = torch.utils.data.DataLoader(
+         torch.arange(graph.num_nodes()),
+         batch_size=batch_size,
+         shuffle=True,
+         collate_fn=model.sample)
+
+     optimizer = torch.optim.SparseAdam(model.parameters(), lr=lr)
+     for epoch in range(num_epochs):
+         for batch_walk in tqdm(dataloader):
+             loss = model(batch_walk)
+             if use_wandb:
+                 wandb.log({"loss": loss.item()})
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+     if use_wandb:
+         wandb.finish()
+
+     node_embeddings = model.node_embed.weight.detach().to(device)
+
+     ### Extract & save item embeddings
+     item_embeddings = extract_item_embeddings(node_embeddings, bipartite_graph, graph)
+     item_embeddings = normalize_embeddings(item_embeddings)
+     np.save(embeddings_savepath, item_embeddings)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Prepare DeepWalk embeddings.")
+     parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
+     parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file.")
+     parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where embeddings will be saved.")
+     parser.add_argument("--emb_dim", type=int, default=384, help="Dimensionality of the embeddings.")
+     parser.add_argument("--window_size", type=int, default=4, help="Window size for the DeepWalk algorithm.")
+     parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training.")
+     parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for training.")
+     parser.add_argument("--num_epochs", type=int, default=2, help="Number of epochs for training.")
+     parser.add_argument("--device", type=str, default="cpu", help="Device to use for training (cpu or cuda).")
+     parser.add_argument("--wandb_name", type=str, help="Name for WandB run.")
+     parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
+     args = parser.parse_args()
+
+     prepare_deepwalk_embeddings(**vars(args))
exp/evaluate.py ADDED
@@ -0,0 +1,81 @@
+ import argparse
+ import json
+
+ import pandas as pd
+ import numpy as np
+
+ from app.recommendations import RecommenderSystem
+
+
+ def precision_at_k(recommended_items, relevant_items, k):
+     recommended_at_k = set(recommended_items[:k])
+     relevant_set = set(relevant_items)
+     return len(recommended_at_k & relevant_set) / k
+
+
+ def evaluate_recsys(
+     metrics_savepath,
+     val_ratings_path,
+     faiss_index_path,
+     db_path,
+     n_recommend_items,
+ ):
+     recsys = RecommenderSystem(
+         faiss_index_path=faiss_index_path,
+         db_path=db_path)
+
+     val_ratings = pd.read_csv(val_ratings_path)
+     grouped_items = val_ratings.groupby("user_id")["item_id"].apply(list).reset_index()
+     grouped_items = grouped_items["item_id"].tolist()
+
+     metric_arrays = {
+         "precision@1": [],
+         "precision@3": [],
+         "precision@10": []
+     }
+
+     for item_group in grouped_items:
+         if len(item_group) == 1:
+             continue
+
+         ### Precision@k is computed for each edge.
+         ### We first aggregate it over all edges of a user,
+         ### and then aggregate over all users.
+         user_metric_arrays = dict()
+         for metric in metric_arrays.keys():
+             user_metric_arrays[metric] = []
+
+         for item in item_group:
+             recommend_items = list(recsys.recommend_items(item, n_recommend_items))
+             relevant_items = set(item_group) - {item}
+
+             user_metric_arrays["precision@1"].append(
+                 precision_at_k(recommend_items, relevant_items, k=1))
+             user_metric_arrays["precision@3"].append(
+                 precision_at_k(recommend_items, relevant_items, k=3))
+             user_metric_arrays["precision@10"].append(
+                 precision_at_k(recommend_items, relevant_items, k=10))
+
+         for metric in metric_arrays.keys():
+             user_metric = np.mean(user_metric_arrays[metric])
+             metric_arrays[metric].append(user_metric)
+
+     metrics = dict()
+     for metric, array in metric_arrays.items():
+         metrics[metric] = np.mean(array)
+
+     with open(metrics_savepath, "w") as f:
+         json.dump(metrics, f)
+     print(f"Saved metrics to {metrics_savepath}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Evaluate a recommendation system.")
+     parser.add_argument("--metrics_savepath", required=True, type=str, help="Path to save the evaluation metrics.")
+     parser.add_argument("--val_ratings_path", required=True, type=str, help="Path to the csv file with validation ratings.")
+     parser.add_argument("--faiss_index_path", required=True, type=str, help="Path to the FAISS index.")
+     parser.add_argument("--db_path", required=True, type=str, help="Path to the database file.")
+     parser.add_argument("--n_recommend_items", type=int, default=10, help="Number of items to recommend.")
+     args = parser.parse_args()
+     evaluate_recsys(**vars(args))
exp/gnn.py ADDED
@@ -0,0 +1,284 @@
+ import argparse
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import dgl
+ import torch
+ import wandb
+ from tqdm.auto import tqdm
+
+ from utils import prepare_graphs, normalize_embeddings, LRSchedule
+
+
+ class GNNLayer(torch.nn.Module):
+     def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
+         super().__init__()
+         self._skip_connection = skip_connection
+         self._bidirectional = bidirectional
+
+         self._norm = torch.nn.LayerNorm(hidden_dim)
+         self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+         self._activation = torch.nn.ReLU()
+
+         if bidirectional:
+             self._norm_rev = torch.nn.LayerNorm(hidden_dim)
+             self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
+             self._activation_rev = torch.nn.ReLU()
+
+     def forward(self, graph, x):
+         y = self._activation(self._conv(graph, self._norm(x)))
+         if self._bidirectional:
+             y = y + self._activation_rev(self._conv_rev(dgl.reverse(graph), self._norm_rev(x)))
+         if self._skip_connection:
+             return x + y
+         else:
+             return y
+
+
+ class GNNModel(torch.nn.Module):
+     def __init__(
+         self,
+         bipartite_graph,
+         text_embeddings,
+         deepwalk_embeddings,
+         num_layers,
+         hidden_dim,
+         aggregator_type,
+         skip_connection,
+         bidirectional,
+         num_traversals,
+         termination_prob,
+         num_random_walks,
+         num_neighbor,
+     ):
+         super().__init__()
+
+         self._bipartite_graph = bipartite_graph
+         self._text_embeddings = text_embeddings
+         self._deepwalk_embeddings = deepwalk_embeddings
+
+         self._sampler = dgl.sampling.PinSAGESampler(
+             bipartite_graph, "Item", "User", num_traversals,
+             termination_prob, num_random_walks, num_neighbor)
+
+         self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
+         self._deepwalk_encoder = torch.nn.Linear(deepwalk_embeddings.shape[-1], hidden_dim)
+
+         self._layers = torch.nn.ModuleList()
+         for _ in range(num_layers):
+             self._layers.append(GNNLayer(
+                 hidden_dim, aggregator_type, skip_connection, bidirectional))
+
+     def _sample_subgraph(self, frontier_ids):
+         num_layers = len(self._layers)
+         device = self._bipartite_graph.device
+
+         subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
+         prev_ids = set()
+
+         for _ in range(num_layers):
+             frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
+             new_edges = self._sampler(frontier_ids).edges()
+             subgraph.add_edges(*new_edges)
+             prev_ids |= set(frontier_ids.cpu().tolist())
+             frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
+             frontier_ids = list(frontier_ids - prev_ids)
+
+         return subgraph
+
+     def forward(self, ids):
+         ### Sample subgraph
+         sampled_subgraph = self._sample_subgraph(ids)
+         sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
+
+         ### Encode text & DeepWalk embeddings
+         text_embeddings = self._text_embeddings[
+             sampled_subgraph.ndata[dgl.NID]]
+         deepwalk_embeddings = self._deepwalk_embeddings[
+             sampled_subgraph.ndata[dgl.NID]]
+         features = self._text_encoder(text_embeddings) \
+             + self._deepwalk_encoder(deepwalk_embeddings)
+
+         ### GNN goes brr...
+         for layer in self._layers:
+             features = layer(sampled_subgraph, features)
+
+         ### Select features for initial ids
+         # TODO: write it more efficiently?
+         matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
+         ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
+         features = features[ids_in_subgraph]
+
+         ### Normalize and return
+         features = features / torch.linalg.norm(features, dim=1, keepdim=True)
+         return features
+
+
+ ### Based on https://arxiv.org/pdf/2205.03169
+ def nt_xent_loss(sim, temperature):
+     sim = sim / temperature
+     n = sim.shape[0] // 2  # n = |user_batch|
+
+     alignment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n)+n])
+
+     mask = torch.diag(torch.ones(2*n, dtype=torch.bool)).to(sim.device)
+     sim = torch.where(mask, -torch.inf, sim)
+     sim = sim[:n, :]
+     distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))
+
+     loss = alignment_loss + distribution_loss
+     return loss
+
+
+ def sample_item_batch(user_batch, bipartite_graph):
+     sampled_edges = dgl.sampling.sample_neighbors(
+         bipartite_graph, {"User": user_batch}, fanout=2
+     ).edges(etype="ItemUser")
+     item_batch = sampled_edges[0]
+     item_batch = item_batch[torch.argsort(sampled_edges[1])]
+     item_batch = item_batch.reshape(-1, 2)
+     item_batch = item_batch.T
+     return item_batch
+
+
+ def prepare_gnn_embeddings(
+     # Paths
+     items_path,
+     ratings_path,
+     text_embeddings_path,
+     deepwalk_embeddings_path,
+     embeddings_savepath,
+     # Learning hyperparameters
+     temperature,
+     batch_size,
+     lr,
+     num_epochs,
+     # Model hyperparameters
+     num_layers,
+     hidden_dim,
+     aggregator_type,
+     skip_connection,
+     bidirectional,
+     num_traversals,
+     termination_prob,
+     num_random_walks,
+     num_neighbor,
+     # Misc
+     device,
+     wandb_name,
+     use_wandb,
+ ):
+     ### Prepare graph
+     bipartite_graph, _ = prepare_graphs(items_path, ratings_path)
+     bipartite_graph = bipartite_graph.to(device)
+
+     ### Init wandb
+     if use_wandb:
+         wandb.init(project="graph-rec-gnn", name=wandb_name)
+
+     ### Prepare model
+     text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
+     deepwalk_embeddings = torch.tensor(np.load(deepwalk_embeddings_path)).to(device)
+     model = GNNModel(
+         bipartite_graph=bipartite_graph,
+         text_embeddings=text_embeddings,
+         deepwalk_embeddings=deepwalk_embeddings,
+         num_layers=num_layers,
+         hidden_dim=hidden_dim,
+         aggregator_type=aggregator_type,
+         skip_connection=skip_connection,
+         bidirectional=bidirectional,
+         num_traversals=num_traversals,
+         termination_prob=termination_prob,
+         num_random_walks=num_random_walks,
+         num_neighbor=num_neighbor
+     )
+     model = model.to(device)
+
+     ### Prepare dataloader
+     all_users = torch.arange(bipartite_graph.num_nodes("User")).to(device)
+     all_users = all_users[bipartite_graph.in_degrees(all_users, etype="ItemUser") > 1]  # We need to sample 2 items per user
+     dataloader = torch.utils.data.DataLoader(
+         all_users, batch_size=batch_size, shuffle=True, drop_last=True)
+
+     ### Prepare optimizer & LR scheduler
+     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+     total_steps = num_epochs * len(dataloader)
+     lr_schedule = LRSchedule(
+         total_steps=total_steps,
+         warmup_steps=int(0.1*total_steps),
+         final_factor=0.1)  # TODO: move to args
+     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
+
+     ### Train loop
+     model.train()
+     for epoch in range(num_epochs):
+         for user_batch in tqdm(dataloader):
+             item_batch = sample_item_batch(user_batch, bipartite_graph)  # (2, |user_batch|)
+             item_batch = item_batch.reshape(-1)  # (2 * |user_batch|)
+             features = model(item_batch)  # (2 * |user_batch|, hidden_dim)
+             sim = features @ features.T  # (2 * |user_batch|, 2 * |user_batch|)
+             loss = nt_xent_loss(sim, temperature)
+             if use_wandb:
+                 wandb.log({"loss": loss.item()})
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             lr_scheduler.step()
+
+     if use_wandb:
+         wandb.finish()
+
+     ### Process full dataset
+     model.eval()
+     with torch.no_grad():
+         # The GNN outputs vectors of size hidden_dim.
+         item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
+         for items_batch in tqdm(torch.utils.data.DataLoader(
+             torch.arange(bipartite_graph.num_nodes("Item")),
+             batch_size=batch_size,
+             shuffle=True
+         )):
+             item_embeddings[items_batch] = model(items_batch.to(device))
+
+     ### Extract & save item embeddings
+     item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
+     np.save(embeddings_savepath, item_embeddings)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")
+
+     # Paths
+     parser.add_argument("--items_path", type=str, required=True, help="Path to the items file")
+     parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file")
+     parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
+     parser.add_argument("--deepwalk_embeddings_path", type=str, required=True, help="Path to the deepwalk embeddings file")
+     parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
+
+     # Learning hyperparameters
+     parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
+     parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
+     parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
+     parser.add_argument("--num_epochs", type=int, default=4, help="Number of epochs")
+
+     # Model hyperparameters
+     parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
+     parser.add_argument("--hidden_dim", type=int, default=384, help="Hidden dimension size")
+     parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
+     parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
+     parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
+     parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
+     parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
+     parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
+     parser.add_argument("--num_neighbor", type=int, default=3, help="Number of neighbors in PinSAGE-like sampler")
+
+     # Misc
+     parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
+     parser.add_argument("--wandb_name", type=str, help="WandB run name")
+     parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
+
+     args = parser.parse_args()
+
+     prepare_gnn_embeddings(**vars(args))
exp/prepare_db.py ADDED
@@ -0,0 +1,33 @@
+ import argparse
+ import sqlite3
+ import io
+
+ import pandas as pd
+ import numpy as np
+
+
+ def convert_numpy_array_to_text(array):
+     stream = io.BytesIO()
+     np.save(stream, array)
+     stream.seek(0)
+     return sqlite3.Binary(stream.read())
+
+
+ def prepare_items_db(items_path, embeddings_path, db_path):
+     items = pd.read_csv(items_path)
+     embeddings = np.load(embeddings_path)
+     items["embedding"] = np.split(embeddings, embeddings.shape[0])
+
+     sqlite3.register_adapter(np.ndarray, convert_numpy_array_to_text)
+     with sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES) as conn:
+         items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Prepare items database from a CSV file.")
+     parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
+     parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
+     parser.add_argument("--db_path", required=True, type=str, help="Path to the SQLite database file.")
+
+     args = parser.parse_args()
+     prepare_items_db(**vars(args))
exp/prepare_embeddings.sh ADDED
@@ -0,0 +1,52 @@
+ #!/bin/bash
+ set -e
+
+ input_directory="$1"
+ save_directory="$2"
+ device="${3:-cpu}"
+
+ echo Running on "$device".
+
+ PYTHONPATH=. python exp/process_raw_data.py \
+     --input_directory "$input_directory" \
+     --save_directory "$save_directory" \
+     --create_train_val_split
+
+ PYTHONPATH=. python exp/deepwalk.py \
+     --items_path "$save_directory/items.csv" \
+     --ratings_path "$save_directory/train_ratings.csv" \
+     --embeddings_savepath "$save_directory/deepwalk_embeddings.npy" \
+     --device "$device" \
+     --no_wandb
+
+ PYTHONPATH=. python exp/sbert.py \
+     --items_path "$save_directory/items.csv" \
+     --embeddings_savepath "$save_directory/text_embeddings.npy" \
+     --device "$device"
+
+ PYTHONPATH=. python exp/gnn.py \
+     --items_path "$save_directory/items.csv" \
+     --ratings_path "$save_directory/train_ratings.csv" \
+     --text_embeddings_path "$save_directory/text_embeddings.npy" \
+     --deepwalk_embeddings_path "$save_directory/deepwalk_embeddings.npy" \
+     --embeddings_savepath "$save_directory/embeddings.npy" \
+     --device "$device" \
+     --no_wandb
+
+ PYTHONPATH=. python exp/prepare_index.py \
+     --embeddings_path "$save_directory/embeddings.npy" \
+     --save_path "$save_directory/index.faiss"
+
+ PYTHONPATH=. python exp/prepare_db.py \
+     --items_path "$save_directory/items.csv" \
+     --embeddings_path "$save_directory/embeddings.npy" \
+     --db_path "$save_directory/items.db"
+
+ PYTHONPATH=. python exp/evaluate.py \
+     --metrics_savepath "$save_directory/metrics.json" \
+     --val_ratings_path "$save_directory/val_ratings.csv" \
+     --faiss_index_path "$save_directory/index.faiss" \
+     --db_path "$save_directory/items.db"
+
+ echo "Evaluation metrics:"
+ cat "$save_directory/metrics.json"
exp/prepare_index.py ADDED
@@ -0,0 +1,20 @@
+ import argparse
+
+ import faiss
+ import numpy as np
+
+
+ def build_index(embeddings_path, save_path, n_neighbors):
+     embeddings = np.load(embeddings_path)
+     # Use the n_neighbors argument (default 32) instead of a hard-coded value.
+     index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
+     index.add(embeddings)
+     faiss.write_index(index, save_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Build an HNSW index from embeddings.")
+     parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the embeddings file.")
+     parser.add_argument("--save_path", type=str, required=True, help="Path to save the built index.")
+     parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors for the index.")
+     args = parser.parse_args()
+     build_index(**vars(args))
exp/process_raw_data.py ADDED
@@ -0,0 +1,116 @@
+ import argparse
+ import os
+ import json
+
+ import pandas as pd
+ import numpy as np
+
+
+ def book_filter(book, ratings_count_threshold=10_000):
+     try:
+         if book["ratings_count"] == "":
+             return False
+         if int(book["ratings_count"]) < ratings_count_threshold:
+             return False
+
+         if book["description"] == "":
+             return False
+
+         if book["title"] == "":
+             return False
+
+         if book["title_without_series"] == "":
+             return False
+
+         possible_lang_codes = {"eng", "en-GB", "en-US"}
+         if book["language_code"] not in possible_lang_codes:
+             return False
+
+         return True
+     except Exception:
+         return False
+
+
+ def process_raw_data_goodreads(input_directory, save_directory, positive_rating_threshold=4.0):
+     os.makedirs(save_directory, exist_ok=True)
+
+     ### Process items
+     columns = [
+         "book_id",
+         "description",
+         "title_without_series",
+     ]
+     numeric_columns = [
+         "book_id",
+     ]
+
+     items = []
+     with open(os.path.join(input_directory, "goodreads_books.json"), "r") as f:
+         for line in f:
+             item = json.loads(line)
+             if book_filter(item):
+                 items.append([item[col] for col in columns])
+     items = pd.DataFrame(items, columns=columns)
+     for col in numeric_columns:
+         items[col] = pd.to_numeric(items[col])
+     items["item_id"] = items.index
+     items["title"] = items["title_without_series"]
+     items.drop("title_without_series", axis=1, inplace=True)
+     items.to_csv(os.path.join(save_directory, "items.csv"), index=False)
+
+     ### Process ratings
+     ratings = pd.read_csv(os.path.join(input_directory, "goodreads_interactions.csv"))
+
+     book_id_map = pd.read_csv(os.path.join(input_directory, "book_id_map.csv"))
+     csv_to_usual_map = dict(zip(book_id_map["book_id_csv"], book_id_map["book_id"]))
+     usual_to_csv_map = dict(zip(book_id_map["book_id"], book_id_map["book_id_csv"]))
+
+     book_ids = items["book_id"].unique()
+     book_ids_csv = set([usual_to_csv_map[book_id] for book_id in book_ids])
+     ratings = ratings[ratings["rating"] >= positive_rating_threshold]
+     ratings = ratings[ratings["book_id"].isin(book_ids_csv)]
+
+     book_to_item_id_map = dict(zip(items["book_id"], items["item_id"]))
+     ratings["item_id"] = ratings["book_id"].map(csv_to_usual_map).map(book_to_item_id_map)
+
+     user_ids = list(ratings["user_id"].unique())
+     user_ids_map = dict(zip(user_ids, range(len(user_ids))))
+     ratings["user_id"] = ratings["user_id"].map(user_ids_map)
+
+     ratings.to_csv(os.path.join(save_directory, "ratings.csv"), index=False)
+
+
+ def create_train_val_split(ratings_path, train_savepath, val_savepath, seed=42):
+     ratings = pd.read_csv(ratings_path)
+     user_ids = ratings["user_id"].unique()
+
+     rng = np.random.default_rng(seed=seed)
+     train_size = int(len(user_ids) * 0.9)
+     train_indices = rng.choice(user_ids, size=train_size, replace=False)
+
+     train_data = ratings.loc[ratings["user_id"].isin(train_indices)]
+     val_data = ratings.loc[~ratings["user_id"].isin(train_indices)]
+
+     print(f"Train size: {len(train_data)}.")
+     print(f"Validation size: {len(val_data)}.")
+
+     train_data.to_csv(train_savepath, index=False)
+     val_data.to_csv(val_savepath, index=False)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Process raw data.")
+     parser.add_argument("--input_directory", required=True, type=str, help="Directory containing the raw data.")
+     parser.add_argument("--save_directory", required=True, type=str, help="Directory where processed data will be saved.")
+     parser.add_argument("--create_train_val_split", action="store_true", help="Flag to indicate whether to create a train-validation split.")
+     args = parser.parse_args()
+
+     print("Processing raw data...")
+     process_raw_data_goodreads(args.input_directory, args.save_directory)
+     if args.create_train_val_split:
+         create_train_val_split(
+             os.path.join(args.save_directory, "ratings.csv"),
+             os.path.join(args.save_directory, "train_ratings.csv"),
+             os.path.join(args.save_directory, "val_ratings.csv")
+         )
+     print("The raw data has been successfully processed.")
exp/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ -r ../requirements.txt # install base requirements of app
+ dgl==2.1.0
+ torch==2.1.2
+ wandb==0.17.0
+ tqdm==4.66.4
+ pydantic==2.5.3
+ sentence_transformers==3.0.1
exp/requirements_gpu.txt ADDED
@@ -0,0 +1,8 @@
+ -f https://data.dgl.ai/wheels/cu121/repo.html
+ -r ../requirements.txt # install base requirements of app
+ dgl==2.1.0
+ torch==2.1.0
+ wandb==0.17.0
+ tqdm==4.66.4
+ pydantic==2.5.3
+ sentence_transformers==3.0.1
exp/sbert.py ADDED
@@ -0,0 +1,39 @@
+ import argparse
+ import os
+
+ import numpy as np
+ import pandas as pd
+ from tqdm.auto import tqdm
+ from sentence_transformers import SentenceTransformer
+
+ from utils import normalize_embeddings
+
+
+ def prepare_sbert_embeddings(
+     items_path,
+     embeddings_savepath,
+     model_name,
+     batch_size,
+     device
+ ):
+     items = pd.read_csv(items_path).sort_values("item_id")
+     sentences = items["description"].values
+     model = SentenceTransformer(model_name).to(device)
+     embeddings = []
+     for start_index in tqdm(range(0, len(sentences), batch_size)):
+         batch = sentences[start_index:start_index+batch_size]
+         embeddings.extend(model.encode(batch))
+     embeddings = normalize_embeddings(np.array(embeddings))
+     np.save(embeddings_savepath, embeddings)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Prepare SBERT embeddings.")
+     parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
+     parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to save the embeddings.")
+     parser.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Name of the SBERT model to use.")
+     parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
+     parser.add_argument("--device", type=str, default="cpu", help="Device to use for training (cpu or cuda).")
+     args = parser.parse_args()
+
+     prepare_sbert_embeddings(**vars(args))
exp/utils.py ADDED
@@ -0,0 +1,69 @@
+ import numpy as np
+ import pandas as pd
+ import dgl
+ import torch
+
+
+ def normalize_embeddings(embeddings):
+     embeddings_norm = np.linalg.norm(embeddings, axis=1)
+     nonzero_embeddings = embeddings_norm > 0.0
+     embeddings[nonzero_embeddings] /= embeddings_norm[nonzero_embeddings, None]
+     return embeddings
+
+
+ def prepare_graphs(items_path, ratings_path):
+     items = pd.read_csv(items_path)
+     ratings = pd.read_csv(ratings_path)
+
+     n_users = np.max(ratings["user_id"].unique()) + 1
+     item_ids = torch.tensor(sorted(items["item_id"].unique()))
+
+     edges = (torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"]))
+     reverse_edges = (edges[1], edges[0])
+
+     bipartite_graph = dgl.heterograph(
+         data_dict={
+             ("User", "UserItem", "Item"): edges,
+             ("Item", "ItemUser", "User"): reverse_edges
+         },
+         num_nodes_dict={
+             "User": n_users,
+             "Item": len(item_ids)
+         }
+     )
+     graph = dgl.to_homogeneous(bipartite_graph)
+     graph = dgl.add_self_loop(graph)
+     return bipartite_graph, graph
+
+
+ def extract_item_embeddings(node_embeddings, bipartite_graph, graph):
+     item_ntype = bipartite_graph.ntypes.index("Item")
+     item_mask = graph.ndata[dgl.NTYPE] == item_ntype
+     item_embeddings = node_embeddings[item_mask]
+     original_ids = graph.ndata[dgl.NID][item_mask]
+     item_embeddings = item_embeddings[torch.argsort(original_ids)]
+     return item_embeddings.cpu().numpy()
+
+
+ class LRSchedule:
+     def __init__(self, total_steps, warmup_steps, final_factor):
+         self._total_steps = total_steps
+         self._warmup_steps = warmup_steps
+         self._final_factor = final_factor
+
+     def __call__(self, step):
+         if step >= self._total_steps:
+             return self._final_factor
+
+         if self._warmup_steps > 0:
+             warmup_factor = step / self._warmup_steps
+         else:
+             warmup_factor = 1.0
+
+         steps_after_warmup = step - self._warmup_steps
+         total_steps_after_warmup = self._total_steps - self._warmup_steps
+         after_warmup_factor = 1 \
+             - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)
+
+         factor = min(warmup_factor, after_warmup_factor)
+         return min(max(factor, 0), 1)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit==1.35.0
+ pandas==2.2.2
+ numpy==1.26.4
+ faiss-cpu==1.8.0