Skip to content
Snippets Groups Projects
Commit f9d36ede authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Add setup for training

parent ece9ede3
No related branches found
No related tags found
No related merge requests found
[submodule "segment-anything"]
path = segment-anything
url = https://github.com/facebookresearch/segment-anything.git
{
"data": {
"dataset": "clevr3d",
"num_points": 2000 ,
"kwargs": {
"downsample": 1
}
},
"model":{
"encoder": "osrt",
"encoder_kwargs": {
"pos_start_octave": -5,
"num_slots": 6
},
"decoder": "slot_mixer",
"decoder_kwargs":{
"pos_start_octave": -5
}
},
"training":{
"num_workers": 4,
"batch_size": 64,
"num_gpu": 8,
"model_selection_metric": "psnr",
"model_selection_mode": "max",
"print_every": 10,
"visualize_every": 5000,
"validate_every": 5000,
"checkpoint_every": 1000,
"backup_every": 25000,
"max_it": 333000000,
"decay_it": 4000000,
"lr_warmup": 5000,
"precision": "16-mixed",
"out_dir": "."
}
}
\ No newline at end of file
"""
Training-script structure inspired by the Lit-LLaMA fine-tuning script: https://github.com/Lightning-AI/lit-llama/blob/main/finetune/full.py
Training loop inspired by and adapted from lightning-sam: https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/train.py
"""
import sys
from pathlib import Path
import os
import time
from functools import partial
import json
import argparse
import math
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from lightning.fabric.fabric import _FabricOptimizer
from lightning.fabric.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from jsonargparse.cli import CLI
import json
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from osrt.model import OSRT
from osrt.encoder import FeatureMasking
from osrt import data
from osrt.utils.training import AverageMeter
from osrt.utils.losses import DiceLoss, FocalLoss
from generate import generate
from lit_llama.model import Block, LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from lit_llama.utils import save_model_checkpoint
from scripts.prepare_alpaca import generate_prompt
torch.set_float32_matmul_precision('high')
from osrt.layers import Transformer
from osrt import data
from osrt.model import OSRT
__LOG10 = math.log(10)
from segment_anything.modeling.transformer import TwoWayTransformer
# Module-level switches and intervals read by the token-based helpers in this file
# (train(), get_batch(), generate_response() and the validate() that loops over get_batch()).
instruction_tuning = True  # when true, generate_response() builds its prompt with generate_prompt()
eval_interval = 1000  # train() runs validation every `eval_interval` optimizer steps
save_interval = 1000  # train() saves a checkpoint every `save_interval` optimizer steps
eval_iters = 100  # number of batches the get_batch()-based validate() averages the loss over
log_interval = 100  # train() prints the loss every `log_interval` iterations
# Hyperparameters
learning_rate = 3e-5  # lr given to AdamW in main(); also the end value of the linear warmup in train()
micro_batch_size = 4  # number of samples drawn by each get_batch() call
# NOTE(review): the definition below is disabled (it is a bare string literal), yet train()
# still reads `gradient_accumulation_iters` - confirm where that name is meant to come from.
"""gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0"""
epoch_size = 50000 # train dataset size
num_epochs = 5
# NOTE(review): `max_iters` is commented out here but train() iterates over range(max_iters).
#max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
weight_decay = 0.0
block_size = 512  # passed to generate() as max_seq_length
warmup_iters = 100  # number of optimizer steps of linear warmup in train()
class LrScheduler():
    """Learning-rate schedule with a linear warm-up followed by exponential decay.

    The rate rises linearly from 0 to ``peak_lr`` over the first ``peak_it``
    iterations, then is multiplied by ``decay_rate`` every ``decay_it``
    iterations (applied continuously, not in discrete steps).
    """

    def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
        self.peak_lr, self.peak_it = peak_lr, peak_it
        self.decay_rate, self.decay_it = decay_rate, decay_it

    def get_cur_lr(self, it):
        """Return the learning rate to use at iteration ``it``."""
        if not (it < self.peak_it):
            # Past the peak: exponential decay measured from the end of warm-up.
            it_since_peak = it - self.peak_it
            decay_factor = self.decay_rate ** (it_since_peak / self.decay_it)
            return self.peak_lr * decay_factor
        # Warm-up period: fraction of the peak reached so far.
        warmup_fraction = it / self.peak_it
        return self.peak_lr * warmup_fraction
