Skip to content
Snippets Groups Projects
Commit 0282d0d8 authored by Schneider Leo's avatar Schneider Leo
Browse files

data augmented

parent 1cb6c3ab
No related branches found
No related tags found
No related merge requests found
...@@ -81,8 +81,6 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, ...@@ -81,8 +81,6 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
mem = losses_rt mem = losses_rt
torch.save(model.state_dict(), output.strip('.csv')+'pt') torch.save(model.state_dict(), output.strip('.csv')+'pt')
print('model saved') print('model saved')
if e % save_inter == 0:
save(model, 'model_common_' + str(e) + '.pt')
model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True)) model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True))
save_pred(model, data_test, output, criterion_rt, metric_rt, wandb) save_pred(model, data_test, output, criterion_rt, metric_rt, wandb)
...@@ -162,7 +160,7 @@ def save_pred(model, data_val, output_path, criterion_rt, metric_rt, wandb=None ...@@ -162,7 +160,7 @@ def save_pred(model, data_val, output_path, criterion_rt, metric_rt, wandb=None
loss_rt = criterion_rt(rt, pred_rt) loss_rt = criterion_rt(rt, pr_rt)
losses_rt += loss_rt.item() losses_rt += loss_rt.item()
dist_rt = metric_rt(rt, pred_rt) dist_rt = metric_rt(rt, pred_rt)
dist_rt_acc += dist_rt.item() dist_rt_acc += dist_rt.item()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment