From 07760ba9b7ad93699a1d43f7ac4eaec3b4fb3d7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=AA=20Trung=20Ho=C3=A0ng?=
 <lthoang@users.noreply.github.com>
Date: Mon, 22 Jan 2024 10:07:04 +0700
Subject: [PATCH] Fix max_seq_length variable using the maximum number of
 baskets in training sequences as default (#591)

---
 cornac/models/beacon/recom_beacon.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/cornac/models/beacon/recom_beacon.py b/cornac/models/beacon/recom_beacon.py
index 1838a2e3..93d1b061 100644
--- a/cornac/models/beacon/recom_beacon.py
+++ b/cornac/models/beacon/recom_beacon.py
@@ -51,6 +51,10 @@ class Beacon(NextBasketRecommender):
         Number of hops for constructing correlation matrix.
         If 0, zeros matrix will be used.
 
+    max_seq_length: int, optional, default: None
+        Maximum basket sequence length.
+        If None, it is the maximum number of basket in training sequences.
+
     n_epochs: int, optional, default: 15
         Number of training epochs
 
@@ -83,6 +87,7 @@ class Beacon(NextBasketRecommender):
         rnn_cell_type="LSTM",
         dropout_rate=0.5,
         nb_hop=1,
+        max_seq_length=None,
         n_epochs=15,
         batch_size=32,
         lr=0.001,
@@ -99,10 +104,12 @@ class Beacon(NextBasketRecommender):
         self.alpha = alpha
         self.rnn_cell_type = rnn_cell_type
         self.dropout_rate = dropout_rate
+        self.max_seq_length = max_seq_length
         self.seed = seed
         self.lr = lr
 
     def fit(self, train_set, val_set=None):
+        super().fit(train_set=train_set, val_set=val_set)
         import tensorflow.compat.v1 as tf
 
         from .beacon_tf import BeaconModel
@@ -113,8 +120,12 @@ class Beacon(NextBasketRecommender):
         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
         tf.logging.set_verbosity(tf.logging.ERROR)
 
-        super().fit(train_set=train_set, val_set=val_set)
-
+        # max sequence length
+        self.max_seq_length = (
+            max([len(bids) for bids in train_set.user_basket_data.values()])
+            if self.max_seq_length is None  # init max_seq_length
+            else self.max_seq_length
+        )
         self.correlation_matrix = self._build_correlation_matrix(
             train_set=train_set, val_set=val_set, n_items=self.total_items
         )
@@ -132,7 +143,7 @@ class Beacon(NextBasketRecommender):
             self.emb_dim,
             self.rnn_unit,
             self.alpha,
-            train_set.max_basket_size,
+            self.max_seq_length,
             self.total_items,
             self.item_probs,
             self.correlation_matrix,
-- 
GitLab