diff --git a/image_ref/config.py b/image_ref/config.py index 95bd541903d2907d4cc37b5f7990fccdd10322d4..817e0cda7e5b7ac5fd9e41fd298ba308c005cf3f 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -22,7 +22,7 @@ def load_args_contrastive(): parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--pretrain_path', type=str, default=None) - parser.add_argument('--wandb', type=str, default='image_ref/best_model_base_ray.pt') + parser.add_argument('--wandb', type=str, default='wandb_run') args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/hyperparameter_res_analysis.py b/image_ref/hyperparameter_res_analysis.py index cdabac5f5e441d1037416bc4a1eedfebdc82abb8..7a411a3a87abbe63223f0b0b815c672875bbe91a 100644 --- a/image_ref/hyperparameter_res_analysis.py +++ b/image_ref/hyperparameter_res_analysis.py @@ -2,5 +2,3 @@ import pandas as pd import numpy as np df = pd.read_csv('../df_results_contrastive.csv') - -best_param = df[df['val loss']<0.003] \ No newline at end of file diff --git a/image_ref/main.py b/image_ref/main.py index bab6e980df95e1f786ad1ce376e5420c9ddf3191..32cba8fc5efdbd9d5559ccfda7d57be424d44514 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -101,8 +101,8 @@ def run_duo(args): os.environ["WANDB_MODE"] = "offline" os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run") - wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb) - + wdb.init(project="contrastive_classification", dir='./wandb_run') + print('Wandb initialised') # load data data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_val=args.dataset_val_dir, diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 80fa177ad708f7c00a2a7eb16d18ff5135c50164..801ffc23cc44938bc8e38d78df3095e9ef9c56ab 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -50,6 +50,7 @@ def train_model(config,args): optimizer = optim.Adam(model.parameters(), lr=config["lr"]) elif config['optimizer']=='SGD' : optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) + # init training n_class = len(data_train.dataset.classes) weight = torch.Tensor([1/n_class,1-1/n_class])