From c1f3383570b509930b6e59bcc5021b991283d7ad Mon Sep 17 00:00:00 2001
From: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
Date: Thu, 28 Mar 2024 08:54:31 -0700
Subject: [PATCH] Add metadata file when saving model (#607)

---
 cornac/models/recommender.py | 36 ++++++++++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py
index 62a17f7f..c7080a4b 100644
--- a/cornac/models/recommender.py
+++ b/cornac/models/recommender.py
@@ -20,6 +20,7 @@ import pickle
 import warnings
 from datetime import datetime
 from glob import glob
+import json
 
 import numpy as np
 
@@ -219,7 +220,7 @@ class Recommender:
 
         return self.__class__(**init_params)
 
-    def save(self, save_dir=None, save_trainset=False):
+    def save(self, save_dir=None, save_trainset=False, metadata=None):
         """Save a recommender model to the filesystem.
 
         Parameters
@@ -232,6 +233,10 @@ class Recommender:
             if we want to deploy model later because train_set is
             required for certain evaluation steps.
 
+        metadata: dict, default: None
+            Metadata to be saved with the model. This is useful
+            to store model details.
+
         Returns
         -------
         model_file : str
@@ -246,16 +251,27 @@ class Recommender:
         model_file = os.path.join(model_dir, "{}.pkl".format(timestamp))
 
         saved_model = copy.deepcopy(self)
-        pickle.dump(saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL)
+        pickle.dump(
+            saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL
+        )
         if self.verbose:
             print("{} model is saved to {}".format(self.name, model_file))
 
+        metadata = {} if metadata is None else metadata
+        metadata["model_classname"] = type(saved_model).__name__
+        metadata["model_file"] = os.path.basename(model_file)
+
         if save_trainset:
+            trainset_file = model_file + ".trainset"
             pickle.dump(
                 self.train_set,
-                open(model_file + ".trainset", "wb"),
+                open(trainset_file, "wb"),
                 protocol=pickle.HIGHEST_PROTOCOL,
             )
+            metadata["trainset_file"] = os.path.basename(trainset_file)
+
+        with open(model_file + ".meta", "w", encoding="utf-8") as f:
+            json.dump(metadata, f, ensure_ascii=False, indent=4)
 
         return model_file
 
@@ -502,9 +518,7 @@ class Recommender:
         )
         item_scores = all_item_scores[item_indices]
 
-        if (
-            k != -1
-        ):  # O(n + k log k), faster for small k which is usually the case
+        if k != -1:  # O(n + k log k), faster for small k which is usually the case
             partitioned_idx = np.argpartition(item_scores, -k)
             top_k_idx = partitioned_idx[-k:]
             sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])]
@@ -545,7 +559,9 @@ class Recommender:
             raise ValueError(f"{user_id} is unknown to the model.")
 
         if k < -1 or k > self.total_items:
-            raise ValueError(f"k={k} is invalid, there are {self.total_users} users in total.")
+            raise ValueError(
+                f"k={k} is invalid, there are {self.total_users} users in total."
+            )
 
         item_indices = np.arange(self.total_items)
         if remove_seen:
@@ -622,7 +638,11 @@ class Recommender:
 
         if self.stopped_epoch > 0:
             print("Early stopping:")
-            print("- best epoch = {}, stopped epoch = {}".format(self.best_epoch, self.stopped_epoch))
+            print(
+                "- best epoch = {}, stopped epoch = {}".format(
+                    self.best_epoch, self.stopped_epoch
+                )
+            )
             print(
                 "- best monitored value = {:.6f} (delta = {:.6f})".format(
                     self.best_value, current_value - self.best_value
-- 
GitLab