Commit 2217fd04 authored by Schneider Leo

add : ray opti

parent 4059d5c6

import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

from config import load_args_contrastive
from dataset_ref import load_data_duo
from model import Classification_model_duo_contrastive

# ray
from ray import train, tune
from ray.air import RunConfig
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
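
# Hyperparameter search for the contrastive duo classifier with Ray Tune:
# each trial trains the model under a sampled configuration (learning rate,
# noise threshold, positive proportion, optimizer, data sampler), reports
# per-epoch metrics together with a checkpoint, and the ASHA scheduler
# stops under-performing trials early.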


def train_model(config, args):
    # load data
    data_train, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                  base_dir_val=args.dataset_val_dir,
                                                  base_dir_test=args.dataset_test_dir,
                                                  batch_size=args.batch_size,
                                                  ref_dir=args.dataset_ref_dir,
                                                  noise_threshold=config['noise'],
                                                  positive_prop=config['positive_prop'],
                                                  sampler=config['sampler'])
    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.double()
    # move parameters to GPU; rebind `model` to the DataParallel wrapper when
    # several GPUs are available so the forward calls below go through it
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    model.to(device)

    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)

    # init training
    loss_function = nn.CrossEntropyLoss()

    # Load an existing checkpoint through the `get_checkpoint()` API;
    # Ray Tune sets it when a trial is resumed or restored.
    if train.get_checkpoint():
        loaded_checkpoint = train.get_checkpoint()
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            )
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    # train model
    for e in range(args.epoches):
        # train loss
        model.train()
        losses = 0.
        acc = 0.
        for param in model.parameters():
            param.requires_grad = True
        for imaer, imana, img_ref, label in data_train:
            label = label.long()
            if torch.cuda.is_available():
                imaer = imaer.cuda()
                imana = imana.cuda()
                img_ref = img_ref.cuda()
                label = label.cuda()
            # call the model (or its DataParallel wrapper) directly so the
            # batch is split across GPUs when several are available
            pred_logits = model(imaer, imana, img_ref)
            pred_class = torch.argmax(pred_logits, dim=1)
            acc += (pred_class == label).sum().item()
            loss = loss_function(pred_logits, label)
            losses += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses_train = losses / len(data_train.dataset)
        acc_train = acc / len(data_train.dataset)

        # validation loss
        model.eval()
        losses = 0.
        acc = 0.
        acc_contrastive = 0.
        for param in model.parameters():
            param.requires_grad = False
        for imaer, imana, img_ref, label in data_val_batch:
            imaer = imaer.transpose(0, 1)
            imana = imana.transpose(0, 1)
            img_ref = img_ref.transpose(0, 1)
            label = label.transpose(0, 1)
            label = label.squeeze()
            label = label.long()
            if torch.cuda.is_available():
                imaer = imaer.cuda()
                imana = imana.cuda()
                img_ref = img_ref.cuda()
                label = label.cuda()
            label_class = torch.argmin(label).data.cpu().numpy()
            pred_logits = model(imaer, imana, img_ref)
            pred_class = torch.argmax(pred_logits[:, 0]).tolist()
            acc_contrastive += (
                torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
            acc += (pred_class == label_class)
            loss = loss_function(pred_logits, label)
            losses += loss.item()
        losses_val = losses / (label.shape[0] * len(data_val_batch.dataset))
        acc_val = acc / (len(data_val_batch.dataset))
        acc_contrastive_val = acc_contrastive / (label.shape[0] * len(data_val_batch.dataset))
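
        # Metric bookkeeping for the group-style validation batches: each
        # batch holds exactly one sample labelled 0 (located with argmin
        # above). "val acc" counts whether the sample with the highest
        # class-0 logit is that one, while "val contrastive acc" is plain
        # per-sample accuracy over all pairs in the batch.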

        # Save a checkpoint. It is automatically registered with Ray Tune
        # and can be accessed through `get_checkpoint()` in future
        # iterations. Note that to checkpoint a single file it still has to
        # be placed under a directory, from which the Checkpoint is built.
        with tempfile.TemporaryDirectory(
                dir='/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/checkpoints') as temp_checkpoint_dir:
            path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save(
                (model.state_dict(), optimizer.state_dict()), path
            )
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            print(checkpoint.path)
            train.report(
                {"train loss": losses_train,
                 "train contrastive acc": acc_train,
                 "val loss": losses_val,
                 "val acc": acc_val,
                 "val contrastive acc": acc_contrastive_val},
                checkpoint=checkpoint,
            )
    print("Finished Training")


def test_model(best_result, args):
    # load data
    _, data_val_batch, _ = load_data_duo(base_dir_train=args.dataset_train_dir,
                                         base_dir_val=args.dataset_val_dir,
                                         base_dir_test=args.dataset_test_dir,
                                         batch_size=args.batch_size,
                                         ref_dir=args.dataset_ref_dir,
                                         noise_threshold=best_result.config['noise'],
                                         positive_prop=best_result.config['positive_prop'],
                                         sampler=best_result.config['sampler'])
    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.double()
    # load weights from the best trial's checkpoint
    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
    model_state, optimizer_state = torch.load(checkpoint_path)
    model.load_state_dict(model_state)
    # move parameters to GPU
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
    model.to(device)
    # init evaluation
    loss_function = nn.CrossEntropyLoss()
    model.eval()
    losses = 0.
    acc = 0.
    acc_contrastive = 0.
    for param in model.parameters():
        param.requires_grad = False
    for imaer, imana, img_ref, label in data_val_batch:
        imaer = imaer.transpose(0, 1)
        imana = imana.transpose(0, 1)
        img_ref = img_ref.transpose(0, 1)
        label = label.transpose(0, 1)
        label = label.squeeze()
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        label_class = torch.argmin(label).data.cpu().numpy()
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
        acc_contrastive += (
            torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
        acc += (pred_class == label_class)
        loss = loss_function(pred_logits, label)
        losses += loss.item()
    losses = losses / (label.shape[0] * len(data_val_batch.dataset))
    acc = acc / (len(data_val_batch.dataset))
    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_val_batch.dataset))
    print("Best trial test set: loss {} acc {} acc_contrastive {}".format(losses, acc, acc_contrastive))


def main(args, gpus_per_trial=1):
    config = {
        "lr": tune.loguniform(1e-4, 1e-2),
        # tune.loguniform needs a strictly positive lower bound, so the noise
        # threshold is drawn uniformly over the intended [0, 500] range instead
        "noise": tune.uniform(0, 500),
        "positive_prop": tune.uniform(0, 100),
        "optimizer": tune.choice(['Adam', 'SGD']),
        "sampler": tune.choice(['random', 'balanced']),
    }
    scheduler = ASHAScheduler(
        max_t=100,
        grace_period=20,
        reduction_factor=3,
        brackets=1,
    )
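    # With grace_period=20 every trial trains for at least 20 reported epochs;
    # beyond that, ASHA promotes roughly the best third of trials at each rung
    # (reduction_factor=3) up to max_t=100 epochs.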
    algo = OptunaSearch()
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model, args=args),
            resources={"cpu": 80, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            time_budget_s=3600 * 23.5,
            search_alg=algo,
            scheduler=scheduler,
            num_samples=50,
            metric="val loss",
            mode='min',
        ),
        run_config=RunConfig(
            storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test",
            name="test_experiment_no_scheduler"
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("val loss", "min")
    print("Best trial config: {}".format(best_result.config))
    # the metric keys must match the names used in train.report(...)
    print("Best trial final validation loss: {}".format(
        best_result.metrics["val loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["val acc"]))

    test_model(best_result, args)


if __name__ == '__main__':
    args = load_args_contrastive()
    print(args)
    main(args)
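
# A minimal sketch of resuming an interrupted search (e.g. after the 23.5 h
# time budget expires), assuming the same parsed `args`. `tune.Tuner.restore`
# rebuilds the tuner from the experiment directory and an identically
# constructed trainable:
#
#   trainable = tune.with_resources(
#       tune.with_parameters(train_model, args=args),
#       resources={"cpu": 80, "gpu": 1},
#   )
#   tuner = tune.Tuner.restore(
#       "/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/"
#       "image_ref/ray_results_test/test_experiment_no_scheduler",
#       trainable=trainable,
#   )
#   results = tuner.fit()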