Commit d8885c6a authored by Arthur Batel

bivaecf early stopping

parent 0ed95849
@@ -132,7 +132,7 @@ def rm_numeric(t: str) -> str:
 def rm_punctuation(t: str) -> str:
     """
-    Remove "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~" from t.
+    Remove "!"#$%&'()*+,-./:;<=>?@[]^_`{|}~" from t.
     """
     return t.translate(str.maketrans('', '', string.punctuation))
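The docstring change drops the backslash, presumably to silence Python's invalid-escape-sequence warning for "\]" inside a normal string literal. An alternative (a suggestion, not part of this commit) keeps the docstring accurate to string.punctuation, which does contain a backslash, by using a raw string:

import string

def rm_punctuation(t: str) -> str:
    r"""
    Remove "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~" from t.
    """
    # string.punctuation includes the backslash, so the raw docstring
    # can list it without triggering an escape-sequence warning.
    return t.translate(str.maketrans('', '', string.punctuation))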
@@ -163,6 +163,8 @@ def learn(
     verbose,
     device=torch.device("cpu"),
     dtype=torch.float32,
+    val_set=None,  # Add validation set parameter
+    patience=10,  # Add patience parameter for early stopping
 ):
     user_params = it.chain(
         bivae.user_encoder.parameters(),
@@ -191,8 +193,12 @@ def learn(
     x.data = np.ones_like(x.data)  # Binarize data
     tx = x.transpose()

+    # Initialize variables for early stopping
+    best_val_loss = float('inf')
+    patience_counter = 0
+
     progress_bar = trange(1, n_epochs + 1, disable=not verbose)
-    for _ in progress_bar:
+    for epoch in progress_bar:
         # item side
         i_sum_loss = 0.0
         i_count = 0
@@ -255,6 +261,37 @@ def learn(
         progress_bar.set_postfix(
             loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / (u_count))
         )
+
+        # Validation loss calculation
+        if val_set is not None:
+            val_loss = 0.0
+            val_count = 0
+            with torch.no_grad():  # No need to compute gradients during validation
+                for u_ids in val_set.user_iter(batch_size, shuffle=False):
+                    u_batch = val_set.matrix[u_ids, :].A
+                    u_batch = torch.tensor(u_batch, dtype=dtype, device=device)
+
+                    # Reconstruct batch
+                    theta, u_batch_, u_mu, u_std = bivae(u_batch, user=True, beta=bivae.beta)
+
+                    # Compute validation loss
+                    u_loss = bivae.loss(u_batch, u_batch_, u_mu, 0.0, u_std, beta_kl)
+                    val_loss += u_loss.data.item()
+                    val_count += len(u_batch)
+
+            avg_val_loss = val_loss / val_count
+            progress_bar.set_postfix(loss_i=(i_sum_loss / i_count), loss_u=(u_sum_loss / u_count),
+                                     val_loss=avg_val_loss)
+
+            # Early stopping check
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                patience_counter = 0  # Reset patience counter
+            else:
+                patience_counter += 1
+                if patience_counter >= patience:
+                    print(f"Early stopping at epoch {epoch} due to no improvement in validation loss.")
+                    break  # Stop training
+
     # infer mu_beta
     for i_ids in train_set.item_iter(batch_size, shuffle=False):
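The addition boils down to patience-based early stopping: track the best validation loss seen so far, and stop once it has not improved for a fixed number of epochs. A minimal, self-contained sketch of the pattern (train_one_epoch and compute_val_loss are hypothetical stand-ins for the item/user training passes and the validation pass above):

import random

def train_one_epoch():
    pass  # stand-in for the item-side and user-side training passes

def compute_val_loss():
    return random.random()  # stand-in for the validation pass

def early_stopping_loop(n_epochs=100, patience=10):
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(1, n_epochs + 1):
        train_one_epoch()
        val_loss = compute_val_loss()

        if val_loss < best_val_loss:
            best_val_loss = val_loss  # new best; reset the counter
            patience_counter = 0
        else:
            patience_counter += 1  # one more epoch without improvement
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

Note that the commit stops training but does not restore the parameters from the best epoch; a common refinement is to checkpoint the model whenever best_val_loss improves and reload that checkpoint after the loop.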
@@ -19,6 +19,8 @@ from ..recommender import Recommender
 from ..recommender import ANNMixin, MEASURE_DOT
 from ...utils.common import scale
 from ...exception import ScoreException
+import torch
+from .bivae import BiVAE, learn


 class BiVAECF(Recommender, ANNMixin):
@@ -130,9 +132,6 @@ class BiVAECF(Recommender, ANNMixin):
         """
         Recommender.fit(self, train_set, val_set)

-        import torch
-        from .bivae import BiVAE, learn
-
         self.device = (
             torch.device("cuda:0")
             if (self.use_gpu and torch.cuda.is_available())
@@ -175,6 +174,7 @@ class BiVAECF(Recommender, ANNMixin):
                 batch_size=self.batch_size,
             ).to(self.device)

             learn(
                 self.bivae,
                 train_set,
@@ -184,6 +184,8 @@ class BiVAECF(Recommender, ANNMixin):
                 beta_kl=self.beta_kl,
                 verbose=self.verbose,
                 device=self.device,
+                val_set=val_set,  # Pass validation set
+                patience=30,  # Optional: you can modify the patience as needed
             )
         elif self.verbose:
             print("%s is trained already (trainable = False)" % (self.name))
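With these changes, passing a validation set to fit() activates the early-stopping check (note that fit() hard-codes patience=30 in this commit rather than exposing it as a model parameter). A usage sketch, assuming the standard Cornac API; the dataset choice and hyperparameters are illustrative:

from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.models import BiVAECF

# Carve out a validation split; it is what drives the early-stopping check.
feedback = movielens.load_feedback()
ratio_split = RatioSplit(data=feedback, test_size=0.2, val_size=0.1, seed=123)

model = BiVAECF(
    k=20,
    n_epochs=500,  # upper bound; early stopping may end training sooner
    batch_size=128,
    learning_rate=0.001,
    verbose=True,
)
model.fit(ratio_split.train_set, ratio_split.val_set)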