Skip to content

Commit

Permalink
Round transform rule output for integers (#54)
Browse files Browse the repository at this point in the history
* Round transform rule output for integers

* Do something to trigger tests
  • Loading branch information
nsmith- authored Mar 18, 2021
1 parent 1012812 commit 9c8d10f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/correction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <optional>
#include <algorithm>
#include <stdexcept>
#include <cmath>
#include "correction.h"

using namespace correction;
Expand Down Expand Up @@ -160,7 +161,7 @@ double Transform::evaluate(const std::vector<Variable::Type>& values) const {
v = vnew;
}
else if ( std::holds_alternative<int>(v) ) {
v = (int) vnew;
v = (int) std::round(vnew);
}
else {
throw std::logic_error("I should not have ever seen a string");
Expand Down
32 changes: 5 additions & 27 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ def test_transform():
)
corr = cset["test"]
assert corr.evaluate(0.5) == 0.1
assert corr.evaluate(1.5) == 0.1

cset = wrap(
schema.Correction(
Expand All @@ -893,6 +894,8 @@ def test_transform():
{"key": 0, "value": 0},
{"key": 1, "value": 4},
{"key": 2, "value": 0},
{"key": 9, "value": 3.000001},
{"key": 10, "value": 2.999999},
],
),
content=schema.Category(
Expand All @@ -913,30 +916,5 @@ def test_transform():
assert corr.evaluate(2) == 0.0
with pytest.raises(IndexError):
corr.evaluate(3)


def evaluate(expr, variables, parameters):
cset = {
"schema_version": 2,
"corrections": [
{
"name": "test",
"version": 1,
"inputs": [
{"name": vname, "type": "real"}
for vname, _ in zip("xyzt", variables)
],
"output": {"name": "f", "type": "real"},
"data": {
"nodetype": "formula",
"expression": expr,
"parser": "TFormula",
"variables": [vname for vname, _ in zip("xyzt", variables)],
"parameters": parameters or None,
},
}
],
}
schema.CorrectionSet.parse_obj(cset)
corr = core.CorrectionSet.from_string(json.dumps(cset))["test"]
return corr.evaluate(*variables)
assert corr.evaluate(9) == 0.1
assert corr.evaluate(10) == 0.1

0 comments on commit 9c8d10f

Please sign in to comment.