Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pseudo_image
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Léo Calmettes
pseudo_image
Commits
a93e5bcc
Commit
a93e5bcc
authored
2 weeks ago
by
Schneider Leo
Browse files
Options
Downloads
Patches
Plain Diff
add comments
parent
92d24083
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
image_processing/build_dataset.py
+7
-24
7 additions, 24 deletions
image_processing/build_dataset.py
image_processing/build_image.py
+4
-5
4 additions, 5 deletions
image_processing/build_image.py
main.py
+19
-3
19 additions, 3 deletions
main.py
with
30 additions
and
32 deletions
image_processing/build_dataset.py
+
7
−
24
View file @
a93e5bcc
...
@@ -98,7 +98,7 @@ antibiotic_enterrobacter_breakpoints = {
...
@@ -98,7 +98,7 @@ antibiotic_enterrobacter_breakpoints = {
def
create_antibio_dataset
(
path
=
'
../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx
'
,
suffix
=
'
-d200
'
):
def
create_antibio_dataset
(
path
=
'
../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx
'
,
suffix
=
'
-d200
'
):
"""
"""
Extract and build file name corresponding to each sample
Extract and build file name corresponding to each sample
and transform antioresistance measurements to labels
:param path: excel path
:param path: excel path
:return: dataframe
:return: dataframe
"""
"""
...
@@ -114,8 +114,9 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
...
@@ -114,8 +114,9 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
'
TIC (disk)
'
,
'
TIC (vitek)
'
,
'
TOB (disk)
'
,
'
TOB (vitek)
'
,
'
TZP (disk)
'
,
'
TZP (mic)
'
,
'
TZP (vitek)
'
]]
'
TIC (disk)
'
,
'
TIC (vitek)
'
,
'
TOB (disk)
'
,
'
TOB (vitek)
'
,
'
TZP (disk)
'
,
'
TZP (mic)
'
,
'
TZP (vitek)
'
]]
for
test
in
antibiotic_tests
:
# S - Susceptible R - Resistant U- Uncertain
for
test
in
antibiotic_tests
:
# S - Susceptible R - Resistant U- Uncertain
#convert to string and transform >8 to 8
#convert to string and transform
(pex
>8 to 8
)
df
[
test
]
=
df
[
test
].
map
(
lambda
x
:
float
(
str
(
x
).
replace
(
'
>
'
,
''
).
replace
(
'
<
'
,
''
)))
df
[
test
]
=
df
[
test
].
map
(
lambda
x
:
float
(
str
(
x
).
replace
(
'
>
'
,
''
).
replace
(
'
<
'
,
''
)))
#categorise each antibioresistance according to AST breakpoints table
df
[
test
+
'
cat
'
]
=
'
NA
'
df
[
test
+
'
cat
'
]
=
'
NA
'
if
'
mic
'
in
test
or
'
vitek
'
in
test
:
if
'
mic
'
in
test
or
'
vitek
'
in
test
:
try
:
try
:
...
@@ -123,6 +124,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
...
@@ -123,6 +124,7 @@ def create_antibio_dataset(path='../data/label_raw/230804_strain_peptides_antibi
df
.
loc
[
df
[
test
]
>=
antibiotic_enterrobacter_breakpoints
[
test
][
'
R
'
],
test
+
'
cat
'
]
=
'
R
'
df
.
loc
[
df
[
test
]
>=
antibiotic_enterrobacter_breakpoints
[
test
][
'
R
'
],
test
+
'
cat
'
]
=
'
R
'
df
.
loc
[(
antibiotic_enterrobacter_breakpoints
[
test
][
'
S
'
]
<
df
[
test
])
&
(
df
[
test
]
<
antibiotic_enterrobacter_breakpoints
[
test
][
'
R
'
]),
test
+
'
cat
'
]
=
'
U
'
df
.
loc
[(
antibiotic_enterrobacter_breakpoints
[
test
][
'
S
'
]
<
df
[
test
])
&
(
df
[
test
]
<
antibiotic_enterrobacter_breakpoints
[
test
][
'
R
'
]),
test
+
'
cat
'
]
=
'
U
'
except
:
except
:
#for empty cells
pass
pass
elif
'
disk
'
in
test
:
elif
'
disk
'
in
test
:
try
:
try
:
...
@@ -160,6 +162,7 @@ def create_dataset():
...
@@ -160,6 +162,7 @@ def create_dataset():
for
path
in
glob
.
glob
(
"
../data/raw_data/**.mzML
"
):
for
path
in
glob
.
glob
(
"
../data/raw_data/**.mzML
"
):
print
(
path
)
print
(
path
)
species
=
None
species
=
None
#check if file exists in the label table
if
path
.
split
(
"
/
"
)[
-
1
]
in
label
[
'
path_ana
'
].
values
:
if
path
.
split
(
"
/
"
)[
-
1
]
in
label
[
'
path_ana
'
].
values
:
species
=
label
[
label
[
'
path_ana
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
species
'
].
values
[
0
]
species
=
label
[
label
[
'
path_ana
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
species
'
].
values
[
0
]
name
=
label
[
label
[
'
path_ana
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
sample_name
'
].
values
[
0
]
name
=
label
[
label
[
'
path_ana
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
sample_name
'
].
values
[
0
]
...
@@ -168,7 +171,7 @@ def create_dataset():
...
@@ -168,7 +171,7 @@ def create_dataset():
species
=
label
[
label
[
'
path_aer
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
species
'
].
values
[
0
]
species
=
label
[
label
[
'
path_aer
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
species
'
].
values
[
0
]
name
=
label
[
label
[
'
path_aer
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
sample_name
'
].
values
[
0
]
name
=
label
[
label
[
'
path_aer
'
]
==
path
.
split
(
"
/
"
)[
-
1
]][
'
sample_name
'
].
values
[
0
]
analyse
=
'
AER
'
analyse
=
'
AER
'
if
species
is
not
None
:
if
species
is
not
None
:
#save image in species specific dir
directory_path_png
=
'
../data/processed_data/png_image/{}
'
.
format
(
species
)
directory_path_png
=
'
../data/processed_data/png_image/{}
'
.
format
(
species
)
directory_path_npy
=
'
../data/processed_data/npy_image/{}
'
.
format
(
species
)
directory_path_npy
=
'
../data/processed_data/npy_image/{}
'
.
format
(
species
)
if
not
os
.
path
.
isdir
(
directory_path_png
):
if
not
os
.
path
.
isdir
(
directory_path_png
):
...
@@ -179,6 +182,7 @@ def create_dataset():
...
@@ -179,6 +182,7 @@ def create_dataset():
mpimg
.
imsave
(
directory_path_png
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.png
'
,
mat
)
mpimg
.
imsave
(
directory_path_png
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.png
'
,
mat
)
np
.
save
(
directory_path_npy
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.npy
'
,
mat
)
np
.
save
(
directory_path_npy
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.npy
'
,
mat
)
#reiterate for other kind of raw file
label
=
create_antibio_dataset
(
suffix
=
'
_100vW_100SPD
'
)
label
=
create_antibio_dataset
(
suffix
=
'
_100vW_100SPD
'
)
for
path
in
glob
.
glob
(
"
../data/raw_data/**.mzML
"
):
for
path
in
glob
.
glob
(
"
../data/raw_data/**.mzML
"
):
print
(
path
)
print
(
path
)
...
@@ -203,26 +207,5 @@ def create_dataset():
...
@@ -203,26 +207,5 @@ def create_dataset():
np
.
save
(
directory_path_npy
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.npy
'
,
mat
)
np
.
save
(
directory_path_npy
+
"
/
"
+
name
+
'
_
'
+
analyse
+
'
.npy
'
,
mat
)
def
extract_antio_res_labels
():
"""
Extract and organise labels from raw excel file
:param
path: excel
path
:return: dataframe
"""
path
=
'
../data/label_raw/230804_strain_peptides_antibiogram_Enterobacterales.xlsx
'
df
=
pd
.
read_excel
(
path
,
header
=
1
)
df
=
df
[[
'
sample_name
'
,
'
species
'
,
'
AMC (disk)
'
,
'
AMK (disk)
'
,
'
AMK (mic)
'
,
'
AMK (vitek)
'
,
'
AMP (vitek)
'
,
'
AMX (disk)
'
,
'
AMX (vitek)
'
,
'
ATM (disk)
'
,
'
ATM (vitek)
'
,
'
CAZ (disk)
'
,
'
CAZ (mic)
'
,
'
CAZ (vitek)
'
,
'
CHL (vitek)
'
,
'
CIP (disk)
'
,
'
CIP (vitek)
'
,
'
COL (disk)
'
,
'
COL (mic)
'
,
'
CRO (mic)
'
,
'
CRO (vitek)
'
,
'
CTX (disk)
'
,
'
CTX (mic)
'
,
'
CTX (vitek)
'
,
'
CXM (vitek)
'
,
'
CZA (disk)
'
,
'
CZA (vitek)
'
,
'
CZT (disk)
'
,
'
CZT (vitek)
'
,
'
ETP (disk)
'
,
'
ETP (mic)
'
,
'
ETP (vitek)
'
,
'
FEP (disk)
'
,
'
FEP (mic)
'
,
'
FEP (vitek)
'
,
'
FOS (disk)
'
,
'
FOX (disk)
'
,
'
FOX (vitek)
'
,
'
GEN (disk)
'
,
'
GEN (mic)
'
,
'
GEN (vitek)
'
,
'
IPM (disk)
'
,
'
IPM (mic)
'
,
'
IPM (vitek)
'
,
'
LTM (disk)
'
,
'
LVX (disk)
'
,
'
LVX (vitek)
'
,
'
MEC (disk)
'
,
'
MEM (disk)
'
,
'
MEM (mic)
'
,
'
MEM (vitek)
'
,
'
NAL (vitek)
'
,
'
NET (disk)
'
,
'
OFX (vitek)
'
,
'
PIP (vitek)
'
,
'
PRL (disk)
'
,
'
SXT (disk)
'
,
'
SXT (vitek)
'
,
'
TCC (disk)
'
,
'
TCC (vitek)
'
,
'
TEM (disk)
'
,
'
TEM (vitek)
'
,
'
TGC (disk)
'
,
'
TGC (vitek)
'
,
'
TIC (disk)
'
,
'
TIC (vitek)
'
,
'
TOB (disk)
'
,
'
TOB (vitek)
'
,
'
TZP (disk)
'
,
'
TZP (mic)
'
,
'
TZP (vitek)
'
]]
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
df
=
create_antibio_dataset
()
df
=
create_antibio_dataset
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
image_processing/build_image.py
+
4
−
5
View file @
a93e5bcc
...
@@ -28,6 +28,7 @@ def plot_spectra_2d(exp, ms_level=1, marker_size=5, out_path='temp.png'):
...
@@ -28,6 +28,7 @@ def plot_spectra_2d(exp, ms_level=1, marker_size=5, out_path='temp.png'):
def
build_image_ms1
(
path
,
bin_mz
):
def
build_image_ms1
(
path
,
bin_mz
):
#load raw data
e
=
oms
.
MSExperiment
()
e
=
oms
.
MSExperiment
()
oms
.
MzMLFile
().
load
(
path
,
e
)
oms
.
MzMLFile
().
load
(
path
,
e
)
e
.
updateRanges
()
e
.
updateRanges
()
...
@@ -36,7 +37,7 @@ def build_image_ms1(path, bin_mz):
...
@@ -36,7 +37,7 @@ def build_image_ms1(path, bin_mz):
dico
=
dict
(
s
.
split
(
'
=
'
,
1
)
for
s
in
id
.
split
())
dico
=
dict
(
s
.
split
(
'
=
'
,
1
)
for
s
in
id
.
split
())
max_cycle
=
int
(
dico
[
'
cycle
'
])
max_cycle
=
int
(
dico
[
'
cycle
'
])
list_cycle
=
[[]
for
_
in
range
(
max_cycle
)]
list_cycle
=
[[]
for
_
in
range
(
max_cycle
)]
#get ms window size from first ms1 spectra (similar for all ms1 spectra)
for
s
in
e
:
for
s
in
e
:
if
s
.
getMSLevel
()
==
1
:
if
s
.
getMSLevel
()
==
1
:
ms1_start_mz
=
s
.
getInstrumentSettings
().
getScanWindows
()[
0
].
begin
ms1_start_mz
=
s
.
getInstrumentSettings
().
getScanWindows
()[
0
].
begin
...
@@ -47,6 +48,7 @@ def build_image_ms1(path, bin_mz):
...
@@ -47,6 +48,7 @@ def build_image_ms1(path, bin_mz):
print
(
'
start
'
,
ms1_start_mz
,
'
end
'
,
ms1_end_mz
)
print
(
'
start
'
,
ms1_start_mz
,
'
end
'
,
ms1_end_mz
)
n_bin_ms1
=
int
(
total_ms1_mz
//
bin_mz
)
n_bin_ms1
=
int
(
total_ms1_mz
//
bin_mz
)
size_bin_ms1
=
total_ms1_mz
/
n_bin_ms1
size_bin_ms1
=
total_ms1_mz
/
n_bin_ms1
#organise sepctra by their MSlevel (only MS1 are kept)
for
spec
in
e
:
# data structure
for
spec
in
e
:
# data structure
id
=
spec
.
getNativeID
()
id
=
spec
.
getNativeID
()
dico
=
dict
(
s
.
split
(
'
=
'
,
1
)
for
s
in
id
.
split
())
dico
=
dict
(
s
.
split
(
'
=
'
,
1
)
for
s
in
id
.
split
())
...
@@ -54,16 +56,13 @@ def build_image_ms1(path, bin_mz):
...
@@ -54,16 +56,13 @@ def build_image_ms1(path, bin_mz):
list_cycle
[
int
(
dico
[
'
cycle
'
])
-
1
].
insert
(
0
,
spec
)
list_cycle
[
int
(
dico
[
'
cycle
'
])
-
1
].
insert
(
0
,
spec
)
im
=
np
.
zeros
([
max_cycle
,
n_bin_ms1
])
im
=
np
.
zeros
([
max_cycle
,
n_bin_ms1
])
for
c
in
range
(
max_cycle
):
# Build image line by line
for
c
in
range
(
max_cycle
):
# Build one cycle image
line
=
np
.
zeros
(
n_bin_ms1
)
line
=
np
.
zeros
(
n_bin_ms1
)
if
len
(
list_cycle
[
c
])
>
0
:
if
len
(
list_cycle
[
c
])
>
0
:
for
k
in
range
(
len
(
list_cycle
[
c
])):
for
k
in
range
(
len
(
list_cycle
[
c
])):
ms1
=
list_cycle
[
c
][
k
]
ms1
=
list_cycle
[
c
][
k
]
intensity
=
ms1
.
get_peaks
()[
1
]
intensity
=
ms1
.
get_peaks
()[
1
]
mz
=
ms1
.
get_peaks
()[
0
]
mz
=
ms1
.
get_peaks
()[
0
]
id
=
ms1
.
getNativeID
()
dico
=
dict
(
s
.
split
(
'
=
'
,
1
)
for
s
in
id
.
split
())
for
i
in
range
(
ms1
.
size
()):
for
i
in
range
(
ms1
.
size
()):
line
[
int
((
mz
[
i
]
-
ms1_start_mz
)
//
size_bin_ms1
)]
+=
intensity
[
i
]
line
[
int
((
mz
[
i
]
-
ms1_start_mz
)
//
size_bin_ms1
)]
+=
intensity
[
i
]
...
...
This diff is collapsed.
Click to expand it.
main.py
+
19
−
3
View file @
a93e5bcc
...
@@ -59,20 +59,26 @@ def test(model, data_test, loss_function, epoch):
...
@@ -59,20 +59,26 @@ def test(model, data_test, loss_function, epoch):
return
losses
,
acc
return
losses
,
acc
def
run
(
args
):
def
run
(
args
):
#load data
data_train
,
data_test
=
load_data
(
base_dir
=
args
.
dataset_dir
,
batch_size
=
args
.
batch_size
)
data_train
,
data_test
=
load_data
(
base_dir
=
args
.
dataset_dir
,
batch_size
=
args
.
batch_size
)
#load model
model
=
Classification_model
(
model
=
args
.
model
,
n_class
=
len
(
data_train
.
dataset
.
dataset
.
classes
))
model
=
Classification_model
(
model
=
args
.
model
,
n_class
=
len
(
data_train
.
dataset
.
dataset
.
classes
))
#load weights
if
args
.
pretrain_path
is
not
None
:
if
args
.
pretrain_path
is
not
None
:
load_model
(
model
,
args
.
pretrain_path
)
load_model
(
model
,
args
.
pretrain_path
)
#move parameters to GPU
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
model
=
model
.
cuda
()
model
=
model
.
cuda
()
#init accumulator
best_acc
=
0
best_acc
=
0
train_acc
=
[]
train_acc
=
[]
train_loss
=
[]
train_loss
=
[]
val_acc
=
[]
val_acc
=
[]
val_loss
=
[]
val_loss
=
[]
#init training
loss_function
=
nn
.
CrossEntropyLoss
()
loss_function
=
nn
.
CrossEntropyLoss
()
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
#traing
for
e
in
range
(
args
.
epoches
):
for
e
in
range
(
args
.
epoches
):
loss
,
acc
=
train
(
model
,
data_train
,
optimizer
,
loss_function
,
e
)
loss
,
acc
=
train
(
model
,
data_train
,
optimizer
,
loss_function
,
e
)
train_loss
.
append
(
loss
)
train_loss
.
append
(
loss
)
...
@@ -84,6 +90,7 @@ def run(args):
...
@@ -84,6 +90,7 @@ def run(args):
if
acc
>
best_acc
:
if
acc
>
best_acc
:
save_model
(
model
,
args
.
save_path
)
save_model
(
model
,
args
.
save_path
)
best_acc
=
acc
best_acc
=
acc
#plot and save training figs
plt
.
plot
(
train_acc
)
plt
.
plot
(
train_acc
)
plt
.
plot
(
val_acc
)
plt
.
plot
(
val_acc
)
plt
.
plot
(
train_acc
)
plt
.
plot
(
train_acc
)
...
@@ -92,6 +99,7 @@ def run(args):
...
@@ -92,6 +99,7 @@ def run(args):
plt
.
show
()
plt
.
show
()
plt
.
savefig
(
'
output/training_plot_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
plt
.
savefig
(
'
output/training_plot_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
#load and evaluated best model
load_model
(
model
,
args
.
save_path
)
load_model
(
model
,
args
.
save_path
)
make_prediction
(
model
,
data_test
,
'
output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
make_prediction
(
model
,
data_test
,
'
output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
...
@@ -175,21 +183,28 @@ def test_duo(model, data_test, loss_function, epoch):
...
@@ -175,21 +183,28 @@ def test_duo(model, data_test, loss_function, epoch):
return
losses
,
acc
return
losses
,
acc
def
run_duo
(
args
):
def
run_duo
(
args
):
#load data
data_train
,
data_test
=
load_data_duo
(
base_dir
=
args
.
dataset_dir
,
batch_size
=
args
.
batch_size
)
data_train
,
data_test
=
load_data_duo
(
base_dir
=
args
.
dataset_dir
,
batch_size
=
args
.
batch_size
)
#load model
model
=
Classification_model_duo
(
model
=
args
.
model
,
n_class
=
len
(
data_train
.
dataset
.
dataset
.
classes
))
model
=
Classification_model_duo
(
model
=
args
.
model
,
n_class
=
len
(
data_train
.
dataset
.
dataset
.
classes
))
model
.
double
()
model
.
double
()
#load weight
if
args
.
pretrain_path
is
not
None
:
if
args
.
pretrain_path
is
not
None
:
load_model
(
model
,
args
.
pretrain_path
)
load_model
(
model
,
args
.
pretrain_path
)
#move parameters to GPU
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
model
=
model
.
cuda
()
model
=
model
.
cuda
()
#init accumulators
best_acc
=
0
best_acc
=
0
train_acc
=
[]
train_acc
=
[]
train_loss
=
[]
train_loss
=
[]
val_acc
=
[]
val_acc
=
[]
val_loss
=
[]
val_loss
=
[]
#init training
loss_function
=
nn
.
CrossEntropyLoss
()
loss_function
=
nn
.
CrossEntropyLoss
()
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
#train model
for
e
in
range
(
args
.
epoches
):
for
e
in
range
(
args
.
epoches
):
loss
,
acc
=
train_duo
(
model
,
data_train
,
optimizer
,
loss_function
,
e
)
loss
,
acc
=
train_duo
(
model
,
data_train
,
optimizer
,
loss_function
,
e
)
train_loss
.
append
(
loss
)
train_loss
.
append
(
loss
)
...
@@ -201,6 +216,7 @@ def run_duo(args):
...
@@ -201,6 +216,7 @@ def run_duo(args):
if
acc
>
best_acc
:
if
acc
>
best_acc
:
save_model
(
model
,
args
.
save_path
)
save_model
(
model
,
args
.
save_path
)
best_acc
=
acc
best_acc
=
acc
# plot and save training figs
plt
.
plot
(
train_acc
)
plt
.
plot
(
train_acc
)
plt
.
plot
(
val_acc
)
plt
.
plot
(
val_acc
)
plt
.
plot
(
train_acc
)
plt
.
plot
(
train_acc
)
...
@@ -209,7 +225,7 @@ def run_duo(args):
...
@@ -209,7 +225,7 @@ def run_duo(args):
plt
.
show
()
plt
.
show
()
plt
.
savefig
(
'
output/training_plot_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
plt
.
savefig
(
'
output/training_plot_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
#load and evaluate best model
load_model
(
model
,
args
.
save_path
)
load_model
(
model
,
args
.
save_path
)
make_prediction_duo
(
model
,
data_test
,
'
output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
make_prediction_duo
(
model
,
data_test
,
'
output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png
'
.
format
(
args
.
noise_threshold
,
args
.
lr
,
args
.
model
,
args
.
model_type
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment