erermeev-d
commited on
Commit
·
75a5562
1
Parent(s):
51488f6
Added saving of the final model checkpoint
Browse files- exp/gnn/train.py +4 -0
- exp/pipeline.sh +1 -0
exp/gnn/train.py
CHANGED
@@ -94,6 +94,9 @@ def prepare_gnn_embeddings(config):
|
|
94 |
item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
|
95 |
np.save(config["embeddings_savepath"], item_embeddings)
|
96 |
|
|
|
|
|
|
|
97 |
|
98 |
if __name__ == "__main__":
|
99 |
parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")
|
@@ -104,6 +107,7 @@ if __name__ == "__main__":
|
|
104 |
parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
|
105 |
parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
|
106 |
parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
|
|
|
107 |
|
108 |
# Learning hyperparameters
|
109 |
parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
|
|
|
94 |
item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
|
95 |
np.save(config["embeddings_savepath"], item_embeddings)
|
96 |
|
97 |
+
### Save final model
|
98 |
+
torch.save(model.to("cpu").state_dict(), config["model_savepath"])
|
99 |
+
|
100 |
|
101 |
if __name__ == "__main__":
|
102 |
parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")
|
|
|
107 |
parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
|
108 |
parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
|
109 |
parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
|
110 |
+
parser.add_argument("--model_savepath", type=str, required=True, help="Path to save final model checkpoint.")
|
111 |
|
112 |
# Learning hyperparameters
|
113 |
parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
|
exp/pipeline.sh
CHANGED
@@ -23,6 +23,7 @@ PYTHONPATH=. python exp/gnn/train.py \
|
|
23 |
--val_ratings_path "$save_directory/val_ratings.csv" \
|
24 |
--text_embeddings_path "$save_directory/text_embeddings.npy" \
|
25 |
--embeddings_savepath "$save_directory/embeddings.npy"\
|
|
|
26 |
--device $device \
|
27 |
--no_wandb
|
28 |
|
|
|
23 |
--val_ratings_path "$save_directory/val_ratings.csv" \
|
24 |
--text_embeddings_path "$save_directory/text_embeddings.npy" \
|
25 |
--embeddings_savepath "$save_directory/embeddings.npy"\
|
26 |
+
--model_savepath "$save_directory/model.pt" \
|
27 |
--device $device \
|
28 |
--no_wandb
|
29 |
|