Skip to content
Snippets Groups Projects
Unverified Commit c1f33835, authored by Quoc-Tuan Truong, committed by GitHub
Browse files

Add metadata file when saving model (#607)

parent 93a38528
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ import pickle ...@@ -20,6 +20,7 @@ import pickle
import warnings import warnings
from datetime import datetime from datetime import datetime
from glob import glob from glob import glob
import json
import numpy as np import numpy as np
...@@ -219,7 +220,7 @@ class Recommender: ...@@ -219,7 +220,7 @@ class Recommender:
return self.__class__(**init_params) return self.__class__(**init_params)
def save(self, save_dir=None, save_trainset=False): def save(self, save_dir=None, save_trainset=False, metadata=None):
"""Save a recommender model to the filesystem. """Save a recommender model to the filesystem.
Parameters Parameters
...@@ -232,6 +233,10 @@ class Recommender: ...@@ -232,6 +233,10 @@ class Recommender:
if we want to deploy model later because train_set is if we want to deploy model later because train_set is
required for certain evaluation steps. required for certain evaluation steps.
metadata: dict, default: None
Metadata to be saved with the model. This is useful
to store model details.
Returns Returns
------- -------
model_file : str model_file : str
...@@ -246,16 +251,27 @@ class Recommender: ...@@ -246,16 +251,27 @@ class Recommender:
model_file = os.path.join(model_dir, "{}.pkl".format(timestamp)) model_file = os.path.join(model_dir, "{}.pkl".format(timestamp))
saved_model = copy.deepcopy(self) saved_model = copy.deepcopy(self)
pickle.dump(saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(
saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL
)
if self.verbose: if self.verbose:
print("{} model is saved to {}".format(self.name, model_file)) print("{} model is saved to {}".format(self.name, model_file))
metadata = {} if metadata is None else metadata
metadata["model_classname"] = type(saved_model).__name__
metadata["model_file"] = os.path.basename(model_file)
if save_trainset: if save_trainset:
trainset_file = model_file + ".trainset"
pickle.dump( pickle.dump(
self.train_set, self.train_set,
open(model_file + ".trainset", "wb"), open(trainset_file, "wb"),
protocol=pickle.HIGHEST_PROTOCOL, protocol=pickle.HIGHEST_PROTOCOL,
) )
metadata["trainset_file"] = os.path.basename(trainset_file)
with open(model_file + ".meta", "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)
return model_file return model_file
...@@ -502,9 +518,7 @@ class Recommender: ...@@ -502,9 +518,7 @@ class Recommender:
) )
item_scores = all_item_scores[item_indices] item_scores = all_item_scores[item_indices]
if ( if k != -1: # O(n + k log k), faster for small k which is usually the case
k != -1
): # O(n + k log k), faster for small k which is usually the case
partitioned_idx = np.argpartition(item_scores, -k) partitioned_idx = np.argpartition(item_scores, -k)
top_k_idx = partitioned_idx[-k:] top_k_idx = partitioned_idx[-k:]
sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])] sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])]
...@@ -545,7 +559,9 @@ class Recommender: ...@@ -545,7 +559,9 @@ class Recommender:
raise ValueError(f"{user_id} is unknown to the model.") raise ValueError(f"{user_id} is unknown to the model.")
if k < -1 or k > self.total_items: if k < -1 or k > self.total_items:
raise ValueError(f"k={k} is invalid, there are {self.total_users} users in total.") raise ValueError(
f"k={k} is invalid, there are {self.total_users} users in total."
)
item_indices = np.arange(self.total_items) item_indices = np.arange(self.total_items)
if remove_seen: if remove_seen:
...@@ -622,7 +638,11 @@ class Recommender: ...@@ -622,7 +638,11 @@ class Recommender:
if self.stopped_epoch > 0: if self.stopped_epoch > 0:
print("Early stopping:") print("Early stopping:")
print("- best epoch = {}, stopped epoch = {}".format(self.best_epoch, self.stopped_epoch)) print(
"- best epoch = {}, stopped epoch = {}".format(
self.best_epoch, self.stopped_epoch
)
)
print( print(
"- best monitored value = {:.6f} (delta = {:.6f})".format( "- best monitored value = {:.6f} (delta = {:.6f})".format(
self.best_value, current_value - self.best_value self.best_value, current_value - self.best_value
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment