From 93a385286041f6dae987288da47331314d56b93f Mon Sep 17 00:00:00 2001
From: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
Date: Wed, 27 Mar 2024 09:44:28 -0700
Subject: [PATCH] Add save and load functions for BiVAECF model (#608)

---
 cornac/models/bivaecf/recom_bivaecf.py | 56 ++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/cornac/models/bivaecf/recom_bivaecf.py b/cornac/models/bivaecf/recom_bivaecf.py
index 70b314c5..9f40602e 100644
--- a/cornac/models/bivaecf/recom_bivaecf.py
+++ b/cornac/models/bivaecf/recom_bivaecf.py
@@ -255,3 +255,59 @@ class BiVAECF(Recommender, ANNMixin):
         """
         item_vectors = self.bivae.mu_beta.detach().cpu().numpy()
         return item_vectors
+
+    def save(self, save_dir=None, save_trainset=True):
+        """Save model to the filesystem.
+
+        Parameters
+        ----------
+        save_dir: str, default: None
+            Path to a directory for the model to be stored.
+
+        save_trainset: bool, default: True
+            Save train_set together with the model. This is useful
+            if we want to deploy model later because train_set is
+            required for certain evaluation steps.
+
+        Returns
+        -------
+        model_file : str
+            Path to the model file stored on the filesystem.
+        """
+        import torch
+
+        if save_dir is None:
+            return
+
+        self.bivae.to(torch.device("cpu"))
+        model_file = Recommender.save(
+            self, save_dir=save_dir, save_trainset=save_trainset
+        )
+
+        return model_file
+
+    @staticmethod
+    def load(model_path, trainable=False):
+        """Load model from the filesystem.
+
+        Parameters
+        ----------
+        model_path: str, required
+            Path to a file or directory where the model is stored. If a directory is
+            provided, the latest model will be loaded.
+
+        trainable: boolean, optional, default: False
+            Set it to True if you would like to finetune the model. By default,
+            the model parameters are assumed to be fixed after being loaded.
+
+        Returns
+        -------
+        self : object
+        """
+        import torch
+
+        model = Recommender.load(model_path, trainable)
+        if "cuda" in str(model.device) and torch.cuda.is_available():
+            model.bivae.to(model.device)
+
+        return model
-- 
GitLab