Skip to content
Snippets Groups Projects
config.py 1.03 KiB
import argparse


def load_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--save_inter', type=int, default=100)
    parser.add_argument('--eval_inter', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--n_test', type=int, default=None)
    parser.add_argument('--n_train', type=int, default=None)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--model', type=str, default='RT_multi_sum')
    parser.add_argument('--wandb', type=str, default=None)
    parser.add_argument('--coef_pretext', type=float, default=1.)
    parser.add_argument('--dataset_train', type=str, default='database/data.csv')
    parser.add_argument('--dataset_test', type=str, default='database/data.csv')
    parser.add_argument('--layers_sizes', nargs='+', type=int, default=[256, 512, 512])
    args = parser.parse_args()

    return args