erermeev-d committed
Commit d4852d9 · 0 parent(s) · Initial commit

Files changed:
- .gitattributes +1 -0
- Dockerfile +18 -0
- Makefile +3 -0
- README.md +42 -0
- app/__init__.py +0 -0
- app/database.py +35 -0
- app/main.py +60 -0
- app/recommendations.py +19 -0
- exp/__init__.py +0 -0
- exp/deepwalk.py +80 -0
- exp/evaluate.py +81 -0
- exp/gnn.py +284 -0
- exp/prepare_db.py +33 -0
- exp/prepare_embeddings.sh +52 -0
- exp/prepare_index.py +20 -0
- exp/process_raw_data.py +116 -0
- exp/requirements.txt +7 -0
- exp/requirements_gpu.txt +8 -0
- exp/sbert.py +39 -0
- exp/utils.py +69 -0
- requirements.txt +4 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
```
data/* filter=lfs diff=lfs merge=lfs -text
```
Dockerfile
ADDED
@@ -0,0 +1,18 @@
```
FROM python:3.10.13

COPY requirements.txt .
RUN pip3 install -r requirements.txt

RUN mkdir /data
RUN wget https://storage.yandexcloud.net/eremeev-d-bucket-main/1722760266.tar -O data.tar
RUN tar -xf data.tar -C /data
RUN rm data.tar

RUN mkdir /app
COPY app app

EXPOSE 8501

ENV PYTHONPATH .

ENTRYPOINT ["streamlit", "run", "app/main.py", "--server.port=8501"]
```
Makefile
ADDED
@@ -0,0 +1,3 @@
```
run-app:
	docker build -t graph-rec-app .
	docker run --rm -p 8501:8501 graph-rec-app
```
README.md
ADDED
@@ -0,0 +1,42 @@
---
title: A simple Graph-based Recommender System
emoji: 📚
colorFrom: purple
colorTo: yellow
sdk: docker
app_port: 8501
---
# A simple Graph-based Recommender System

### What is it?

This app is a simple graph-based recommender system that searches for items and recommends similar ones. It can be applied to any dataset; for demonstration purposes, we use the (filtered) [Goodreads](https://mengtingwan.github.io/data/goodreads#datasets) dataset.

### Where can I try this app?

The app is currently deployed at Hugging Face Spaces ([link](https://huggingface.co/spaces/eremeev-d/graph-rec)). You may need to wait a minute or two for the app to start.

### How to use it?

Simply enter a keyword (e.g., "Brave") into the search bar and press the "Search" button. The app will display relevant books along with their short descriptions.

For each book, you can click "Recommend Similar Items" to see other books you might enjoy if you liked the selected one.

### How to reproduce the embeddings computation?

First, install the requirements from `exp/requirements.txt` (or `exp/requirements_gpu.txt` for GPU).

Then, download the raw data from the [Goodreads website](https://mengtingwan.github.io/data/goodreads#datasets). You will need the following files: `book_id_map.csv`, `goodreads_books.json`, `goodreads_interactions.csv` and `user_id_map.csv`. You can download these files manually or use this [Kaggle dataset](https://www.kaggle.com/datasets/eremeevd/graph-rec-goodreads).

Finally, run the following command at the root of the repo:
```
sh exp/prepare_embeddings.sh INPUT_DIRECTORY SAVE_DIRECTORY
```
where `INPUT_DIRECTORY` is the path to the directory with the raw data (e.g. `/kaggle/input/graph-rec-goodreads/goodreads-books`) and `SAVE_DIRECTORY` is the path to the directory where results will be saved (e.g. `/kaggle/working/embeddings`). To use the obtained embeddings, copy the resulting `index.faiss` and `items.db` files to `app/data`.

To run on GPU, use:
```
sh exp/prepare_embeddings.sh INPUT_DIRECTORY SAVE_DIRECTORY cuda
```

For further information, refer to the `exp` directory in this repo.
app/__init__.py
ADDED
File without changes
app/database.py
ADDED
@@ -0,0 +1,35 @@
```python
import io
import sqlite3

import numpy as np


class ItemDatabase:
    def __init__(self, db_path):
        # Convert the "embedding" column back into a numpy array on read.
        sqlite3.register_converter("embedding", self._blob_to_numpy_array)
        self._db_path = db_path

    @staticmethod
    def _blob_to_numpy_array(blob):
        out = io.BytesIO(blob)
        out.seek(0)
        return np.load(out)

    def _connect(self):
        return sqlite3.connect(
            self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    def search_items(self, query, n_items=10):
        with self._connect() as conn:
            c = conn.cursor()
            # Bind the search string as a parameter instead of formatting it
            # into the SQL text, to avoid SQL injection.
            c.execute(
                "select item_id from items where title like ?",
                (f"%{query}%",))
            rows = c.fetchmany(n_items)
            return [row[0] for row in rows]

    def get_item(self, item_id):
        with self._connect() as conn:
            c = conn.cursor()
            c.row_factory = sqlite3.Row
            c.execute("select * from items where item_id = ?", (item_id,))
            return c.fetchone()
```
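For reference, a minimal usage sketch of `ItemDatabase` (assuming an `items.db` produced by `exp/prepare_db.py`; the path below is illustrative):

```python
# Minimal sketch: search by title, then fetch full rows. Path is illustrative.
from app.database import ItemDatabase

db = ItemDatabase(db_path="/data/items.db")

for item_id in db.search_items("Brave", n_items=5):
    item = db.get_item(item_id)
    # Rows behave like dicts thanks to sqlite3.Row;
    # item["embedding"] comes back as a (1, dim) numpy array.
    print(item["item_id"], item["title"])
```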
app/main.py
ADDED
@@ -0,0 +1,60 @@
```python
import streamlit as st

from app.database import ItemDatabase
from app.recommendations import RecommenderSystem


def show_item(item_id):
    item = st.session_state["db"].get_item(item_id)
    title = item["title"]
    with st.container(border=True):
        st.write(f"**{title}**")
        st.write(item["description"])
        if st.button("Recommend similar items", key=item["item_id"]):
            st.session_state["recommendation_query"] = item["item_id"]
            st.session_state["search_query"] = None  # reset
            st.rerun()


def main():
    st.title("Graph-based RecSys")

    # Cache the database and the recommender in the session state
    # so they are created once per user session.
    if "db" not in st.session_state:
        st.session_state["db"] = ItemDatabase(
            db_path="/data/items.db")
    if "recsys" not in st.session_state:
        st.session_state["recsys"] = RecommenderSystem(
            faiss_index_path="/data/index.faiss",
            db_path="/data/items.db")

    if "search_query" not in st.session_state:
        st.session_state["search_query"] = None
    if "recommendation_query" not in st.session_state:
        st.session_state["recommendation_query"] = None

    search_query = st.text_input("Enter item name", st.session_state["search_query"])

    if st.button("Search"):
        st.session_state["search_query"] = search_query
        st.session_state["recommendation_query"] = None  # reset

    if st.session_state["recommendation_query"] is not None:
        query = st.session_state["recommendation_query"]
        base_item_title = st.session_state["db"].get_item(query)["title"]
        st.subheader(f'Recommendation Results for "{base_item_title}"')
        results = st.session_state["recsys"].recommend_items(query)
        for item_id in results:
            show_item(item_id)

    elif st.session_state["search_query"] is not None:
        query = st.session_state["search_query"]
        st.subheader(f'Search Results for "{query}"')
        results = st.session_state["db"].search_items(query)
        for item_id in results:
            show_item(item_id)


if __name__ == "__main__":
    main()
```
app/recommendations.py
ADDED
@@ -0,0 +1,19 @@
```python
import itertools

import faiss

from app.database import ItemDatabase


class RecommenderSystem:
    def __init__(self, faiss_index_path, db_path):
        self._index = faiss.read_index(faiss_index_path)
        self._db = ItemDatabase(db_path)

    def recommend_items(self, query, n_items=10):
        # The stored embedding has shape (1, dim), as written by
        # exp/prepare_db.py, which is what faiss expects for a single query.
        query_embedding = self._db.get_item(query)["embedding"]
        # Retrieve one extra neighbor, since the query item
        # is usually its own nearest neighbor.
        _, results = self._index.search(query_embedding, k=n_items + 1)
        results = filter(lambda item: item != query, results[0])
        return itertools.islice(results, n_items)
```
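A minimal sketch of using the recommender in isolation (paths are illustrative and assume the artifacts produced by `exp/prepare_embeddings.sh`):

```python
# Sketch: load the index + database and ask for neighbors of one item.
from app.recommendations import RecommenderSystem

recsys = RecommenderSystem(
    faiss_index_path="/data/index.faiss",
    db_path="/data/items.db")

# recommend_items returns an iterator of item ids;
# the query item itself is filtered out of the results.
for item_id in recsys.recommend_items(query=42, n_items=5):
    print(item_id)
```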
exp/__init__.py
ADDED
File without changes
exp/deepwalk.py
ADDED
@@ -0,0 +1,80 @@
```python
import argparse

import numpy as np
import dgl
import torch
import wandb
from tqdm.auto import tqdm

from utils import prepare_graphs, extract_item_embeddings, normalize_embeddings


def prepare_deepwalk_embeddings(
    items_path,
    ratings_path,
    embeddings_savepath,
    emb_dim,
    window_size,
    batch_size,
    lr,
    num_epochs,
    device,
    wandb_name,
    use_wandb
):
    ### Prepare graph
    bipartite_graph, graph = prepare_graphs(items_path, ratings_path)
    bipartite_graph = bipartite_graph.to(device)
    graph = graph.to(device)

    ### Run DeepWalk
    if use_wandb:
        wandb.init(project="graph-recs-deepwalk", name=wandb_name)

    model = dgl.nn.DeepWalk(graph, emb_dim=emb_dim, window_size=window_size)
    model = model.to(device)
    dataloader = torch.utils.data.DataLoader(
        torch.arange(graph.num_nodes()),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=model.sample)

    optimizer = torch.optim.SparseAdam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        for batch_walk in tqdm(dataloader):
            loss = model(batch_walk)
            if use_wandb:
                wandb.log({"loss": loss.item()})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if use_wandb:
        wandb.finish()

    node_embeddings = model.node_embed.weight.detach().to(device)

    ### Extract & save item embeddings
    item_embeddings = extract_item_embeddings(node_embeddings, bipartite_graph, graph)
    item_embeddings = normalize_embeddings(item_embeddings)
    np.save(embeddings_savepath, item_embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare DeepWalk embeddings.")
    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
    parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file.")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where embeddings will be saved.")
    parser.add_argument("--emb_dim", type=int, default=384, help="Dimensionality of the embeddings.")
    parser.add_argument("--window_size", type=int, default=4, help="Window size for the DeepWalk algorithm.")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training.")
    parser.add_argument("--lr", type=float, default=1e-2, help="Learning rate for training.")
    parser.add_argument("--num_epochs", type=int, default=2, help="Number of epochs for training.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use for training (cpu or cuda).")
    parser.add_argument("--wandb_name", type=str, help="Name for WandB run.")
    parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")
    args = parser.parse_args()

    prepare_deepwalk_embeddings(**vars(args))
```
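As background for the hyperparameters above: `dgl.nn.DeepWalk` trains node embeddings $\Phi$ with a skip-gram objective over sampled random walks, approximately

$$\max_{\Phi} \sum_{t} \sum_{\substack{-w \le j \le w \\ j \ne 0}} \log p\big(v_{t+j} \mid \Phi(v_t)\big),$$

where $w$ is `window_size` and, in practice, the softmax is approximated with negative sampling.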
exp/evaluate.py
ADDED
@@ -0,0 +1,81 @@
```python
import argparse
import json

import pandas as pd
import numpy as np

from app.recommendations import RecommenderSystem


def precision_at_k(recommended_items, relevant_items, k):
    recommended_at_k = set(recommended_items[:k])
    relevant_set = set(relevant_items)
    return len(recommended_at_k & relevant_set) / k


def evaluate_recsys(
    metrics_savepath,
    val_ratings_path,
    faiss_index_path,
    db_path,
    n_recommend_items,
):
    recsys = RecommenderSystem(
        faiss_index_path=faiss_index_path,
        db_path=db_path)

    val_ratings = pd.read_csv(val_ratings_path)
    grouped_items = val_ratings.groupby("user_id")["item_id"].apply(list).reset_index()
    grouped_items = grouped_items["item_id"].tolist()

    metric_arrays = {
        "precision@1": [],
        "precision@3": [],
        "precision@10": []
    }

    for item_group in grouped_items:
        if len(item_group) == 1:
            continue

        ### Precision@k is computed for each (user, item) rating pair.
        ### We first aggregate it over all items of a user,
        ### and then aggregate over all users.
        user_metric_arrays = dict()
        for metric in metric_arrays.keys():
            user_metric_arrays[metric] = []

        for item in item_group:
            recommend_items = list(recsys.recommend_items(item, n_recommend_items))
            relevant_items = set(item_group) - {item}

            user_metric_arrays["precision@1"].append(
                precision_at_k(recommend_items, relevant_items, k=1))
            user_metric_arrays["precision@3"].append(
                precision_at_k(recommend_items, relevant_items, k=3))
            user_metric_arrays["precision@10"].append(
                precision_at_k(recommend_items, relevant_items, k=10))

        for metric in metric_arrays.keys():
            user_metric = np.mean(user_metric_arrays[metric])
            metric_arrays[metric].append(user_metric)

    metrics = dict()
    for metric, array in metric_arrays.items():
        metrics[metric] = np.mean(array)

    with open(metrics_savepath, "w") as f:
        json.dump(metrics, f)
    print(f"Saved metrics to {metrics_savepath}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a recommendation system.")
    parser.add_argument("--metrics_savepath", required=True, type=str, help="Path to save the evaluation metrics.")
    parser.add_argument("--val_ratings_path", required=True, type=str, help="Path to the csv file with validation ratings.")
    parser.add_argument("--faiss_index_path", required=True, type=str, help="Path to the FAISS index.")
    parser.add_argument("--db_path", required=True, type=str, help="Path to the database file.")
    parser.add_argument("--n_recommend_items", type=int, default=10, help="Number of items to recommend.")
    args = parser.parse_args()
    evaluate_recsys(**vars(args))
```
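The metric above is standard precision@k. For a single query item,

$$\text{precision@}k = \frac{|\,\text{top-}k \text{ recommendations} \cap \text{relevant items}\,|}{k},$$

where the relevant items are the other items rated by the same validation user; as in the code, it is averaged first over a user's items and then over all users.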
exp/gnn.py
ADDED
@@ -0,0 +1,284 @@
```python
import argparse

import numpy as np
import dgl
import torch
import wandb
from tqdm.auto import tqdm

from utils import prepare_graphs, normalize_embeddings, LRSchedule


class GNNLayer(torch.nn.Module):
    def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
        super().__init__()
        self._skip_connection = skip_connection
        self._bidirectional = bidirectional

        self._norm = torch.nn.LayerNorm(hidden_dim)
        self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
        self._activation = torch.nn.ReLU()

        if bidirectional:
            self._norm_rev = torch.nn.LayerNorm(hidden_dim)
            self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
            self._activation_rev = torch.nn.ReLU()

    def forward(self, graph, x):
        y = self._activation(self._conv(graph, self._norm(x)))
        if self._bidirectional:
            y = y + self._activation_rev(self._conv_rev(dgl.reverse(graph), self._norm_rev(x)))
        if self._skip_connection:
            return x + y
        else:
            return y


class GNNModel(torch.nn.Module):
    def __init__(
        self,
        bipartite_graph,
        text_embeddings,
        deepwalk_embeddings,
        num_layers,
        hidden_dim,
        aggregator_type,
        skip_connection,
        bidirectional,
        num_traversals,
        termination_prob,
        num_random_walks,
        num_neighbor,
    ):
        super().__init__()

        self._bipartite_graph = bipartite_graph
        self._text_embeddings = text_embeddings
        self._deepwalk_embeddings = deepwalk_embeddings

        self._sampler = dgl.sampling.PinSAGESampler(
            bipartite_graph, "Item", "User", num_traversals,
            termination_prob, num_random_walks, num_neighbor)

        self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
        self._deepwalk_encoder = torch.nn.Linear(deepwalk_embeddings.shape[-1], hidden_dim)

        self._layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            self._layers.append(GNNLayer(
                hidden_dim, aggregator_type, skip_connection, bidirectional))

    def _sample_subgraph(self, frontier_ids):
        num_layers = len(self._layers)
        device = self._bipartite_graph.device

        subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
        prev_ids = set()

        for _ in range(num_layers):
            frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
            new_edges = self._sampler(frontier_ids).edges()
            subgraph.add_edges(*new_edges)
            prev_ids |= set(frontier_ids.cpu().tolist())
            frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
            frontier_ids = list(frontier_ids - prev_ids)

        return subgraph

    def forward(self, ids):
        ### Sample subgraph
        sampled_subgraph = self._sample_subgraph(ids)
        sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)

        ### Encode text & DeepWalk embeddings
        text_embeddings = self._text_embeddings[
            sampled_subgraph.ndata[dgl.NID]]
        deepwalk_embeddings = self._deepwalk_embeddings[
            sampled_subgraph.ndata[dgl.NID]]
        features = self._text_encoder(text_embeddings) \
            + self._deepwalk_encoder(deepwalk_embeddings)

        ### GNN goes brr...
        for layer in self._layers:
            features = layer(sampled_subgraph, features)

        ### Select features for the initial ids
        # TODO: write it more efficiently?
        matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
        ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
        features = features[ids_in_subgraph]

        ### Normalize and return
        features = features / torch.linalg.norm(features, dim=1, keepdim=True)
        return features


### Based on https://arxiv.org/pdf/2205.03169
def nt_xent_loss(sim, temperature):
    sim = sim / temperature
    n = sim.shape[0] // 2  # n = |user_batch|

    alignment_loss = -torch.mean(sim[torch.arange(n), torch.arange(n) + n])

    mask = torch.diag(torch.ones(2 * n, dtype=torch.bool)).to(sim.device)
    sim = torch.where(mask, -torch.inf, sim)
    sim = sim[:n, :]
    distribution_loss = torch.mean(torch.logsumexp(sim, dim=1))

    loss = alignment_loss + distribution_loss
    return loss


def sample_item_batch(user_batch, bipartite_graph):
    sampled_edges = dgl.sampling.sample_neighbors(
        bipartite_graph, {"User": user_batch}, fanout=2
    ).edges(etype="ItemUser")
    item_batch = sampled_edges[0]
    item_batch = item_batch[torch.argsort(sampled_edges[1])]
    item_batch = item_batch.reshape(-1, 2)
    item_batch = item_batch.T
    return item_batch


def prepare_gnn_embeddings(
    # Paths
    items_path,
    ratings_path,
    text_embeddings_path,
    deepwalk_embeddings_path,
    embeddings_savepath,
    # Learning hyperparameters
    temperature,
    batch_size,
    lr,
    num_epochs,
    # Model hyperparameters
    num_layers,
    hidden_dim,
    aggregator_type,
    skip_connection,
    bidirectional,
    num_traversals,
    termination_prob,
    num_random_walks,
    num_neighbor,
    # Misc
    device,
    wandb_name,
    use_wandb,
):
    ### Prepare graph
    bipartite_graph, _ = prepare_graphs(items_path, ratings_path)
    bipartite_graph = bipartite_graph.to(device)

    ### Init wandb
    if use_wandb:
        wandb.init(project="graph-rec-gnn", name=wandb_name)

    ### Prepare model
    text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
    deepwalk_embeddings = torch.tensor(np.load(deepwalk_embeddings_path)).to(device)
    model = GNNModel(
        bipartite_graph=bipartite_graph,
        text_embeddings=text_embeddings,
        deepwalk_embeddings=deepwalk_embeddings,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        aggregator_type=aggregator_type,
        skip_connection=skip_connection,
        bidirectional=bidirectional,
        num_traversals=num_traversals,
        termination_prob=termination_prob,
        num_random_walks=num_random_walks,
        num_neighbor=num_neighbor
    )
    model = model.to(device)

    ### Prepare dataloader
    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(device)
    all_users = all_users[bipartite_graph.in_degrees(all_users, etype="ItemUser") > 1]  # We need to sample 2 items per user
    dataloader = torch.utils.data.DataLoader(
        all_users, batch_size=batch_size, shuffle=True, drop_last=True)

    ### Prepare optimizer & LR scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    total_steps = num_epochs * len(dataloader)
    lr_schedule = LRSchedule(
        total_steps=total_steps,
        warmup_steps=int(0.1 * total_steps),
        final_factor=0.1)  # TODO: move to args
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    ### Train loop
    model.train()
    for epoch in range(num_epochs):
        for user_batch in tqdm(dataloader):
            item_batch = sample_item_batch(user_batch, bipartite_graph)  # (2, |user_batch|)
            item_batch = item_batch.reshape(-1)  # (2 * |user_batch|)
            features = model(item_batch)  # (2 * |user_batch|, hidden_dim)
            sim = features @ features.T  # (2 * |user_batch|, 2 * |user_batch|)
            loss = nt_xent_loss(sim, temperature)
            if use_wandb:
                wandb.log({"loss": loss.item()})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

    if use_wandb:
        wandb.finish()

    ### Process full dataset
    model.eval()
    with torch.no_grad():
        # The model outputs hidden_dim-sized vectors, so size the buffer accordingly.
        item_embeddings = torch.zeros(bipartite_graph.num_nodes("Item"), hidden_dim).to(device)
        for items_batch in tqdm(torch.utils.data.DataLoader(
            torch.arange(bipartite_graph.num_nodes("Item")),
            batch_size=batch_size,
            shuffle=True
        )):
            item_embeddings[items_batch] = model(items_batch.to(device))

    ### Extract & save item embeddings
    item_embeddings = normalize_embeddings(item_embeddings.cpu().numpy())
    np.save(embeddings_savepath, item_embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")

    # Paths
    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file")
    parser.add_argument("--ratings_path", type=str, required=True, help="Path to the ratings file")
    parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
    parser.add_argument("--deepwalk_embeddings_path", type=str, required=True, help="Path to the deepwalk embeddings file")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")

    # Learning hyperparameters
    parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=4, help="Number of epochs")

    # Model hyperparameters
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
    parser.add_argument("--hidden_dim", type=int, default=384, help="Hidden dimension size")
    parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
    parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
    parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
    parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
    parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
    parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
    parser.add_argument("--num_neighbor", type=int, default=3, help="Number of neighbors in PinSAGE-like sampler")

    # Misc
    parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
    parser.add_argument("--wandb_name", type=str, help="WandB run name")
    parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")

    args = parser.parse_args()

    prepare_gnn_embeddings(**vars(args))
```
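For reference, `nt_xent_loss` implements the temperature-scaled contrastive loss from the linked paper: with $n$ positive pairs $(i,\, i+n)$ and a similarity matrix $s$ already divided by the temperature $\tau$,

$$\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} s_{i,\,i+n} + \frac{1}{n} \sum_{i=1}^{n} \log \sum_{v \ne i} \exp(s_{i,v}),$$

an alignment term that pulls the two items sampled for the same user together, plus a log-partition term over every other entry in the row (only the self-similarity on the diagonal is masked out).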
exp/prepare_db.py
ADDED
@@ -0,0 +1,33 @@
```python
import argparse
import sqlite3
import io

import pandas as pd
import numpy as np


def convert_numpy_array_to_blob(array):
    # Serialize a numpy array into bytes for storage in SQLite.
    stream = io.BytesIO()
    np.save(stream, array)
    stream.seek(0)
    return sqlite3.Binary(stream.read())


def prepare_items_db(items_path, embeddings_path, db_path):
    items = pd.read_csv(items_path)
    embeddings = np.load(embeddings_path)
    # Each row gets its embedding as a (1, dim) array.
    items["embedding"] = np.split(embeddings, embeddings.shape[0])

    sqlite3.register_adapter(np.ndarray, convert_numpy_array_to_blob)
    with sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES) as conn:
        items.to_sql("items", conn, if_exists="replace", index=False, dtype={"embedding": "embedding"})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare items database from a CSV file.")
    parser.add_argument("--items_path", required=True, type=str, help="Path to the CSV file containing items.")
    parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the .npy file containing item embeddings.")
    parser.add_argument("--db_path", required=True, type=str, help="Path to the SQLite database file.")

    args = parser.parse_args()
    prepare_items_db(**vars(args))
```
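This adapter and the converter registered in `app/database.py` form a matching pair. A self-contained sketch of the round trip (function names here are illustrative):

```python
# Round trip: numpy array -> SQLite blob -> numpy array.
import io
import sqlite3

import numpy as np


def array_to_blob(array):
    stream = io.BytesIO()
    np.save(stream, array)
    return sqlite3.Binary(stream.getvalue())


def blob_to_array(blob):
    return np.load(io.BytesIO(blob))


sqlite3.register_adapter(np.ndarray, array_to_blob)
sqlite3.register_converter("embedding", blob_to_array)

conn = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
conn.execute("create table t (e embedding)")
conn.execute("insert into t values (?)", (np.arange(3.0).reshape(1, 3),))
(restored,) = conn.execute("select e from t").fetchone()
assert np.allclose(restored, [[0.0, 1.0, 2.0]])
```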
exp/prepare_embeddings.sh
ADDED
@@ -0,0 +1,52 @@
```bash
#!/bin/bash
set -e

input_directory="$1"
save_directory="$2"
device="${3:-cpu}"

echo "Running on $device."

PYTHONPATH=. python exp/process_raw_data.py \
    --input_directory "$input_directory" \
    --save_directory "$save_directory" \
    --create_train_val_split

PYTHONPATH=. python exp/deepwalk.py \
    --items_path "$save_directory/items.csv" \
    --ratings_path "$save_directory/train_ratings.csv" \
    --embeddings_savepath "$save_directory/deepwalk_embeddings.npy" \
    --device "$device" \
    --no_wandb

PYTHONPATH=. python exp/sbert.py \
    --items_path "$save_directory/items.csv" \
    --embeddings_savepath "$save_directory/text_embeddings.npy" \
    --device "$device"

PYTHONPATH=. python exp/gnn.py \
    --items_path "$save_directory/items.csv" \
    --ratings_path "$save_directory/train_ratings.csv" \
    --text_embeddings_path "$save_directory/text_embeddings.npy" \
    --deepwalk_embeddings_path "$save_directory/deepwalk_embeddings.npy" \
    --embeddings_savepath "$save_directory/embeddings.npy" \
    --device "$device" \
    --no_wandb

PYTHONPATH=. python exp/prepare_index.py \
    --embeddings_path "$save_directory/embeddings.npy" \
    --save_path "$save_directory/index.faiss"

PYTHONPATH=. python exp/prepare_db.py \
    --items_path "$save_directory/items.csv" \
    --embeddings_path "$save_directory/embeddings.npy" \
    --db_path "$save_directory/items.db"

PYTHONPATH=. python exp/evaluate.py \
    --metrics_savepath "$save_directory/metrics.json" \
    --val_ratings_path "$save_directory/val_ratings.csv" \
    --faiss_index_path "$save_directory/index.faiss" \
    --db_path "$save_directory/items.db"

echo "Evaluation metrics:"
cat "$save_directory/metrics.json"
```
exp/prepare_index.py
ADDED
@@ -0,0 +1,20 @@
```python
import argparse

import faiss
import numpy as np


def build_index(embeddings_path, save_path, n_neighbors):
    embeddings = np.load(embeddings_path)
    # The second argument of IndexHNSWFlat is the HNSW M parameter
    # (links per node); wire the CLI flag through instead of hardcoding it.
    index = faiss.IndexHNSWFlat(embeddings.shape[-1], n_neighbors)
    index.add(embeddings)
    faiss.write_index(index, save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build an HNSW index from embeddings.")
    parser.add_argument("--embeddings_path", required=True, type=str, help="Path to the embeddings file.")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the built index.")
    parser.add_argument("--n_neighbors", type=int, default=32, help="Number of neighbors (HNSW M) for the index.")
    args = parser.parse_args()
    build_index(**vars(args))
```
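A quick sketch of querying the saved index (the path is illustrative):

```python
# Sketch: load the HNSW index and search with a single (1, dim) query vector.
import faiss
import numpy as np

index = faiss.read_index("index.faiss")  # illustrative path
query = np.random.rand(1, index.d).astype("float32")
distances, neighbor_ids = index.search(query, k=10)
print(neighbor_ids[0])  # item_ids, since vectors were added in item_id order
```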
exp/process_raw_data.py
ADDED
@@ -0,0 +1,116 @@
```python
import argparse
import os
import json

import pandas as pd
import numpy as np


def book_filter(book, ratings_count_threshold=10_000):
    try:
        if book["ratings_count"] == "":
            return False
        if int(book["ratings_count"]) < ratings_count_threshold:
            return False

        if book["description"] == "":
            return False

        if book["title"] == "":
            return False

        if book["title_without_series"] == "":
            return False

        possible_lang_codes = {"eng", "en-GB", "en-US"}
        if not book["language_code"] in possible_lang_codes:
            return False

        return True
    except Exception:
        return False


def process_raw_data_goodreads(input_directory, save_directory, positive_rating_threshold=4.0):
    os.makedirs(save_directory, exist_ok=True)

    ### Process items
    columns = [
        "book_id",
        "description",
        "title_without_series",
    ]
    numeric_columns = [
        "book_id",
    ]

    items = []
    with open(os.path.join(input_directory, "goodreads_books.json"), "r") as f:
        for line in f:
            item = json.loads(line)
            if book_filter(item):
                items.append([item[col] for col in columns])
    items = pd.DataFrame(items, columns=columns)
    for col in numeric_columns:
        items[col] = pd.to_numeric(items[col])
    items["item_id"] = items.index
    items["title"] = items["title_without_series"]
    items.drop("title_without_series", axis=1, inplace=True)
    items.to_csv(os.path.join(save_directory, "items.csv"), index=False)

    ### Process ratings
    ratings = pd.read_csv(os.path.join(input_directory, "goodreads_interactions.csv"))

    book_id_map = pd.read_csv(os.path.join(input_directory, "book_id_map.csv"))
    csv_to_usual_map = dict(zip(book_id_map["book_id_csv"], book_id_map["book_id"]))
    usual_to_csv_map = dict(zip(book_id_map["book_id"], book_id_map["book_id_csv"]))

    book_ids = items["book_id"].unique()
    book_ids_csv = set([usual_to_csv_map[book_id] for book_id in book_ids])
    ratings = ratings[ratings["rating"] >= positive_rating_threshold]
    ratings = ratings[ratings["book_id"].isin(book_ids_csv)]

    book_to_item_id_map = dict(zip(items["book_id"], items["item_id"]))
    ratings["item_id"] = ratings["book_id"].map(csv_to_usual_map).map(book_to_item_id_map)

    user_ids = list(ratings["user_id"].unique())
    user_ids_map = dict(zip(user_ids, range(len(user_ids))))
    ratings["user_id"] = ratings["user_id"].map(user_ids_map)

    ratings.to_csv(os.path.join(save_directory, "ratings.csv"), index=False)


def create_train_val_split(ratings_path, train_savepath, val_savepath, seed=42):
    ratings = pd.read_csv(ratings_path)
    user_ids = ratings["user_id"].unique()

    rng = np.random.default_rng(seed=seed)
    train_size = int(len(user_ids) * 0.9)
    train_indices = rng.choice(user_ids, size=train_size, replace=False)

    train_data = ratings.loc[ratings["user_id"].isin(train_indices)]
    val_data = ratings.loc[~ratings["user_id"].isin(train_indices)]

    print(f"Train size: {len(train_data)}.")
    print(f"Validation size: {len(val_data)}.")

    train_data.to_csv(train_savepath, index=False)
    val_data.to_csv(val_savepath, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process raw data.")
    parser.add_argument("--input_directory", required=True, type=str, help="Directory containing the raw data.")
    parser.add_argument("--save_directory", required=True, type=str, help="Directory where processed data will be saved.")
    parser.add_argument("--create_train_val_split", action="store_true", help="Flag to indicate whether to create a train-validation split.")
    args = parser.parse_args()

    print("Processing raw data...")
    process_raw_data_goodreads(args.input_directory, args.save_directory)
    if args.create_train_val_split:
        create_train_val_split(
            os.path.join(args.save_directory, "ratings.csv"),
            os.path.join(args.save_directory, "train_ratings.csv"),
            os.path.join(args.save_directory, "val_ratings.csv")
        )
    print("The raw data has been successfully processed.")
```
exp/requirements.txt
ADDED
@@ -0,0 +1,7 @@
```
-r ../requirements.txt  # install base requirements of app
dgl==2.1.0
torch==2.1.2
wandb==0.17.0
tqdm==4.66.4
pydantic==2.5.3
sentence_transformers==3.0.1
```
exp/requirements_gpu.txt
ADDED
@@ -0,0 +1,8 @@
```
-f https://data.dgl.ai/wheels/cu121/repo.html
-r ../requirements.txt  # install base requirements of app
dgl==2.1.0
torch==2.1.0
wandb==0.17.0
tqdm==4.66.4
pydantic==2.5.3
sentence_transformers==3.0.1
```
exp/sbert.py
ADDED
@@ -0,0 +1,39 @@
```python
import argparse

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

from utils import normalize_embeddings


def prepare_sbert_embeddings(
    items_path,
    embeddings_savepath,
    model_name,
    batch_size,
    device
):
    # Sort by item_id so that row i of the embedding matrix
    # corresponds to the item with item_id == i.
    items = pd.read_csv(items_path).sort_values("item_id")
    sentences = items["description"].values
    model = SentenceTransformer(model_name).to(device)
    embeddings = []
    for start_index in tqdm(range(0, len(sentences), batch_size)):
        batch = sentences[start_index:start_index + batch_size]
        embeddings.extend(model.encode(batch))
    embeddings = normalize_embeddings(np.array(embeddings))
    np.save(embeddings_savepath, embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare SBERT embeddings.")
    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file.")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to save the embeddings.")
    parser.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Name of the SBERT model to use.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use (cpu or cuda).")
    args = parser.parse_args()

    prepare_sbert_embeddings(**vars(args))
```
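A quick sanity check on the saved artifact (the path is illustrative): `normalize_embeddings` makes every nonzero row unit-norm, so cosine similarity between items reduces to a dot product:

```python
# Sketch: verify row-wise L2 normalization of the saved embeddings.
import numpy as np

embeddings = np.load("text_embeddings.npy")  # illustrative path
norms = np.linalg.norm(embeddings, axis=1)
assert np.allclose(norms[norms > 0], 1.0)

# With unit-norm rows, cosine similarity is just a dot product:
sim = embeddings[0] @ embeddings[1]
```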
exp/utils.py
ADDED
@@ -0,0 +1,69 @@
```python
import numpy as np
import pandas as pd
import dgl
import torch


def normalize_embeddings(embeddings):
    embeddings_norm = np.linalg.norm(embeddings, axis=1)
    nonzero_embeddings = embeddings_norm > 0.0
    embeddings[nonzero_embeddings] /= embeddings_norm[nonzero_embeddings, None]
    return embeddings


def prepare_graphs(items_path, ratings_path):
    items = pd.read_csv(items_path)
    ratings = pd.read_csv(ratings_path)

    n_users = np.max(ratings["user_id"].unique()) + 1
    item_ids = torch.tensor(sorted(items["item_id"].unique()))

    edges = torch.tensor(ratings["user_id"]), torch.tensor(ratings["item_id"])
    reverse_edges = (edges[1], edges[0])

    bipartite_graph = dgl.heterograph(
        data_dict={
            ("User", "UserItem", "Item"): edges,
            ("Item", "ItemUser", "User"): reverse_edges
        },
        num_nodes_dict={
            "User": n_users,
            "Item": len(item_ids)
        }
    )
    graph = dgl.to_homogeneous(bipartite_graph)
    graph = dgl.add_self_loop(graph)
    return bipartite_graph, graph


def extract_item_embeddings(node_embeddings, bipartite_graph, graph):
    item_ntype = bipartite_graph.ntypes.index("Item")
    item_mask = graph.ndata[dgl.NTYPE] == item_ntype
    item_embeddings = node_embeddings[item_mask]
    original_ids = graph.ndata[dgl.NID][item_mask]
    item_embeddings = item_embeddings[torch.argsort(original_ids)]
    return item_embeddings.cpu().numpy()


class LRSchedule:
    def __init__(self, total_steps, warmup_steps, final_factor):
        self._total_steps = total_steps
        self._warmup_steps = warmup_steps
        self._final_factor = final_factor

    def __call__(self, step):
        if step >= self._total_steps:
            return self._final_factor

        if self._warmup_steps > 0:
            warmup_factor = step / self._warmup_steps
        else:
            warmup_factor = 1.0

        steps_after_warmup = step - self._warmup_steps
        total_steps_after_warmup = self._total_steps - self._warmup_steps
        after_warmup_factor = 1 \
            - (1 - self._final_factor) * (steps_after_warmup / total_steps_after_warmup)

        factor = min(warmup_factor, after_warmup_factor)
        return min(max(factor, 0), 1)
```
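`LRSchedule` is a linear warmup followed by linear decay, expressed as a multiplicative factor on the base learning rate:

$$\text{factor}(t) = \min\left(\frac{t}{t_{\text{warmup}}},\ 1 - (1 - f_{\text{final}}) \cdot \frac{t - t_{\text{warmup}}}{t_{\text{total}} - t_{\text{warmup}}}\right),$$

clamped to $[0, 1]$, with $\text{factor}(t) = f_{\text{final}}$ once $t \ge t_{\text{total}}$.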
requirements.txt
ADDED
@@ -0,0 +1,4 @@
```
streamlit==1.35.0
pandas==2.2.2
numpy==1.26.4
faiss-cpu==1.8.0
```