# train_sa.py — Slot Attention autoencoder training script.
# Originally authored by Alexandre Chapin (commit 6d55e14b).
import datetime
import time
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import yaml
from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
def train_step(batch, model, optimizer, device):
    """Perform a single training step.

    Args:
        batch: dict with key 'input_images' holding a (B, 1, C, H, W) tensor.
               (Shape inferred from the squeeze on dim=1 — TODO confirm.)
        model: autoencoder returning (recon_combined, recons, masks, slots).
        optimizer: torch optimizer over the model's parameters.
        device: torch.device to move the batch to.

    Returns:
        float: the scalar MSE reconstruction loss for this batch.
    """
    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
    input_image = F.interpolate(input_image, size=128)
    # Get the prediction of the model and compute the loss.
    preds = model(input_image)
    recon_combined, recons, masks, slots = preds
    # Model output is channels-last; permute the target to match.
    input_image = input_image.permute(0, 2, 3, 1)
    # BUG FIX: nn.MSELoss is a module class — nn.MSELoss(a, b) constructs a
    # module, it does not compute a loss. Use the functional form instead.
    loss_value = F.mse_loss(recon_combined, input_image)
    del recons, masks, slots  # Unused.
    # Get and apply gradients.
    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()
    return loss_value.item()
def eval_step(batch, model, device):
    """Perform a single eval step.

    Args:
        batch: dict with key 'input_images' holding a (B, 1, C, H, W) tensor.
               (Shape inferred from the squeeze on dim=1 — TODO confirm.)
        model: autoencoder returning (recon_combined, recons, masks, slots).
        device: torch.device to move the batch to.

    Returns:
        tuple(float, float): (MSE reconstruction loss, PSNR) for this batch.
    """
    # Evaluation needs no gradients; avoids building the autograd graph.
    with torch.no_grad():
        input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
        input_image = F.interpolate(input_image, size=128)
        # Get the prediction of the model and compute the loss.
        preds = model(input_image)
        recon_combined, recons, masks, slots = preds
        # Model output is channels-last; permute the target to match.
        input_image = input_image.permute(0, 2, 3, 1)
        # BUG FIX: nn.MSELoss(a, b) constructs a module, it does not compute a
        # loss. Use the functional form instead.
        loss_value = F.mse_loss(recon_combined, input_image)
        del recons, masks, slots  # Unused.
        psnr = mse2psnr(loss_value)
    return loss_value.item(), psnr.item()
def main():
    """Train a Slot Attention autoencoder from a YAML config.

    CLI (unchanged):
        config   path to the YAML configuration file
        --wandb  log the run to Weights and Biases (still a TODO below)
        --seed   random seed (default 0)
        --ckpt   directory for checkpoints and visualizations (default ".")
    """
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
    args = parser.parse_args()
    # NOTE(review): yaml.load with CLoader executes no arbitrary tags but is
    # not safe_load; fine for trusted local configs only.
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    torch.manual_seed(args.seed)

    ### Hyperparameters of the model and LR schedule.
    batch_size = cfg["training"]["batch_size"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    base_learning_rate = 0.0004
    num_train_steps = cfg["training"]["max_it"]
    warmup_steps = cfg["training"]["warmup_it"]
    decay_rate = cfg["training"]["decay_rate"]
    decay_steps = cfg["training"]["decay_it"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resolution = (128, 128)

    #### Create datasets
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=1,
        shuffle=True, worker_init_fn=data.worker_init_fn)
    vis_dataset = data.get_dataset('test', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:')
    print(f'Model slot attention: {num_params}')
    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)

    #### Save the initial checkpoint.
    # BUG FIX: torch.save returns None — the old `ckpt_manager = torch.save(...)`
    # bound None and was never a checkpoint manager. Just save.
    global_step = 0
    ckpt = {
        'network': model,
        'optimizer': optimizer,
        'global_step': global_step
    }
    torch.save(ckpt, args.ckpt + '/ckpt.pth')
    # To resume instead: ckpt = torch.load(args.ckpt + '/ckpt.pth') and then
    # restore the three entries below from the loaded dict.
    model = ckpt['network']
    optimizer = ckpt['optimizer']
    global_step = ckpt['global_step']

    # TODO: set up wandb logging when args.wandb is passed (generate/reuse a
    # run_id, pick online/offline mode, wandb.init(project='osrt', ...)).

    start = time.time()
    # One "epoch" here is a full pass over train_loader; the total number of
    # optimizer steps approximates cfg["training"]["max_it"].
    epochs = num_train_steps // len(train_loader)
    for epoch in range(epochs):
        total_loss = 0
        model.train()
        for batch in tqdm(train_loader):
            # Linear LR warm-up, then continuous exponential decay.
            if global_step < warmup_steps:
                learning_rate = base_learning_rate * global_step / warmup_steps
            else:
                learning_rate = base_learning_rate
            learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            total_loss += train_step(batch, model, optimizer, device)
            global_step += 1
        total_loss /= len(train_loader)

        # We save the checkpoints (epoch 0 included: `not 0 % k` is True).
        if not epoch % cfg["training"]["checkpoint_every"]:
            ckpt['global_step'] = global_step
            ckpt['model_state_dict'] = model.state_dict()
            torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
            print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")

        # We visualize some test data; no gradients needed for a forward pass.
        if not epoch % cfg["training"]["visualize_every"]:
            with torch.no_grad():
                image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
                image = F.interpolate(image, size=128)
                image = image.to(device)
                recon_combined, recons, masks, slots = model(image)
                visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)

        # Log the training loss.
        if not epoch % cfg["training"]["print_every"]:
            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")

        # Periodic validation pass, gradient-free and in eval mode.
        if not epoch % cfg["training"]["validate_every"]:
            val_loss = 0
            val_psnr = 0
            model.eval()
            with torch.no_grad():
                for batch in tqdm(val_loader):
                    mse, psnr = eval_step(batch, model, device)
                    val_loss += mse
                    val_psnr += psnr
            val_loss /= len(val_loader)
            val_psnr /= len(val_loader)
            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
            model.train()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()