Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
Segment-Object-Centric
Manage
Activity
Members
Labels
Plan
Issues
3
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Alexandre Chapin
Segment-Object-Centric
Commits
cf16a8f4
Commit
cf16a8f4
authored
2 years ago
by
Alexandre Chapin
Browse files
Options
Downloads
Patches
Plain Diff
Fix import
parent
5d6ea2d3
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
osrt/model.py
+11
-8
11 additions, 8 deletions
osrt/model.py
train_lit.py
+2
-3
2 additions, 3 deletions
train_lit.py
with
13 additions
and
11 deletions
osrt/model.py
+
11
−
8
View file @
cf16a8f4
...
@@ -7,7 +7,8 @@ import osrt.layers as layers
...
@@ -7,7 +7,8 @@ import osrt.layers as layers
import
lightning.pytorch
as
pl
import
lightning.pytorch
as
pl
import
torch
import
torch
import
torch.optim
as
optim
import
torch.optim
as
optim
from
osrt.utils.common
import
mse2psnr
,
compute_adjusted_rand_index
from
osrt.utils.common
import
mse2psnr
from
osrt.utils.losses
import
compute_ari
from
torch.optim.lr_scheduler
import
LambdaLR
from
torch.optim.lr_scheduler
import
LambdaLR
from
typing
import
Dict
from
typing
import
Dict
...
@@ -54,10 +55,11 @@ class OSRT(nn.Module):
...
@@ -54,10 +55,11 @@ class OSRT(nn.Module):
self
.
encoder
.
train
()
self
.
encoder
.
train
()
self
.
decoder
.
train
()
self
.
decoder
.
train
()
"""
class LitOSRT(pl.LightningModule):
class LitOSRT(pl.LightningModule):
def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
"""
OSRT Model
OSRT Model
The definition of the encoder/decoder are defined in the config file with the path to classes
The definition of the encoder/decoder are defined in the config file with the path to classes
Args:
Args:
...
@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule):
...
@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule):
decoder: class of the decoder to use
decoder: class of the decoder to use
cfg: config file containing informations of the model
cfg: config file containing informations of the model
extract_masks: wether to use masks for training
extract_masks: wether to use masks for training
"""
super().__init__()
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters()
self.cfg = cfg
self.cfg = cfg
...
@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule):
...
@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule):
return self.encoder(x) # Returns: slot_latents
return self.encoder(x) # Returns: slot_latents
def compute_loss(self, batch):
def compute_loss(self, batch):
"""
Args:
Args:
batch: dict containing the informations for training --> input images, rays and position
batch: dict containing the informations for training --> input images, rays and position
extract_masks (Bool): whether to use masks to compute the segmentation loss or not
extract_masks (Bool): whether to use masks to compute the segmentation loss or not
Returns:
Returns:
loss: the loss value
loss: the loss value
loss_terms: a dict containing more loss values
loss_terms: a dict containing more loss values
"""
device = self.device
device = self.device
render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs
render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs
...
@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule):
...
@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule):
# These are not actually used as part of the training loss.
# These are not actually used as part of the training loss.
# We just add the to the dict to report them.
# We just add the to the dict to report them.
loss_terms
[
'
ari
'
]
=
compute_a
djusted_rand_index
(
true_seg
.
transpose
(
1
,
2
),
loss_terms[
'
ari
'
] = compute_a
ri
(true_seg.transpose(1, 2),
pred_seg.transpose(1, 2))
pred_seg.transpose(1, 2))
loss_terms
[
'
fg_ari
'
]
=
compute_a
djusted_rand_index
(
true_seg
.
transpose
(
1
,
2
)[:,
1
:],
loss_terms[
'
fg_ari
'
] = compute_a
ri
(true_seg.transpose(1, 2)[:, 1:],
pred_seg.transpose(1, 2))
pred_seg.transpose(1, 2))
# TODO : add new ari metrics
# TODO : add new ari metrics
...
@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule):
...
@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule):
'
interval
'
:
'
step
'
'
interval
'
:
'
step
'
}
}
}
}
"""
\ No newline at end of file
This diff is collapsed.
Click to expand it.
train_lit.py
+
2
−
3
View file @
cf16a8f4
...
@@ -29,7 +29,7 @@ __LOG10 = math.log(10)
...
@@ -29,7 +29,7 @@ __LOG10 = math.log(10)
def
validate
(
fabric
:
L
.
Fabric
,
model
:
OSRT
,
val_dataloader
:
DataLoader
,
epoch
:
int
=
0
):
def
validate
(
fabric
:
L
.
Fabric
,
model
:
OSRT
,
val_dataloader
:
DataLoader
,
epoch
:
int
=
0
):
# TODO : add segmentation also to select the model following how it's done in the training
# TODO : add segmentation also to select the model following how it's done in the training
"""
model.eval()
model
.
eval
()
mses
=
AverageMeter
()
mses
=
AverageMeter
()
psnrs
=
AverageMeter
()
psnrs
=
AverageMeter
()
...
@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
...
@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
state_dict
=
model
.
state_dict
()
state_dict
=
model
.
state_dict
()
if
fabric
.
global_rank
==
0
:
if
fabric
.
global_rank
==
0
:
torch
.
save
(
state_dict
,
os
.
path
.
join
(
cfg
.
out_dir
,
f
"
epoch-
{
epoch
:
06
d
}
-psnr
{
psnrs
.
avg
:
.
2
f
}
-mse
{
mses
.
avg
:
.
2
f
}
-ckpt.pth
"
))
torch
.
save
(
state_dict
,
os
.
path
.
join
(
cfg
.
out_dir
,
f
"
epoch-
{
epoch
:
06
d
}
-psnr
{
psnrs
.
avg
:
.
2
f
}
-mse
{
mses
.
avg
:
.
2
f
}
-ckpt.pth
"
))
model.train()
"""
model
.
train
()
pass
def
train_sam
(
def
train_sam
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment