From c1f3383570b509930b6e59bcc5021b991283d7ad Mon Sep 17 00:00:00 2001 From: Quoc-Tuan Truong <tqtg@users.noreply.github.com> Date: Thu, 28 Mar 2024 08:54:31 -0700 Subject: [PATCH] Add metadata file when saving model (#607) --- cornac/models/recommender.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index 62a17f7f..c7080a4b 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -20,6 +20,7 @@ import pickle import warnings from datetime import datetime from glob import glob +import json import numpy as np @@ -219,7 +220,7 @@ class Recommender: return self.__class__(**init_params) - def save(self, save_dir=None, save_trainset=False): + def save(self, save_dir=None, save_trainset=False, metadata=None): """Save a recommender model to the filesystem. Parameters @@ -232,6 +233,10 @@ class Recommender: if we want to deploy model later because train_set is required for certain evaluation steps. + metadata: dict, default: None + Metadata to be saved with the model. This is useful + to store model details. + Returns ------- model_file : str @@ -246,16 +251,27 @@ class Recommender: model_file = os.path.join(model_dir, "{}.pkl".format(timestamp)) saved_model = copy.deepcopy(self) - pickle.dump(saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump( + saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL + ) if self.verbose: print("{} model is saved to {}".format(self.name, model_file)) + metadata = {} if metadata is None else metadata + metadata["model_classname"] = type(saved_model).__name__ + metadata["model_file"] = os.path.basename(model_file) + if save_trainset: + trainset_file = model_file + ".trainset" pickle.dump( self.train_set, - open(model_file + ".trainset", "wb"), + open(trainset_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL, ) + metadata["trainset_file"] = os.path.basename(trainset_file) + + with open(model_file + ".meta", "w", encoding="utf-8") as f: + json.dump(metadata, f, ensure_ascii=False, indent=4) return model_file @@ -502,9 +518,7 @@ class Recommender: ) item_scores = all_item_scores[item_indices] - if ( - k != -1 - ): # O(n + k log k), faster for small k which is usually the case + if k != -1: # O(n + k log k), faster for small k which is usually the case partitioned_idx = np.argpartition(item_scores, -k) top_k_idx = partitioned_idx[-k:] sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])] @@ -545,7 +559,9 @@ class Recommender: raise ValueError(f"{user_id} is unknown to the model.") if k < -1 or k > self.total_items: - raise ValueError(f"k={k} is invalid, there are {self.total_users} users in total.") + raise ValueError( + f"k={k} is invalid, there are {self.total_users} users in total." + ) item_indices = np.arange(self.total_items) if remove_seen: @@ -622,7 +638,11 @@ class Recommender: if self.stopped_epoch > 0: print("Early stopping:") - print("- best epoch = {}, stopped epoch = {}".format(self.best_epoch, self.stopped_epoch)) + print( + "- best epoch = {}, stopped epoch = {}".format( + self.best_epoch, self.stopped_epoch + ) + ) print( "- best monitored value = {:.6f} (delta = {:.6f})".format( self.best_value, current_value - self.best_value -- GitLab