diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index 1110115b001d048d25e115fdda97d47befcea4d9..9e0ee5974160169948175804051a75ef53c1a6f2 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -1,13 +1,12 @@
 import argparse
 import torch
-from osrt.encoder import SamAutomaticMask, FeatureMasking
 from osrt.model import OSRT
-from segment_anything import sam_model_registry
 import time
 import matplotlib.pyplot as plt
 import numpy as np
 import cv2
 
+
 def show_anns(masks):
     ax = plt.gca()
     ax.set_autoscale_on(False)
@@ -57,13 +56,13 @@ if __name__ == '__main__':
     cfg['encoder'] = 'sam'
     cfg['decoder'] = 'slot_mixer'
     cfg['encoder_kwargs'] = {
-        'points_per_side': 12,
+        'points_per_side': 32,
         'box_nms_thresh': 0.7,
         'stability_score_thresh': 0.9,
         'pred_iou_thresh': 0.88,
         'sam_model': model_type,
         'sam_path': checkpoint,
-        'points_per_batch': 16
+        'points_per_batch': 12
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
@@ -71,10 +70,10 @@ if __name__ == '__main__':
     model = OSRT(cfg)#FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh= 0.9, pred_iou_thresh=0.88, points_per_batch=64)
     model.to(device, non_blocking=True)
 
-    num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
+    """num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
     num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
 
-    """print('Number of parameters:')
+    print('Number of parameters:')
     print(f'\tEncoder: {num_encoder_params}')
 
     num_mask_encoder_params = sum(p.numel() for p in model.encoder.mask_generator.parameters())
@@ -87,12 +86,12 @@ if __name__ == '__main__':
     print(f'\t\t\tMask Decoder: {num_mask_params}.')
     print(f'\t\t\tPrompt Encoder: {num_prompt_params}.')
     print(f'\t\tSlot Attention: {num_slotatt_params}.')
-    print(f'\tDecoder: {num_decoder_params}')
-"""
+    print(f'\tDecoder: {num_decoder_params}')"""
+
     images= []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(2):
+    for j in range(10):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -119,12 +118,12 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast():
-        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
+        masks = model.encoder.mask_generator(images_t, (h, w), None, None, extract_embeddings=False)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
 
     if args.visualize:
-        for j in range(2):
+        for j in range(10):
             image = images_path[j]
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
diff --git a/osrt/encoder.py b/osrt/encoder.py
index aa258ee9355d00b4d7cb6828f35a91bc04736b21..f5b63c6ee3ff1339d4c4d70eaeceb4f48eede921 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -123,8 +123,8 @@ class FeatureMasking(nn.Module):
                  num_slots=32,
                  slot_dim=1536,
                  slot_iters=3,
-                 sam_model="default",
-                 sam_path="sam_vit_h_4b8939.pth",
+                 sam_model="vit_t",
+                 sam_path="mobile_sam.pt",
                  randomize_initial_slots=False):
         super().__init__()
 
@@ -252,8 +252,6 @@ class SamAutomaticMask(nn.Module):
         self.mask_decoder = mask_decoder
         for param in self.mask_decoder.parameters():
             param.requires_grad = True
-        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
 
         # Transform image to a square by putting it to the longest side
         #self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
@@ -280,6 +278,9 @@ class SamAutomaticMask(nn.Module):
             nn.Linear(2500, self.token_dim),
         )
 
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
diff --git a/sam_test.py b/sam_test.py
index fa34060be430b1898c9faf275cab317ca6abc202..56d4d8682666f13534b2ec891e9cf8350922acc4 100644
--- a/sam_test.py
+++ b/sam_test.py
@@ -1,18 +1,14 @@
-import argparse
 import torch
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
-from torchvision import transforms
-from PIL import Image
-import time
 import matplotlib.pyplot as plt
-import matplotlib as mpl
 import numpy as np
 import cv2
+import time
 
 
-def show_anns(masks):
-    if len(masks) == 0:
+def show_anns(anns):
+    if len(anns) == 0:
         return
-    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
     ax = plt.gca()
     ax.set_autoscale_on(False)
@@ -24,100 +20,44 @@ def show_anns(masks):
         img[m] = color_mask
     ax.imshow(img)
 
-def show_points(coords, ax, marker_size=100):
-    ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size)
-
-
-if __name__ == '__main__':
-    # Arguments
-    parser = argparse.ArgumentParser(
-        description='Test Segment Anything Auto Mask simplified implementation'
-    )
-    parser.add_argument('--model', default='vit_b', type=str, help='Model to use')
-    parser.add_argument('--path_model', default='.', type=str, help='Path to the model')
-
-    args = parser.parse_args()
-    device = "cuda"
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-    model_type = args.model
-    if args.model == 'vit_h':
-        checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
-    elif args.model == 'vit_b':
-        checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
-    else:
-        checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
-    ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
-    images_path = []
-    with open(ycb_path + "image_sets/train.txt", 'r') as f:
-        for line in f.readlines():
-            line = line.strip()
-            images_path.append(ycb_path + 'data/' + line + "-color.png")
-    import random
-    random.shuffle(images_path)
+model_type = "vit_t"
+sam_checkpoint = "./mobile_sam.pt"
 
-    sam = sam_model_registry[model_type](checkpoint=checkpoint)
-    sam.to(device=device)
-    mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-    ])
-    labels = [1 for i in range(len(mask_generator.point_grids))]
-    with torch.no_grad():
-        for image in images_path:
-            img = cv2.imread(image)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            """img_depth = cv2.imread(image.replace("color", "depth"))
-            img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)"""
+mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+mobile_sam.to(device=device)
+mobile_sam.eval()
 
-            h, w, _ = img.shape
-            points = mask_generator.point_grids[0]
-            new_points= []
-            for val in points:
-                x, y = val[0], val[1]
-                x *= w
-                y *= h
-                new_points.append([x, y])
-            new_points = np.array(new_points)
+mask_generator = SamAutomaticMaskGenerator(mobile_sam, points_per_side=16, points_per_batch= 12)
+ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
+images_path = []
+with open(ycb_path + "image_sets/train.txt", 'r') as f:
+    for line in f.readlines():
+        line = line.strip()
+        images_path.append(ycb_path + 'data/' + line + "-color.png")
 
-            start = time.time()
-            masks = mask_generator.generate(img)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-
-            plt.figure(figsize=(15,15))
-            plt.imshow(img)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()
-
-            """fig, ax = plt.subplots()
-            cmap = plt.cm.get_cmap('plasma')
-            img = ax.imshow(img_depth, cmap=cmap)
-            cbar = fig.colorbar(img, ax=ax)
-            depth_array_new = img.get_array()
-            plt.show()
-
-            depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB)
-            plt.imshow(depth_array_new)
-            plt.show()
-            print(depth_array_new.shape)
-
-            start = time.time()
-            masks = mask_generator.generate(depth_array_new)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-
-
-            plt.figure(figsize=(15,15))
-            plt.imshow(depth_array_new)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()"""
+import random
+#random.shuffle(images_path)
+images= []
+from torchvision import transforms
+transform = transforms.ToTensor()
+for j in range(20):
+    image = images_path[j]
+    img = cv2.imread(image)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    images.append(transform(img).unsqueeze(0))
+images_t = torch.stack(images).to(device)
+
+start = time.time()
+masks = mask_generator.generate(images_t)
+end = time.time()
+print(f"Inference time : {int((end-start) * 1000)}ms")
+plt.figure(figsize=(15,15))
+plt.imshow(img)
+show_anns(masks) # show masks
+plt.axis('off')
+plt.show()
\ No newline at end of file
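
Note on the new sam_test.py driver: SamAutomaticMaskGenerator.generate() takes a single HWC uint8 RGB numpy array per call, so the batched tensor built above would normally be fed to it one frame at a time. Below is a minimal per-image sketch under the same settings as the patch, assuming the installed segment_anything package is the MobileSAM fork that registers the "vit_t" model type; "frame.png" is a placeholder path, not a file from the repository.

import cv2
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"

# "vit_t" / mobile_sam.pt as in the patch; the "vit_t" key exists only in the MobileSAM fork
sam = sam_model_registry["vit_t"](checkpoint="./mobile_sam.pt")
sam.to(device=device)
sam.eval()

mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=16, points_per_batch=12)

# generate() expects one HWC uint8 RGB image; "frame.png" stands in for a YCB color frame
img = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(img)  # list of dicts with 'segmentation', 'area', 'bbox', ...
print(f"{len(masks)} masks found")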