Skip to content
Snippets Groups Projects
Commit 062a8fdb authored by Fize Jacques's avatar Fize Jacques
Browse files

Made a nice script for baseline model training

parent c0d73c75
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
# coding = utf-8
# -*- coding: utf-8 -*-
import os
import argparse
# BASIC
import pandas as pd
import numpy as np
# ML
# MACHINE LEARNING
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
# ML HELPERS
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from glob import glob
from joblib import dump,load
#PROGRESS BAR
from tqdm import tqdm
from ngram import NGram
parser = argparse.ArgumentParser()
parser.add_argument("dataset_name")
parser.add_argument("inclusion_fn")
parser.add_argument("adjacency_fn")
parser.add_argument("cooc_fn")
args = parser.parse_args()
DATASET_NAME = args.dataset_name
I_FN = args.inclusion_fn
A_FN = args.adjacency_fn
C_FN = args.cooc_fn
OUTPUT_DIR = "outputs"
verbose = False
for fn in [I_FN,A_FN,C_FN]:
if not os.path.exists(fn):
raise FileNotFoundError("{0} does not exists !".format(fn))
classifier_dict = {
"naive-bayes":MultinomialNB(),
"svm":SVC(),
"sgd":SGDClassifier(),
"knn":KNeighborsClassifier(),
"decision-tree": DecisionTreeClassifier(),
"random-forest":RandomForestClassifier()
}
parameters = {
"naive-bayes":[{"alpha":[0,1]}],
"svm":[{"kernel":["rbf","poly"], 'gamma': [1e-1,1e-2,1e-3, 1,10,100]}],
"sgd":[{"penalty":["l1","l2"],"loss":["hinge","modified_huber","log"]}],
"knn":[{"n_neighbors":list(range(4,8)),"p":[1,2]}],
"decision-tree": [{"criterion":["gini","entropy"]}],
"random-forest":[{"criterion":["gini","entropy"],"n_estimators":[10,50,100]}]
}
combinaison = [
[I_FN,C_FN,A_FN],
[I_FN,C_FN],
[C_FN],
[C_FN,A_FN]
]
combinaison_label = [
"PIC",
"IC",
"C",
"PC"
]
for ix, comb in enumerate(combinaison):
df = pd.concat([pd.read_csv(fn,sep="\t").head(500) for fn in comb])
index = NGram(n=4)
data_vectorizer = Pipeline([
('vect', CountVectorizer(tokenizer=index.split)),
('tfidf', TfidfTransformer()),
])
X_train,y_train = (df[df.split == "train"].toponym + " " + df[df.split == "train"].toponym_context).values, df[df.split == "train"].hp_split
X_test,y_test = (df[df.split == "test"].toponym + " " + df[df.split == "test"].toponym_context).values, df[df.split == "test"].hp_split
data_vectorizer.fit((df.toponym + " " + df.toponym_context).values)
dump(data_vectorizer,"{2}/{0}_{1}_vectorizer.pkl".format(DATASET_NAME,combinaison_label[ix],OUTPUT_DIR))
X_train = data_vectorizer.transform(X_train)
X_test = data_vectorizer.transform(X_test)
for CLASSIFIER in tqdm(classifier_dict):
if verbose : print("TRAIN AND EVAL {0}".format(CLASSIFIER))
clf = GridSearchCV(
classifier_dict[CLASSIFIER], parameters[CLASSIFIER], scoring='f1_weighted',n_jobs=-1
)
clf.fit(X_train, y_train)
if verbose : print("Best Parameters : ",clf.best_params_)
y_pred = clf.best_estimator_.predict(X_test)
if verbose : print(classification_report(y_test,y_pred))
dump(clf.best_estimator_,"{0}/{1}_{2}_{3}.pkl".format(OUTPUT_DIR,DATASET_NAME,combinaison_label[ix],CLASSIFIER))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment