Skip to content

Commit

Permalink
Shortened training time for various unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tjkessler committed Jun 11, 2019
1 parent 35f3fe9 commit f017562
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
4 changes: 4 additions & 0 deletions tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_train_project(self):
sv = Server()
sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1])
sv.create_project('test_project', 2, 2)
sv._vars['epochs'] = 100
sv.train()
for pool in range(2):
self.assertTrue(exists(join(
Expand All @@ -72,6 +73,7 @@ def test_use_project(self):
sv = Server()
sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1])
sv.create_project('test_project', 2, 2)
sv._vars['epochs'] = 100
sv.train()
results = sv.use()
self.assertEqual(len(results), len(sv._df))
Expand All @@ -84,6 +86,7 @@ def test_save_project(self):
sv = Server()
sv.load_data('cn_model_v1.0.csv', random=True, split=[0.7, 0.2, 0.1])
sv.create_project('test_project', 2, 2)
sv._vars['epochs'] = 100
sv.train()
sv.save_project()
self.assertTrue(exists('test_project.prj'))
Expand All @@ -96,6 +99,7 @@ def test_multiprocessing_train(self):
sv = Server(num_processes=8)
sv.load_data('cn_model_v1.0.csv')
sv.create_project('test_project', 2, 4)
sv._vars['epochs'] = 100
sv.train()
for pool in range(2):
self.assertTrue(exists(join(
Expand Down
1 change: 1 addition & 0 deletions tests/tools/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_predict(self):
sv = Server()
sv.load_data('cn_model_v2.0.csv')
sv.create_project('test_project', 1, 1)
sv._vars['epochs'] = 100
sv.train()
sv.save_project()

Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,11 @@ def test_train_model(self):
df.create_sets(random=True)
pd = df.package_sets()
config = server_utils.default_config()
config['epochs'] = 100
r_squared = server_utils.train_model(
pd, config, 'test', 'r2', filename='test_train.h5'
)
self.assertTrue(exists('test_train.h5'))
self.assertGreaterEqual(r_squared, 0)
self.assertLessEqual(r_squared, 1)
remove('test_train.h5')

def test_use_model(self):
Expand All @@ -161,6 +160,7 @@ def test_use_model(self):
df.create_sets(random=True)
pd = df.package_sets()
config = server_utils.default_config()
config['epochs'] = 100
_ = server_utils.train_model(
pd, config, 'test', 'rmse', filename='test_use.h5'
)
Expand Down

0 comments on commit f017562

Please sign in to comment.