Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
O
outillage
Manage
Activity
Members
Labels
Plan
Issues
0
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Alice Brenon
outillage
Commits
a12c79c3
Commit
a12c79c3
authored
1 year ago
by
Alice Brenon
Browse files
Options
Downloads
Patches
Plain Diff
A nice little script to make predictions from an already fine-tuned BERT model
parent
f09a3cf4
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
scripts/ML/predict.py
+123
-0
123 additions, 0 deletions
scripts/ML/predict.py
with
123 additions
and
0 deletions
scripts/ML/predict.py
0 → 100644
+
123
−
0
View file @
a12c79c3
#!/usr/bin/env python3
import
numpy
import
pandas
import
pickle
import
sklearn
from
sys
import
argv
import
torch
from
tqdm
import
tqdm
from
transformers
import
BertForSequenceClassification
,
BertTokenizer
,
TextClassificationPipeline
class
Classifier
:
"""
A class wrapping all the different models and classes used throughout a
classification task:
- tokenizer
- classifier
- pipeline
- label encoder
Once created, it behaves as a function which you apply to a generator
containing the texts to classify
"""
def
__init__
(
self
,
root_path
):
self
.
device
=
torch
.
device
(
"
cuda
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
self
.
_init_tokenizer
()
self
.
_init_model
(
root_path
)
self
.
_init_pipe
()
self
.
_init_encoder
(
f
"
{
root_path
}
/label_encoder.pkl
"
)
self
.
log
()
def
_init_model
(
self
,
path
):
bert
=
BertForSequenceClassification
.
from_pretrained
(
path
)
self
.
model
=
bert
.
to
(
self
.
device
.
type
)
def
_init_tokenizer
(
self
):
model_name
=
'
bert-base-multilingual-cased
'
self
.
tokenizer
=
BertTokenizer
.
from_pretrained
(
model_name
)
def
_init_pipe
(
self
):
self
.
pipe
=
TextClassificationPipeline
(
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
return_all_scores
=
True
,
device
=
self
.
device
)
def
_init_encoder
(
self
,
path
):
with
open
(
path
,
'
rb
'
)
as
pickled
:
self
.
encoder
=
pickle
.
load
(
pickled
)
def
log
(
self
):
if
self
.
device
.
type
==
'
cpu
'
:
print
(
'
No GPU available, using the CPU instead.
'
)
else
:
print
(
'
We will use the GPU:
'
,
torch
.
cuda
.
get_device_name
(
0
))
def
__call__
(
self
,
text_generator
):
tokenizer_kwargs
=
{
'
padding
'
:
True
,
'
truncation
'
:
True
,
'
max_length
'
:
512
}
predictions
=
[]
for
output
in
tqdm
(
self
.
pipe
(
text_generator
,
**
tokenizer_kwargs
)):
byScoreDesc
=
sorted
(
output
,
key
=
lambda
d
:
d
[
'
score
'
],
reverse
=
True
)
predictions
.
append
([
int
(
byScoreDesc
[
0
][
'
label
'
][
6
:]),
byScoreDesc
[
0
][
'
score
'
],
int
(
byScoreDesc
[
1
][
'
label
'
][
6
:])])
predictions
=
numpy
.
array
(
predictions
)
return
list
(
self
.
encoder
.
inverse_transform
(
predictions
[:,
0
].
astype
(
int
)))
class
Source
:
"""
A class to handle the normalised path used in the project and loading the
actual text input as a generator from records when they are needed
"""
def
__init__
(
self
,
root_path
):
"""
Positional arguments
:param root_path: the path to a GÉODE-style folder containing the text
version of the corpus on which to predict the classes
"""
self
.
root_path
=
root_path
def
path_to
(
self
,
record
):
article_relative_path
=
"
{work}/T{volume}/{article}
"
.
format
(
**
record
)
prefix
=
f
"
{
self
.
root_path
}
/
{
article_relative_path
}
"
if
'
paragraph
'
in
record
:
return
f
"
{
prefix
}
/
{
record
.
paragraph
}
.txt
"
else
:
return
f
"
{
prefix
}
.txt
"
def
load_text
(
self
,
record
):
with
open
(
self
.
path_to
(
record
),
'
r
'
)
as
file
:
return
file
.
read
()
def
iterate
(
self
,
records
):
for
_
,
record
in
records
.
iterrows
():
yield
self
.
load_text
(
record
)
def
label
(
classify
,
source
,
tsv_path
,
name
=
'
label
'
):
"""
Make predictions on a set of document
Positional arguments
:param classify: an instance of the Classifier class above
:param source: an instance of the Source class above
:param tsv_path: the path to a TSV file containing (at least) article or
paragraph records (additional metadata will be ignored)
Keyword arguments
:param name: defaults to
'
label
'
— the name of the column to be created, that is
to say, the name of the category you are predicting with your model (if your
model labels in
"
Red
"
,
"
Green
"
, or
"
Blue
"
, you may want to use
`name=
'
color
'
`).
:return: a panda dataframe containing the records from the input TSV file plus
an additional column
"""
records
=
pandas
.
read_csv
(
tsv_path
,
sep
=
'
\t
'
)
records
[
name
]
=
classify
(
source
.
iterate
(
records
))
return
records
if
__name__
==
'
__main__
'
:
classify
=
Classifier
(
argv
[
1
])
source
=
Source
(
argv
[
2
])
label
(
classify
,
source
,
argv
[
3
]).
to_csv
(
argv
[
4
],
sep
=
'
\t
'
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment