forked from sebastiankmiec/NinaTools
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial, minimal ninaeval Python package added.
- Loading branch information
sebastiankmiec
committed
Nov 27, 2018
0 parents
commit 1c44f61
Showing
13 changed files
with
549 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.pyc | ||
*.swp | ||
all_data/ | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from ninaeval.config import config_parser, config_setup | ||
from ninaeval.utils.nina_data import NinaDataParser, BaselineDataset | ||
|
||
DATA_PATH = "all_data/" | ||
|
||
def main(): | ||
|
||
# Reads JSON file via --json, or command line arguments: | ||
config_param = config_parser.parse_config() | ||
|
||
# Basic setup: | ||
classifier = config_setup.get_model(config_param.model)() | ||
feat_extractor = config_setup.get_feat_extract(config_param.features)() | ||
|
||
# Generate a dataset, if necessary: | ||
print("Checking for existing features extracted...") | ||
dataset = BaselineDataset(DATA_PATH, feat_extractor) | ||
|
||
if not dataset.load_dataset(): | ||
print("Loading Ninapro data from processed directory...") | ||
data_parser = NinaDataParser(DATA_PATH) | ||
loaded_nina = data_parser.load_processed_data() | ||
|
||
print("Extracting dataset features for training, and testing...") | ||
dataset.create_dataset(loaded_nina) | ||
|
||
# Train on the training dataset: | ||
print("Training classifier on training dataset...") | ||
classifier.train_model(dataset.train_features, dataset.train_labels) | ||
|
||
# Classify the testing dataset: | ||
print("Testing classifier on testing dataset...") | ||
print(classifier.perform_inference(dataset.test_features, dataset.test_labels)) | ||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from optparse import OptionParser | ||
from argparse import ArgumentParser | ||
import json | ||
|
||
# | ||
# Model/Feature abbreviations and choices. | ||
# | ||
model_choices = {"rf": "RandomForest", "svm": "SupportVectorMachine", "conv": "ConvolutionalNN"} | ||
feature_choices = {"rms": "RMS", "ts": "TimeStatistics", "dwt": "DiscreteWaveletTransform", "hist": "HistogramBins", | ||
"all": "AllFeatures", "scat": "Scattering"} | ||
|
||
######################################################################################################################## | ||
######################################################################################################################## | ||
|
||
|
||
def init_parser(): | ||
''' | ||
Create a parser to read command line or JSON file configuration parameters. | ||
:return: ArgumentParser | ||
''' | ||
|
||
parser = OptionParser() | ||
|
||
# Model choices | ||
# | ||
parser.add_option('--model', action='store', type='choice', default='rf', | ||
choices=list(model_choices.keys())) | ||
|
||
parser.add_option('--features', action='store', type='choice', default='rms', | ||
choices=list(feature_choices.keys())) | ||
|
||
# Actions to perform | ||
# | ||
parser.add_option('--action', action='store', type='choice', default='train', | ||
choices=['train', 'validate', 'test']) | ||
|
||
parser.add_option('--data', action='store', type='choice', default='v1', | ||
choices=['baseline', 'v1']) | ||
|
||
# | ||
# Model training settings | ||
# | ||
parser.add_option('--rf_trees', action='store', type='int', default='128') | ||
|
||
return parser | ||
|
||
|
||
def json_to_string(json_path): | ||
""" | ||
Helper function (parse_config): Parses a JSON file. | ||
:param json_path: Path to a JSON file. | ||
:return: list of JSON file arguments, in command-line format | ||
""" | ||
|
||
if not json_path.endswith(".json"): | ||
raise ValueError('Expecting a .json file passed to --json argument.') | ||
|
||
try: | ||
f = open(json_path) | ||
raw_json = f.read().replace("\n", "") | ||
json_dict = json.loads(raw_json) | ||
except Exception as e: | ||
print("Invalid JSON file passed to --json.") | ||
print(e) | ||
exit() | ||
|
||
# Convert JSON dict to a list of command line arguments: | ||
command_list = [] | ||
json_keys = list(json_dict.keys()) | ||
|
||
for key in json_keys: | ||
|
||
cur_command = "" | ||
if not "--" in key: | ||
cur_command += "--" | ||
|
||
cur_command += key + "=" + json_dict[key] | ||
command_list.append(cur_command) | ||
|
||
return command_list | ||
|
||
|
||
|
||
def parse_config(): | ||
''' | ||
Checks for --json for a JSON file, otherwise, reads command line arguments. | ||
:return: A dictionary of configuration options | ||
''' | ||
|
||
# Check for "--json" argument first; | ||
json_parser = ArgumentParser() | ||
json_parser.add_argument("--json", action="store", dest="JSON_PATH") | ||
json_path, _ = json_parser.parse_known_args() | ||
json_path = json_path.JSON_PATH | ||
|
||
if json_path: | ||
argparser = init_parser() | ||
command_list = json_to_string(json_path) | ||
options, _ = argparser.parse_args(args=command_list) | ||
|
||
else: | ||
argparser = init_parser() | ||
options, _ = argparser.parse_args() | ||
|
||
return options |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from ninaeval.models import baseline_model | ||
from ninaeval.config.config_parser import model_choices, feature_choices | ||
|
||
def get_model(model_abbrev): | ||
""" | ||
:param model_abbrev: An abbreviated model name (config_parser.py). | ||
:return: ClassifierModel | ||
""" | ||
|
||
model = None | ||
try: | ||
model = getattr(baseline_model, model_choices[model_abbrev]) | ||
except AttributeError: | ||
pass | ||
|
||
return model | ||
|
||
def get_feat_extract(features_abbrev): | ||
""" | ||
:param features_abbrev: An abbreviated model name (config_parser.py). | ||
:return: FeatureExtractor | ||
""" | ||
|
||
feat_ext = None | ||
try: | ||
feat_ext = getattr(baseline_model, feature_choices[features_abbrev]) | ||
except AttributeError: | ||
pass | ||
|
||
return feat_ext |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from ninaeval.models.model import ClassifierModel, FeatureExtractor | ||
import numpy as np | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
# | ||
# Baseline Classifiers | ||
# | ||
class RandomForest(ClassifierModel): | ||
|
||
num_trees = 128 | ||
|
||
def __init__(self): | ||
self.classifier = RandomForestClassifier(n_estimators=self.num_trees) | ||
|
||
def train_model(self, train_features, train_labels): | ||
self.classifier.fit(train_features, train_labels) | ||
|
||
def perform_inference(self, test_features, test_labels): | ||
predictions = self.classifier.predict(test_features) | ||
return self.classifier_accuracy(predictions, test_labels) | ||
|
||
# | ||
# Baseline Feature Extractors | ||
# | ||
class RMS(FeatureExtractor): | ||
|
||
def extract_feature_point(self, raw_samples): | ||
return np.sqrt(np.mean(np.square(raw_samples), axis=0)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class ClassifierModel(ABC): | ||
|
||
@abstractmethod | ||
def train_model(self, train_features, train_labels): | ||
pass | ||
|
||
@abstractmethod | ||
def perform_inference(self, test_features, test_labels): | ||
""" | ||
Given test features and labels, compute predictions and classifier accuracy, | ||
:param test_features:Features from the test split. | ||
:param test_labels: Labels from the test split. | ||
:return: Classifier accuracy from 0 ~ 1.0. | ||
""" | ||
pass | ||
|
||
def classifier_accuracy(self, predictions, test_labels): | ||
errors = predictions == test_labels | ||
acc_rate = len([x for x in errors if (x == True)]) / len(errors) | ||
return acc_rate | ||
|
||
|
||
class FeatureExtractor(ABC): | ||
|
||
@abstractmethod | ||
def extract_feature_point(self, raw_samples): | ||
""" | ||
:param raw_samples: A window of emg samples. | ||
:return: A single feature point. | ||
""" | ||
pass |
Empty file.
Oops, something went wrong.