Skip to content
Snippets Groups Projects
Unverified Commit 93a38528 authored by Quoc-Tuan Truong's avatar Quoc-Tuan Truong Committed by GitHub
Browse files

Add save and load functions for BiVAECF model (#608)

parent 91ed98c6
No related branches found
No related tags found
No related merge requests found
......@@ -255,3 +255,59 @@ class BiVAECF(Recommender, ANNMixin):
"""
item_vectors = self.bivae.mu_beta.detach().cpu().numpy()
return item_vectors
def save(self, save_dir=None, save_trainset=True):
"""Save model to the filesystem.
Parameters
----------
save_dir: str, default: None
Path to a directory for the model to be stored.
save_trainset: bool, default: True
Save train_set together with the model. This is useful
if we want to deploy model later because train_set is
required for certain evaluation steps.
Returns
-------
model_file : str
Path to the model file stored on the filesystem.
"""
import torch
if save_dir is None:
return
self.bivae.to(torch.device("cpu"))
model_file = Recommender.save(
self, save_dir=save_dir, save_trainset=save_trainset
)
return model_file
@staticmethod
def load(model_path, trainable=False):
"""Load model from the filesystem.
Parameters
----------
model_path: str, required
Path to a file or directory where the model is stored. If a directory is
provided, the latest model will be loaded.
trainable: boolean, optional, default: False
Set it to True if you would like to finetune the model. By default,
the model parameters are assumed to be fixed after being loaded.
Returns
-------
self : object
"""
import torch
model = Recommender.load(model_path, trainable)
if "cuda" in str(model.device) and torch.cuda.is_available():
model.bivae.to(model.device)
return model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment