erermeev-d committed
Commit 51488f6 · Parent: c746c39

Added config logging in WandB

Files changed (1):
  exp/gnn/train.py (+32 -59)
exp/gnn/train.py CHANGED
@@ -19,71 +19,44 @@ from exp.gnn.utils import (
     sample_item_batch, inference_model)
 
 
-def prepare_gnn_embeddings(
-    # Paths
-    items_path,
-    train_ratings_path,
-    val_ratings_path,
-    text_embeddings_path,
-    embeddings_savepath,
-    # Learning hyperparameters
-    temperature,
-    batch_size,
-    lr,
-    num_epochs,
-    # Model hyperparameters
-    num_layers,
-    hidden_dim,
-    aggregator_type,
-    skip_connection,
-    bidirectional,
-    num_traversals,
-    termination_prob,
-    num_random_walks,
-    num_neighbor,
-    # Misc
-    validate_every_n_epoch,
-    device,
-    wandb_name,
-    use_wandb,
-):
+def prepare_gnn_embeddings(config):
     ### Prepare graph
-    bipartite_graph, _ = prepare_graphs(items_path, train_ratings_path)
-    bipartite_graph = bipartite_graph.to(device)
+    bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
+    bipartite_graph = bipartite_graph.to(config["device"])
 
     ### Init wandb
-    if use_wandb:
-        wandb.init(project="graph-rec-gnn", name=wandb_name)
+    if config["use_wandb"]:
+        wandb.init(project="graph-rec-gnn", name=config["wandb_name"], config=config)
 
     ### Prepare model
-    text_embeddings = torch.tensor(np.load(text_embeddings_path)).to(device)
+    text_embeddings = torch.tensor(np.load(config["text_embeddings_path"])).to(config["device"])
     model = GNNModel(
         bipartite_graph=bipartite_graph,
         text_embeddings=text_embeddings,
-        num_layers=num_layers,
-        hidden_dim=hidden_dim,
-        aggregator_type=aggregator_type,
-        skip_connection=skip_connection,
-        bidirectional=bidirectional,
-        num_traversals=num_traversals,
-        termination_prob=termination_prob,
-        num_random_walks=num_random_walks,
-        num_neighbor=num_neighbor
+        num_layers=config["num_layers"],
+        hidden_dim=config["hidden_dim"],
+        aggregator_type=config["aggregator_type"],
+        skip_connection=config["skip_connection"],
+        bidirectional=config["bidirectional"],
+        num_traversals=config["num_traversals"],
+        termination_prob=config["termination_prob"],
+        num_random_walks=config["num_random_walks"],
+        num_neighbor=config["num_neighbor"]
     )
-    model = model.to(device)
+    model = model.to(config["device"])
 
     ### Prepare dataloader
-    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(device)
+    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(config["device"])
     all_users = all_users[bipartite_graph.in_degrees(all_users, etype="ItemUser") > 1]  # We need to sample 2 items per user
     dataloader = torch.utils.data.DataLoader(
-        all_users, batch_size=batch_size, shuffle=True, drop_last=True)
+        all_users, batch_size=config["batch_size"], shuffle=True, drop_last=True)
 
     ### Prepare optimizer & LR scheduler
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
 
     ### Train loop
-    for epoch in range(num_epochs):
+    for epoch in range(config["num_epochs"]):
         ### Train
         model.train()
         for user_batch in tqdm(dataloader):
@@ -91,35 +64,35 @@ def prepare_gnn_embeddings(
             item_batch = item_batch.reshape(-1)  # (2 * |user_batch|)
             features = model(item_batch)  # (2 * |user_batch|, hidden_dim)
             sim = features @ features.T  # (2 * |user_batch|, 2 * |user_batch|)
-            loss = nt_xent_loss(sim, temperature)
-            if use_wandb:
+            loss = nt_xent_loss(sim, config["temperature"])
+            if config["use_wandb"]:
                 wandb.log({"loss": loss.item()})
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
         lr_scheduler.step()
         ### Validation
-        if (validate_every_n_epoch is not None) and (((epoch + 1) % validate_every_n_epoch) == 0):
+        if (config["validate_every_n_epoch"] is not None) and (((epoch + 1) % config["validate_every_n_epoch"]) == 0):
             item_embeddings = inference_model(
-                model, bipartite_graph, batch_size, hidden_dim, device)
+                model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
             with tempfile.TemporaryDirectory() as tmp_dir_name:
                 tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
                 np.save(tmp_embeddings_path, item_embeddings)
-                prepare_recsys(items_path, tmp_embeddings_path, tmp_dir_name)
+                prepare_recsys(config["items_path"], tmp_embeddings_path, tmp_dir_name)
                 metrics = evaluate_recsys(
-                    val_ratings_path,
+                    config["val_ratings_path"],
                     os.path.join(tmp_dir_name, "index.faiss"),
                     os.path.join(tmp_dir_name, "items.db"))
-                print(f"Epoch {epoch + 1} / {num_epochs}. {metrics}")
-                if use_wandb:
+                print(f"Epoch {epoch + 1} / {config['num_epochs']}. {metrics}")
+                if config["use_wandb"]:
                     wandb.log(metrics)
 
-    if use_wandb:
+    if config["use_wandb"]:
         wandb.finish()
 
     ### Process full dataset
-    item_embeddings = inference_model(model, bipartite_graph, batch_size, hidden_dim, device)
-    np.save(embeddings_savepath, item_embeddings)
+    item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
+    np.save(config["embeddings_savepath"], item_embeddings)
 
 
 if __name__ == "__main__":
@@ -157,4 +130,4 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    prepare_gnn_embeddings(**vars(args))
+    prepare_gnn_embeddings(vars(args))
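The net change is that all hyperparameters now travel through one config dict, and passing that dict to wandb.init via config= records every value with the run. Below is a minimal, self-contained sketch of the argparse -> dict -> wandb.init(config=...) pattern the diff introduces; the flags shown are a hypothetical subset of the real script's arguments, for illustration only.

    # Minimal sketch (not part of the commit): the argparse -> config dict ->
    # wandb.init(config=...) wiring that the diff introduces.
    import argparse

    import wandb

    parser = argparse.ArgumentParser()
    # Hypothetical subset of the real script's flags, for illustration only.
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--wandb_name", default=None)
    parser.add_argument("--use_wandb", action="store_true")
    args = parser.parse_args()

    config = vars(args)  # argparse.Namespace -> plain dict, as in prepare_gnn_embeddings(vars(args))

    if config["use_wandb"]:
        # config= stores every hyperparameter in the run's config, so runs can
        # be filtered and compared by lr, batch_size, etc. in the W&B UI.
        wandb.init(project="graph-rec-gnn", name=config["wandb_name"], config=config)
        wandb.finish()

Compared with the old keyword-argument signature called as prepare_gnn_embeddings(**vars(args)), the single dict keeps the signature stable as hyperparameters are added and gives wandb.init one object to log, at the cost of losing per-argument checking at call time.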