Skip to content
Snippets Groups Projects
Commit d012a29a authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : all tensor to float32

parent d5e36476
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,7 @@ def load_args(): ...@@ -15,6 +15,7 @@ def load_args():
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data') parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data') parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data')
parser.add_argument('--dataset_test_dir', type=str, default=None) parser.add_argument('--dataset_test_dir', type=str, default=None)
parser.add_argument('--outname', type=str, default=None)
parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--save_path', type=str, default='output/best_model.pt')
parser.add_argument('--pretrain_path', type=str, default=None) parser.add_argument('--pretrain_path', type=str, default=None)
......
...@@ -10,9 +10,64 @@ import os.path ...@@ -10,9 +10,64 @@ import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path from pathlib import Path
from collections import OrderedDict from collections import OrderedDict
from skimage import measure
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
IMG_EXTENSIONS = ".npy" IMG_EXTENSIONS = ".npy"
class Random_erasing2:
"""with a probability prob, erases a proportion prop of the image"""
def __init__(self, prob, prop):
self.prob = prob
self.prop = prop
def __call__(self, x):
if np.random.rand() < self.prob:
labels = measure.label(x.numpy() > 0, connectivity=1)
regions = measure.regionprops(labels)
pics_suppr = np.random.rand(len(regions)) < self.prop
for k in range(len(regions)):
if pics_suppr[k]:
try:
_, y1, x1, _, y2, x2 = regions[k].bbox
except:
raise Exception(regions[k].bbox)
x[:, y1:y2, x1:x2] *= regions[k].image == False
return x
return x
class Random_int_noise:
"""With a probability prob, adds a gaussian noise to the image """
def __init__(self, prob, maximum):
self.prob = prob
self.minimum = 1 / maximum
self.delta = maximum - self.minimum
def __call__(self, x):
if np.random.rand() < self.prob:
return x * (self.minimum + torch.rand_like(x) * self.delta)
return x
class Random_shift_rt:
"""With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
def __init__(self, prob, mean, std):
self.prob = prob
self.mean = torch.tensor(float(mean))
self.std = float(std)
def __call__(self, x):
if np.random.rand() < self.prob:
shift = torch.normal(self.mean, self.std)
return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
return x
class Threshold_noise: class Threshold_noise:
"""Remove intensities under given threshold""" """Remove intensities under given threshold"""
...@@ -31,9 +86,6 @@ class Log_normalisation: ...@@ -31,9 +86,6 @@ class Log_normalisation:
def __call__(self, x): def __call__(self, x):
return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon) return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
class Random_shift_rt:
pass
class Repeat: class Repeat:
"Repeat tensor along new dim" "Repeat tensor along new dim"
......
...@@ -97,11 +97,11 @@ def run(args): ...@@ -97,11 +97,11 @@ def run(args):
plt.plot(train_acc) plt.plot(train_acc)
plt.ylim(0, 1.05) plt.ylim(0, 1.05)
plt.show() plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) plt.savefig('output/training_plot_{}.png'.format(args.outname))
#load and evaluated best model #load and evaluated best model
load_model(model, args.save_path) load_model(model, args.save_path)
make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) make_prediction(model,data_test, 'output/confusion_matrix_{}.png'.format(args.outname))
def make_prediction(model, data, f_name): def make_prediction(model, data, f_name):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment