diff --git a/osrt/trainer.py b/osrt/trainer.py
index 92764993fb24bb6364bcfa4e5b8cec54b94e75a7..0fc9285c9277053f959023cdd45f4cfaff9fae85 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed as dist
 import numpy as np
 from tqdm import tqdm
 
@@ -11,80 +12,28 @@ import os
 import math
 from collections import defaultdict
 
-class OSRTSamTrainer:
-    def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
-        self.model = model
-        self.optimizer = optimizer
-        self.config = cfg
-        self.device = device
-        self.out_dir = out_dir
-        self.render_kwargs = render_kwargs
-        if 'num_coarse_samples' in cfg['training']:
-            self.render_kwargs['num_coarse_samples'] = cfg['training']['num_coarse_samples']
-        if 'num_fine_samples' in cfg['training']:
-            self.render_kwargs['num_fine_samples'] = cfg['training']['num_fine_samples']
-
-    def evaluate(self, val_loader, **kwargs):
-        ''' Performs an evaluation.
-        Args:
-            val_loader (dataloader): pytorch dataloader
-        '''
-        self.model.eval()
-        eval_lists = defaultdict(list)
+def train(args, model, rank, world_size, train_loader, optimizer, epoch):
+    ddp_loss = torch.zeros(2).to(rank)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        model.train()
+        optimizer.zero_grad()
 
-        loader = val_loader if get_rank() > 0 else tqdm(val_loader)
-        sceneids = []
+        input_images = data.get('input_images').to(rank)
+        input_camera_pos = data.get('input_camera_pos').to(rank)
+        input_rays = data.get('input_rays').to(rank)
+        target_pixels = data.get('target_pixels').to(rank)
 
-        for data in loader:
-            sceneids.append(data['sceneid'])
-            eval_step_dict = self.eval_step(data, **kwargs)
-
-            for k, v in eval_step_dict.items():
-                eval_lists[k].append(v)
-
-        sceneids = torch.cat(sceneids, 0).cuda()
-        sceneids = torch.cat(gather_all(sceneids), 0)
-
-        print(f'Evaluated {len(torch.unique(sceneids))} unique scenes.')
-
-        eval_dict = {k: torch.cat(v, 0) for k, v in eval_lists.items()}
-        eval_dict = reduce_dict(eval_dict, average=True)  # Average across processes
-        eval_dict = {k: v.mean().item() for k, v in eval_dict.items()}  # Average across batch_size
-        print('Evaluation results:')
-        print(eval_dict)
-        return eval_dict
-
-    def train_step(self, data, it):
-        self.model.train()
-        self.optimizer.zero_grad()
-        loss, loss_terms = self.compute_loss(data, it)
-        loss = loss.mean(0)
-        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
-        loss.backward()
-        self.optimizer.step()
-        return loss.item(), loss_terms
-
-    def compute_loss(self, data, it):
-        device = self.device
-
-        input_images = data.get('input_images').to(device)
-        input_camera_pos = data.get('input_camera_pos').to(device)
-        input_rays = data.get('input_rays').to(device)
-        target_pixels = data.get('target_pixels').to(device)
-
-        input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k,  h, w, c]
-        h, w, c = input_images[0][0].shape
         with torch.cuda.amp.autocast(): 
-            z = self.model.encoder(input_images, (h, w), input_camera_pos, input_rays)
+            z = model.encoder(input_images, input_camera_pos, input_rays)
 
-        target_camera_pos = data.get('target_camera_pos').to(device)
-        target_rays = data.get('target_rays').to(device)
+        target_camera_pos = data.get('target_camera_pos').to(rank)
+        target_rays = data.get('target_rays').to(rank)
 
         loss = 0.
         loss_terms = dict()
-        
+
         with torch.cuda.amp.autocast(): 
-            pred_pixels, extras = self.model.decoder(z, target_camera_pos, target_rays, **self.render_kwargs)
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)#, **self.render_kwargs)
 
         loss = loss + ((pred_pixels - target_pixels)**2).mean((1, 2))
         loss_terms['mse'] = loss
@@ -95,7 +44,7 @@ class OSRTSamTrainer:
 
         if 'segmentation' in extras:
             pred_seg = extras['segmentation']
-            true_seg = data['target_masks'].to(device).float()
+            true_seg = data['target_masks'].to(rank).float()
 
             # These are not actually used as part of the training loss. 
             # We just add the to the dict to report them.
@@ -103,126 +52,20 @@ class OSRTSamTrainer:
                                                             pred_seg.transpose(1, 2))
 
             loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
-                                                               pred_seg.transpose(1, 2))
-
-        return loss, loss_terms
-
-    def eval_step(self, data, full_scale=False):
-        with torch.no_grad():
-            loss, loss_terms = self.compute_loss(data, 1000000)
+                                                                pred_seg.transpose(1, 2))
 
-        mse = loss_terms['mse']
-        psnr = mse2psnr(mse)
-        return {'psnr': psnr, 'mse': mse, **loss_terms}
-
-    def render_image(self, z, camera_pos, rays, **render_kwargs):
-        """
-        Args:
-            z [n, k, c]: set structured latent variables
-            camera_pos [n, 3]: camera position
-            rays [n, h, w, 3]: ray directions
-            render_kwargs: kwargs passed on to decoder
-        """
-        batch_size, height, width = rays.shape[:3]
-        rays = rays.flatten(1, 2)
-        camera_pos = camera_pos.unsqueeze(1).repeat(1, rays.shape[1], 1)
-
-        max_num_rays = self.config['data']['num_points'] * \
-                self.config['training']['batch_size'] // (rays.shape[0] * get_world_size())
-        num_rays = rays.shape[1]
-        img = torch.zeros_like(rays)
-        all_extras = []
-        for i in range(0, num_rays, max_num_rays):
-            img[:, i:i+max_num_rays], extras = self.model.decoder(
-                z, camera_pos[:, i:i+max_num_rays], rays[:, i:i+max_num_rays],
-                **render_kwargs)
-            all_extras.append(extras)
-
-        agg_extras = {}
-        for key in all_extras[0]:
-            agg_extras[key] = torch.cat([extras[key] for extras in all_extras], 1)
-            agg_extras[key] = agg_extras[key].view(batch_size, height, width, -1)
-
-        img = img.view(img.shape[0], height, width, 3)
-        return img, agg_extras
-
-
-    def visualize(self, data, mode='val'):
-        self.model.eval()
-
-        with torch.no_grad():
-            device = self.device
-            input_images = data.get('input_images').to(device)
-            input_camera_pos = data.get('input_camera_pos').to(device)
-            input_rays = data.get('input_rays').to(device)
-
-            camera_pos_base = input_camera_pos[:, 0]
-            input_rays_base = input_rays[:, 0]
-
-            if 'transform' in data:
-                # If the data is transformed in some different coordinate system, where
-                # rotating around the z axis doesn't make sense, we first undo this transform,
-                # then rotate, and then reapply it.
-
-                transform = data['transform'].to(device)
-                inv_transform = torch.inverse(transform)
-                camera_pos_base = nerf.transform_points_torch(camera_pos_base, inv_transform)
-                input_rays_base = nerf.transform_points_torch(
-                    input_rays_base, inv_transform.unsqueeze(1).unsqueeze(2), translate=False)
-            else:
-                transform = None
-
-            input_images_np = np.transpose(input_images.cpu().numpy(), (0, 1, 3, 4, 2))
-
-            z = self.model.encoder(input_images, input_camera_pos, input_rays)
-
-            batch_size, num_input_images, height, width, _ = input_rays.shape
-
-            num_angles = 6
-
-            columns = []
-            for i in range(num_input_images):
-                header = 'input' if num_input_images == 1 else f'input {i+1}'
-                columns.append((header, input_images_np[:, i], 'image'))
-
-            if 'input_masks' in data:
-                input_mask = data['input_masks'][:, 0]
-                columns.append(('true seg 0°', input_mask.argmax(-1), 'clustering'))
-
-            row_labels = None
-            for i in range(num_angles):
-                angle = i * (2 * math.pi / num_angles)
-                angle_deg = (i * 360) // num_angles
-
-                camera_pos_rot = nerf.rotate_around_z_axis_torch(camera_pos_base, angle)
-                rays_rot = nerf.rotate_around_z_axis_torch(input_rays_base, angle)
-
-                if transform is not None:
-                    camera_pos_rot = nerf.transform_points_torch(camera_pos_rot, transform)
-                    rays_rot = nerf.transform_points_torch(
-                        rays_rot, transform.unsqueeze(1).unsqueeze(2), translate=False)
-
-                img, extras = self.render_image(z, camera_pos_rot, rays_rot, **self.render_kwargs)
-                columns.append((f'render {angle_deg}°', img.cpu().numpy(), 'image'))
-
-                if 'depth' in extras:
-                    depth_img = extras['depth'].unsqueeze(-1) / self.render_kwargs['max_dist']
-                    depth_img = depth_img.view(batch_size, height, width, 1)
-                    columns.append((f'depths {angle_deg}°', depth_img.cpu().numpy(), 'image'))
-
-                if 'segmentation' in extras:
-                    pred_seg = extras['segmentation'].cpu()
-                    columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
-                    if i == 0:
-                        ari = compute_adjusted_rand_index(
-                            input_mask.flatten(1, 2).transpose(1, 2)[:, 1:],
-                            pred_seg.flatten(1, 2).transpose(1, 2))
-                        row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]
-
-            output_img_path = os.path.join(self.out_dir, f'renders-{mode}')
-            vis.draw_visualization_grid(columns, output_img_path, row_labels=row_labels)
 
+        loss = loss.mean(0)
+        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(input_images)
 
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
+        
 class SRTTrainer:
     def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
         self.model = model
diff --git a/osrt/utils/common.py b/osrt/utils/common.py
index 4cafbf2cb0610b731adac81204b2e4346fca1814..0c848a715cd625412a57245bb2ea208be85c66f9 100644
--- a/osrt/utils/common.py
+++ b/osrt/utils/common.py
@@ -189,6 +189,8 @@ def init_ddp():
     setup_dist_print(local_rank == 0)
     return local_rank, world_size
 
+def cleanup():
+    dist.destroy_process_group()
 
 def setup_dist_print(is_main):
     import builtins as __builtin__
diff --git a/segment-anything b/segment-anything
index 6fdee8f2727f4506cfbbe553e23b895e27956588..f7b29ba9df1496489af8c71a4bdabed7e8b017b1 160000
--- a/segment-anything
+++ b/segment-anything
@@ -1 +1 @@
-Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588
+Subproject commit f7b29ba9df1496489af8c71a4bdabed7e8b017b1