# Source code for geocoding.algorithm_selection

import numpy as np
import os
import shutil
import argparse
import pickle
import time
import warnings

from geocoding import clf_utilities as clf_ut, writers as wrtrs
from geocoding.config import Config


# Silence warnings whose message starts with "Solver terminated early" so they
# do not clutter the per-fold / per-classifier progress output.
warnings.filterwarnings(
    action='ignore',
    message='Solver terminated early.*',
)


def main():
    """
    Implements the second step of the experiment pipeline. Trains a series of
    classifiers based on different configurations in a nested cross
    validation scheme.

    For every fold ``1..Config.n_folds`` the train/test arrays are loaded from
    ``<Config.base_dir>/experiments/<experiment_path>/features_extraction_results/fold_<i>``,
    every classifier named in ``Config.included_classifiers`` is evaluated on
    that fold, and the collected scores are written to
    ``<Config.base_dir>/experiments/<experiment_path>/algorithm_selection_results``.
    An existing results directory is deleted and recreated.

    Command line arguments:
        -experiment_path (required): name of the experiment directory under
            ``<Config.base_dir>/experiments``.

    Returns:
        None
    """
    # Construct argument parser and parse arguments
    ap = argparse.ArgumentParser()
    ap.add_argument('-experiment_path', required=True)
    args = vars(ap.parse_args())

    experiment_dir = os.path.join(
        Config.base_dir, 'experiments', args['experiment_path'])
    features_path = os.path.join(experiment_dir, 'features_extraction_results')

    t1 = time.time()
    results = []
    # Encoded label of the baseline service; computed once, on first use,
    # because it does not depend on the fold.
    baseline_service_enc = None

    for i in range(1, Config.n_folds + 1):
        print('Fold:', i)
        fold_path = os.path.join(features_path, f'fold_{i}')
        X_train = np.load(os.path.join(fold_path, 'X_train.npy'))
        X_test = np.load(os.path.join(fold_path, 'X_test.npy'))
        y_train = np.load(os.path.join(fold_path, 'y_train.npy'))
        y_test = np.load(os.path.join(fold_path, 'y_test.npy'))

        for clf_name in Config.included_classifiers:
            print('Classifier:', clf_name)
            if clf_name == 'Baseline':
                if baseline_service_enc is None:
                    # NOTE(security): unpickling executes arbitrary code; the
                    # encoder file must come from a trusted features
                    # extraction run.
                    encoder_path = os.path.join(features_path, 'encoder.pkl')
                    with open(encoder_path, 'rb') as f:
                        encoder = pickle.load(f)
                    baseline_service_enc = encoder.transform(
                        [Config.baseline_service])[0]
                # Every test sample gets the same prediction: the encoded
                # baseline service repeated across all columns.
                y_pred = np.tile(
                    baseline_service_enc,
                    (len(X_test), len(Config.services)))
            else:
                clf = clf_ut.train_classifier(clf_name, X_train, y_train)
                pred = clf.predict_proba(X_test)
                # Column indices ordered by descending predicted probability.
                y_pred = np.argsort(-pred, axis=1)

            info = {'fold': i, 'classifier': clf_name}
            scores = clf_ut.evaluate(y_test, y_pred)
            # NOTE: the selected hyperparameters are not part of the rows
            # written to the results.
            results.append(dict(info, **scores))

    results_path = os.path.join(experiment_dir, 'algorithm_selection_results')
    # Start from an empty directory so stale results never mix with new ones.
    if os.path.exists(results_path):
        shutil.rmtree(results_path)
    os.makedirs(results_path)
    wrtrs.write_results(results_path, results, 'algorithm_selection')

    print(f'Algorithm selection done in {time.time() - t1:.3f} sec.')
# Run the algorithm selection step only when executed as a script,
# not when the module is imported.
if __name__ == '__main__':
    main()