From 93a385286041f6dae987288da47331314d56b93f Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong <tqtg@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:44:28 -0700 Subject: [PATCH] Add save and load functions for BiVAECF model (#608) --- cornac/models/bivaecf/recom_bivaecf.py | 56 ++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/cornac/models/bivaecf/recom_bivaecf.py b/cornac/models/bivaecf/recom_bivaecf.py index 70b314c5..9f40602e 100644 --- a/cornac/models/bivaecf/recom_bivaecf.py +++ b/cornac/models/bivaecf/recom_bivaecf.py @@ -255,3 +255,59 @@ class BiVAECF(Recommender, ANNMixin): """ item_vectors = self.bivae.mu_beta.detach().cpu().numpy() return item_vectors + + def save(self, save_dir=None, save_trainset=True): + """Save model to the filesystem. + + Parameters + ---------- + save_dir: str, default: None + Path to a directory for the model to be stored. + + save_trainset: bool, default: True + Save train_set together with the model. This is useful + if we want to deploy model later because train_set is + required for certain evaluation steps. + + Returns + ------- + model_file : str + Path to the model file stored on the filesystem. + """ + import torch + + if save_dir is None: + return + + self.bivae.to(torch.device("cpu")) + model_file = Recommender.save( + self, save_dir=save_dir, save_trainset=save_trainset + ) + + return model_file + + @staticmethod + def load(model_path, trainable=False): + """Load model from the filesystem. + + Parameters + ---------- + model_path: str, required + Path to a file or directory where the model is stored. If a directory is + provided, the latest model will be loaded. + + trainable: boolean, optional, default: False + Set it to True if you would like to finetune the model. By default, + the model parameters are assumed to be fixed after being loaded. + + Returns + ------- + self : object + """ + import torch + + model = Recommender.load(model_path, trainable) + if "cuda" in str(model.device) and torch.cuda.is_available(): + model.bivae.to(model.device) + + return model -- GitLab