Skip to content
Snippets Groups Projects
Commit d012a29a authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : all tensor to float32

parent d5e36476
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ def load_args():
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data')
parser.add_argument('--dataset_test_dir', type=str, default=None)
parser.add_argument('--outname', type=str, default=None)
parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt')
parser.add_argument('--pretrain_path', type=str, default=None)
......
......@@ -10,9 +10,64 @@ import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict
from skimage import measure
from sklearn.model_selection import train_test_split
IMG_EXTENSIONS = ".npy"
class Random_erasing2:
"""with a probability prob, erases a proportion prop of the image"""
def __init__(self, prob, prop):
self.prob = prob
self.prop = prop
def __call__(self, x):
if np.random.rand() < self.prob:
labels = measure.label(x.numpy() > 0, connectivity=1)
regions = measure.regionprops(labels)
pics_suppr = np.random.rand(len(regions)) < self.prop
for k in range(len(regions)):
if pics_suppr[k]:
try:
_, y1, x1, _, y2, x2 = regions[k].bbox
except:
raise Exception(regions[k].bbox)
x[:, y1:y2, x1:x2] *= regions[k].image == False
return x
return x
class Random_int_noise:
"""With a probability prob, adds a gaussian noise to the image """
def __init__(self, prob, maximum):
self.prob = prob
self.minimum = 1 / maximum
self.delta = maximum - self.minimum
def __call__(self, x):
if np.random.rand() < self.prob:
return x * (self.minimum + torch.rand_like(x) * self.delta)
return x
class Random_shift_rt:
"""With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
def __init__(self, prob, mean, std):
self.prob = prob
self.mean = torch.tensor(float(mean))
self.std = float(std)
def __call__(self, x):
if np.random.rand() < self.prob:
shift = torch.normal(self.mean, self.std)
return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
return x
class Threshold_noise:
"""Remove intensities under given threshold"""
......@@ -31,9 +86,6 @@ class Log_normalisation:
def __call__(self, x):
return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
class Random_shift_rt:
pass
class Repeat:
"Repeat tensor along new dim"
......
......@@ -97,11 +97,11 @@ def run(args):
plt.plot(train_acc)
plt.ylim(0, 1.05)
plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
plt.savefig('output/training_plot_{}.png'.format(args.outname))
#load and evaluated best model
load_model(model, args.save_path)
make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
make_prediction(model,data_test, 'output/confusion_matrix_{}.png'.format(args.outname))
def make_prediction(model, data, f_name):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment