Commit 3a4fb34b authored by Alexandre Chapin

Setup epochs + add validation steps

parent 0c882e98
@@ -6,12 +6,12 @@ model:
 training:
   num_workers: 2
   batch_size: 32
-  visualize_every: 5000
-  validate_every: 5000
-  checkpoint_every: 1000
   backup_every: 25000
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
   print_every: 1
+  validate_every: 1
+  checkpoint_every: 1
+  visualize_every: 2
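With the loop now epoch-based (see the training script below), these counters are interpreted per epoch rather than per step. A small self-contained sketch of that interpretation, where steps_per_epoch is a made-up stand-in for len(train_loader):

import yaml

cfg = yaml.safe_load("""
training:
  max_it: 333000000
  validate_every: 1
  visualize_every: 2
""")

steps_per_epoch = 1000  # hypothetical len(train_loader)
epochs = cfg["training"]["max_it"] // steps_per_epoch

for epoch in range(min(epochs, 4)):  # truncated for illustration
    if not epoch % cfg["training"]["validate_every"]:
        print(f"epoch {epoch}: validate")
    if not epoch % cfg["training"]["visualize_every"]:
        print(f"epoch {epoch}: visualize")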
@@ -5,16 +5,15 @@ import torch.nn as nn
 import torch.optim as optim
 import argparse
 import yaml

 from osrt.model import SlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.common import mse2psnr

 from torch.utils.data import DataLoader
 import torch.nn.functional as F

-def l2_loss(prediction, target):
-    return torch.mean(torch.pow(prediction - target, 2))
+from tqdm import tqdm

 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
@@ -25,7 +24,7 @@ def train_step(batch, model, optimizer, device):
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
     input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = l2_loss(input_image, recon_combined)
+    loss_value = F.mse_loss(recon_combined, input_image)
     del recons, masks, slots  # Unused.

     # Get and apply gradients.
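train_step computes the reconstruction loss with PyTorch's MSE. There are two equivalent ways to write it: the functional form used above, or an instantiated nn.MSELoss module. Note that nn.MSELoss is a class whose constructor takes reduction arguments, not tensors, so it must be constructed before being called. A minimal, self-contained check:

import torch
import torch.nn as nn
import torch.nn.functional as F

pred = torch.rand(2, 128, 128, 3)
target = torch.rand(2, 128, 128, 3)

# Functional form: no module object needed.
loss_a = F.mse_loss(pred, target)

# Module form: instantiate nn.MSELoss first, then call it on tensors.
criterion = nn.MSELoss()
loss_b = criterion(pred, target)

assert torch.allclose(loss_a, loss_b)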
@@ -35,6 +34,20 @@ def train_step(batch, model, optimizer, device):
     return loss_value.item()

+def eval_step(batch, model, device):
+    """Perform a single eval step."""
+    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
+    input_image = F.interpolate(input_image, size=128)
+
+    # Get the prediction of the model and compute the loss.
+    preds = model(input_image)
+    recon_combined, recons, masks, slots = preds
+    input_image = input_image.permute(0, 2, 3, 1)
+    loss_value = F.mse_loss(recon_combined, input_image)
+    del recons, masks, slots  # Unused.
+    psnr = mse2psnr(loss_value)
+    return loss_value.item(), psnr.item()
+
 def main():
     # Arguments
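eval_step converts the reconstruction MSE to PSNR via mse2psnr from osrt.utils.common, whose implementation is not shown in this diff. A sketch under the usual definition PSNR = -10 * log10(MSE), which assumes image values in [0, 1]:

import torch

def mse2psnr(mse: torch.Tensor) -> torch.Tensor:
    # Assumed definition; with pixel values in [0, 1] the peak signal is 1,
    # so PSNR reduces to -10 * log10(MSE).
    return -10.0 * torch.log10(mse)

# Example: an MSE of 0.01 corresponds to 20 dB.
print(mse2psnr(torch.tensor(0.01)))  # tensor(20.)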
@@ -50,10 +63,10 @@ def main():
     with open(args.config, 'r') as f:
         cfg = yaml.load(f, Loader=yaml.CLoader)

-    # Set random seed.
+    ### Set random seed.
     torch.manual_seed(args.seed)

-    # Hyperparameters of the model.
+    ### Hyperparameters of the model.
     batch_size = cfg["training"]["batch_size"]
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
@@ -66,25 +79,32 @@ def main():
     resolution = (128, 128)

     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
         train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)

+    val_dataset = data.get_dataset('val', cfg['data'])
+    val_loader = DataLoader(
+        val_dataset, batch_size=batch_size, num_workers=1,
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
     vis_dataset = data.get_dataset('test', cfg['data'])
     vis_loader = DataLoader(
         vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)

     #### Create model
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())

     print('Number of parameters:')
-    print(f'\Model slot attention: {num_params}')
+    print(f'Model slot attention: {num_params}')
     optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)

-    # Prepare checkpoint manager.
+    #### Prepare checkpoint manager.
     global_step = 0
     ckpt = {
         'network': model,
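The ckpt dict keeps the live module under 'network', and each save in the loop below adds 'global_step' plus the weights under 'model_state_dict'. A hedged sketch of a matching state_dict-based resume; the helper name and the missing-file guard are assumptions, and the script itself restores the optimizer object directly from the loaded dict:

import os
import torch

def load_checkpoint(path, model):
    # Hypothetical resume helper matching the keys saved in the loop below.
    # Returns the global step to resume from (0 if no checkpoint exists).
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])
    return ckpt['global_step']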
@@ -97,44 +117,72 @@ def main():
     optimizer = ckpt['optimizer']
     global_step = ckpt['global_step']

     start = time.time()
-    for batch in train_loader:
-        #batch = next(iter(train_loader))
-        # Learning rate warm-up.
-        if global_step < warmup_steps:
-            learning_rate = base_learning_rate * global_step / warmup_steps
-        else:
-            learning_rate = base_learning_rate
-        learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = learning_rate
-        loss_value = train_step(batch, model, optimizer, device)
-        # Update the global step. We update it before logging the loss and saving
-        # the model so that the last checkpoint is saved at the last iteration.
-        global_step += 1
-        # Log the training loss.
-        if not global_step % cfg["training"]["print_every"]:
-            print(f"Step: {global_step}, Loss: {loss_value}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-            start = time.time()
+    """ TODO : setup wandb
+    if args.wandb:
+        if run_id is None:
+            run_id = wandb.util.generate_id()
+            print(f'Sampled new wandb run_id {run_id}.')
+        else:
+            print(f'Resuming wandb with existing run_id {run_id}.')
+        # Tell in which mode to launch the logging in W&B (for offline cluster)
+        if args.offline_log:
+            mode = "offline"
+        else:
+            mode = "online"
+        wandb.init(project='osrt', name=os.path.dirname(args.config),
+                   id=run_id, resume=True, mode=mode, sync_tensorboard=True)
+        wandb.config = cfg"""
+
+    epochs = num_train_steps // len(train_loader)
+    for epoch in range(epochs):
+        total_loss = 0
+        model.train()
+        for batch in tqdm(train_loader):
+            # Learning rate warm-up.
+            if global_step < warmup_steps:
+                learning_rate = base_learning_rate * global_step / warmup_steps
+            else:
+                learning_rate = base_learning_rate
+            learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = learning_rate
+            total_loss += train_step(batch, model, optimizer, device)
+            global_step += 1
+        total_loss /= len(train_loader)

         # We save the checkpoints
-        if not global_step % cfg["training"]["checkpoint_every"]:
+        if not epoch % cfg["training"]["checkpoint_every"]:
             # Save the checkpoint of the model.
             ckpt['global_step'] = global_step
             ckpt['model_state_dict'] = model.state_dict()
             torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
             print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")

         # We visualize some test data
-        if not global_step % cfg["training"]["visualize_every"]:
+        if not epoch % cfg["training"]["visualize_every"]:
             image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
             image = F.interpolate(image, size=128)
             image = image.to(device)
             recon_combined, recons, masks, slots = model(image)
             visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)

+        # Log the training loss.
+        if not epoch % cfg["training"]["print_every"]:
+            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
+
+        # We evaluate on the validation set
+        if not epoch % cfg["training"]["validate_every"]:
+            val_loss = 0
+            val_psnr = 0
+            model.eval()
+            for batch in tqdm(val_loader):
+                mse, psnr = eval_step(batch, model, device)
+                val_loss += mse
+                val_psnr += psnr
+            val_loss /= len(val_loader)
+            val_psnr /= len(val_loader)
+            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
+            model.train()

 if __name__ == "__main__":
     main()
\ No newline at end of file
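The learning-rate schedule in the loop above, linear warm-up to the base rate followed by continuous exponential decay (the decay factor is applied from step 0 onward, including during warm-up), can be read as a pure function of the global step. A sketch using the values from the config; the base rate of 4e-4 is an assumption, since base_learning_rate comes from the config and is not shown in this diff:

def learning_rate_at(step, base_lr=4e-4, warmup=10000, decay_rate=0.5, decay_steps=100000):
    # Linear warm-up from 0 to base_lr, then smooth exponential decay:
    # the rate halves every decay_steps steps (decay_rate ** (step / decay_steps)).
    lr = base_lr * step / warmup if step < warmup else base_lr
    return lr * decay_rate ** (step / decay_steps)

print(learning_rate_at(0))       # 0.0
print(learning_rate_at(10000))   # ~0.93 * base_lr (decay already applies)
print(learning_rate_at(110000))  # roughly half the value at the end of warm-up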