diff --git a/.gitmodules b/.gitmodules
index a8ff4b8fe903913972aeaee9f41e0109c6ac84f1..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "segment-anything"]
-	path = segment-anything
-	url = https://github.com/facebookresearch/segment-anything.git
diff --git a/runs/test/config.json b/runs/test/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6246b88b65bdf17cf8f6d50d39ea85c0214afa8c
--- /dev/null
+++ b/runs/test/config.json
@@ -0,0 +1,38 @@
+{
+    "data": {
+        "dataset": "clevr3d",
+        "num_points": 2000,
+        "kwargs": {
+            "downsample": 1
+        }
+    },
+    "model": {
+        "encoder": "osrt",
+        "encoder_kwargs": {
+            "pos_start_octave": -5,
+            "num_slots": 6
+        },
+        "decoder": "slot_mixer",
+        "decoder_kwargs": {
+            "pos_start_octave": -5
+        }
+    },
+    "training": {
+        "num_workers": 4,
+        "batch_size": 64,
+        "num_gpu": 8,
+        "model_selection_metric": "psnr",
+        "model_selection_mode": "max",
+        "print_every": 10,
+        "visualize_every": 5000,
+        "validate_every": 5000,
+        "checkpoint_every": 1000,
+        "backup_every": 25000,
+        "max_it": 333000000,
+        "decay_it": 4000000,
+        "lr_warmup": 5000,
+        "precision": "16-mixed",
+        "out_dir": "."
+    }
+
+}
\ No newline at end of file
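The `training` block above fully determines the learning-rate schedule built by `configure_opt()` in the patched `train_lit.py` below: a linear warmup to a hard-coded peak of 1e-4 over `lr_warmup` steps, followed by an exponential decay with factor 0.16 applied once per `decay_it` steps. A minimal sketch of that curve using the values from this config (not part of the patch; all names here are illustrative):

```python
peak_lr = 1e-4        # hard-coded in configure_opt()
lr_warmup = 5000      # "lr_warmup" above
decay_it = 4_000_000  # "decay_it" above
decay_rate = 0.16     # hard-coded in configure_opt()

def lr_at(step: int) -> float:
    if step < lr_warmup:                  # linear warmup up to peak_lr
        return peak_lr * step / lr_warmup
    # exponential decay: one factor of 0.16 every decay_it steps past the peak
    return peak_lr * decay_rate ** ((step - lr_warmup) / decay_it)

for step in (0, 2500, 5000, 4_005_000):
    print(f"step {step:>9}: lr = {lr_at(step):.2e}")
# step 0: 0.00e+00, step 2500: 5.00e-05, step 5000: 1.00e-04, step 4005000: 1.60e-05
```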
diff --git a/train_lit.py b/train_lit.py
index ef1d7ba42a3d572c21276cac3c9d06fbc0280a65..da98f7785c642f6a4b845d4a54356719114e1967 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -1,99 +1,224 @@
 """
-Code inspired from Lit-Llama training script : https://github.com/Lightning-AI/lit-llama/blob/main/finetune/full.py
+Code inspired by and adapted from: https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/train.py
 """
-import sys
-from pathlib import Path
+
 import os
 import time
-from functools import partial
+import json
+import argparse
+import math
 
 import lightning as L
-from lightning.fabric.strategies import FSDPStrategy
-import numpy as np
+import segmentation_models_pytorch as smp
 import torch
+import torch.nn.functional as F
+from lightning.fabric.fabric import _FabricOptimizer
+from lightning.fabric.loggers import TensorBoardLogger
 from torch.utils.data import DataLoader
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from jsonargparse.cli import CLI
-import json
 
-# support running without installing as a package
-wd = Path(__file__).parent.parent.resolve()
-sys.path.append(str(wd))
+from osrt.model import OSRT
+from osrt.encoder import FeatureMasking
+from osrt import data
+from osrt.utils.training import AverageMeter
+from osrt.utils.losses import DiceLoss, FocalLoss
 
-from generate import generate
-from lit_llama.model import Block, LLaMA, LLaMAConfig
-from lit_llama.tokenizer import Tokenizer
-from lit_llama.utils import save_model_checkpoint
-from scripts.prepare_alpaca import generate_prompt
+torch.set_float32_matmul_precision('high')
 
-from osrt.layers import Transformer
-from osrt import data
-from osrt.model import OSRT
+__LOG10 = math.log(10)
 
-from segment_anything.modeling.transformer import TwoWayTransformer
-
-
-instruction_tuning = True
-eval_interval = 1000
-save_interval = 1000
-eval_iters = 100
-log_interval = 100
-
-# Hyperparameters
-learning_rate = 3e-5
-micro_batch_size = 4
-"""gradient_accumulation_iters = batch_size // micro_batch_size
-assert gradient_accumulation_iters > 0"""
-epoch_size = 50000 # train dataset size
-num_epochs = 5
-#max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
-weight_decay = 0.0
-block_size = 512
-warmup_iters = 100
-
-class LrScheduler():
-    """ Implements a learning rate schedule with warum up and decay """
-    def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
-        self.peak_lr = peak_lr
-        self.peak_it = peak_it
-        self.decay_rate = decay_rate
-        self.decay_it = decay_it
-
-    def get_cur_lr(self, it):
-        if it < self.peak_it: # Warmup period
-            return self.peak_lr * (it / self.peak_it)
-        it_since_peak = it - self.peak_it
-        return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
-
-
-def main(
-    config_path:str,
-    data_dir: str = "data/alpaca",
-    out_dir: str = "out/full/alpaca",
-    checkpoint :str = None
+def validate(fabric: L.Fabric, cfg, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
+    # TODO : also track segmentation quality for model selection, as in the training loop
+    model.eval()
+    mses = AverageMeter()
+    psnrs = AverageMeter()
+
+    sceneids = []
+
+    with torch.no_grad():
+        for iter, data in enumerate(val_dataloader):
+            sceneids.append(data['sceneid'])
+
+            input_images = data.get('input_images')
+            input_camera_pos = data.get('input_camera_pos')
+            input_rays = data.get('input_rays')
+            target_pixels = data.get('target_pixels')
+
+            if isinstance(model.encoder, FeatureMasking):
+                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
+                h, w, c = input_images[0][0].shape
+                z = model.encoder(input_images, (h, w), input_camera_pos, input_rays)
+            else:
+                z = model.encoder(input_images, input_camera_pos, input_rays)
+
+            target_camera_pos = data.get('target_camera_pos')
+            target_rays = data.get('target_rays')
+
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # TODO : pass **self.render_kwargs
+
+            # Per-example MSE over rays and channels, converted to PSNR (dB)
+            loss_mse = ((pred_pixels - target_pixels) ** 2).mean((1, 2))
+            psnr = -10. * torch.log(loss_mse) / __LOG10
+            mses.update(loss_mse.mean().item())
+            psnrs.update(psnr.mean().item())
+            fabric.print(f"Val [{epoch}] - [{iter}/{len(val_dataloader)}] : psnr {psnr.mean():.4f}, mse: {loss_mse.mean():.6f}")
+
+    fabric.print(f'Validation [{epoch}]: Mean psnr: [{psnrs.avg:.4f}] -- Mean mse: [{mses.avg:.4f}]')
+
+    fabric.print(f"Saving checkpoint to {cfg['training']['out_dir']}")
+    state_dict = model.state_dict()
+    if fabric.global_rank == 0:
+        torch.save(state_dict, os.path.join(cfg['training']['out_dir'], f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
+    model.train()
+
+
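Aside on the PSNR computed in `validate()` above: `torch.log` is the natural logarithm, so dividing by `__LOG10 = math.log(10)` turns `-10 * ln(mse)` into the usual `-10 * log10(mse)`, i.e. PSNR in dB for pixel values normalised to [0, 1]. A quick standalone check (not part of the patch):

```python
import math

def psnr_from_mse(mse: float) -> float:
    # -10 * log10(mse), written with natural logs as in validate() above
    return -10.0 * math.log(mse) / math.log(10)

print(psnr_from_mse(0.01))   # 20.0 dB
print(psnr_from_mse(0.001))  # 30.0 dB
```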
+def train_sam(
+    cfg,
+    fabric: L.Fabric,
+    model: OSRT,
+    optimizer: _FabricOptimizer,
+    scheduler: _FabricOptimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
 ):
+    """The SAM training loop."""
+
+    focal_loss = FocalLoss()
+    dice_loss = DiceLoss()
+    nb_epochs = cfg["training"]["max_it"] // cfg["training"]["batch_size"]
+    for epoch in range(1, nb_epochs):
+        # TODO : add psnr loss ?
+        batch_time = AverageMeter()
+        data_time = AverageMeter()
+        focal_losses = AverageMeter()
+        dice_losses = AverageMeter()
+        mse_losses = AverageMeter()
+        total_losses = AverageMeter()
+        end = time.time()
+        validated = False
+
+        for iter, data in enumerate(train_dataloader):
+            # TODO : validate_every is given in iterations in the config but compared against epochs here
+            if epoch > 1 and epoch % cfg["training"]["validate_every"] == 0 and not validated:
+                validate(fabric, cfg, model, val_dataloader, epoch)
+                validated = True
+
+            data_time.update(time.time() - end)
+
+            # TODO : adapt to our model
+            input_images = data.get('input_images')
+            input_camera_pos = data.get('input_camera_pos')
+            input_rays = data.get('input_rays')
+            target_pixels = data.get('target_pixels')
+
+            masks_info = None
+            if isinstance(model.encoder, FeatureMasking):
+                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
+                h, w, c = input_images[0][0].shape
+                masks_info, z = model.encoder(input_images, (h, w), input_camera_pos, input_rays, extract_masks=True)
+            else:
+                z = model.encoder(input_images, input_camera_pos, input_rays)
+
+            target_camera_pos = data.get('target_camera_pos')
+            target_rays = data.get('target_rays')
+
+            loss_focal = torch.tensor(0., device=fabric.device)
+            loss_dice = torch.tensor(0., device=fabric.device)
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # TODO : pass **self.render_kwargs
+
+            # MSE on the predicted pixels, reduced to a scalar so it can enter the backward pass
+            loss_mse = ((pred_pixels - target_pixels) ** 2).mean()
+
+            batch_size = input_images.shape[0]
+
+            if 'segmentation' in extras and masks_info is not None:
+                # TODO : for visualisation only, it could be interesting to check the real GT
+                # true_seg = data['target_masks'].float()
+
+                pred_masks = extras['segmentation']
+                # TODO : check the content of num_masks
+                num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
+                for pred_mask, gt_mask in zip(pred_masks, masks_info["segmentations"]):
+                    loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
+                    loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
+
+            # TODO : check the values of the loss and see if the scale is ok
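`AverageMeter` is imported from `osrt.utils.training`; the loop above only relies on `.update(val, n)`, `.val` and `.avg`. A minimal implementation consistent with that usage would look like the following sketch (the real class may differ):

```python
class AverageMeter:
    """Tracks the most recent value and a running weighted average."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val        # last observed value
        self.sum += val * n   # weighted by batch size n
        self.count += n
        self.avg = self.sum / self.count
```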
+            loss_total = 20. * loss_focal + loss_dice + loss_mse
+            # TODO : also evaluate ARI / FG-ARI, and newer metrics from recent papers
+            """loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
+                                                               pred_seg.transpose(1, 2))
+
+            loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+                                                               pred_seg.transpose(1, 2))"""
+
+            optimizer.zero_grad()
+            fabric.backward(loss_total)
+            optimizer.step()
+            scheduler.step()
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            focal_losses.update(loss_focal.item(), batch_size)
+            dice_losses.update(loss_dice.item(), batch_size)
+            mse_losses.update(loss_mse.item(), batch_size)
+            total_losses.update(loss_total.item(), batch_size)
+
+            fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
+                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
+                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
+                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
+                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
+                         f' | MSE Loss [{mse_losses.val:.4f} ({mse_losses.avg:.4f})]'
+                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
+
+
+def configure_opt(cfg, model: OSRT):
+    decay_it = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
+    peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
+    peak_lr = 1e-4
+    decay_rate = 0.16
 
-    with open(config_path, 'r') as f:
-        cfg = json.load(f)
+    # Reimplements the removed LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=decay_it, decay_rate=0.16)
+    # as a LambdaLR factor: lr_lambda returns a multiplier of the base lr (peak_lr), not an absolute lr.
+    def lr_lambda(step):
+        if step < peak_it:  # Linear warmup up to peak_lr
+            return step / peak_it
+        it_since_peak = step - peak_it
+        return decay_rate ** (it_since_peak / decay_it)
+
+    # NOTE: the previous AdamW used weight_decay=0.0; decay_rate above is the lr decay factor, not a weight decay
+    optimizer = torch.optim.Adam(model.parameters(), lr=peak_lr)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+    return optimizer, scheduler
+
+
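Note on the `LambdaLR` semantics, since it replaces the old hand-rolled `LrScheduler`: the lambda's return value is multiplied by the optimizer's base lr at every `scheduler.step()`. That is why `configure_opt()` above sets the base lr to `peak_lr` and has `lr_lambda` return a dimensionless factor rather than an absolute learning rate. A toy check with a dummy parameter (illustrative only, not part of the patch):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)  # base lr plays the role of peak_lr
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, step / 5000))

for _ in range(3):
    opt.step()
    sched.step()
print(sched.get_last_lr())  # [6e-08], i.e. 3/5000 of the way through warmup
```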
+def main(cfg) -> None:
+
+    #########################
+    ### Setup parameters
+    #########################
     num_devices = cfg['training']['num_gpu'] if 'num_gpu' in cfg['training'] else 1
     num_workers = cfg['training']['num_workers'] if 'num_workers' in cfg['training'] else 1
     batch_size = cfg['training']['batch_size'] // num_devices
-
-    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Transformer, TwoWayTransformer})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing={Transformer, TwoWayTransformer}, limit_all_gathers=True)
-
-    # TODO : activer precision bf16
-    fabric = L.Fabric(accelerator="cuda", devices=num_devices, precision=cfg["training"]["precision"], strategy=strategy)
+
+    #########################
+    ### Launch the model
+    #########################
+    fabric = L.Fabric(accelerator="gpu",
+                      devices=num_devices,
+                      strategy="auto",
+                      loggers=[TensorBoardLogger(cfg['training']['out_dir'], name="lightning-sam")])
     fabric.launch()
     fabric.seed_everything(1337 + fabric.global_rank)
 
     if fabric.global_rank == 0:
-        os.makedirs(out_dir, exist_ok=True)
+        os.makedirs(cfg['training']['out_dir'], exist_ok=True)
 
-    ###################
-    # Import Dataset
-    ###################
+    with fabric.device:
+        model = OSRT(cfg)
+
+    #########################
+    ### Loading the dataset
+    #########################
     train_dataset = data.get_dataset('train', cfg['data'])
     val_dataset = data.get_dataset('val', cfg['data'])
     test_dataset = data.get_dataset('test', cfg['data'])
@@ -107,162 +232,36 @@ def main(
     train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
+    data_vis_val = next(iter(vis_loader_val))  # Validation set data for visualization
+    data_vis_val = fabric.to_device(data_vis_val)
 
-    if checkpoint:
-        checkpoint = torch.load(checkpoint)
-
-    with fabric.device:
-        torch.set_default_tensor_type(torch.HalfTensor)
-        model = OSRT(cfg['model']).bfloat16()
-        torch.set_default_tensor_type(torch.FloatTensor)
-    if checkpoint:
-        model.load_state_dict(checkpoint, strict=False)
-
-    model = fabric.setup_module(model)
-
-    params = [p for p in model.parameters() if p.requires_grad]
-    # Setup scheduler
-    warmup_iters = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
-    peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
-    lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=warmup_iters, decay_rate=0.16)
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
-    optimizer = fabric.setup_optimizers(optimizer)
-
-    train(fabric, model, optimizer, train_loader, val_loader, out_dir)
-
-    # Save the final checkpoint at the end of training
-    save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
-
-
-def train(
-    fabric: L.Fabric,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer,
-    train_data: DataLoader, # TODO : maybe use np.array
-    val_data: DataLoader,
-    out_dir: str,
-) -> None:
-    """The training loop.
-
-    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
-    """
-    step_count = 0
-    model.train()
+    #########################
+    ### Prepare the optimizer
+    #########################
+    optimizer, scheduler = configure_opt(cfg, model)
+    model, optimizer = fabric.setup(model, optimizer)
 
-    for iter_num in range(max_iters):
+    #########################
+    ### Training
+    #########################
+    train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader)
+    validate(fabric, cfg, model, val_loader, epoch=0)
 
"__main__": + ### Arguments + parser = argparse.ArgumentParser( + description='Train a 3D scene representation model.' ) - output = tokenizer.decode(output) - return output # output.split("### Response:")[1].strip() - + parser.add_argument('config', type=str, help='Path to config file.') + parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.') + parser.add_argument('--checkpoint', type=str, default='', help='Path to a model checkpoint') -@torch.no_grad() -def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor: - fabric.print("Validating ...") - model.eval() - losses = torch.zeros(eval_iters) - for k in range(eval_iters): - input_ids, targets = get_batch(fabric, val_data) - logits = model(input_ids) - loss = loss_fn(logits, targets) - losses[k] = loss.item() - out = losses.mean() - - # produce an example: - instruction = "Recommend a movie for me to watch during the weekend and explain the reason." + args = parser.parse_args() - output = generate_response(model, instruction) - fabric.print(instruction) - fabric.print(output) - - model.train() - return out.item() - -def loss_fn(logits, targets): - # shift the targets such that output n predicts token n+1 - logits = logits[..., :-1, :].contiguous() - targets = targets[..., 1:].contiguous() - loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) - return loss - - -def get_batch(fabric: L.Fabric, data: list): - ix = torch.randint(len(data), (micro_batch_size,)) - - input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] - labels = [data[i]["labels"].type(torch.int64) for i in ix] - - max_len = max(len(s) for s in input_ids) - - def pad_right(x, pad_id): - # pad right based on the longest sequence - n = max_len - len(x) - return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) - - x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) - y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) - x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) - return x, y - - -def load_datasets(data_dir): - train_data = torch.load(os.path.join(data_dir, "train.pt")) - val_data = torch.load(os.path.join(data_dir, "test.pt")) - return train_data, val_data - - -if __name__ == "__main__": - # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" - # torch.backends.cuda.enable_flash_sdp(False) - torch.set_float32_matmul_precision("high") - - CLI(main) \ No newline at end of file + ######################### + ### Creating utility var + ######################### + with open(args.config, 'r') as f: + cfg = json.load(f) + main(cfg) \ No newline at end of file