From 25da06561ba7146b2210eeaa6f4723d2d3582508 Mon Sep 17 00:00:00 2001
From: Ikenna Oluigbo <ikenna-victor.oluigbo@etu.univ-lyon1.fr>
Date: Wed, 7 Dec 2022 03:00:42 +0000
Subject: [PATCH] Train node corpus and generate embedding

---
 train.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 train.py

diff --git a/train.py b/train.py
new file mode 100644
index 0000000..c0b921f
--- /dev/null
+++ b/train.py
@@ -0,0 +1,78 @@
+import argparse
+from gensim.models import Word2Vec 
+from attribwalk import *
+from builder import *
+import networkx as nx
+
+'''Enter input network and features in builder 
+To preserve neighborhood topology, walk type == structure
+To preserve node contexts, walk type == attribute 
+To preserve both neighborhood topology and contextual attributes, walk type == hybrid'''
+
+
+def parse_args():
+    '''
+    Parses arguments.
+    '''
+    parser = argparse.ArgumentParser(description="Run SNEFAN.")
+    
+    parser.add_argument('--output', nargs='?', default='output/cora.emb',
+                        help='Embedding path')            #default='emb/karate.emb'
+
+    parser.add_argument('--dimensions', type=int, default=64,
+                        help='Number of dimensions. Default is 64.')
+
+    parser.add_argument('--walk-length', type=int, default=40,
+                        help='Length of walk per source. Default is 40.')
+
+    parser.add_argument('--num-walks', type=int, default=5,
+                        help='Number of walks per source. Default is 5.')
+
+    parser.add_argument('--window-size', type=int, default=5,
+                        help='Context size for optimization. Default is 5.')
+
+    parser.add_argument('--epochs', default=1, type=int,
+                      help='Number of epochs in SGD')
+    
+    parser.add_argument("--walk-type", nargs = "?", default = "hybrid",
+                    help = "Random walk order... choose either structure or attribute or hybrid")
+
+    parser.add_argument('--workers', type=int, default=8,
+                        help='Number of parallel workers. Default is 8.')
+
+    parser.add_argument('--min-count', type=int, default=0,
+                        help='Minimum count of Training words. Default is 0.')
+    
+    parser.add_argument('--sg', type=int, default=1,
+                        help='Training Algorithm. CBOW=0,SkipGram=1. Default is 1.')
+    
+    return parser.parse_args()
+
+
+def learn_embeddings(walks):
+    '''
+    Learn embeddings by optimizing the Skipgram objective using SGD.
+    '''
+    walks = [list(map(str, walk)) for walk in walks]
+    print("Training Node Corpus...")
+    model = Word2Vec(walks, vector_size=args.dimensions, window=args.window_size, 
+                     min_count=args.min_count, sg=args.sg, workers=args.workers, epochs=args.epochs,  
+                    sample=1e-5, alpha=0.25, min_alpha=0.01, negative=5)
+    print("Saving Embeddings...")
+    model.wv.save_word2vec_format(args.output)
+    
+    return model
+
+
+def main(args):
+    G = build_graph()
+    walks = ATTRIB_NEIGH(args.num_walks, args.walk_length, args.walk_type)
+    learn_embeddings(walks)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
+    
+
+
-- 
GitLab