From 702fd61c5401e0fc278cb27373aafc4eaeaa4819 Mon Sep 17 00:00:00 2001
From: Abderaouf Gacem <gcmabderaouf@gmail.com>
Date: Mon, 13 Feb 2023 11:32:33 +0000
Subject: [PATCH] init

---
 main.py | 231 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 main.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..df9254e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,231 @@
+import torch
+import torch.nn.functional as F
+import pandas as pd
+from torch_geometric.utils import degree
+
+import utils
+from Args import Args
+
+path = "./data/"
+args = Args(dataset='Flickr',
+            model='graph_saint_gcn',
+            edge_weight=True,
+            use_normalization=True,
+            hidden_channels=256,
+            lr=0.001,
+            use_cuda=True,
+            epochs=50,
+            loader="forest_fire",
+            p=0.5,
+            connectivity=False,
+            batch_size=14000,
+            num_workers=4,
+            num_steps=20,
+            sample_coverage=100,
+            test_loader=False)
+
+# Model names that use GraphSAINT-style sampling; includes the configured
+# 'graph_saint_gcn' so the normalization branches below are actually taken.
+SAINT_MODELS = ("graph_saint", "graph_saint_gcn", "saint_gcn")
+
+dataset = utils.get_dataset(args.dataset, path)
+data = dataset[0]
+
+if args.edge_weight:
+    _, col = data.edge_index
+    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.
+
+loader = utils.get_loader(loader=args.loader, data=data, save_dir=dataset.processed_dir,
+                          batch_size=args.batch_size, num_workers=args.num_workers,
+                          **args.kwargs)
+if args.test_loader:
+    test_loader = utils.get_loader(loader="test_loader", data=data, save_dir=None,
+                                   batch_size=args.test_batch_size,
+                                   num_workers=args.num_workers)
+
+device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')
+if args.model == 'gat':
+    model = args.model_class(dataset.num_features, dataset.num_classes,
+                             hidden_channels=args.hidden_channels,
+                             heads=args.heads).to(device)
+else:
+    model = args.model_class(dataset.num_features, dataset.num_classes,
+                             hidden_channels=args.hidden_channels).to(device)
+
+if args.model == 'gcn':
+    # Weight decay on the first GCN layer only, none on the output layer.
+    optimizer = torch.optim.Adam([
+        dict(params=model.conv1.parameters(), weight_decay=5e-4),
+        dict(params=model.conv2.parameters(), weight_decay=0),
+    ], lr=args.lr)
+elif args.model == 'gat':
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
+else:
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+
+def train():
+    if args.model in SAINT_MODELS:
+        model.train()
+        model.set_aggr('add' if args.use_normalization else 'mean')
+
+        total_train_loss = total_train_acc = total_train_nodes = 0
+        total_val_loss = total_val_acc = total_val_nodes = 0
+        for batch in loader:
+            batch = batch.to(device)
+            optimizer.zero_grad()
+
+            if args.use_normalization:
+                edge_weight = batch.edge_norm * batch.edge_weight
+                out = model(batch.x, batch.edge_index, edge_weight)
+                loss = F.nll_loss(out, batch.y, reduction='none')
+                loss = (loss * batch.node_norm)[batch.train_mask].sum()
+                # Keep per-node losses (reduction='none') so the SAINT node-norm
+                # weights can be applied before masking and summing.
+                val_loss = F.nll_loss(out, batch.y, reduction='none')
+                val_loss = (val_loss * batch.node_norm)[batch.val_mask].sum()
+            else:
+                out = model(batch.x, batch.edge_index)
+                loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
+                val_loss = F.nll_loss(out[batch.val_mask], batch.y[batch.val_mask])
+
+            loss.backward()
+            optimizer.step()
+
+            train_nodes = batch.train_mask.sum().item()
+            total_train_loss += loss.item() * train_nodes
+            total_train_nodes += train_nodes
+
+            val_nodes = batch.val_mask.sum().item()
+            total_val_loss += val_loss.item() * val_nodes
+            total_val_nodes += val_nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            accs = []
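+            # Data.__call__ yields (key, value) pairs for the requested attribute
+            # names, so accs[0] is train accuracy and accs[1] is val accuracy.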
+            for _, mask in batch('train_mask', 'val_mask'):
+                accs.append(correct[mask].sum().item() / mask.sum().item())
+            total_val_acc += accs[1] * val_nodes
+            total_train_acc += accs[0] * train_nodes
+
+        return {"train_acc": total_train_acc / total_train_nodes,
+                "train_loss": total_train_loss / total_train_nodes,
+                "val_acc": total_val_acc / total_val_nodes,
+                "val_loss": total_val_loss / total_val_nodes}
+    else:
+        model.train()
+
+        total_train_loss = total_train_acc = total_train_nodes = 0
+        total_val_loss = total_val_acc = total_val_nodes = 0
+        for batch in loader:
+            batch = batch.to(device)
+            optimizer.zero_grad()
+            out = model(batch.x, batch.edge_index)
+
+            loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
+            loss.backward()
+            optimizer.step()
+
+            train_nodes = batch.train_mask.sum().item()
+            total_train_loss += loss.item() * train_nodes
+            total_train_nodes += train_nodes
+
+            val_loss = F.nll_loss(out[batch.val_mask], batch.y[batch.val_mask])
+            val_nodes = batch.val_mask.sum().item()
+            total_val_loss += val_loss.item() * val_nodes
+            total_val_nodes += val_nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            accs = []
+            for _, mask in batch('train_mask', 'val_mask'):
+                accs.append(correct[mask].sum().item() / mask.sum().item())
+            total_val_acc += accs[1] * val_nodes
+            total_train_acc += accs[0] * train_nodes
+
+        return {"train_acc": total_train_acc / total_train_nodes,
+                "train_loss": total_train_loss / total_train_nodes,
+                "val_acc": total_val_acc / total_val_nodes,
+                "val_loss": total_val_loss / total_val_nodes}
+
+
+@torch.no_grad()
+def test():
+    model.eval()
+
+    if args.model in SAINT_MODELS:
+        model.set_aggr('mean')
+
+    if args.test_loader:
+        out = model.inference(data.x, test_loader, device)
+    else:
+        out = model(data.x.to(device), data.edge_index.to(device))
+
+    y_pred = out.argmax(dim=-1)
+    correct = y_pred.eq(data.y.to(device))
+
+    accs = []
+    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
+        accs.append(correct[mask].sum().item() / mask.sum().item())
+    return accs
+
+
+@torch.no_grad()
+def traintest():
+    model.eval()
+    total_loss = total_acc = total_nodes = 0
+    if args.model == "cluster_gcn":
+        for batch in loader:
+            batch = batch.to(device)
+            out = model(batch.x, batch.edge_index)
+
+            loss = F.nll_loss(out[batch.test_mask], batch.y[batch.test_mask])
+
+            nodes = batch.test_mask.sum().item()
+            total_loss += loss.item() * nodes
+            total_nodes += nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            acc = correct[batch.test_mask].sum().item() / batch.test_mask.sum().item()
+            total_acc += acc * nodes
+
+        return total_acc / total_nodes
+    elif args.model in SAINT_MODELS:
+        model.set_aggr('add' if args.use_normalization else 'mean')
+
+        for batch in loader:
+            batch = batch.to(device)
+
+            if args.use_normalization:
+                edge_weight = batch.edge_norm * batch.edge_weight
+                out = model(batch.x, batch.edge_index, edge_weight)
+                loss = F.nll_loss(out, batch.y, reduction='none')
+                loss = (loss * batch.node_norm)[batch.test_mask].sum()
+            else:
+                out = model(batch.x, batch.edge_index)
+                loss = F.nll_loss(out[batch.test_mask], batch.y[batch.test_mask])
+
+            nodes = batch.test_mask.sum().item()
+            total_loss += loss.item() * nodes
+            total_nodes += nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            acc = correct[batch.test_mask].sum().item() / batch.test_mask.sum().item()
+            total_acc += acc * nodes
+
+        return total_acc / total_nodes
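+
+# Training driver: record per-epoch stats, checkpoint whenever validation
+# accuracy improves, then reload the best checkpoint for the final test pass.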
+stats = []
+best_val_acc = 0
+best_model = None
+for epoch in range(1, args.epochs + 1):
+    stat = train()
+    print(f'Epoch: {epoch:02d}, Train: {stat["train_acc"]:.4f}, Val: {stat["val_acc"]:.4f}')
+    stats.append(stat)
+    if stat["val_acc"] > best_val_acc:
+        best_val_acc = stat["val_acc"]
+        best_model = '{}.pkl'.format(epoch)
+        torch.save(model.state_dict(), best_model)
+
+pd.DataFrame(stats).to_csv(f'{args.dataset}_{args.model}_{args.batch_size}_{args.loader}.csv')
+
+model.load_state_dict(torch.load(best_model))
+print("test accuracy: " + str(traintest()))
-- 
GitLab