-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_distribution.py
111 lines (81 loc) · 3.89 KB
/
test_distribution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import unittest
from config_sedimentdrift import MartiniConf
from sedimentdrift import SedimentDrift
from bathymetry import Bathymetry
from probability_distribution_map_v2 import SedimentDistribution
from config_plot import ConfigPlot
import pandas as pd
import numpy as np
import xarray as xr
class TestSedimentDistribution(unittest.TestCase):
    """Unit tests for ``SedimentDistribution``.

    Most tests exercise ``get_indexes_of_last_valid_position`` with a small
    synthetic dataset holding one variable, ``test_var``, with dimensions
    ``("time", "trajectory")``. NaN entries mark positions that are not
    valid; the expected values asserted below are, per trajectory, the time
    index of the last non-NaN entry, with all-NaN trajectories omitted.
    """

    # Size of the synthetic dataset shared by the index tests.
    _N_TIMES = 10
    _N_TRAJECTORIES = 2

    def setUp(self):
        """Create a fresh ``SedimentDistribution`` before every test."""
        super().setUp()
        self.distribution = SedimentDistribution()

    def tearDown(self) -> None:
        """Drop the object created in ``setUp`` after every test."""
        super().tearDown()
        # Reset the attribute setUp actually assigns; the previous code
        # cleared ``self.config``, which is never set on this test case.
        self.distribution = None

    def _all_valid_values(self):
        """Return a ``(time, trajectory)`` array of ones (no NaN anywhere)."""
        return np.ones((self._N_TIMES, self._N_TRAJECTORIES))

    def _make_dataset(self, values):
        """Wrap ``values`` in an ``xr.Dataset`` under the name ``test_var``.

        ``values`` must be a 2-D array laid out as ``(time, trajectory)``.
        The intermediate type checks from the original tests are kept so a
        failure in dataset construction is reported separately from a
        failure in the method under test.
        """
        time = np.arange(0, values.shape[0], 1)
        trajectory = np.ones(values.shape[1])
        da = xr.DataArray(
            data=values,
            dims=["time", "trajectory"],
            coords=[time, trajectory],
        )
        self.assertIsInstance(da, xr.DataArray)

        ds = da.to_dataset(name="test_var")
        self.assertIsInstance(ds, xr.Dataset)
        return ds

    def _last_valid_indexes(self, values):
        """Run the method under test on a dataset built from ``values``."""
        ds = self._make_dataset(values)
        return self.distribution.get_indexes_of_last_valid_position(
            ds, var_name="test_var"
        )

    def test_setup_returns_config_object(self):
        """The distribution exposes a ``MartiniConf`` configuration."""
        self.assertIsNotNone(self.distribution.config_sedimentdrift)
        self.assertIsInstance(self.distribution.config_sedimentdrift, MartiniConf)

    def test_init_sediment_distribution_class(self):
        """The distribution holds a ``Bathymetry`` and one can be built from its config."""
        bathymetry = Bathymetry(self.distribution.config_sedimentdrift)
        self.assertIsInstance(self.distribution.bath, Bathymetry)
        self.assertIsNotNone(bathymetry.config)

    def test_index_of_last_valid_position(self):
        """Leading and trailing NaNs: expect the last non-NaN time index."""
        test_var = self._all_valid_values()
        test_var[0:3, 0] = np.nan
        test_var[0:5, 1] = np.nan
        test_var[-2:, 1] = np.nan
        # The array is stored as (time, trajectory). Shown transposed, i.e.
        # one row per trajectory along time, it looks like this:
        # [[nan nan nan 1. 1. 1. 1. 1. 1. 1.]
        # [nan nan nan nan nan 1. 1. 1. nan nan]]
        res = self._last_valid_indexes(test_var)

        index_array_expected = [9, 7]
        np.testing.assert_array_equal(res, index_array_expected)

    def test_index_of_last_valid_position_all_valid(self):
        """No NaNs at all: expect the final time index for both trajectories."""
        test_var = self._all_valid_values()
        res = self._last_valid_indexes(test_var)

        index_array_expected = [9, 9]
        np.testing.assert_array_equal(res, index_array_expected)

    def test_index_of_last_valid_position_when_non_valid(self):
        """Every value NaN: expect an empty result."""
        test_var = self._all_valid_values()
        test_var[:, :] = np.nan
        res = self._last_valid_indexes(test_var)

        index_array_expected = []
        np.testing.assert_array_equal(res, index_array_expected)

    def test_index_of_last_valid_position_when_one_valid_and_one_not_valid(self):
        """One all-NaN trajectory: expect only the valid trajectory's index."""
        test_var = self._all_valid_values()
        test_var[:, 0] = np.nan
        res = self._last_valid_indexes(test_var)

        index_array_expected = [9]
        np.testing.assert_array_equal(res, index_array_expected)
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()