Skip to content
Snippets Groups Projects
Unverified Commit 07760ba9 authored by Lê Trung Hoàng's avatar Lê Trung Hoàng Committed by GitHub
Browse files

Fix max_seq_length variable using the maximum number of baskets in training...

Fix max_seq_length variable using the maximum number of baskets in training sequences as default (#591)
parent 85ce38cf
Branches main
No related tags found
No related merge requests found
...@@ -51,6 +51,10 @@ class Beacon(NextBasketRecommender): ...@@ -51,6 +51,10 @@ class Beacon(NextBasketRecommender):
Number of hops for constructing correlation matrix. Number of hops for constructing correlation matrix.
If 0, zeros matrix will be used. If 0, zeros matrix will be used.
max_seq_length: int, optional, default: None
Maximum basket sequence length.
If None, it is the maximum number of basket in training sequences.
n_epochs: int, optional, default: 15 n_epochs: int, optional, default: 15
Number of training epochs Number of training epochs
...@@ -83,6 +87,7 @@ class Beacon(NextBasketRecommender): ...@@ -83,6 +87,7 @@ class Beacon(NextBasketRecommender):
rnn_cell_type="LSTM", rnn_cell_type="LSTM",
dropout_rate=0.5, dropout_rate=0.5,
nb_hop=1, nb_hop=1,
max_seq_length=None,
n_epochs=15, n_epochs=15,
batch_size=32, batch_size=32,
lr=0.001, lr=0.001,
...@@ -99,10 +104,12 @@ class Beacon(NextBasketRecommender): ...@@ -99,10 +104,12 @@ class Beacon(NextBasketRecommender):
self.alpha = alpha self.alpha = alpha
self.rnn_cell_type = rnn_cell_type self.rnn_cell_type = rnn_cell_type
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.max_seq_length = max_seq_length
self.seed = seed self.seed = seed
self.lr = lr self.lr = lr
def fit(self, train_set, val_set=None): def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from .beacon_tf import BeaconModel from .beacon_tf import BeaconModel
...@@ -113,8 +120,12 @@ class Beacon(NextBasketRecommender): ...@@ -113,8 +120,12 @@ class Beacon(NextBasketRecommender):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
super().fit(train_set=train_set, val_set=val_set) # max sequence length
self.max_seq_length = (
max([len(bids) for bids in train_set.user_basket_data.values()])
if self.max_seq_length is None # init max_seq_length
else self.max_seq_length
)
self.correlation_matrix = self._build_correlation_matrix( self.correlation_matrix = self._build_correlation_matrix(
train_set=train_set, val_set=val_set, n_items=self.total_items train_set=train_set, val_set=val_set, n_items=self.total_items
) )
...@@ -132,7 +143,7 @@ class Beacon(NextBasketRecommender): ...@@ -132,7 +143,7 @@ class Beacon(NextBasketRecommender):
self.emb_dim, self.emb_dim,
self.rnn_unit, self.rnn_unit,
self.alpha, self.alpha,
train_set.max_basket_size, self.max_seq_length,
self.total_items, self.total_items,
self.item_probs, self.item_probs,
self.correlation_matrix, self.correlation_matrix,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment