diff --git a/cornac/data/text.py b/cornac/data/text.py
index 8545206cbcb4b80c0220ca8ec71cf043c0b4f6d6..10d1e92fddcf22c9236840c509103a662a4a6456 100644
--- a/cornac/data/text.py
+++ b/cornac/data/text.py
@@ -132,7 +132,7 @@ def rm_numeric(t: str) -> str:
 
 def rm_punctuation(t: str) -> str:
     """
-    Remove "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~" from t.
+    Remove "!"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" from t.
     """
     return t.translate(str.maketrans('', '', string.punctuation))
 
diff --git a/cornac/models/bivaecf/bivae.py b/cornac/models/bivaecf/bivae.py
index a66f9d36fdf07b707df691d90ee976eee2d7dade..5221ca813c0862133767f276530be1f501ec6688 100644
--- a/cornac/models/bivaecf/bivae.py
+++ b/cornac/models/bivaecf/bivae.py
@@ -163,6 +163,8 @@ def learn(
     verbose,
     device=torch.device("cpu"),
     dtype=torch.float32,
+    val_set=None,  # optional validation set monitored for early stopping
+    patience=10,  # epochs to wait for validation-loss improvement before stopping
 ):
     user_params = it.chain(
         bivae.user_encoder.parameters(),
@@ -191,8 +193,12 @@ def learn(
     x.data = np.ones_like(x.data)  # Binarize data
     tx = x.transpose()
 
+    # Track the best validation loss and the epochs since it last improved
+    best_val_loss = float("inf")
+    patience_counter = 0
+
     progress_bar = trange(1, n_epochs + 1, disable=not verbose)
-    for _ in progress_bar:
+    for epoch in progress_bar:
         # item side
         i_sum_loss = 0.0
         i_count = 0
@@ -255,6 +261,37 @@ def learn(
         progress_bar.set_postfix(
             loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count))
         )
+        # Validation loss on the user side
+        if val_set is not None:
+            val_loss = 0.0
+            val_count = 0
+
+            with torch.no_grad():  # no gradient computation needed for validation
+                for u_ids in val_set.user_iter(batch_size, shuffle=False):
+                    u_batch = (val_set.matrix[u_ids, :].A > 0).astype(np.float32)  # binarize like the training data
+                    u_batch = torch.tensor(u_batch, dtype=dtype, device=device)
+
+                    # Reconstruct the validation batch
+                    theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta)
+
+                    # Accumulate validation loss
+                    u_loss = bivae.loss(u_batch, u_batch_, u_mu, 0.0, u_std, beta_kl)
+                    val_loss += u_loss.item()
+                    val_count += len(u_batch)
+
+            avg_val_loss = val_loss / val_count
+            progress_bar.set_postfix(loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / u_count),
+                                     val_loss=avg_val_loss)
+
+            # Early stopping check
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                patience_counter = 0  # reset on improvement
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"Early stopping at epoch {epoch}: validation loss has not improved for {patience} epochs.")
+                    break  # stop training; mu_beta/mu_theta are still inferred below
 
     # infer mu_beta
     for i_ids in train_set.item_iter(batch_size, shuffle=False):
diff --git a/cornac/models/bivaecf/recom_bivaecf.py b/cornac/models/bivaecf/recom_bivaecf.py
index 9f40602eba7aa6adea847887a2fb80580b3f9818..70e503e6334d54dffdccf721082941fd76669338 100644
--- a/cornac/models/bivaecf/recom_bivaecf.py
+++ b/cornac/models/bivaecf/recom_bivaecf.py
@@ -19,6 +19,8 @@ from ..recommender import Recommender
 from ..recommender import ANNMixin, MEASURE_DOT
 from ...utils.common import scale
 from ...exception import ScoreException
+import torch
+from .bivae import BiVAE, learn
 
 
 class BiVAECF(Recommender, ANNMixin):
@@ -130,9 +132,6 @@ class BiVAECF(Recommender, ANNMixin):
         """
         Recommender.fit(self, train_set, val_set)
 
-        import torch
-        from .bivae import BiVAE, learn
-
         self.device = (
             torch.device("cuda:0")
             if (self.use_gpu and torch.cuda.is_available())
@@ -184,6 +183,8 @@ class BiVAECF(Recommender, ANNMixin):
                 beta_kl=self.beta_kl,
                 verbose=self.verbose,
                 device=self.device,
+                val_set=val_set,  # forward the validation set (may be None) to learn()
+                patience=30,  # early-stopping patience; fixed here rather than exposed as a hyperparameter
             )
         elif self.verbose:
             print("%s is trained already (trainable = False)" % (self.name))
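
Usage note: the early-stopping path only activates when `fit()` receives a non-`None` validation set. Below is a minimal sketch of exercising it through cornac's standard evaluation flow; the dataset, hyperparameter values, and metrics are illustrative and not part of this patch. The key detail is a non-zero `val_size`, so that `RatioSplit` produces the `val_set` that `fit()` forwards to `learn()`.

```python
import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.models import BiVAECF

# Hold out a validation split: without val_size > 0, val_set is None
# and training runs for the full n_epochs as before.
feedback = movielens.load_feedback(variant="100K")
ratio_split = RatioSplit(
    data=feedback,
    test_size=0.2,
    val_size=0.1,  # enables the early-stopping path added in this patch
    exclude_unknowns=True,
    seed=123,
    verbose=True,
)

bivaecf = BiVAECF(
    k=50,
    encoder_structure=[100],
    act_fn="tanh",
    likelihood="pois",
    n_epochs=500,  # upper bound; early stopping may end training sooner
    batch_size=256,
    learn_rate=0.001,
    seed=123,
    verbose=True,
)

cornac.Experiment(
    eval_method=ratio_split,
    models=[bivaecf],
    metrics=[cornac.metrics.Recall(k=20), cornac.metrics.NDCG(k=20)],
).run()
```

Design note: `fit()` currently passes a literal `patience=30` while `learn()` defaults to 10; if per-model tuning is wanted, a follow-up could expose `patience` as a `BiVAECF` constructor argument instead of hardcoding it in `fit()`.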