diff --git a/config/config.py b/config/config.py index 8f8a0a0d85dfbe799c478c7705c7baeeeaede269..723a20b757c04b2d7ce24084c835767c2c5ad6d4 100644 --- a/config/config.py +++ b/config/config.py @@ -3,20 +3,21 @@ import argparse def load_args(): parser = argparse.ArgumentParser() + parser.add_argument('--test', default = None) parser.add_argument('--epoches', type=int, default=20) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--augment_args', nargs = '+', type = float, default = [1,0,0,0.99,0.1,0.,7.5]) + parser.add_argument('--augment_args', nargs = '+', type = float, default = [1,0,0,0.05,0.1,0.,7.5]) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam") parser.add_argument('--beta1', type=float, default=0.938) parser.add_argument('--beta2', type=float, default=0.9928) parser.add_argument('--momentum', type=float, default=0.9) - parser.add_argument('--weighted_entropy', type=bool, default = False) + parser.add_argument('--weighted_entropy', type=bool, default = True) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model_type', type=str, default='duo') - parser.add_argument('--dataset_dir', type=str, default='data/fused_data/clean_species') + parser.add_argument('--dataset_dir', type=str, default='data/Antibio/AMK') parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--pretrain_path', type=str, default=None) diff --git a/dataset/dataset.py b/dataset/dataset.py index b17eddeb64b8f5ff82394bac06df00c0cd4c5410..b3b0aec8765d329ead70aa20b7caaff436826e87 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -76,10 +76,10 @@ class Random_erasing2: for k in range(len(regions)): if pics_suppr[k]: try: - y1,x1,y2,x2 = regions[k].bbox + _,y1,x1,_,y2,x2 = regions[k].bbox except: raise Exception(regions[k].bbox) - x[y1:y2,x1:x2] *= regions[k].image== False + x[:,y1:y2,x1:x2] *= regions[k].image== False return x return x @@ -203,6 +203,7 @@ def make_dataset_custom( is_valid_file = cast(Callable[[str], bool], is_valid_file) instances = [] + indexes = [] available_classes = set() for target_class in sorted(class_to_idx.keys()): class_index = class_to_idx[target_class] @@ -219,6 +220,7 @@ def make_dataset_custom( if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer): item = path_aer, path_ana, class_index instances.append(item) + indexes.append(class_index) if target_class not in available_classes: available_classes.add(target_class) @@ -229,19 +231,19 @@ def make_dataset_custom( if extensions is not None: msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" raise FileNotFoundError(msg) - - return instances + return instances,indexes class ImageFolderDuo(data.Dataset): def __init__(self, root, transform=None, target_transform=None, flist_reader=make_dataset_custom, loader=npy_loader): self.root = root - self.imlist = flist_reader(root) + self.imlist,self.lbllist = flist_reader(root) self.transform = transform self.target_transform = target_transform self.loader = loader self.classes = torchvision.datasets.folder.find_classes(root)[0] + raise Exception(self.classes) def __getitem__(self, index): impathAER, impathANA, target = self.imlist[index] @@ -278,9 +280,10 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True): train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform) val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform) - classes_name = os.listdir(args.dataset_dir) - repart_weights = [name for name in classes_name for k in range(len(os.listdir(args.dataset_dir+"/"+name))//2)] - iTrain,iVal = train_test_split(range(len(train_dataset)),test_size = 0.2, shuffle = shuffle, stratify = repart_weights, random_state=args.random_state) + # classes_name = os.listdir(args.dataset_dir) + # repart_weights = np.array([name for name in classes_name for k in range(len(os.listdir(args.dataset_dir+"/"+name))//2)]) + # args.test = repart_weights + iTrain,iVal,_,_ = train_test_split(np.array((range(len(train_dataset)))), train_dataset.lbllist,test_size = 0.2, shuffle = shuffle, stratify = train_dataset.lbllist, random_state=args.random_state) train_dataset = torch.utils.data.Subset(train_dataset, iTrain) val_dataset = torch.utils.data.Subset(val_dataset, iVal) # generator1 = torch.Generator().manual_seed(42) diff --git a/main.py b/main.py index d01f86cc9a25929f8cc1f7bedc89e3863e34e33a..e1c2b614eba04d513bb8540e1d3b2570fe47247a 100644 --- a/main.py +++ b/main.py @@ -222,7 +222,8 @@ def run_duo(args): val_loss=[] #init training if args.weighted_entropy: - loss_weights = torch.tensor([2/len(listdir(args.dataset_dir+"/"+class_name)) for class_name in listdir(args.dataset_dir)]) + #Problème ! + loss_weights = torch.tensor([1.*len(listdir(args.dataset_dir+"/"+class_name)) for class_name in listdir(args.dataset_dir)]) if torch.cuda.is_available(): loss_weights = loss_weights.cuda() loss_function = nn.CrossEntropyLoss(loss_weights) diff --git a/output/best_model.pt b/output/best_model.pt index 5d79fe1ad168f3fc79ee5d6bcf638c1c59db41a0..d7b6ddd2806e5a75ac671ff88a94a43ed669cdb3 100644 Binary files a/output/best_model.pt and b/output/best_model.pt differ diff --git a/output/confiance_matrix_species_clean_best_param.png b/output/confiance_matrix_species_clean_best_param.png index 66e01c75b008f3616fc1536f7c4537f4da4aa26f..380db4375ce50df7c1c4317450ff80c0590aa841 100644 Binary files a/output/confiance_matrix_species_clean_best_param.png and b/output/confiance_matrix_species_clean_best_param.png differ diff --git a/output/confusion_matrix_species_clean_best_param.png b/output/confusion_matrix_species_clean_best_param.png index 2c96e25224338d19992ee1a8aa308951f82c381d..b460f833cfd1d23f8a457bdb7bdaa7c3bacd853e 100644 Binary files a/output/confusion_matrix_species_clean_best_param.png and b/output/confusion_matrix_species_clean_best_param.png differ diff --git a/output/species_clean_best_param.png b/output/species_clean_best_param.png index 3c8a4c23bd3b2491464cae32581992b9bfb163e8..116beca9a3d5a4830b2baf8a8d1f8b221ac0add2 100644 Binary files a/output/species_clean_best_param.png and b/output/species_clean_best_param.png differ