Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
cornac
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Arthur Batel
cornac
Commits
d8885c6a
Commit
d8885c6a
authored
10 months ago
by
Arthur Batel
Browse files
Options
Downloads
Patches
Plain Diff
bivaecf early stopping
parent
0ed95849
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
cornac/data/text.py
+1
-1
1 addition, 1 deletion
cornac/data/text.py
cornac/models/bivaecf/bivae.py
+38
-1
38 additions, 1 deletion
cornac/models/bivaecf/bivae.py
cornac/models/bivaecf/recom_bivaecf.py
+5
-3
5 additions, 3 deletions
cornac/models/bivaecf/recom_bivaecf.py
with
44 additions
and
5 deletions
cornac/data/text.py
+
1
−
1
View file @
d8885c6a
...
...
@@ -132,7 +132,7 @@ def rm_numeric(t: str) -> str:
def
rm_punctuation
(
t
:
str
)
->
str
:
"""
Remove
"
!
"
#$%&
'
()*+,-./:;<=>?@[
\
]^_`{|}~
"
from t.
Remove
"
!
"
#$%&
'
()*+,-./:;<=>?@[]^_`{|}~
"
from t.
"""
return
t
.
translate
(
str
.
maketrans
(
''
,
''
,
string
.
punctuation
))
...
...
This diff is collapsed.
Click to expand it.
cornac/models/bivaecf/bivae.py
+
38
−
1
View file @
d8885c6a
...
...
@@ -163,6 +163,8 @@ def learn(
verbose
,
device
=
torch
.
device
(
"
cpu
"
),
dtype
=
torch
.
float32
,
val_set
=
None
,
# Add validation set parameter
patience
=
10
,
# Add patience parameter for early stopping
):
user_params
=
it
.
chain
(
bivae
.
user_encoder
.
parameters
(),
...
...
@@ -191,8 +193,12 @@ def learn(
x
.
data
=
np
.
ones_like
(
x
.
data
)
# Binarize data
tx
=
x
.
transpose
()
# Initialize variables for early stopping
best_val_loss
=
float
(
'
inf
'
)
patience_counter
=
0
progress_bar
=
trange
(
1
,
n_epochs
+
1
,
disable
=
not
verbose
)
for
_
in
progress_bar
:
for
epoch
in
progress_bar
:
# item side
i_sum_loss
=
0.0
i_count
=
0
...
...
@@ -255,6 +261,37 @@ def learn(
progress_bar
.
set_postfix
(
loss_i
=
(
i_sum_loss
/
i_count
),
loss_u
=
(
u_sum_loss
/
(
u_count
))
)
# Validation loss calculation
if
val_set
is
not
None
:
val_loss
=
0.0
val_count
=
0
with
torch
.
no_grad
():
# No need to compute gradients during validation
for
u_ids
in
val_set
.
user_iter
(
batch_size
,
shuffle
=
False
):
u_batch
=
val_set
.
matrix
[
u_ids
,
:].
A
u_batch
=
torch
.
tensor
(
u_batch
,
dtype
=
dtype
,
device
=
device
)
# Reconstruct batch
theta
,
u_batch_
,
u_mu
,
u_std
=
bivae
(
u_batch
,
user
=
True
,
beta
=
bivae
.
beta
)
# Compute validation loss
u_loss
=
bivae
.
loss
(
u_batch
,
u_batch_
,
u_mu
,
0.0
,
u_std
,
beta_kl
)
val_loss
+=
u_loss
.
data
.
item
()
val_count
+=
len
(
u_batch
)
avg_val_loss
=
val_loss
/
val_count
progress_bar
.
set_postfix
(
loss_i
=
(
i_sum_loss
/
i_count
),
loss_u
=
(
u_sum_loss
/
u_count
),
val_loss
=
avg_val_loss
)
# Early stopping check
if
avg_val_loss
<
best_val_loss
:
best_val_loss
=
avg_val_loss
patience_counter
=
0
# Reset patience counter
else
:
patience_counter
+=
1
if
patience_counter
>=
patience
:
print
(
f
"
Early stopping at epoch
{
epoch
}
due to no improvement in validation loss.
"
)
break
# Stop training
# infer mu_beta
for
i_ids
in
train_set
.
item_iter
(
batch_size
,
shuffle
=
False
):
...
...
This diff is collapsed.
Click to expand it.
cornac/models/bivaecf/recom_bivaecf.py
+
5
−
3
View file @
d8885c6a
...
...
@@ -19,6 +19,8 @@ from ..recommender import Recommender
from
..recommender
import
ANNMixin
,
MEASURE_DOT
from
...utils.common
import
scale
from
...exception
import
ScoreException
import
torch
from
.bivae
import
BiVAE
,
learn
class
BiVAECF
(
Recommender
,
ANNMixin
):
...
...
@@ -130,9 +132,6 @@ class BiVAECF(Recommender, ANNMixin):
"""
Recommender
.
fit
(
self
,
train_set
,
val_set
)
import
torch
from
.bivae
import
BiVAE
,
learn
self
.
device
=
(
torch
.
device
(
"
cuda:0
"
)
if
(
self
.
use_gpu
and
torch
.
cuda
.
is_available
())
...
...
@@ -175,6 +174,7 @@ class BiVAECF(Recommender, ANNMixin):
batch_size
=
self
.
batch_size
,
).
to
(
self
.
device
)
learn
(
self
.
bivae
,
train_set
,
...
...
@@ -184,6 +184,8 @@ class BiVAECF(Recommender, ANNMixin):
beta_kl
=
self
.
beta_kl
,
verbose
=
self
.
verbose
,
device
=
self
.
device
,
val_set
=
val_set
,
# Pass validation set
patience
=
30
,
# Optional: You can modify the patience as needed
)
elif
self
.
verbose
:
print
(
"
%s is trained already (trainable = False)
"
%
(
self
.
name
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment