From 702fd61c5401e0fc278cb27373aafc4eaeaa4819 Mon Sep 17 00:00:00 2001
From: Abderaouf Gacem <gcmabderaouf@gmail.com>
Date: Mon, 13 Feb 2023 11:32:33 +0000
Subject: [PATCH] init

---
 main.py | 231 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 main.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..df9254e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,231 @@
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from torch_geometric.utils import degree
+
+import utils
+from Args import Args
+
+path = "./data/"
+args = Args(dataset='Flickr',
+            model='graph_saint_gcn',
+            edge_weight=True,
+            use_normalization=True,
+            hidden_channels=256,
+            lr=0.001,
+            use_cuda=True,
+            epochs=50,
+            loader="forest_fire",
+            p=0.5,
+            connectivity=False,
+            batch_size=14000,
+            num_workers=4,
+            num_steps=20,
+            sample_coverage=100,
+            test_loader=False)
+
+
+dataset = utils.get_dataset(args.dataset, path)
+data = dataset[0]
+
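+# Optional inverse in-degree edge weights (as in the PyG GraphSAINT example):
+# each edge (row -> col) gets weight 1 / in_degree(col).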
+if args.edge_weight:
+    row, col = data.edge_index
+    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.
+
+
+loader = utils.get_loader(loader=args.loader, data=data,
+                          save_dir=dataset.processed_dir,
+                          batch_size=args.batch_size,
+                          num_workers=args.num_workers,
+                          **args.kwargs)
+if args.test_loader:
+    test_loader = utils.get_loader(loader="test_loader", data=data,
+                                   save_dir=None,
+                                   batch_size=args.test_batch_size,
+                                   num_workers=args.num_workers)
+
+
+device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')
+if args.model == 'gat':
+    model = args.model_class(dataset.num_features, dataset.num_classes,
+                             hidden_channels=args.hidden_channels,
+                             heads=args.heads).to(device)
+else:
+    model = args.model_class(dataset.num_features, dataset.num_classes,
+                             hidden_channels=args.hidden_channels).to(device)
+
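+# Per-model optimizer setup: the two-layer GCN applies weight decay only to
+# the first convolution (the usual Kipf & Welling setup), GAT decays all
+# parameters, and every other model uses plain Adam.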
+if args.model == 'gcn':
+    optimizer = torch.optim.Adam([
+        dict(params=model.conv1.parameters(), weight_decay=5e-4),
+        dict(params=model.conv2.parameters(), weight_decay=0),
+    ], lr=args.lr)
+elif args.model == 'gat':
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
+else:
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+def train():
+    if args.model in ("graph_saint", "saint_gcn", "graph_saint_gcn"):
+        model.train()
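+        # GraphSAINT training: 'add' aggregation is required when applying
+        # the sampler's edge/node normalization, 'mean' otherwise.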
+        model.set_aggr('add' if args.use_normalization else 'mean')
+
+        total_train_loss = total_train_acc = total_train_nodes = 0
+        total_val_loss = total_val_acc = total_val_nodes = 0
+        for batch in loader:
+            batch = batch.to(device)
+            optimizer.zero_grad()
+
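+            # With sample_coverage > 0 the GraphSAINT loader attaches
+            # edge_norm and node_norm to each batch; weighting the per-node
+            # losses by node_norm makes the mini-batch loss an unbiased
+            # estimate of the full-graph loss (Zeng et al., GraphSAINT).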
+            if args.use_normalization:
+                edge_weight = batch.edge_norm * batch.edge_weight
+                out = model(batch.x, batch.edge_index, edge_weight)
+                loss = F.nll_loss(out, batch.y, reduction='none')
+                loss = (loss * batch.node_norm)[batch.train_mask].sum()
+                val_loss = F.nll_loss(out, batch.y, reduction='none')
+                val_loss = (val_loss * batch.node_norm)[batch.val_mask].sum()
+            else:
+                out = model(batch.x, batch.edge_index)
+                loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
+                val_loss = F.nll_loss(out[batch.val_mask], batch.y[batch.val_mask])
+
+            loss.backward()
+            optimizer.step()
+
+            train_nodes = batch.train_mask.sum().item()
+            total_train_loss += loss.item() * train_nodes
+            total_train_nodes += train_nodes
+
+            val_nodes = batch.val_mask.sum().item()
+            total_val_loss += val_loss.item() * val_nodes
+            total_val_nodes += val_nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            accs = []
+            for _, mask in batch('train_mask', 'val_mask'):
+                accs.append(correct[mask].sum().item() / mask.sum().item())
+            total_val_acc += accs[1] * val_nodes
+            total_train_acc += accs[0] * train_nodes
+
+        return {"train_acc": total_train_acc / total_train_nodes,
+                "train_loss": total_train_loss / total_train_nodes,
+                "val_acc": total_val_acc / total_val_nodes,
+                "val_loss": total_val_loss / total_val_nodes}
+    else:
+        model.train()
+
+        total_train_loss = total_train_acc = total_train_nodes = 0
+        total_val_loss = total_val_acc = total_val_nodes = 0
+        for batch in loader:
+            batch = batch.to(device)
+            optimizer.zero_grad()
+            out = model(batch.x, batch.edge_index)
+
+            loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
+            loss.backward()
+            optimizer.step()
+
+            train_nodes = batch.train_mask.sum().item()
+            total_train_loss += loss.item() * train_nodes
+            total_train_nodes += train_nodes
+
+            val_loss = F.nll_loss(out[batch.val_mask], batch.y[batch.val_mask])
+            val_nodes = batch.val_mask.sum().item()
+            total_val_loss += val_loss.item() * val_nodes
+            total_val_nodes += val_nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            accs = []
+            for _, mask in batch('train_mask', 'val_mask'):
+                accs.append(correct[mask].sum().item() / mask.sum().item())
+            total_val_acc += accs[1] * val_nodes
+            total_train_acc += accs[0] * train_nodes
+
+        return {"train_acc": total_train_acc / total_train_nodes,
+                "train_loss": total_train_loss / total_train_nodes,
+                "val_acc": total_val_acc / total_val_nodes,
+                "val_loss": total_val_loss / total_val_nodes}
+
+
+@torch.no_grad()
+def test():
+    model.eval()
+
+    if args.model in ("graph_saint", "saint_gcn", "graph_saint_gcn"):
+        model.set_aggr('mean')
+
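+    # With a dedicated test loader, use the model's batched inference()
+    # helper; otherwise run a single full-graph forward pass.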
+    if args.test_loader:
+        out = model.inference(data.x, test_loader, device)
+    else:
+        out = model(data.x.to(device), data.edge_index.to(device))
+
+    y_pred = out.argmax(dim=-1)
+    correct = y_pred.eq(data.y.to(device))
+
+    accs = []
+    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
+        accs.append(correct[mask].sum().item() / mask.sum().item())
+    return accs
+
+
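+# Final test-set evaluation, computed batch-wise over the same sampled
+# training loader (hence the name).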
+@torch.no_grad()
+def traintest():
+    model.eval()
+    total_loss = total_acc = total_nodes = 0
+    if args.model == "cluster_gcn":
+        for batch in loader:
+            batch = batch.to(device)
+            
+            out = model(batch.x, batch.edge_index)
+            loss = F.nll_loss(out[batch.test_mask], batch.y[batch.test_mask])
+
+            nodes = batch.test_mask.sum().item()
+            total_loss += loss.item() * nodes
+            total_nodes += nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            acc = correct[batch.test_mask].sum().item() / batch.test_mask.sum().item()
+            total_acc += acc * nodes
+
+        return total_acc / total_nodes
+    elif args.model in ("graph_saint", "saint_gcn", "graph_saint_gcn"):
+        model.set_aggr('add' if args.use_normalization else 'mean')
+
+        for batch in loader:
+            batch = batch.to(device)
+
+            if args.use_normalization:
+                edge_weight = batch.edge_norm * batch.edge_weight
+                out = model(batch.x, batch.edge_index, edge_weight)
+                loss = F.nll_loss(out, batch.y, reduction='none')
+                loss = (loss * batch.node_norm)[batch.test_mask].sum()
+            else:
+                out = model(batch.x, batch.edge_index)
+                loss = F.nll_loss(out[batch.test_mask], batch.y[batch.test_mask])
+
+            nodes = batch.test_mask.sum().item()
+            total_loss += loss.item() * nodes
+            total_nodes += nodes
+
+            y_pred = out.argmax(dim=-1)
+            correct = y_pred.eq(batch.y)
+
+            acc = correct[batch.test_mask].sum().item() / batch.test_mask.sum().item()
+            total_acc += acc * nodes
+
+        return total_acc / total_nodes
+
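+# Training loop: log per-epoch statistics and checkpoint the model whenever
+# validation accuracy improves; the best checkpoint is reloaded for testing.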
+stats = []
+best_val_acc = 0
+for epoch in range(1, args.epochs + 1):
+    stat = train()
+    print(f'Epoch: {epoch:02d}, Train: {stat["train_acc"]:.4f}, Val: {stat["val_acc"]:.4f}')
+    stats.append(stat)
+    if stat["val_acc"] > best_val_acc:
+        best_val_acc = stat["val_acc"]
+        best_model = '{}.pkl'.format(epoch)
+        torch.save(model.state_dict(), best_model)
+
+pd.DataFrame(stats).to_csv(f'{args.dataset}_{args.model}_{args.batch_size}_{args.loader}.csv')
+
+model.load_state_dict(torch.load(best_model))
+print(f'Test accuracy: {traintest()}')
-- 
GitLab