From a93e5bccf909cfdd60e03616517e754d4bf82688 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 27 Mar 2025 16:15:54 +0100
Subject: [PATCH] add comments

---
 image_processing/build_dataset.py | 31 +++++++------------------------
 image_processing/build_image.py   |  9 ++++-----
 main.py                           | 22 +++++++++++++++++++---
 3 files changed, 30 insertions(+), 32 deletions(-)

diff --git a/image_processing/build_dataset.py b/image_processing/build_dataset.py
index 21e181d..d2820b2 100644
--- a/image_processing/build_dataset.py
+++ b/image_processing/build_dataset.py
@@ -98,7 +98,7 @@ antibiotic_enterrobacter_breakpoints = {
 
 def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx',suffix='-d200'):
     """
-    Extract and build file name corresponding to each sample
+    Extract and build file name corresponding to each sample and transform antioresistance measurements to labels
     :param path: excel path
     :return: dataframe
     """
@@ -114,8 +114,9 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
     'TIC (disk)','TIC (vitek)','TOB (disk)','TOB (vitek)','TZP (disk)','TZP (mic)','TZP (vitek)']]
 
     for test in antibiotic_tests :# S - Susceptible R - Resistant U- Uncertain
-        #convert to string and transform >8 to 8
+        #convert to string and transform (e.g. >8 to 8)
         df[test] = df[test].map(lambda x :float(str(x).replace('>','').replace('<','')))
+        #categorise each antibioresistance according to AST breakpoints table
         df[test+' cat']= 'NA'
         if 'mic' in test or 'vitek' in test :
             try :
@@ -123,6 +124,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
                 df.loc[df[test] >= antibiotic_enterrobacter_breakpoints[test]['R'], test + ' cat'] = 'R'
                 df.loc[(antibiotic_enterrobacter_breakpoints[test]['S'] < df[test]) & (df[test]  < antibiotic_enterrobacter_breakpoints[test]['R']), test + ' cat'] = 'U'
             except:
+                #for empty cells
                 pass
         elif 'disk' in test:
             try :
@@ -160,6 +162,7 @@ def create_dataset():
     for path in glob.glob("../data/raw_data/**.mzML"):
         print(path)
         species = None
+        #check if file exists in the label table
         if path.split("/")[-1] in label['path_ana'].values:
             species = label[label['path_ana'] == path.split("/")[-1]]['species'].values[0]
             name = label[label['path_ana'] == path.split("/")[-1]]['sample_name'].values[0]
@@ -168,7 +171,7 @@ def create_dataset():
             species = label[label['path_aer'] == path.split("/")[-1]]['species'].values[0]
             name = label[label['path_aer'] == path.split("/")[-1]]['sample_name'].values[0]
             analyse = 'AER'
-        if species is not None:
+        if species is not None: #save image in species specific dir
             directory_path_png = '../data/processed_data/png_image/{}'.format(species)
             directory_path_npy = '../data/processed_data/npy_image/{}'.format(species)
             if not os.path.isdir(directory_path_png):
@@ -179,6 +182,7 @@ def create_dataset():
             mpimg.imsave(directory_path_png + "/" + name + '_' + analyse + '.png', mat)
             np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
 
+    #reiterate for the other kind of raw file
     label = create_antibio_dataset(suffix='_100vW_100SPD')
     for path in glob.glob("../data/raw_data/**.mzML"):
         print(path)
@@ -203,26 +207,5 @@ def create_dataset():
             np.save(directory_path_npy + "/" + name + '_' + analyse + '.npy', mat)
 
 
-def extract_antio_res_labels():
-    """
-    Extract and organise labels from raw excel file
-    :param
-    path: excel
-    path
-    :return: dataframe
-    """
-    path = '../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx'
-    df = pd.read_excel(path, header=1)
-    df = df[['sample_name','species','AMC (disk)','AMK (disk)','AMK (mic)','AMK (vitek)','AMP (vitek)','AMX (disk)',
-    'AMX (vitek)','ATM (disk)','ATM (vitek)','CAZ (disk)','CAZ (mic)','CAZ (vitek)','CHL (vitek)','CIP (disk)',
-    'CIP (vitek)','COL (disk)','COL (mic)','CRO (mic)','CRO (vitek)','CTX (disk)','CTX (mic)','CTX (vitek)',
-    'CXM (vitek)','CZA (disk)','CZA (vitek)','CZT (disk)','CZT (vitek)','ETP (disk)','ETP (mic)','ETP (vitek)',
-    'FEP (disk)','FEP (mic)','FEP (vitek)','FOS (disk)','FOX (disk)','FOX (vitek)','GEN (disk)','GEN (mic)',
-    'GEN (vitek)','IPM (disk)','IPM (mic)','IPM (vitek)','LTM (disk)','LVX (disk)','LVX (vitek)','MEC (disk)',
-    'MEM (disk)','MEM (mic)','MEM (vitek)','NAL (vitek)','NET (disk)','OFX (vitek)','PIP (vitek)','PRL (disk)',
-    'SXT (disk)','SXT (vitek)','TCC (disk)','TCC (vitek)','TEM (disk)','TEM (vitek)','TGC (disk)','TGC (vitek)',
-    'TIC (disk)','TIC (vitek)','TOB (disk)','TOB (vitek)','TZP (disk)','TZP (mic)','TZP (vitek)']]
-
-
 if __name__ =='__main__' :
     df = create_antibio_dataset()
\ No newline at end of file
diff --git a/image_processing/build_image.py b/image_processing/build_image.py
index 336940a..a2d2143 100644
--- a/image_processing/build_image.py
+++ b/image_processing/build_image.py
@@ -28,6 +28,7 @@ def plot_spectra_2d(exp, ms_level=1, marker_size=5, out_path='temp.png'):
 
 
 def build_image_ms1(path, bin_mz):
+    #load raw data
     e = oms.MSExperiment()
     oms.MzMLFile().load(path, e)
     e.updateRanges()
@@ -36,7 +37,7 @@ def build_image_ms1(path, bin_mz):
     dico = dict(s.split('=', 1) for s in id.split())
     max_cycle = int(dico['cycle'])
     list_cycle = [[] for _ in range(max_cycle)]
-
+    #get ms window size from the first ms1 spectrum (identical for all ms1 spectra)
     for s in e:
         if s.getMSLevel() == 1:
             ms1_start_mz = s.getInstrumentSettings().getScanWindows()[0].begin
@@ -47,6 +48,7 @@ def build_image_ms1(path, bin_mz):
     print('start',ms1_start_mz,'end',ms1_end_mz)
     n_bin_ms1 = int(total_ms1_mz//bin_mz)
     size_bin_ms1 = total_ms1_mz / n_bin_ms1
+    #organise spectra by their MS level (only MS1 are kept)
     for spec in e:  # data structure
         id = spec.getNativeID()
         dico = dict(s.split('=', 1) for s in id.split())
@@ -54,16 +56,13 @@ def build_image_ms1(path, bin_mz):
             list_cycle[int(dico['cycle']) - 1].insert(0, spec)
 
     im = np.zeros([max_cycle, n_bin_ms1])
-
-    for c in range(max_cycle):  # Build one cycle image
+    for c in range(max_cycle):  # Build image line by line
         line = np.zeros(n_bin_ms1)
         if len(list_cycle[c]) > 0:
             for k in range(len(list_cycle[c])):
                 ms1 = list_cycle[c][k]
                 intensity = ms1.get_peaks()[1]
                 mz = ms1.get_peaks()[0]
-                id = ms1.getNativeID()
-                dico = dict(s.split('=', 1) for s in id.split())
                 for i in range(ms1.size()):
                     line[int((mz[i] - ms1_start_mz) // size_bin_ms1)] += intensity[i]
 
diff --git a/main.py b/main.py
index 30f7843..91a64ce 100644
--- a/main.py
+++ b/main.py
@@ -59,20 +59,26 @@ def test(model, data_test, loss_function, epoch):
     return losses,acc
 
 def run(args):
+    #load data
     data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    #load model
     model = Classification_model(model = args.model, n_class=len(data_train.dataset.dataset.classes))
+    #load weights
     if args.pretrain_path is not None :
         load_model(model,args.pretrain_path)
+    #move parameters to GPU
     if torch.cuda.is_available():
         model = model.cuda()
+    #init accumulators
     best_acc = 0
     train_acc=[]
     train_loss=[]
     val_acc=[]
     val_loss=[]
+    #init training
     loss_function = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-
+    #train model
     for e in range(args.epoches):
         loss, acc = train(model,data_train,optimizer,loss_function,e)
         train_loss.append(loss)
@@ -84,6 +90,7 @@ def run(args):
             if acc > best_acc :
                 save_model(model,args.save_path)
                 best_acc = acc
+    #plot and save training figs
     plt.plot(train_acc)
     plt.plot(val_acc)
     plt.plot(train_acc)
@@ -92,6 +99,7 @@ def run(args):
     plt.show()
     plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
 
+    #load and evaluate best model
     load_model(model, args.save_path)
     make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
 
@@ -175,21 +183,28 @@ def test_duo(model, data_test, loss_function, epoch):
     return losses,acc
 
 def run_duo(args):
+    #load data
     data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    #load model
     model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
     model.double()
+    #load weights
     if args.pretrain_path is not None :
         load_model(model,args.pretrain_path)
+    #move parameters to GPU
     if torch.cuda.is_available():
         model = model.cuda()
+
+    #init accumulators
     best_acc = 0
     train_acc=[]
     train_loss=[]
     val_acc=[]
     val_loss=[]
+    #init training
     loss_function = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-
+    #train model
     for e in range(args.epoches):
         loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
         train_loss.append(loss)
@@ -201,6 +216,7 @@ def run_duo(args):
             if acc > best_acc :
                 save_model(model,args.save_path)
                 best_acc = acc
+    # plot and save training figs
     plt.plot(train_acc)
     plt.plot(val_acc)
     plt.plot(train_acc)
@@ -209,7 +225,7 @@ def run_duo(args):
     plt.show()
 
     plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
-
+    #load and evaluate best model
     load_model(model, args.save_path)
     make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
 
-- 
GitLab