Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
Peptide Detectability
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Léo Schneider
Peptide Detectability
Commits
89df63ba
Commit
89df63ba
authored
2 months ago
by
Schneider Leo
Browse files
Options
Downloads
Patches
Plain Diff
astral dataset
parent
7537263a
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
dataset_comparison.py
+1
-1
1 addition, 1 deletion
dataset_comparison.py
dataset_extraction.py
+60
-6
60 additions, 6 deletions
dataset_extraction.py
main_fine_tune.py
+87
-18
87 additions, 18 deletions
main_fine_tune.py
with
148 additions
and
25 deletions
dataset_comparison.py
+
1
−
1
View file @
89df63ba
...
@@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict
...
@@ -3,7 +3,7 @@ from datasets import load_dataset, DatasetDict
df_list
=
[
"
Wilhelmlab/detectability-proteometools
"
,
"
Wilhelmlab/detectability-wang
"
,
"
Wilhelmlab/detectability-sinitcyn
"
]
df_list
=
[
"
Wilhelmlab/detectability-proteometools
"
,
"
Wilhelmlab/detectability-wang
"
,
"
Wilhelmlab/detectability-sinitcyn
"
]
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_f
inetune
_no_miscleavage.csv
'
)
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_f
lyer
_no_miscleavage.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
)
for
label_type
in
[
'
Classes fragment
'
,
'
Classes precursor
'
,
'
Classes MaxLFQ
'
]
:
for
label_type
in
[
'
Classes fragment
'
,
'
Classes precursor
'
,
'
Classes MaxLFQ
'
]
:
...
...
This diff is collapsed.
Click to expand it.
dataset_extraction.py
+
60
−
6
View file @
89df63ba
...
@@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"}
...
@@ -11,7 +11,7 @@ binary_labels = {0: "Non-Flyer", 1: "Flyer"}
"""
"""
def
build_dataset
(
intensity_col
=
'
Fragment.Quant.Raw
'
,
coverage_treshold
=
20
,
min_peptide
=
4
,
f_name
=
'
out_df.csv
'
):
def
build_dataset
(
coverage_treshold
=
20
,
min_peptide
=
4
,
f_name
=
'
out_df.csv
'
):
df
=
pd
.
read_excel
(
'
ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx
'
)
df
=
pd
.
read_excel
(
'
ISA_data/250326_gut_microbiome_std_17_proteomes_data_training_detectability.xlsx
'
)
df_non_flyer
=
pd
.
read_csv
(
'
ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv
'
)
df_non_flyer
=
pd
.
read_csv
(
'
ISA_data/250422_FASTA_17_proteomes_gut_std_ozyme_+_conta_peptides_digested_filtered.csv
'
)
#No flyer
#No flyer
...
@@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
...
@@ -38,12 +38,12 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
dico_final
=
{}
dico_final
=
{}
# iterate over each group
# iterate over each group
for
group_name
,
df_group
in
df1_grouped
:
for
group_name
,
df_group
in
df1_grouped
:
seq
=
df_group
.
sort_values
(
by
=
[
intensity_col
])[
'
Stripped.Sequence
'
].
to_list
()
seq
=
df_group
.
sort_values
(
by
=
[
'
Fragment.Quant.Raw
'
])[
'
Stripped.Sequence
'
].
to_list
()
value_frag
=
df_group
.
sort_values
(
by
=
[
intensity_col
])[
intensity_col
].
to_list
()
value_frag
=
df_group
.
sort_values
(
by
=
[
'
Fragment.Quant.Raw
'
])[
'
Fragment.Quant.Raw
'
].
to_list
()
value_prec
=
df_group
.
sort_values
(
by
=
[
'
Precursor.Quantity
'
])[
'
Precursor.Quantity
'
].
to_list
()
value_prec
=
df_group
.
sort_values
(
by
=
[
'
Precursor.Quantity
'
])[
'
Precursor.Quantity
'
].
to_list
()
value_prec_frag
=
df_group
.
sort_values
(
by
=
[
intensity_col
])[
'
Precursor.Quantity
'
].
to_list
()
value_prec_frag
=
df_group
.
sort_values
(
by
=
[
'
Fragment.Quant.Raw
'
])[
'
Precursor.Quantity
'
].
to_list
()
value_maxlfq
=
df_group
.
sort_values
(
by
=
[
'
MaxLFQ
'
])[
'
MaxLFQ
'
].
to_list
()
value_maxlfq
=
df_group
.
sort_values
(
by
=
[
'
MaxLFQ
'
])[
'
MaxLFQ
'
].
to_list
()
value_maxlfq_frag
=
df_group
.
sort_values
(
by
=
[
intensity_col
])[
'
MaxLFQ
'
].
to_list
()
value_maxlfq_frag
=
df_group
.
sort_values
(
by
=
[
'
Fragment.Quant.Raw
'
])[
'
MaxLFQ
'
].
to_list
()
threshold_weak_flyer_frag
=
value_frag
[
int
(
len
(
seq
)
/
3
)]
threshold_weak_flyer_frag
=
value_frag
[
int
(
len
(
seq
)
/
3
)]
threshold_medium_flyer_frag
=
value_frag
[
int
(
2
*
len
(
seq
)
/
3
)]
threshold_medium_flyer_frag
=
value_frag
[
int
(
2
*
len
(
seq
)
/
3
)]
threshold_weak_flyer_prec
=
value_prec
[
int
(
len
(
seq
)
/
3
)]
threshold_weak_flyer_prec
=
value_prec
[
int
(
len
(
seq
)
/
3
)]
...
@@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
...
@@ -82,5 +82,59 @@ def build_dataset(intensity_col = 'Fragment.Quant.Raw',coverage_treshold = 20, m
df_final
=
df_final
[[
'
Sequences
'
,
'
Proteins
'
,
'
Classes fragment
'
,
'
Classes precursor
'
,
'
Classes MaxLFQ
'
]]
df_final
=
df_final
[[
'
Sequences
'
,
'
Proteins
'
,
'
Classes fragment
'
,
'
Classes precursor
'
,
'
Classes MaxLFQ
'
]]
df_final
.
to_csv
(
f_name
,
index
=
False
)
df_final
.
to_csv
(
f_name
,
index
=
False
)
df_non_flyer
.
to_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
,
index
=
False
)
df_non_flyer
.
to_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
,
index
=
False
)
def
build_dataset_astral
(
coverage_treshold
=
20
,
min_peptide
=
4
,
f_name
=
'
out_df.csv
'
):
df
=
pd
.
read_excel
(
'
ISA_data/250505_Flyers_ASTRAL_mix_12_species.xlsx
'
)
df_non_flyer
=
pd
.
read_excel
(
'
ISA_data/250505_Non_flyers_ASTRAL_mix_12_species.xlsx
'
)
#No flyer
df_non_flyer
=
df_non_flyer
[
df_non_flyer
[
'
Cystein ?
'
]
==
0
]
df_non_flyer
=
df_non_flyer
[
pd
.
isna
(
df_non_flyer
[
'
Miscleavage ?
'
])]
df_non_flyer
=
df_non_flyer
[
pd
.
isna
(
df_non_flyer
[
'
MaxLFQ
'
])]
df_non_flyer
[
'
Sequences
'
]
=
df_non_flyer
[
'
Peptide
'
]
df_non_flyer
[
'
Proteins
'
]
=
df_non_flyer
[
'
ProteinID
'
]
df_non_flyer
=
df_non_flyer
[[
'
Sequences
'
,
'
Proteins
'
]].
drop_duplicates
()
df_non_flyer
[
'
Classes MaxLFQ
'
]
=
0
#Flyer
df_filtered
=
df
[
~
(
pd
.
isna
(
df
[
'
Proteotypic ?
'
]))]
df_filtered
=
df_filtered
[
df_filtered
[
'
Coverage
'
]
>=
coverage_treshold
]
df_filtered
=
df_filtered
[
pd
.
isna
(
df_filtered
[
'
Miscleavage ?
'
])]
peptide_count
=
df_filtered
.
groupby
([
"
Protein.Names
"
]).
size
().
reset_index
(
name
=
'
counts
'
)
filtered_sequence
=
peptide_count
[
peptide_count
[
'
counts
'
]
>=
min_peptide
][
"
Protein.Names
"
]
df_filtered
=
df_filtered
[
df_filtered
[
"
Protein.Names
"
].
isin
(
filtered_sequence
.
to_list
())]
df1_grouped
=
df_filtered
.
groupby
(
"
Protein.Names
"
)
dico_final
=
{}
# iterate over each group
for
group_name
,
df_group
in
df1_grouped
:
seq
=
df_group
.
sort_values
(
by
=
[
'
20250129_ISA_MIX-1_48SPD_001
'
])[
'
Stripped.Sequence
'
].
to_list
()
value_maxlfq
=
df_group
.
sort_values
(
by
=
[
'
20250129_ISA_MIX-1_48SPD_001
'
])[
'
20250129_ISA_MIX-1_48SPD_001
'
].
to_list
()
value_maxlfq_frag
=
df_group
.
sort_values
(
by
=
[
'
20250129_ISA_MIX-1_48SPD_001
'
])[
'
20250129_ISA_MIX-1_48SPD_001
'
].
to_list
()
threshold_weak_flyer_maxflq
=
value_maxlfq
[
int
(
len
(
seq
)
/
3
)]
threshold_medium_flyer_maxlfq
=
value_maxlfq
[
int
(
2
*
len
(
seq
)
/
3
)]
prot
=
df_group
[
'
Protein.Group
'
].
to_list
()[
0
]
for
i
in
range
(
len
(
seq
)):
if
value_maxlfq_frag
[
i
]
<
threshold_weak_flyer_maxflq
:
label_maxlfq
=
1
elif
value_maxlfq_frag
[
i
]
<
threshold_medium_flyer_maxlfq
:
label_maxlfq
=
2
else
:
label_maxlfq
=
3
dico_final
[
seq
[
i
]]
=
(
prot
,
label_maxlfq
)
df_final
=
pd
.
DataFrame
.
from_dict
(
dico_final
,
orient
=
'
index
'
,
columns
=
[
'
Proteins
'
,
'
Classes MaxLFQ
'
])
df_final
[
'
Sequences
'
]
=
df_final
.
index
df_final
=
df_final
.
reset_index
()
df_final
=
df_final
[[
'
Sequences
'
,
'
Proteins
'
,
'
Classes MaxLFQ
'
]]
df_final
.
to_csv
(
'
ISA_data/df_flyer_no_miscleavage_astral_15.csv
'
,
index
=
False
)
df_non_flyer
.
to_csv
(
'
ISA_data/df_non_flyer_no_miscleavage_astral.csv
'
,
index
=
False
)
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
build_dataset
(
coverage_treshold
=
20
,
min_peptide
=
4
,
f_name
=
'
ISA_data/df_finetune_no_miscleavage.csv
'
)
build_dataset
_astral
(
coverage_treshold
=
20
,
min_peptide
=
15
)
This diff is collapsed.
Click to expand it.
main_fine_tune.py
+
87
−
18
View file @
89df63ba
...
@@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict
...
@@ -9,8 +9,8 @@ from datasets import load_dataset, DatasetDict
from
dlomix.reports.DetectabilityReport
import
DetectabilityReport
,
predictions_report
from
dlomix.reports.DetectabilityReport
import
DetectabilityReport
,
predictions_report
def
create_ISA_dataset
(
classe_type
=
'
Classes MaxLFQ
'
,
manual_seed
=
42
,
split
=
(
0.8
,
0.2
),
frac_no_fly_train
=
1
,
frac_no_fly_val
=
2
):
def
create_ISA_dataset
(
classe_type
=
'
Classes MaxLFQ
'
,
manual_seed
=
42
,
split
=
(
0.8
,
0.2
),
frac_no_fly_train
=
1
,
frac_no_fly_val
=
1
):
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_f
inetune
_no_miscleavage.csv
'
)
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_f
lyer
_no_miscleavage.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
)
df_no_flyer
[
'
Classes
'
]
=
df_no_flyer
[
classe_type
]
df_no_flyer
[
'
Classes
'
]
=
df_no_flyer
[
classe_type
]
df_no_flyer
=
df_no_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
df_no_flyer
=
df_no_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
...
@@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
...
@@ -28,13 +28,83 @@ def create_ISA_dataset(classe_type='Classes MaxLFQ', manual_seed = 42,split = (0
list_train_split
.
append
(
df_no_flyer
.
iloc
[:
int
(
class_count
*
split
[
0
]
*
frac_no_fly_train
),
:])
list_train_split
.
append
(
df_no_flyer
.
iloc
[:
int
(
class_count
*
split
[
0
]
*
frac_no_fly_train
),
:])
list_val_split
.
append
(
df_no_flyer
.
iloc
[
df_no_flyer
.
shape
[
0
]
-
int
(
class_count
*
split
[
1
]
*
frac_no_fly_val
):,
:])
list_val_split
.
append
(
df_no_flyer
.
iloc
[
df_no_flyer
.
shape
[
0
]
-
int
(
class_count
*
split
[
1
]
*
frac_no_fly_val
):,
:])
df_train
=
pd
.
concat
(
list_train_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
df_train
=
pd
.
concat
(
list_train_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_val
=
pd
.
concat
(
list_val_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
df_val
=
pd
.
concat
(
list_val_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_train
[
'
Proteins
'
]
=
0
df_train
[
'
Proteins
'
]
=
0
df_val
[
'
Proteins
'
]
=
0
df_val
[
'
Proteins
'
]
=
0
df_train
.
to_csv
(
'
temp_fine_tune_df_train.csv
'
,
index
=
False
)
df_train
.
to_csv
(
'
df_preprocessed/df_train_ISA.csv
'
,
index
=
False
)
df_val
.
to_csv
(
'
temp_fine_tune_df_val.csv
'
,
index
=
False
)
df_val
.
to_csv
(
'
df_preprocessed/df_val_ISA_multiclass.csv
'
,
index
=
False
)
def
create_astral_dataset
(
manual_seed
=
42
,
split
=
(
0.8
,
0.2
),
frac_no_fly_train
=
1
,
frac_no_fly_val
=
1
):
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_flyer_no_miscleavage_astral_15.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage_astral.csv
'
)
df_no_flyer
[
'
Classes
'
]
=
df_no_flyer
[
'
Classes MaxLFQ
'
]
df_no_flyer
=
df_no_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
df_flyer
[
'
Classes
'
]
=
df_flyer
[
'
Classes MaxLFQ
'
]
df_flyer
=
df_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
#stratified split
list_train_split
=
[]
list_val_split
=
[]
for
cl
in
[
1
,
2
,
3
]:
df_class
=
df_flyer
[
df_flyer
[
'
Classes
'
]
==
cl
]
class_count
=
df_class
.
shape
[
0
]
list_train_split
.
append
(
df_class
.
iloc
[:
int
(
class_count
*
split
[
0
]),:])
list_val_split
.
append
(
df_class
.
iloc
[
int
(
class_count
*
split
[
0
]):,
:])
list_train_split
.
append
(
df_no_flyer
.
iloc
[:
int
(
class_count
*
split
[
0
]
*
frac_no_fly_train
),
:])
list_val_split
.
append
(
df_no_flyer
.
iloc
[
df_no_flyer
.
shape
[
0
]
-
int
(
class_count
*
split
[
1
]
*
frac_no_fly_val
):,
:])
df_train
=
pd
.
concat
(
list_train_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_val
=
pd
.
concat
(
list_val_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_train
[
'
Proteins
'
]
=
0
df_val
[
'
Proteins
'
]
=
0
df_train
.
to_csv
(
'
df_preprocessed/df_train_astral_15.csv
'
,
index
=
False
)
df_val
.
to_csv
(
'
df_preprocessed/df_val_astral_multiclass_15.csv
'
,
index
=
False
)
def
create_combine_dataset
(
manual_seed
=
42
,
split
=
(
0.8
,
0.2
),
frac_no_fly_train
=
1
,
frac_no_fly_val
=
1
):
df_flyer_astral
=
pd
.
read_csv
(
'
ISA_data/df_flyer_no_miscleavage_astral_7.csv
'
)
df_no_flyer_astral
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage_astral.csv
'
)
df_no_flyer_astral
[
'
Classes
'
]
=
df_no_flyer_astral
[
'
Classes MaxLFQ
'
]
df_no_flyer_astral
=
df_no_flyer_astral
[[
'
Sequences
'
,
'
Classes
'
]]
df_flyer_astral
[
'
Classes
'
]
=
df_flyer_astral
[
'
Classes MaxLFQ
'
]
df_flyer_astral
=
df_flyer_astral
[[
'
Sequences
'
,
'
Classes
'
]]
df_flyer
=
pd
.
read_csv
(
'
ISA_data/df_flyer_no_miscleavage.csv
'
)
df_no_flyer
=
pd
.
read_csv
(
'
ISA_data/df_non_flyer_no_miscleavage.csv
'
)
df_no_flyer
[
'
Classes
'
]
=
df_no_flyer
[
'
Classes MaxLFQ
'
]
df_no_flyer
=
df_no_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
df_flyer
[
'
Classes
'
]
=
df_flyer
[
'
Classes MaxLFQ
'
]
df_flyer
=
df_flyer
[[
'
Sequences
'
,
'
Classes
'
]]
#stratified split
list_train_split
=
[]
list_val_split
=
[]
for
cl
in
[
1
,
2
,
3
]:
df_class_astral
=
df_flyer_astral
[
df_flyer_astral
[
'
Classes
'
]
==
cl
]
class_count_astral
=
df_class_astral
.
shape
[
0
]
df_class_ISA
=
df_flyer
[
df_flyer
[
'
Classes
'
]
==
cl
]
class_count_ISA
=
df_class_ISA
.
shape
[
0
]
list_train_split
.
append
(
df_class_astral
.
iloc
[:
int
(
class_count_astral
*
split
[
0
]),:])
list_val_split
.
append
(
df_class_astral
.
iloc
[
int
(
class_count_astral
*
split
[
0
]):,
:])
list_train_split
.
append
(
df_class_ISA
.
iloc
[:
int
(
class_count_ISA
*
split
[
0
]),
:])
list_val_split
.
append
(
df_class_ISA
.
iloc
[
int
(
class_count_ISA
*
split
[
0
]):,
:])
list_train_split
.
append
(
df_no_flyer_astral
.
iloc
[:
int
(
class_count_astral
*
split
[
0
]
*
frac_no_fly_train
),
:])
list_val_split
.
append
(
df_no_flyer_astral
.
iloc
[
df_no_flyer_astral
.
shape
[
0
]
-
int
(
class_count_astral
*
split
[
1
]
*
frac_no_fly_val
):,
:])
list_train_split
.
append
(
df_no_flyer
.
iloc
[:
int
(
class_count_ISA
*
split
[
0
]
*
frac_no_fly_train
),
:])
list_val_split
.
append
(
df_no_flyer
.
iloc
[
df_no_flyer
.
shape
[
0
]
-
int
(
class_count_ISA
*
split
[
1
]
*
frac_no_fly_val
):,
:])
df_train
=
pd
.
concat
(
list_train_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_val
=
pd
.
concat
(
list_val_split
).
sample
(
frac
=
1
,
random_state
=
manual_seed
)
#shuffle
df_train
[
'
Proteins
'
]
=
0
df_val
[
'
Proteins
'
]
=
0
df_train
.
to_csv
(
'
df_preprocessed/df_train_combined_7.csv
'
,
index
=
False
)
df_val
.
to_csv
(
'
df_preprocessed/df_val_combined_multiclass_7.csv
'
,
index
=
False
)
def
density_plot
(
prediction_path
,
prediction_path_2
,
criteria
=
'
base
'
):
def
density_plot
(
prediction_path
,
prediction_path_2
,
criteria
=
'
base
'
):
df
=
pd
.
read_csv
(
prediction_path
)
df
=
pd
.
read_csv
(
prediction_path
)
...
@@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'):
...
@@ -76,10 +146,8 @@ def density_plot(prediction_path,prediction_path_2,criteria='base'):
def
main
():
def
main
():
total_num_classes
=
len
(
CLASSES_LABELS
)
total_num_classes
=
len
(
CLASSES_LABELS
)
input_dimension
=
len
(
alphabet
)
num_cells
=
64
num_cells
=
64
fine_tuned_model
=
DetectabilityModel
(
num_units
=
num_cells
,
num_clases
=
total_num_classes
)
load_model_path
=
'
pretrained_model/original_detectability_fine_tuned_model_FINAL
'
load_model_path
=
'
pretrained_model/original_detectability_fine_tuned_model_FINAL
'
fine_tuned_model
=
DetectabilityModel
(
num_units
=
num_cells
,
fine_tuned_model
=
DetectabilityModel
(
num_units
=
num_cells
,
num_clases
=
total_num_classes
)
num_clases
=
total_num_classes
)
...
@@ -94,8 +162,8 @@ def main():
...
@@ -94,8 +162,8 @@ def main():
print
(
'
Initialising dataset
'
)
print
(
'
Initialising dataset
'
)
## Data init
## Data init
fine_tune_data
=
DetectabilityDataset
(
data_source
=
'
temp_fine_tune_df_train
.csv
'
,
fine_tune_data
=
DetectabilityDataset
(
data_source
=
'
df_preprocessed/df_train_combined_15
.csv
'
,
val_data_source
=
'
temp_fine_tune_df_val
.csv
'
,
val_data_source
=
'
df_preprocessed/df_val_combined_multiclass_15
.csv
'
,
data_format
=
'
csv
'
,
data_format
=
'
csv
'
,
max_seq_len
=
max_pep_length
,
max_seq_len
=
max_pep_length
,
label_column
=
"
Classes
"
,
label_column
=
"
Classes
"
,
...
@@ -114,7 +182,7 @@ def main():
...
@@ -114,7 +182,7 @@ def main():
verbose
=
1
,
verbose
=
1
,
patience
=
5
)
patience
=
5
)
model_save_path_FT
=
'
output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability
'
model_save_path_FT
=
'
output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability
_combined
'
model_checkpoint_FT
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
filepath
=
model_save_path_FT
,
model_checkpoint_FT
=
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
filepath
=
model_save_path_FT
,
monitor
=
'
val_loss
'
,
monitor
=
'
val_loss
'
,
...
@@ -131,13 +199,13 @@ def main():
...
@@ -131,13 +199,13 @@ def main():
history_fine_tuned
=
fine_tuned_model
.
fit
(
fine_tune_data
.
tensor_train_data
,
history_fine_tuned
=
fine_tuned_model
.
fit
(
fine_tune_data
.
tensor_train_data
,
validation_data
=
fine_tune_data
.
tensor_val_data
,
validation_data
=
fine_tune_data
.
tensor_val_data
,
epochs
=
1
,
epochs
=
1
50
,
callbacks
=
[
callback_FT
,
model_checkpoint_FT
])
callbacks
=
[
callback_FT
,
model_checkpoint_FT
])
## Loading best model weights
## Loading best model weights
#
model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability
'
model_save_path_FT
=
'
output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability
_combined
'
#model fined tuned on ISA data
model_save_path_FT
=
'
pretrained_model/original_detectability_fine_tuned_model_FINAL
'
#base model
#
model_save_path_FT = 'pretrained_model/original_detectability_fine_tuned_model_FINAL' #base model
fine_tuned_model
.
load_weights
(
model_save_path_FT
)
fine_tuned_model
.
load_weights
(
model_save_path_FT
)
...
@@ -170,16 +238,17 @@ def main():
...
@@ -170,16 +238,17 @@ def main():
report_FT
=
DetectabilityReport
(
test_targets_FT_one_hot
,
report_FT
=
DetectabilityReport
(
test_targets_FT_one_hot
,
predictions_FT
,
predictions_FT
,
test_data_df_FT
,
test_data_df_FT
,
output_path
=
'
./output/report_on_
ISA (Base model
categorical train,
binary
val
)
'
,
output_path
=
'
./output/report_on_
combined_15 (Fine tuned model (combined_10)
categorical train,
categorical
val)
'
,
history
=
history_fine_tuned
,
history
=
history_fine_tuned
,
rank_by_prot
=
True
,
rank_by_prot
=
True
,
threshold
=
None
,
threshold
=
None
,
name_of_dataset
=
'
ISA
val dataset (
binary
balanced)
'
,
name_of_dataset
=
'
combined_15
val dataset (
categorical
balanced)
'
,
name_of_model
=
'
Base model (ISA
)
'
)
name_of_model
=
'
Fine tuned model (combined_15
)
'
)
report_FT
.
generate_report
()
report_FT
.
generate_report
()
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
create_ISA_dataset
()
# create_astral_dataset()
# create_combine_dataset(frac_no_fly_val=1,frac_no_fly_train=1)
main
()
main
()
# density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
# density_plot('output/report_on_ISA (Base model)/Dectetability_prediction_report.csv','output/report_on_ISA (Fine-tuned model, half non flyer)/Dectetability_prediction_report.csv')
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment