import torch
import torchvision

from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
import random

# Dataset root for torchvision's ImageFolder (class sub-directories of images).
# NOTE(review): this path is resolved against the current working directory,
# not against this file -- run from the expected directory or make it absolute.
root = '../data/processed_data'
dataset = torchvision.datasets.ImageFolder(root, transform=None)


def _collate_pil_batch(batch):
    """Collate ``(PIL image, label)`` samples into tensors.

    ``dataset`` applies no transform, so its samples are PIL images.
    PyTorch's default collate function rejects PIL images with a
    ``TypeError``, so each image is first converted with
    ``TF.pil_to_tensor``. That keeps the raw uint8 pixel values and does not
    rescale them to [0, 1]. The converted samples are then default-collated.
    """
    return torch.utils.data.dataloader.default_collate(
        [(TF.pil_to_tensor(image), label) for image, label in batch]
    )


# One sample per batch, in dataset order, loaded in the main process.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=_collate_pil_batch,
    pin_memory=False,
)

class Threshold_noise:
    """Suppress low-intensity noise by replacing values at or below a threshold.

    Elementwise, each entry ``v`` of the input becomes ``v`` if
    ``v > threshold`` and ``value`` otherwise. This matches the semantics of
    ``torch.nn.functional.threshold``. The input tensor is not modified.

    Args:
        threshold: Entries less than or equal to this are treated as noise.
        value: Replacement written into noise entries (default ``0``).
    """

    def __init__(self, threshold, value=0):
        self.threshold = threshold
        self.value = value

    def __call__(self, x):
        """Return a copy of tensor ``x`` with entries ``<= threshold`` set to ``value``.

        NOTE(review): expects a ``torch.Tensor``. With ``threshold=100`` the
        scale suggests uint8 pixel values in [0, 255] -- confirm with callers.
        """
        # Out-of-place masked_fill keeps x's shape and dtype and leaves x untouched.
        return x.masked_fill(x <= self.threshold, self.value)

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"(threshold={self.threshold!r}, value={self.value!r})"
        )

# Shared Threshold_noise instance configured with threshold=100.
threshold_transform = Threshold_noise(threshold=100)
# Backward-compatible alias: the old name is misleading because this
# transform does not rotate anything. Prefer ``threshold_transform``.
rotation_transform = threshold_transform