-
Léo Schneider authored2cda896b
config_common.py 1.62 KiB
import argparse
def load_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--save_inter', type=int, default=100)
parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=2048)
parser.add_argument('--n_head', type=int, default=1)
parser.add_argument('--embedding_dim', type=int, default=16)
parser.add_argument('--encoder_ff', type=int, default=2048)
parser.add_argument('--decoder_rt_ff', type=int, default=2048)
parser.add_argument('--decoder_int_ff', type=int, default=512)
parser.add_argument('--encoder_num_layer', type=int, default=2)
parser.add_argument('--decoder_rt_num_layer', type=int, default=1)
parser.add_argument('--decoder_int_num_layer', type=int, default=1)
parser.add_argument('--drop_rate', type=float, default=0.035)
parser.add_argument('--wandb', type=str, default=None)
parser.add_argument('--forward', type=str, default='both')
parser.add_argument('--dataset_train', type=str, default='database/data_DIA_ISA_55_train.pkl')
parser.add_argument('--dataset_val', type=str, default='database/data_DIA_ISA_55_test.pkl')
parser.add_argument('--dataset_test', type=str, default='database/data_DIA_ISA_55_test.pkl')
parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
parser.add_argument('--activation', type=str,default='relu')
args = parser.parse_args()
return args