erermeev-d commited on
Commit
75a5562
·
1 Parent(s): 51488f6

Added saving of the final model checkpoint

Browse files
Files changed (2) hide show
  1. exp/gnn/train.py +4 -0
  2. exp/pipeline.sh +1 -0
exp/gnn/train.py CHANGED
@@ -94,6 +94,9 @@ def prepare_gnn_embeddings(config):
94
  item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
95
  np.save(config["embeddings_savepath"], item_embeddings)
96
 
 
 
 
97
 
98
  if __name__ == "__main__":
99
  parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")
@@ -104,6 +107,7 @@ if __name__ == "__main__":
104
  parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
105
  parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
106
  parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
 
107
 
108
  # Learning hyperparameters
109
  parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
 
94
  item_embeddings = inference_model(model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
95
  np.save(config["embeddings_savepath"], item_embeddings)
96
 
97
+ ### Save final model
98
+ torch.save(model.to("cpu").state_dict(), config["model_savepath"])
99
+
100
 
101
  if __name__ == "__main__":
102
  parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")
 
107
  parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
108
  parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
109
  parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where gnn embeddings will be saved")
110
+ parser.add_argument("--model_savepath", type=str, required=True, help="Path to save final model checkpoint.")
111
 
112
  # Learning hyperparameters
113
  parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
exp/pipeline.sh CHANGED
@@ -23,6 +23,7 @@ PYTHONPATH=. python exp/gnn/train.py \
23
  --val_ratings_path "$save_directory/val_ratings.csv" \
24
  --text_embeddings_path "$save_directory/text_embeddings.npy" \
25
  --embeddings_savepath "$save_directory/embeddings.npy"\
 
26
  --device $device \
27
  --no_wandb
28
 
 
23
  --val_ratings_path "$save_directory/val_ratings.csv" \
24
  --text_embeddings_path "$save_directory/text_embeddings.npy" \
25
  --embeddings_savepath "$save_directory/embeddings.npy"\
26
+ --model_savepath "$save_directory/model.pt" \
27
  --device $device \
28
  --no_wandb
29