Skip to content
Snippets Groups Projects
Commit 5e2d0620 authored by Schneider Leo's avatar Schneider Leo
Browse files

model cuda loading

parent b3a3c45a
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ def load_args():
parser.add_argument('--noise_threshold', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt')
......
......@@ -60,7 +60,7 @@ def test(model, data_test, loss_function, epoch):
def run(args):
data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
model = Classification_model(n_class=len(data_train.dataset.dataset.classes))
model = Classification_model(model = args.model, n_class=len(data_train.dataset.dataset.classes))
if args.pretrain_path is not None :
load_model(model,args.pretrain_path)
if torch.cuda.is_available():
......@@ -89,10 +89,10 @@ def run(args):
plt.plot(train_acc)
plt.plot(train_acc)
plt.show()
plt.savefig('output/training_plot_{}_.png'.format(args.output))
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
load_model(model, args.save_path)
make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output))
make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
def make_prediction(model, data, f_name):
......
......@@ -262,10 +262,11 @@ def resnet152(num_classes=1000,**kwargs):
class Classification_model(nn.Module):
def __init__(self, n_class, *args, **kwargs):
def __init__(self, model, n_class, *args, **kwargs):
super().__init__(*args, **kwargs)
self.n_class = n_class
self.im_encoder = resnet18(num_classes=self.n_class)
if model =='Resnet18':
self.im_encoder = resnet18(num_classes=self.n_class)
def forward(self, input):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment