erermeev-d committed · Commit 51488f6 · 1 Parent(s): c746c39
Added config logging in WandB

Changed files: exp/gnn/train.py (+32 -59)
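The gist of the change: `prepare_gnn_embeddings` now takes a single `config` dict in place of its long parameter list, the `__main__` block builds that dict with `vars(args)`, and the dict is forwarded to `wandb.init(config=config)` so every run records its full hyperparameter set alongside its metrics. A minimal sketch of the same pattern, with hypothetical flag names standing in for the script's real argument list:

import argparse

import wandb

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-3)       # illustrative flag
parser.add_argument("--batch_size", type=int, default=256)  # illustrative flag
args = parser.parse_args()

config = vars(args)  # argparse.Namespace -> plain dict
run = wandb.init(project="graph-rec-gnn", config=config)
# The hyperparameters now appear in the run's Config panel and can be
# read back via wandb.config, e.g. wandb.config["lr"].
run.finish()

`vars()` is the standard way to turn an `argparse.Namespace` into the plain mapping `wandb.init` expects, and logging it once at init time is what makes runs filterable and comparable in the W&B UI.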
exp/gnn/train.py CHANGED

@@ -19,71 +19,44 @@ from exp.gnn.utils import (
     sample_item_batch, inference_model)
 
 
-def prepare_gnn_embeddings(
-    # Paths
-    items_path,
-    train_ratings_path,
-    val_ratings_path,
-    text_embeddings_path,
-    embeddings_savepath,
-    # Learning hyperparameters
-    temperature,
-    batch_size,
-    lr,
-    num_epochs,
-    # Model hyperparameters
-    num_layers,
-    hidden_dim,
-    aggregator_type,
-    skip_connection,
-    bidirectional,
-    num_traversals,
-    termination_prob,
-    num_random_walks,
-    num_neighbor,
-    # Misc
-    validate_every_n_epoch,
-    device,
-    wandb_name,
-    use_wandb,
-):
+def prepare_gnn_embeddings(config):
     ### Prepare graph
-    bipartite_graph, _ = prepare_graphs(items_path, train_ratings_path)
-    bipartite_graph = bipartite_graph.to(device)
+    bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
+    bipartite_graph = bipartite_graph.to(config["device"])
 
     ### Init wandb
-    if use_wandb:
-        wandb.init(project="graph-rec-gnn", name=wandb_name)
+    if config["use_wandb"]:
+        wandb.init(project="graph-rec-gnn", name=config["wandb_name"], config=config)
 
     ### Prepare model
-    text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
+    text_embeddings = torch.tensor(np.load(config["text_embeddings_path"])).to(config["device"])
     model = GNNModel(
         bipartite_graph=bipartite_graph,
         text_embeddings=text_embeddings,
-        num_layers=num_layers,
-        hidden_dim=hidden_dim,
-        aggregator_type=aggregator_type,
-        skip_connection=skip_connection,
-        bidirectional=bidirectional,
-        num_traversals=num_traversals,
-        termination_prob=termination_prob,
-        num_random_walks=num_random_walks,
-        num_neighbor=num_neighbor
+        num_layers=config["num_layers"],
+        hidden_dim=config["hidden_dim"],
+        aggregator_type=config["aggregator_type"],
+        skip_connection=config["skip_connection"],
+        bidirectional=config["bidirectional"],
+        num_traversals=config["num_traversals"],
+        termination_prob=config["termination_prob"],
+        num_random_walks=config["num_random_walks"],
+        num_neighbor=config["num_neighbor"]
     )
-    model = model.to(device)
+    model = model.to(config["device"])
 
     ### Prepare dataloader
-    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(device)
+    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(config["device"])
     all_users = all_users[bipartite_graph.in_degrees(all_users, etype="ItemUser") > 1]  # We need to sample 2 items per user
     dataloader = torch.utils.data.DataLoader(
-        all_users, batch_size=batch_size, shuffle=True, drop_last=True)
+        all_users, batch_size=config["batch_size"], shuffle=True, drop_last=True)
 
     ### Prepare optimizer & LR scheduler
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
     ### Train loop
-    for epoch in range(num_epochs):
+    for epoch in range(config["num_epochs"]):
         ### Train
         model.train()
         for user_batch in tqdm(dataloader):
@@ -91,35 +64,35 @@ def prepare_gnn_embeddings(
             item_batch = item_batch.reshape(-1)  # (2 * |user_batch|)
             features = model(item_batch)  # (2 * |user_batch|, hidden_dim)
             sim = features @ features.T  # (2 * |user_batch|, 2 * |user_batch|)
-            loss = nt_xent_loss(sim, temperature)
-            if use_wandb:
+            loss = nt_xent_loss(sim, config["temperature"])
+            if config["use_wandb"]:
                 wandb.log({"loss": loss.item()})
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
         lr_scheduler.step()
         ### Validation
-        if (validate_every_n_epoch is not None) and (((epoch + 1) % validate_every_n_epoch) == 0):
+        if (config["validate_every_n_epoch"] is not None) and (((epoch + 1) % config["validate_every_n_epoch"]) == 0):
             item_embeddings = inference_model(
-                model, bipartite_graph, batch_size, hidden_dim, device)
+                model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
             with tempfile.TemporaryDirectory() as tmp_dir_name:
                 tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
                 np.save(tmp_embeddings_path, item_embeddings)
-                prepare_recsys(items_path, tmp_embeddings_path, tmp_dir_name)
+                prepare_recsys(config["items_path"], tmp_embeddings_path, tmp_dir_name)
                 metrics = evaluate_recsys(
-                    val_ratings_path,
+                    config["val_ratings_path"],
                     os.path.join(tmp_dir_name, "index.faiss"),
                     os.path.join(tmp_dir_name, "items.db"))
-            print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
-            if use_wandb:
+            print(f"Epoch {epoch + 1} / {config['num_epochs']}. {metrics}")
+            if config["use_wandb"]:
                 wandb.log(metrics)
 
-    if use_wandb:
+    if config["use_wandb"]:
         wandb.finish()
 
     ### Process full dataset
-    item_embeddings = inference_model(model, bipartite_graph, batch_size, hidden_dim, device)
-    np.save(embeddings_savepath, item_embeddings)
+    item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
+    np.save(config["embeddings_savepath"], item_embeddings)
 
 
 if __name__ == "__main__":
@@ -157,4 +130,4 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    prepare_gnn_embeddings(
+    prepare_gnn_embeddings(vars(args))
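A side note on the objective the training loop drives: each user contributes two of their rated items, `reshape(-1)` flattens them into one batch, and `nt_xent_loss(sim, config["temperature"])` consumes the full pairwise similarity matrix. Below is a minimal sketch of an NT-Xent (SimCLR-style) loss, under the assumption that each user's two items land on adjacent rows 2i and 2i+1 after the reshape; the repository's actual `nt_xent_loss` may differ in details (e.g. it may L2-normalize features so `sim` holds cosine similarities):

import torch
import torch.nn.functional as F

def nt_xent_loss(sim: torch.Tensor, temperature: float) -> torch.Tensor:
    """NT-Xent over a (2B, 2B) similarity matrix whose rows 2i and 2i+1
    are assumed to hold the two sampled items of the same user."""
    n = sim.size(0)
    logits = sim / temperature
    # Drop self-similarity from the softmax denominator.
    eye = torch.eye(n, dtype=torch.bool, device=sim.device)
    logits = logits.masked_fill(eye, float("-inf"))
    # Row 2i's positive is row 2i+1 and vice versa (XOR flips the last bit).
    targets = torch.arange(n, device=sim.device) ^ 1
    return F.cross_entropy(logits, targets)

Each row treats its partner as the positive and every other item in the batch as a negative, which is why the dataloader keeps only users with at least two rated items and drops incomplete batches.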