Skip to content
Snippets Groups Projects
Commit 6a9416ea authored by Schneider Leo's avatar Schneider Leo
Browse files

forward = transfer

parent e6425962
No related branches found
No related tags found
No related merge requests found
...@@ -175,15 +175,28 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m ...@@ -175,15 +175,28 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt, def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'): criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):
for e in range(1, epochs + 1):
train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, if forward =='transfer' :
wandb=wandb) for e in range(1, epochs + 1):
if e % eval_inter == 0: train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt',
eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=wandb)
wandb=wandb) if e % eval_inter == 0:
if e % save_inter == 0: eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'both',
save(model, 'model_common_' + str(e) + '.pt') wandb=wandb)
save_pred(model, data_val, forward, output) if e % save_inter == 0:
save(model, 'model_common_' + str(e) + '.pt')
save_pred(model, data_val, 'rt', output)
else :
for e in range(1, epochs + 1):
train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
wandb=wandb)
if e % eval_inter == 0:
eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
wandb=wandb)
if e % save_inter == 0:
save(model, 'model_common_' + str(e) + '.pt')
save_pred(model, data_val, forward, output)
def main(args): def main(args):
...@@ -202,7 +215,11 @@ def main(args): ...@@ -202,7 +215,11 @@ def main(args):
path_test=args.dataset_test, path_test=args.dataset_test,
batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc') batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc')
elif args.forward == 'rt': elif args.forward == 'rt':
data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'], data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
batch_size=args.batch_size, length=25)
elif args.forward == 'transfer':
data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.cvs','database/data_holdout.cvs'],
batch_size=args.batch_size, length=25) batch_size=args.batch_size, length=25)
_, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val, _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment