Skip to content
Snippets Groups Projects
Commit 90037bf2 authored by Schneider Leo's avatar Schneider Leo
Browse files

model cuda loading

parent dfd088de
No related branches found
No related tags found
No related merge requests found
Showing
with 14 additions and 12 deletions
data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY18_AER.png

44.5 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY18_ANA.png

18.9 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY23_AER.png

68.8 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY23_ANA.png

47.4 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY31_AER.png

27.2 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY31_ANA.png

36.1 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY32_AER.png

24.9 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY32_ANA.png

17.1 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY36_ANA.png

32.4 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY37_AER.png

24.2 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY37_ANA.png

17.6 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY42_AER.png

22.1 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY42_ANA.png

13.8 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY7_AER.png

51.1 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY7_ANA.png

42.2 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY8_AER.png

23.9 KiB

data/processed_data/png_image/data_training/Klebsiella oxytoca/KLEOXY8_ANA.png

37.7 KiB

......@@ -38,7 +38,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
l = split_before_number(s)
species = l[0]
nb = l[1]
return '{}-{}-{}_100vW_100SPD.mzML'.format(species,nb,analyse)
return '{}-{}-{}-d200.mzML'.format(species,nb,analyse)
df['path_ana'] = df['sample_name'].map(lambda x: create_fname(x,analyse='ANA'))
df['path_aer'] = df['sample_name'].map(lambda x: create_fname(x, analyse='AER'))
......@@ -64,14 +64,16 @@ def create_dataset():
name = label[label['path_aer'] == path.split("/")[-1]]['sample_name'].values[0]
analyse = 'AER'
if species is not None:
directory_path = '../data/processed_data/{}'.format(species)
if not os.path.isdir(directory_path):
os.makedirs(directory_path)
directory_path_png = '../data/processed_data/png_image/{}'.format(species)
directory_path_npy = '../data/processed_data/npy_image/{}'.format(species)
if not os.path.isdir(directory_path_png):
os.makedirs(directory_path_png)
if not os.path.isdir(directory_path_npy):
os.makedirs(directory_path_npy)
mat = build_image_ms1(path, 1)
mpimg.imsave(directory_path + "/" + name + '_' + analyse + '.png', mat)
np.save(directory_path + "/" + name + '_' + analyse + '.npy', mat)
mpimg.imsave(directory_path_png + "/" + name + '_' + analyse + '.png', mat)
np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
#TODO : train val test split
if __name__ =='__main__' :
create_dataset()
\ No newline at end of file
......@@ -59,7 +59,10 @@ def test(model, data_test, loss_function, epoch):
return losses,acc
def run(args):
model = Classification_model(n_class=9)
data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
model = Classification_model(n_class=len(data_train.dataset.dataset.classes))
if args.pretrain_path is not None :
load_model(model,args.pretrain_path)
if torch.cuda.is_available():
model = model.cuda()
best_acc = 0
......@@ -67,9 +70,6 @@ def run(args):
train_loss=[]
val_acc=[]
val_loss=[]
if args.pretrain_path is not None :
load_model(model,args.pretrain_path)
data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
......@@ -121,7 +121,7 @@ def make_prediction(model, data):
columns=[i for i in classes])
plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True)
plt.savefig('output.png')
plt.savefig('confusion_matrix.png')
def save_model(model, path):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment