Skip to content
Snippets Groups Projects
Commit a856988f authored by Schneider Leo's avatar Schneider Leo
Browse files

data augmented

parent 17fbdb37
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,7 @@ def load_args(): ...@@ -9,7 +9,7 @@ def load_args():
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=2048) parser.add_argument('--batch_size', type=int, default=2048)
parser.add_argument('--model', type=str, default='RT_multi_sum') parser.add_argument('--model', type=str, default='prosit_transformer')
parser.add_argument('--wandb', type=str, default=None) parser.add_argument('--wandb', type=str, default=None)
parser.add_argument('--dataset_train', type=str, default='data/data_prosit/data.csv') parser.add_argument('--dataset_train', type=str, default='data/data_prosit/data.csv')
parser.add_argument('--split_train', type=str, default='train') parser.add_argument('--split_train', type=str, default='train')
......
...@@ -142,8 +142,10 @@ def get_n_params(model): ...@@ -142,8 +142,10 @@ def get_n_params(model):
return pp return pp
def save_pred(model, data_val, output_path): def save_pred(model, data_val, output_path, criterion_rt, metric_rt, wandb=None):
data_frame = pd.DataFrame() data_frame = pd.DataFrame()
losses_rt = 0.
dist_rt_acc = 0.
model.eval() model.eval()
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
...@@ -158,6 +160,21 @@ def save_pred(model, data_val, output_path): ...@@ -158,6 +160,21 @@ def save_pred(model, data_val, output_path):
seqs.extend(seq.data.cpu().tolist()) seqs.extend(seq.data.cpu().tolist())
true_rt.extend(rt.data.cpu().tolist()) true_rt.extend(rt.data.cpu().tolist())
loss_rt = criterion_rt(rt, pred_rt)
losses_rt += loss_rt.item()
dist_rt = metric_rt(rt, pred_rt)
dist_rt_acc += dist_rt.item()
if wandb is not None:
wdb.log({"test rt loss": losses_rt / len(data_val),
"test rt mean metric": dist_rt_acc / len(data_val)})
print('val rt loss', losses_rt / len(data_val),
"val rt mean metric : ",
dist_rt_acc / len(data_val))
data_frame['rt pred'] = pred_rt data_frame['rt pred'] = pred_rt
data_frame['seq'] = seqs data_frame['seq'] = seqs
data_frame['true rt'] = true_rt data_frame['true rt'] = true_rt
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment