"""Visual sanity check: show each training sample's input image next to its mask.

Loads a run config, overrides its dataset path, builds the training split via
``osrt.data.get_dataset`` and, for every batch, prints the mask tensor shape and
plots every sample's input image beside channel 0 of its input mask.

Run with no arguments to use the original hard-coded config/dataset paths.
"""
import argparse

import yaml
from osrt import data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

DEFAULT_CONFIG = "runs/ycb/slot_att/config.yaml"
DEFAULT_DATA_PATH = "/home/achapin/Documents/Datasets/ycbv_small"
DEFAULT_BATCH_SIZE = 2


def _load_config(config_path, data_path):
    """Read the YAML run config at *config_path* and set ``cfg['data']['path']``."""
    # CLoader only exists when PyYAML was built against libyaml; fall back to the
    # pure-Python loader of the same kind so the script still runs without it.
    loader = getattr(yaml, "CLoader", yaml.Loader)
    # NOTE(review): this is a full (non-safe) YAML loader. Acceptable for a trusted
    # local run config; do not point it at untrusted files.
    with open(config_path, 'r', encoding="utf-8") as f:
        cfg = yaml.load(f, Loader=loader)
    cfg['data']['path'] = data_path
    return cfg


def _show_batch(batch):
    """Plot one row per sample: input image (left), mask channel 0 (right)."""
    images = batch['input_images']
    masks = batch['input_masks']
    # The final batch can be smaller than batch_size (drop_last defaults to False),
    # so size the grid from the batch itself instead of assuming 2 samples.
    num_samples = len(images)
    # squeeze=False keeps `axes` 2-D even when there is a single sample.
    fig, axes = plt.subplots(num_samples, 2, squeeze=False)
    for row in range(num_samples):
        # assumes each image is channels-first (C, H, W), as the permute implies
        # — TODO confirm against the dataset's output format.
        axes[row][0].imshow(images[row].permute(1, 2, 0))
        axes[row][1].imshow(masks[row][:, :, 0])
    plt.show()
    # Release the figure; otherwise non-interactive backends accumulate one per batch.
    plt.close(fig)


def main(argv=None):
    """Parse CLI options, build the train loader and display every batch."""
    parser = argparse.ArgumentParser(
        description="Display input images and masks from the training set.")
    parser.add_argument("--config", default=DEFAULT_CONFIG,
                        help="path to the run's config.yaml")
    parser.add_argument("--data-path", default=DEFAULT_DATA_PATH,
                        help="dataset root, overrides cfg['data']['path']")
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
                        help="samples shown per figure")
    args = parser.parse_args(argv)

    cfg = _load_config(args.config, args.data_path)
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=0, shuffle=True)

    for batch in train_loader:
        print(f"Shape masks {batch['input_masks'].shape}")
        _show_batch(batch)


if __name__ == "__main__":
    main()