Skip to content
Snippets Groups Projects
Commit 35f64483 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : new dataset

parent 80516349
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ def load_args(): ...@@ -20,6 +20,7 @@ def load_args():
parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--save_path', type=str, default='output/best_model.pt')
parser.add_argument('--pretrain_path', type=str, default=None) parser.add_argument('--pretrain_path', type=str, default=None)
parser.add_argument('--pretrain_imgnet', action=argparse.BooleanOptionalAction, default=False) parser.add_argument('--pretrain_imgnet', action=argparse.BooleanOptionalAction, default=False)
parser.add_argument('--resolution', type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
return args return args
......
...@@ -194,18 +194,23 @@ def test_duo(model, data_test, loss_function, epoch): ...@@ -194,18 +194,23 @@ def test_duo(model, data_test, loss_function, epoch):
def run_duo(args): def run_duo(args):
#load data #load data
#data_train, data_test,_ = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, batch_size=args.batch_size) if args.resolution == 'max':
data_train, data_test = load_data_duo_patch(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, data_train, data_test = load_data_duo_patch(base_dir_train=args.dataset_train_dir,
base_dir_test=args.dataset_test_dir, batch_size=args.batch_size) base_dir_val=args.dataset_val_dir,
base_dir_test=args.dataset_test_dir, batch_size=args.batch_size)
else :
data_train, data_test,_ = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, batch_size=args.batch_size)
#load model #load model
if not args.pretrain_imgnet: if not args.pretrain_imgnet:
#adaptative path model #adaptative path model
for imaer,imana, label in data_train: for imaer,imana, label in data_train:
img_shape = imaer.shape # batch x tile x img img_shape = imaer.shape # batch x tile x img
break break
model = Classification_model_duo_tile(model = args.model, n_class=len(data_train.dataset.classes),n_tile=img_shape[1]) if args.resolution == 'max' :
model = Classification_model_duo_tile(model = args.model, n_class=len(data_train.dataset.classes),n_tile=img_shape[1])
#model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes)) else :
model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.classes))
else : else :
model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes)) model = Classification_model_duo_pretrained(model = args.model, n_class=len(data_train.dataset.classes))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment