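# quick_test.py (775 B) -- quick visual check of the training data loader.
# The original first lines ("Skip to content", "Snippets Groups Projects") were
# GitLab page navigation text captured when the file was copied, not code.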
import os

import matplotlib.pyplot as plt
import yaml
from torch.utils.data import DataLoader

from osrt import data

# Load the run configuration for the slot-attention YCB experiment.
# Prefer the libyaml-backed C loader for speed. Fall back to the pure-Python
# yaml.Loader when PyYAML was built without libyaml: yaml.CLoader does not
# exist in that case, and yaml.Loader uses the same constructor, so parsing
# is unchanged.
# NOTE(review): both loaders are the non-safe variant and can build arbitrary
# Python objects from YAML tags; only use them on trusted, local config files.
with open("runs/ycb/slot_att/config.yaml", 'r', encoding="utf-8") as f:
    cfg = yaml.load(f, Loader=getattr(yaml, "CLoader", yaml.Loader))

# Point the data config at a local copy of the dataset. The location can be
# overridden with the YCBV_DATA_PATH environment variable, so the script runs
# on other machines. It defaults to the original author's local path.
cfg['data']['path'] = os.environ.get(
    "YCBV_DATA_PATH", "/home/achapin/Documents/Datasets/ycbv_small"
)
train_dataset = data.get_dataset('train', cfg['data'])
# Small shuffled batches. num_workers=0 loads batches in the main process,
# which is the simplest setup for an interactive check.
train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0, shuffle=True)


# Preview every batch. Each figure has one row per sample: the input image on
# the left and index 0 of the mask's last axis on the right. plt.show() blocks
# until the window is closed, then the next batch is shown.
for val in train_loader:
    images = val['input_images']
    masks = val['input_masks']
    print(f"Shape masks {masks.shape}")
    # The final batch can hold fewer than batch_size samples (DataLoader keeps
    # the remainder unless drop_last=True). Size the grid from the batch
    # actually received rather than assuming two rows. squeeze=False keeps
    # `axes` 2-D even when there is a single row.
    n_samples = images.shape[0]
    fig, axes = plt.subplots(n_samples, 2, squeeze=False)
    for (ax_image, ax_mask), image, mask in zip(axes, images, masks):
        # permute moves the leading axis (presumably channels) last, the
        # (H, W, C) layout imshow expects.
        ax_image.imshow(image.permute(1, 2, 0))
        # assumes each mask is (H, W, K) -- TODO confirm; shows slice 0 of K.
        ax_mask.imshow(mask[:, :, 0])
    plt.show()
    # Release the figure once its window has been closed.
    plt.close(fig)