Unverified commit c1f33835, authored by Quoc-Tuan Truong and committed by GitHub

Add metadata file when saving model (#607)

parent 93a38528
@@ -20,6 +20,7 @@ import pickle
 import warnings
 from datetime import datetime
 from glob import glob
+import json
 import numpy as np
@@ -219,7 +220,7 @@ class Recommender:
         return self.__class__(**init_params)
 
-    def save(self, save_dir=None, save_trainset=False):
+    def save(self, save_dir=None, save_trainset=False, metadata=None):
         """Save a recommender model to the filesystem.
 
         Parameters
@@ -232,6 +233,10 @@
             if we want to deploy model later because train_set is
             required for certain evaluation steps.
 
+        metadata: dict, default: None
+            Metadata to be saved with the model. This is useful
+            to store model details.
+
         Returns
         -------
         model_file : str
@@ -246,16 +251,27 @@
         model_file = os.path.join(model_dir, "{}.pkl".format(timestamp))
         saved_model = copy.deepcopy(self)
-        pickle.dump(saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL)
+        pickle.dump(
+            saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL
+        )
 
         if self.verbose:
             print("{} model is saved to {}".format(self.name, model_file))
 
+        metadata = {} if metadata is None else metadata
+        metadata["model_classname"] = type(saved_model).__name__
+        metadata["model_file"] = os.path.basename(model_file)
+
         if save_trainset:
+            trainset_file = model_file + ".trainset"
             pickle.dump(
                 self.train_set,
-                open(model_file + ".trainset", "wb"),
+                open(trainset_file, "wb"),
                 protocol=pickle.HIGHEST_PROTOCOL,
             )
+            metadata["trainset_file"] = os.path.basename(trainset_file)
+
+        with open(model_file + ".meta", "w", encoding="utf-8") as f:
+            json.dump(metadata, f, ensure_ascii=False, indent=4)
         return model_file
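
With this change, save() writes a small JSON sidecar next to the pickled model. Below is a minimal usage sketch, assuming an already-trained Cornac recommender named mf and an output directory of "./saved_models" (both names are illustrative, not from the diff); the .meta keys mirror the ones written above.

import json

# mf is assumed to be a fitted Cornac model, e.g. cornac.models.MF after training.
model_file = mf.save(
    save_dir="./saved_models",   # a timestamped .pkl is created inside this directory
    save_trainset=True,          # also pickles the training set next to the model
    metadata={"dataset": "MovieLens 100K", "note": "baseline run"},  # caller-supplied details
)

# The sidecar is plain JSON located at <model_file>.meta.
with open(model_file + ".meta", "r", encoding="utf-8") as f:
    meta = json.load(f)

# Keys written by save(): model_classname, model_file, and trainset_file
# (the last one only because save_trainset=True), plus the caller-supplied entries.
print(meta["model_classname"], meta["model_file"], meta.get("trainset_file"))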
@@ -502,9 +518,7 @@
         )
         item_scores = all_item_scores[item_indices]
 
-        if (
-            k != -1
-        ):  # O(n + k log k), faster for small k which is usually the case
+        if k != -1:  # O(n + k log k), faster for small k which is usually the case
             partitioned_idx = np.argpartition(item_scores, -k)
             top_k_idx = partitioned_idx[-k:]
             sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])]
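
The reformatted branch above is the top-k selection path: np.argpartition places the k largest scores in the last k slots in O(n), and only those k entries are then fully sorted, which is where the O(n + k log k) note comes from. A self-contained sketch with made-up scores (the array and k below are illustrative only):

import numpy as np

item_scores = np.array([0.1, 0.9, 0.3, 0.7, 0.5, 0.2])
k = 3

partitioned_idx = np.argpartition(item_scores, -k)  # last k positions hold indices of the k largest scores, unordered
top_k_idx = partitioned_idx[-k:]
sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])]  # ascending by score, as in the snippet above

print(sorted_top_k_idx)               # [4 3 1]
print(item_scores[sorted_top_k_idx])  # [0.5 0.7 0.9]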
@@ -545,7 +559,9 @@
             raise ValueError(f"{user_id} is unknown to the model.")
 
         if k < -1 or k > self.total_items:
-            raise ValueError(f"k={k} is invalid, there are {self.total_users} users in total.")
+            raise ValueError(
+                f"k={k} is invalid, there are {self.total_users} users in total."
+            )
 
         item_indices = np.arange(self.total_items)
         if remove_seen:
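
The checks above guard the recommend() entry point: an unknown raw user_id raises, k is validated before ranking, and remove_seen optionally filters items the user has already interacted with. A hedged usage sketch (model and the user id "196" are assumptions; the exact return type is not shown in this diff):

# model is assumed to be a trained Cornac recommender.
try:
    top_items = model.recommend(user_id="196", k=10)  # ranked item ids for this user (assumed)
    print(top_items)
except ValueError as err:
    # Raised for an unknown user id or an out-of-range k, per the checks above.
    print("invalid request:", err)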
@@ -622,7 +638,11 @@
         if self.stopped_epoch > 0:
             print("Early stopping:")
-            print("- best epoch = {}, stopped epoch = {}".format(self.best_epoch, self.stopped_epoch))
+            print(
+                "- best epoch = {}, stopped epoch = {}".format(
+                    self.best_epoch, self.stopped_epoch
+                )
+            )
             print(
                 "- best monitored value = {:.6f} (delta = {:.6f})".format(
                     self.best_value, current_value - self.best_value
......
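
The prints reformatted above report the early-stopping bookkeeping kept during training: the best epoch seen, the epoch at which training stopped, and the gap between the best and current monitored values. A generic sketch of that bookkeeping (the attribute names best_epoch, stopped_epoch, and best_value mirror the diff; the min_delta/patience logic is an illustrative reconstruction, not Cornac's exact implementation):

class EarlyStopMonitor:
    """Illustrative tracker for the fields printed above."""

    def __init__(self, min_delta=0.0, patience=0):
        self.min_delta = min_delta      # smallest improvement that counts
        self.patience = patience        # epochs to wait without improvement
        self.best_value = float("-inf")
        self.best_epoch = 0
        self.stopped_epoch = 0
        self.wait = 0

    def update(self, epoch, current_value):
        """Record the monitored value for this epoch; return True to stop training."""
        if current_value - self.best_value >= self.min_delta:
            self.best_value = current_value
            self.best_epoch = epoch
            self.wait = 0
            return False
        self.wait += 1
        if self.wait > self.patience:
            self.stopped_epoch = epoch
            return True
        return False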