def main(
config_path:str,
data_dir: str = "data/alpaca",
out_dir: str = "out/full/alpaca",
checkpoint :str = None
def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0, out_dir: str = None):
    """Evaluate ``model`` on ``val_dataloader`` and save a checkpoint.

    For every batch the input views are encoded, the target rays are decoded and
    the pixel MSE and the PSNR derived from it are accumulated. After the loop the
    running means are printed and the model's state dict is written to ``out_dir``
    by global rank 0. The model is put back in training mode before returning.

    Args:
        fabric: Fabric instance used for printing, rank checks and device placement.
        model: Model exposing ``encoder`` and ``decoder``.
        val_dataloader: Iterable of batches (mappings) with the keys
            ``input_images``, ``input_camera_pos``, ``input_rays``,
            ``target_pixels``, ``target_camera_pos`` and ``target_rays``.
        epoch: Value used in log messages and in the checkpoint file name.
        out_dir: Directory for the checkpoint. When ``None``, falls back to
            ``cfg["training"]["out_dir"]`` of a module-level ``cfg`` if one
            exists, otherwise to ``"."``.
    """
    # TODO : add segmentation also to select the model following how it's done in the training
    if out_dir is None:
        # The original read `cfg.out_dir` from a module-level `cfg`; that config is a
        # dict (loaded with json.load), so it has to be indexed rather than dotted.
        out_dir = globals().get("cfg", {}).get("training", {}).get("out_dir", ".")

    model.eval()
    mses = AverageMeter()
    psnrs = AverageMeter()
    num_batches = len(val_dataloader)

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_dataloader):
            input_images = batch.get('input_images')
            input_camera_pos = batch.get('input_camera_pos')
            input_rays = batch.get('input_rays')
            target_pixels = batch.get('target_pixels')
            target_camera_pos = batch.get('target_camera_pos')
            target_rays = batch.get('target_rays')

            if isinstance(model.encoder, FeatureMasking):
                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
                h, w, _ = input_images[0][0].shape
                z = model.encoder(input_images, (h, w), input_camera_pos, input_rays)
            else:
                z = model.encoder(input_images, input_camera_pos, input_rays)

            pred_pixels, _ = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs)

            ### Compute MSE on pixels
            # One value per leading-dimension entry; presumably that is the batch axis - TODO confirm.
            mse_per_sample = ((pred_pixels - target_pixels) ** 2).mean((1, 2))
            psnr_per_sample = -10. * torch.log10(mse_per_sample)

            # Reduce to Python floats so the meters and the ':.4f' formatting below
            # work for any batch size.
            num_samples = mse_per_sample.shape[0]
            loss_mse = mse_per_sample.mean().item()
            psnr = psnr_per_sample.mean().item()
            mses.update(loss_mse, num_samples)
            psnrs.update(psnr, num_samples)

            fabric.print(f"Val [{epoch}] - [{batch_idx}/{num_batches}] : psnr {psnr}, mse: {loss_mse}")

    fabric.print(f'Validation [{epoch}]: Mean psnr: [{psnrs.avg:.4f}] -- Mean mse: [{mses.avg:.4f}]')

    fabric.print(f"Saving checkpoint to {out_dir}")
    # Kept outside the rank-0 guard, as in the original: every rank builds the state
    # dict, only rank 0 writes it (presumably needed for sharded strategies - confirm).
    state_dict = model.state_dict()
    if fabric.global_rank == 0:
        torch.save(state_dict, os.path.join(out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
    model.train()
def train_sam(
    cfg,
    fabric: L.Fabric,
    model: OSRT,
    optimizer: _FabricOptimizer,
    scheduler: _FabricOptimizer,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
):
    """The SAM training loop.

    Every step encodes the input views, decodes the target rays and minimises
    ``20 * focal + dice + mse``. The focal and dice terms are only added when the
    decoder returns a ``'segmentation'`` entry AND the encoder produced reference
    masks (the ``FeatureMasking`` branch); otherwise they stay at zero and the
    step trains on the pixel MSE alone.

    Args:
        cfg: Parsed configuration; ``cfg["training"]`` must provide ``max_it``,
            ``batch_size`` and ``validate_every``.
        fabric: Fabric instance used for ``backward``, printing and device placement.
        model: Model exposing ``encoder`` and ``decoder``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Object whose ``step()`` is called after every optimizer step.
            NOTE(review): annotated as ``_FabricOptimizer`` but used as an LR scheduler.
        train_dataloader: Iterable of training batches (mappings).
        val_dataloader: Passed through to ``validate``.
    """
    focal_loss = FocalLoss()
    dice_loss = DiceLoss()

    training_cfg = cfg["training"]
    nb_epochs = training_cfg["max_it"] // training_cfg["batch_size"]
    validate_every = training_cfg["validate_every"]
    num_batches = len(train_dataloader)

    for epoch in range(1, nb_epochs):
        # TODO : add psnr loss ?
        batch_time = AverageMeter()
        data_time = AverageMeter()
        focal_losses = AverageMeter()
        dice_losses = AverageMeter()
        mse_losses = AverageMeter()
        total_losses = AverageMeter()

        # Validate at most once per epoch, before the first training batch.
        if epoch > 1 and epoch % validate_every == 0:
            validate(fabric, model, val_dataloader, epoch)

        end = time.time()
        for batch_idx, batch in enumerate(train_dataloader):
            data_time.update(time.time() - end)

            # TODO : adapt to our model
            input_images = batch.get('input_images')
            input_camera_pos = batch.get('input_camera_pos')
            input_rays = batch.get('input_rays')
            target_pixels = batch.get('target_pixels')
            target_camera_pos = batch.get('target_camera_pos')
            target_rays = batch.get('target_rays')

            # Reference masks only exist when the encoder extracts them.
            masks_info = None
            if isinstance(model.encoder, FeatureMasking):
                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
                h, w, _ = input_images[0][0].shape
                masks_info, z = model.encoder(input_images, (h, w), input_camera_pos, input_rays, extract_masks=True)
            else:
                z = model.encoder(input_images, input_camera_pos, input_rays)

            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs)
            batch_size = input_images.shape[0]

            ### Compute MSE on pixels
            # Reduced to a scalar: fabric.backward() and .item() below both need one.
            loss_mse = ((pred_pixels - target_pixels) ** 2).mean((1, 2)).mean()

            loss_focal = torch.tensor(0., device=fabric.device)
            loss_dice = torch.tensor(0., device=fabric.device)
            if masks_info is not None and 'segmentation' in extras:
                # TODO : for visualisation only, could be interesting to check real GT
                # true_seg = data['target_masks'].float()
                pred_masks = extras['segmentation']
                # TODO : check the content of num_masks
                num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
                for pred_mask, gt_mask in zip(pred_masks, masks_info["segmentations"]):
                    loss_focal = loss_focal + focal_loss(pred_mask, gt_mask, num_masks)
                    loss_dice = loss_dice + dice_loss(pred_mask, gt_mask, num_masks)

            # TODO : check the values of the loss and see if scale is ok
            # Computed unconditionally so a batch without segmentation still trains on MSE.
            loss_total = 20. * loss_focal + loss_dice + loss_mse
            # TODO : check also with ARI, FG-ARI values and new from recent paper
            # loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
            #                                                 pred_seg.transpose(1, 2))
            # loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
            #                                                    pred_seg.transpose(1, 2))

            optimizer.zero_grad()
            fabric.backward(loss_total)
            optimizer.step()
            scheduler.step()

            batch_time.update(time.time() - end)
            end = time.time()

            focal_losses.update(loss_focal.item(), batch_size)
            dice_losses.update(loss_dice.item(), batch_size)
            mse_losses.update(loss_mse.item(), batch_size)
            total_losses.update(loss_total.item(), batch_size)

            fabric.print(f'Epoch: [{epoch}][{batch_idx+1}/{num_batches}]'
                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
                         f' | MSE Loss [{mse_losses.val:.4f} ({mse_losses.avg:.4f})]'
                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
def configure_opt(cfg, model: OSRT):
    """Build the Adam optimizer and its warm-up/decay learning-rate scheduler.

    The learning rate rises linearly from 0 to ``1e-4`` over ``lr_warmup``
    steps, then is multiplied by ``0.16`` every ``decay_it`` steps (applied
    continuously).

    Args:
        cfg: Parsed configuration. ``cfg['training']`` may provide ``decay_it``
            (default 4000000) and ``lr_warmup`` (default 2500).
        model: Module whose ``parameters()`` are optimised.

    Returns:
        Tuple ``(optimizer, scheduler)``; call ``scheduler.step()`` once per
        optimizer step.
    """
    training_cfg = cfg['training']
    decay_it = training_cfg.get('decay_it', 4000000)
    peak_it = training_cfg.get('lr_warmup', 2500)
    peak_lr = 1e-4
    decay_rate = 0.16

    # Equivalent to LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=decay_it, decay_rate=0.16).
    # LambdaLR MULTIPLIES the optimizer's initial lr by this value, so the lambda
    # returns a factor in [0, 1] and the optimizer itself is created with lr=peak_lr.
    def lr_lambda(step):
        if step < peak_it:  # Warmup period
            return step / peak_it
        it_since_peak = step - peak_it
        return decay_rate ** (it_since_peak / decay_it)

    # NOTE(review): weight_decay reuses the LR decay rate (0.16), as in the original;
    # that looks like a conflation of two unrelated settings - confirm the intended value.
    optimizer = torch.optim.Adam(model.parameters(), lr=peak_lr, weight_decay=decay_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
def main(cfg) -> None:
#########################
### Setup parameters
#########################
num_devices = cfg['training']['num_gpu'] if 'num_gpu' in cfg['training'] else 1
num_workers = cfg['training']['num_workers'] if 'num_workers' in cfg['training'] else 1
batch_size = cfg['training']['batch_size'] // num_devices
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Transformer, TwoWayTransformer})
strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing={Transformer, TwoWayTransformer}, limit_all_gathers=True)
# TODO : activer precision bf16
fabric = L.Fabric(accelerator="cuda", devices=num_devices, precision=cfg["training"]["precision"], strategy=strategy)
#########################
### Launch the model
#########################
fabric = L.Fabric(accelerator="gpu",
devices=num_devices,
strategy="auto",
loggers=[TensorBoardLogger(cfg['training']['out_dir'], name="lightning-sam")])
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
os.makedirs(cfg['training']['out_dir'], exist_ok=True)
###################
# Import Dataset
###################
with fabric.device:
model = OSRT(cfg)
#########################
### Loading the dataset
#########################
train_dataset = data.get_dataset('train', cfg['data'])
val_dataset = data.get_dataset('val', cfg['data'])
test_dataset = data.get_dataset('test', cfg['data'])
......@@ -107,162 +232,36 @@ def main(
train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization
data_vis_val = fabric.to_device(data_vis_val)
if checkpoint:
checkpoint = torch.load(checkpoint)
with fabric.device:
torch.set_default_tensor_type(torch.HalfTensor)
model = OSRT(cfg['model']).bfloat16()
torch.set_default_tensor_type(torch.FloatTensor)
if checkpoint:
model.load_state_dict(checkpoint, strict=False)
model = fabric.setup_module(model)
params = [p for p in model.parameters() if p.requires_grad]
# Setup scheduler
warmup_iters = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=warmup_iters, decay_rate=0.16)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
optimizer = fabric.setup_optimizers(optimizer)
train(fabric, model, optimizer, train_loader, val_loader, out_dir)
# Save the final checkpoint at the end of training
save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
def train(
fabric: L.Fabric,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_data: DataLoader, # TODO : maybe use np.array
val_data: DataLoader,
out_dir: str,
) -> None:
"""The training loop.
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""
step_count = 0
model.train()
#########################
### Prepare the optimizer
#########################
optimizer, scheduler = configure_opt(cfg, model)
model, optimizer = fabric.setup(model, optimizer)
for iter_num in range(max_iters):
#########################
### Training
#########################
train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader)
validate(fabric, model, val_loader, epoch=0)
is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
if step_count <= warmup_iters:
# linear warmup
lr = learning_rate * step_count / warmup_iters
for param_group in optimizer.param_groups:
param_group['lr'] = lr
t0 = time.time()
input_ids, targets = get_batch(fabric, train_data)
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
loss = loss_fn(logits, targets)
fabric.backward(loss / gradient_accumulation_iters)
if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
step_count += 1
if step_count % eval_interval == 0:
val_loss = validate(fabric, model, val_data)
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
fabric.barrier()
if step_count % save_interval == 0:
print(f"Saving weights to {out_dir}")
save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
dt = time.time() - t0
if iter_num % log_interval == 0:
fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
def generate_response(model, instruction):
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
sample = {"instruction": instruction, "input": ""}
prompt = instruction
if instruction_tuning:
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
output = generate(
model,
idx=encoded,
max_seq_length=block_size,
max_new_tokens=100,
if __name__ == "__main__":
### Arguments
parser = argparse.ArgumentParser(
description='Train a 3D scene representation model.'
)
output = tokenizer.decode(output)
return output # output.split("### Response:")[1].strip()
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
parser.add_argument('--checkpoint', type=str, default='', help='Path to a model checkpoint')
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
input_ids, targets = get_batch(fabric, val_data)
logits = model(input_ids)
loss = loss_fn(logits, targets)
losses[k] = loss.item()
out = losses.mean()
# produce an example:
instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
args = parser.parse_args()
output = generate_response(model, instruction)
fabric.print(instruction)
fabric.print(output)
model.train()
return out.item()
def loss_fn(logits, targets):
    """Next-token cross-entropy loss.

    The logits at position ``n`` are scored against the target at position
    ``n + 1``: the last logit and the first target are dropped before the
    comparison. Targets equal to ``-1`` are ignored.
    """
    vocab_size = logits.size(-1)
    # Drop the final prediction and the first target so the two line up.
    shifted_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shifted_targets = targets[..., 1:].contiguous().view(-1)
    return torch.nn.functional.cross_entropy(shifted_logits, shifted_targets, ignore_index=-1)
def get_batch(fabric: L.Fabric, data: list):
    """Draw a random micro-batch from ``data`` and move it to the fabric device.

    ``micro_batch_size`` indices are sampled with replacement. Each selected
    entry's ``"input_ids"`` and ``"labels"`` are cast to int64 and right-padded
    to the longest sampled ``input_ids`` (inputs with 0, labels with -1), then
    stacked, pinned and transferred with ``fabric.to_device``.

    Returns:
        Tuple ``(x, y)`` of stacked input ids and labels.
    """
    sample_ids = torch.randint(len(data), (micro_batch_size,))

    def gather(field):
        return [data[i][field].type(torch.int64) for i in sample_ids]

    input_ids = gather("input_ids")
    labels = gather("labels")
    longest = max(map(len, input_ids))

    def pad_to_longest(seq, fill):
        # pad on the right up to the longest sampled sequence
        filler = torch.full((longest - len(seq),), fill, dtype=seq.dtype)
        return torch.cat((seq, filler))

    padded_inputs = torch.stack([pad_to_longest(seq, 0) for seq in input_ids])
    padded_labels = torch.stack([pad_to_longest(seq, -1) for seq in labels])
    x, y = fabric.to_device((padded_inputs.pin_memory(), padded_labels.pin_memory()))
    return x, y
def load_datasets(data_dir):
    """Load the serialized splits from ``data_dir``.

    Returns:
        Tuple ``(train_data, val_data)`` read from ``train.pt`` and ``test.pt``
        respectively (the validation split is the file named ``test.pt``).
    """
    train_path = os.path.join(data_dir, "train.pt")
    val_path = os.path.join(data_dir, "test.pt")
    return torch.load(train_path), torch.load(val_path)
if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
CLI(main)
\ No newline at end of file
#########################
### Creating utility var
#########################
with open(args.config, 'r') as f:
cfg = json.load(f)
main(cfg)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment