Alexandre Chapin / Segment-Object-Centric / Commits

Commit f9d36ede
authored 1 year ago by Alexandre Chapin

Add setup for training

parent ece9ede3

No related branches found. No related tags found. No related merge requests found.

Showing 3 changed files
.gitmodules: 0 additions, 3 deletions
runs/test/config.json: 38 additions, 0 deletions
train_lit.py: 224 additions, 225 deletions

with 262 additions and 228 deletions
.gitmodules  +0 −3 @ f9d36ede

(removed)
[submodule "segment-anything"]
    path = segment-anything
    url = https://github.com/facebookresearch/segment-anything.git
runs/test/config.json  0 → 100644 (new file)  +38 −0 @ f9d36ede
{
    "data": {
        "dataset": "clevr3d",
        "num_points": 2000,
        "kwargs": {
            "downsample": 1
        }
    },
    "model": {
        "encoder": "osrt",
        "encoder_kwargs": {
            "pos_start_octave": -5,
            "num_slots": 6
        },
        "decoder": "slot_mixer",
        "decoder_kwargs": {
            "pos_start_octave": -5
        }
    },
    "training": {
        "num_workers": 4,
        "batch_size": 64,
        "num_gpu": 8,
        "model_selection_metric": "psnr",
        "model_selection_mode": "max",
        "print_every": 10,
        "visualize_every": 5000,
        "validate_every": 5000,
        "checkpoint_every": 1000,
        "backup_every": 25000,
        "max_it": 333000000,
        "decay_it": 4000000,
        "lr_warmup": 5000,
        "precision": "16-mixed",
        "out_dir": "."
    }
}
\ No newline at end of file
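A minimal sketch (not part of the commit) of how a config like this gets consumed; train_lit.py below does essentially this after parsing its command-line arguments. The values in the comments follow from the JSON above:

import json

with open("runs/test/config.json", "r") as f:
    cfg = json.load(f)

# Nested sections mirror the JSON structure.
batch_size = cfg["training"]["batch_size"]          # 64
num_devices = cfg["training"].get("num_gpu", 1)     # 8
per_device_batch = batch_size // num_devices        # 8, the split main() performs
print(cfg["model"]["encoder"], per_device_batch)    # osrt 8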
train_lit.py  +224 −225 @ f9d36ede
"""
Code inspired
from Lit-Llama training script : https://github.com/Lightning-AI/lit-ll
am
a
/blob/main/
finetune/full
.py
Code inspired
and adapted from : https://github.com/luca-medeiros/lightning-s
am/blob/main/
lightning_sam/train
.py
"""
import sys
from pathlib import Path
import os
import time
from functools import partial
import json
import argparse
import math

import lightning as L
from lightning.fabric.strategies import FSDPStrategy
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from lightning.fabric.fabric import _FabricOptimizer
from lightning.fabric.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from jsonargparse.cli import CLI

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from osrt.model import OSRT
from osrt.encoder import FeatureMasking
from osrt import data
from osrt.utils.training import AverageMeter
from osrt.utils.losses import DiceLoss, FocalLoss
from osrt.layers import Transformer

from generate import generate
from lit_llama.model import Block, LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from lit_llama.utils import save_model_checkpoint
from scripts.prepare_alpaca import generate_prompt
from segment_anything.modeling.transformer import TwoWayTransformer

torch.set_float32_matmul_precision('high')

__LOG10 = math.log(10)
instruction_tuning = True
eval_interval = 1000
save_interval = 1000
eval_iters = 100
log_interval = 100

# Hyperparameters
learning_rate = 3e-5
micro_batch_size = 4
"""
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
"""
epoch_size = 50000  # train dataset size
num_epochs = 5
# max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
weight_decay = 0.0
block_size = 512
warmup_iters = 100
class LrScheduler():
    """ Implements a learning rate schedule with warm-up and decay """
    def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
        self.peak_lr = peak_lr
        self.peak_it = peak_it
        self.decay_rate = decay_rate
        self.decay_it = decay_it

    def get_cur_lr(self, it):
        if it < self.peak_it:  # Warmup period
            return self.peak_lr * (it / self.peak_it)
        it_since_peak = it - self.peak_it
        return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
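# A minimal usage sketch (illustrative only) with the values main() passes below
# (peak_lr=1e-4, peak_it=5000 from "lr_warmup", decay_it=4000000, decay_rate=0.16):
"""
sched = LrScheduler(peak_lr=1e-4, peak_it=5000, decay_rate=0.16, decay_it=4000000)
sched.get_cur_lr(0)        # 0.0     -- start of the linear warmup
sched.get_cur_lr(2500)     # 5e-05   -- halfway through the warmup
sched.get_cur_lr(5000)     # 1e-04   -- peak
sched.get_cur_lr(4005000)  # 1e-04 * 0.16 ** 1 = 1.6e-05, one decay period past the peak
"""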
def main(
    config_path: str,
    data_dir: str = "data/alpaca",
    out_dir: str = "out/full/alpaca",
    checkpoint: str = None
def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
    # TODO : add segmentation also to select the model following how it's done in the training
    model.eval()
    mses = AverageMeter()
    psnrs = AverageMeter()
    sceneids = []
    with torch.no_grad():
        for iter, data in enumerate(val_dataloader):
            sceneids.append(data['sceneid'])
            input_images = data.get('input_images')
            input_camera_pos = data.get('input_camera_pos')
            input_rays = data.get('input_rays')
            target_pixels = data.get('target_pixels')

            if isinstance(model.encoder, FeatureMasking):
                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
                h, w, c = input_images[0][0].shape
                z = model.encoder(input_images, (h, w), input_camera_pos, input_rays)
            else:
                z = model.encoder(input_images, input_camera_pos, input_rays)

            target_camera_pos = data.get('target_camera_pos')
            target_rays = data.get('target_rays')

            loss_mse = torch.tensor(0., device=fabric.device)
            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs)

            ### Compute MSE on pixels
            loss_mse = loss_mse + ((pred_pixels - target_pixels) ** 2).mean((1, 2))
            psnr = -10. * torch.log(loss_mse) / __LOG10

            mses.update(loss_mse)
            psnrs.update(psnr)
            fabric.print(f"Val [{epoch}] - [{iter}/{len(val_dataloader)}] : psnr {psnr}, mse: {loss_mse}")

    fabric.print(f'Validation [{epoch}]: Mean psnr: [{psnrs.avg:.4f}] -- Mean mse: [{mses.avg:.4f}]')
    fabric.print(f"Saving checkpoint to {cfg.out_dir}")
    state_dict = model.state_dict()
    if fabric.global_rank == 0:
        torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
    model.train()
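# Sanity check for the PSNR computed above: psnr = -10 * log10(mse), which assumes
# pixel values normalized to [0, 1]. Illustrative numbers only:
"""
mse = torch.tensor(0.01)
psnr = -10. * torch.log(mse) / __LOG10
# psnr == tensor(20.)  -- an MSE of 1e-2 corresponds to 20 dB
"""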
def train_sam(
    cfg,
    fabric: L.Fabric,
    model: OSRT,
    optimizer: _FabricOptimizer,
    scheduler: _FabricOptimizer,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
):
    """
    The SAM training loop.
    """
    focal_loss = FocalLoss()
    dice_loss = DiceLoss()

    nb_epochs = cfg["training"]["max_it"] // cfg["training"]["batch_size"]
    for epoch in range(1, nb_epochs):
        # TODO : add psnr loss ?
        batch_time = AverageMeter()
        data_time = AverageMeter()
        focal_losses = AverageMeter()
        dice_losses = AverageMeter()
        mse_losses = AverageMeter()
        total_losses = AverageMeter()
        end = time.time()
        validated = False

        for iter, data in enumerate(train_dataloader):
            if epoch > 1 and epoch % cfg["training"]["validate_every"] == 0 and not validated:
                validate(fabric, model, val_dataloader, epoch)
                validated = True
            data_time.update(time.time() - end)

            # TODO : adapt to our model
            input_images = data.get('input_images')
            input_camera_pos = data.get('input_camera_pos')
            input_rays = data.get('input_rays')
            target_pixels = data.get('target_pixels')

            if isinstance(model.encoder, FeatureMasking):
                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
                h, w, c = input_images[0][0].shape
                masks_info, z = model.encoder(input_images, (h, w), input_camera_pos, input_rays, extract_masks=True)
            else:
                z = model.encoder(input_images, input_camera_pos, input_rays)

            target_camera_pos = data.get('target_camera_pos')
            target_rays = data.get('target_rays')

            loss_mse = torch.tensor(0., device=fabric.device)
            loss_focal = torch.tensor(0., device=fabric.device)
            loss_dice = torch.tensor(0., device=fabric.device)

            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs)

            ### Compute MSE on pixels
            loss_mse = loss_mse + ((pred_pixels - target_pixels) ** 2).mean((1, 2))

            batch_size = input_images.shape[0]
            if 'segmentation' in extras:
                # TODO : for visualisation only, could be interesting to check real GT
                # true_seg = data['target_masks'].float()
                pred_masks = extras['segmentation']
                # TODO : check the content of num_masks
                num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
                for pred_mask, gt_mask in zip(pred_masks, masks_info["segmentations"]):
                    loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
                    loss_dice += dice_loss(pred_mask, gt_mask, num_masks)

            # TODO : check the values of the loss and see if scale is ok
            loss_total = 20. * loss_focal + loss_dice + loss_mse

            # TODO : check also with ARI, FG-ARI values and new from recent paper
            """
            loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
                                                            pred_seg.transpose(1, 2))
            loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
                                                               pred_seg.transpose(1, 2))
            """

            optimizer.zero_grad()
            fabric.backward(loss_total)
            optimizer.step()
            scheduler.step()
            batch_time.update(time.time() - end)
            end = time.time()

            focal_losses.update(loss_focal.item(), batch_size)
            dice_losses.update(loss_dice.item(), batch_size)
            mse_losses.update(loss_mse.item(), batch_size)
            total_losses.update(loss_total.item(), batch_size)

            fabric.print(f'Epoch: [{epoch}][{iter + 1}/{len(train_dataloader)}]'
                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
                         f' | MSE Loss [{mse_losses.val:.4f} ({mse_losses.avg:.4f})]'
                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
def configure_opt(cfg, model: OSRT):
    warmup_iters = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
    peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
    peak_lr = 1e-4
    decay_rate = 0.16

    with open(config_path, 'r') as f:
        cfg = json.load(f)

    # LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=warmup_iters, decay_rate=0.16)
    def lr_lambda(step):
        if step < peak_it:  # Warmup period
            return peak_lr * (step / peak_it)
        it_since_peak = step - peak_it
        return peak_lr * (decay_rate ** (it_since_peak / warmup_iters))

    # TODO : check begin value of lr
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=decay_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
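# Note: torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's base lr by the
# lambda's return value; it does not treat it as an absolute learning rate. A minimal
# standalone sketch (toy module and numbers, unrelated to this training setup):
"""
module = torch.nn.Linear(2, 2)
opt = torch.optim.Adam(module.parameters(), lr=0.001)
sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(1.0, step / 100))
for _ in range(3):
    opt.step()
    sch.step()
opt.param_groups[0]['lr']  # 0.001 * (3 / 100) = 3e-05
"""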
def main(cfg) -> None:
    #########################
    ### Setup parameters
    #########################
    num_devices = cfg['training']['num_gpu'] if 'num_gpu' in cfg['training'] else 1
    num_workers = cfg['training']['num_workers'] if 'num_workers' in cfg['training'] else 1
    batch_size = cfg['training']['batch_size'] // num_devices

    auto_wrap_policy = partial(transformer_auto_wrap_policy,
                               transformer_layer_cls={Transformer, TwoWayTransformer})
    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy,
                            activation_checkpointing={Transformer, TwoWayTransformer},
                            limit_all_gathers=True)

    # TODO : enable bf16 precision
    fabric = L.Fabric(accelerator="cuda", devices=num_devices,
                      precision=cfg["training"]["precision"], strategy=strategy)

    #########################
    ### Launch the model
    #########################
    fabric = L.Fabric(accelerator="gpu", devices=num_devices, strategy="auto",
                      loggers=[TensorBoardLogger(cfg['training']['out_dir'], name="lightning-sam")])
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        os.makedirs(cfg['training']['out_dir'], exist_ok=True)

    ###################
    # Import Dataset
    ###################
    with fabric.device:
        model = OSRT(cfg)

    #########################
    ### Loading the dataset
    #########################
    train_dataset = data.get_dataset('train', cfg['data'])
    val_dataset = data.get_dataset('val', cfg['data'])
    test_dataset = data.get_dataset('test', cfg['data'])
...

@@ -107,162 +232,36 @@ def main(
    train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)

    data_vis_val = next(iter(vis_loader_val))  # Validation set data for visualization
    data_vis_val = fabric.to_device(data_vis_val)

    if checkpoint:
        checkpoint = torch.load(checkpoint)

    with fabric.device:
        torch.set_default_tensor_type(torch.HalfTensor)
        model = OSRT(cfg['model']).bfloat16()
        torch.set_default_tensor_type(torch.FloatTensor)

    if checkpoint:
        model.load_state_dict(checkpoint, strict=False)

    model = fabric.setup_module(model)

    params = [p for p in model.parameters() if p.requires_grad]

    # Setup scheduler
    warmup_iters = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
    peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
    lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=warmup_iters, decay_rate=0.16)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
    optimizer = fabric.setup_optimizers(optimizer)

    train(fabric, model, optimizer, train_loader, val_loader, out_dir)

    # Save the final checkpoint at the end of training
    save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
def train(
    fabric: L.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_data: DataLoader,  # TODO : maybe use np.array
    val_data: DataLoader,
    out_dir: str,
) -> None:
    """
    The training loop.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
    """
    step_count = 0
    model.train()

    #########################
    ### Prepare the optimizer
    #########################
    optimizer, scheduler = configure_opt(cfg, model)
    model, optimizer = fabric.setup(model, optimizer)

    for iter_num in range(max_iters):
        #########################
        ### Training
        #########################
        train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader)
        validate(fabric, model, val_loader, epoch=0)

        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0

        if step_count <= warmup_iters:
            # linear warmup
            lr = learning_rate * step_count / warmup_iters
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        t0 = time.time()

        input_ids, targets = get_batch(fabric, train_data)
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            fabric.backward(loss / gradient_accumulation_iters)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

            if step_count % eval_interval == 0:
                val_loss = validate(fabric, model, val_data)
                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
                fabric.barrier()

            if step_count % save_interval == 0:
                print(f"Saving weights to {out_dir}")
                save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))

        dt = time.time() - t0
        if iter_num % log_interval == 0:
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt * 1000:.2f}ms")
def generate_response(model, instruction):
    tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
    sample = {"instruction": instruction, "input": ""}
    prompt = instruction
    if instruction_tuning:
        prompt = generate_prompt(sample)

    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
    output = generate(
        model,
        idx=encoded,
        max_seq_length=block_size,
        max_new_tokens=100,
    )
    output = tokenizer.decode(output)
    return output  # output.split("### Response:")[1].strip()


@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
    fabric.print("Validating ...")
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        losses[k] = loss.item()
    out = losses.mean()

    # produce an example:
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    output = generate_response(model, instruction)
    fabric.print(instruction)
    fabric.print(output)

    model.train()
    return out.item()
def loss_fn(logits, targets):
    # shift the targets such that output n predicts token n+1
    logits = logits[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()
    loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
    return loss
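# Shape sketch for the one-token shift above (toy sizes, illustrative only):
"""
logits = torch.randn(2, 8, 32000)          # (batch, seq, vocab)
targets = torch.randint(0, 32000, (2, 8))
logits[..., :-1, :].shape                   # torch.Size([2, 7, 32000]) -- positions 0..6 ...
targets[..., 1:].shape                      # torch.Size([2, 7])        -- ... predict tokens 1..7
"""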
def get_batch(fabric: L.Fabric, data: list):
    ix = torch.randint(len(data), (micro_batch_size,))
    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    max_len = max(len(s) for s in input_ids)

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y
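# pad_right in action (illustrative): labels are padded with -1, matching the
# ignore_index=-1 in loss_fn, so padded positions contribute no loss.
"""
x = torch.tensor([5, 6, 7])                          # with max_len = 5, n = 2
torch.cat((x, torch.full((2,), -1, dtype=x.dtype)))  # tensor([ 5,  6,  7, -1, -1])
"""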
def load_datasets(data_dir):
    train_data = torch.load(os.path.join(data_dir, "train.pt"))
    val_data = torch.load(os.path.join(data_dir, "test.pt"))
    return train_data, val_data
if __name__ == "__main__":
    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")
    CLI(main)


if __name__ == "__main__":
    ### Arguments
    parser = argparse.ArgumentParser(description='Train a 3D scene representation model.')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--checkpoint', type=str, default='', help='Path to a model checkpoint')
    args = parser.parse_args()

    #########################
    ### Creating utility var
    #########################
    with open(args.config, 'r') as f:
        cfg = json.load(f)

    main(cfg)
\ No newline at end of file