Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor out common logic in [Multi]Binning::child #207

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 66 additions & 97 deletions src/correction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,63 @@ namespace {
const std::vector<Variable::Type>& values;
};

std::size_t find_bin_idx(double value,
const std::variant<_UniformBins, _NonUniformBins> &bins_,
const _FlowBehavior &flow,
std::size_t variableIdx,
const char *name)
{
if ( auto *bins = std::get_if<_UniformBins>(&bins_) ) { // uniform binning
if (value < bins->low || value >= bins->high) {
switch (flow) {
case _FlowBehavior::value:
return bins->n; // the default value is stored at the end of the content array, after the last bin
case _FlowBehavior::clamp:
return value < bins->low ? 0 : bins->n - 1; // assuming we always have at least 1 bin
case _FlowBehavior::error:
const std::string belowOrAbove = value < bins->low ? "below" : "above";
auto msg = "Index " + belowOrAbove + " bounds in " + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value);
throw std::runtime_error(std::move(msg));
}
}

std::size_t binIdx = bins->n * ((value - bins->low) / (bins->high - bins->low));
return binIdx;
}

// otherwise we have non-uniform binning
using namespace std::string_literals;
const auto bins = std::get<_NonUniformBins>(bins_);

auto it = std::upper_bound(std::begin(bins), std::end(bins), value);
if ( it == std::begin(bins) ) { // underflow
if ( flow == _FlowBehavior::value ) {
return bins.size() - 1; // the default value is stored at the end of the content array, after the last bin
}
else if ( flow == _FlowBehavior::error ) {
throw std::runtime_error("Index below bounds in "s + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value));
}
else { // clamp
it++;
}
}
else if ( it == std::end(bins) ) { // overflow
if ( flow == _FlowBehavior::value ) {
return bins.size() - 1;
}
else if ( flow == _FlowBehavior::error ) {
throw std::runtime_error("Index above bounds in "s + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value));
}
else { // clamp
it--;
}
}

// -1 because upper_bound returns the edge _after_ the bin we are interested in
const std::size_t binIdx = std::distance(std::begin(bins), it) - 1;
return binIdx;
}

size_t input_index(const std::string_view name, const std::vector<Variable> &inputs) {
size_t idx = 0;
for (const auto& var : inputs) {
Expand All @@ -154,7 +211,7 @@ namespace {
}
throw std::runtime_error("Error: could not find variable " + std::string(name) + " in inputs");
}
}
} // end of anonymous namespace

Variable::Variable(const JSONObject& json) :
name_(json.getRequired<const char *>("name")),
Expand Down Expand Up @@ -406,61 +463,15 @@ Binning::Binning(const JSONObject& json, const Correction& context)
}

// set bin contents
contents_.push_back(std::move(default_value));
for (size_t i=0; i < content.Size(); ++i)
contents_.push_back(resolve_content(content[i], context));
contents_.push_back(std::move(default_value));
}

// Return the content node for the bin that contains the value of input number variableIdx_.
// The input is read as a double; out-of-range values are handled according to flow_
// (return the default value, clamp to the nearest bin, or throw std::runtime_error).
//
// NOTE(review): this span comes from a diff view without +/- markers. The inline lookup
// below (down to `return contents_[std::distance(...)]`) looks like the pre-refactor body,
// and the trailing find_bin_idx call looks like its replacement -- confirm against the
// repository before relying on either half.
const Content& Binning::child(const std::vector<Variable::Type>& values) const {
double value = std::get<double>(values[variableIdx_]);

if ( auto *bins = std::get_if<_UniformBins>(&bins_) ) { // uniform binning
// NOTE(review): the index is computed before the range check, so for out-of-range
// values the double -> size_t conversion is done on a value that may not be representable
std::size_t binIdx = bins->n * ((value - bins->low) / (bins->high - bins->low));
if (value < bins->low || value >= bins->high) {
switch (flow_) {
case _FlowBehavior::value:
return contents_[0u]; // the default value
case _FlowBehavior::clamp:
binIdx = value < bins->low ? 0 : bins->n - 1; // assuming we always have at least 1 bin
break;
case _FlowBehavior::error:
const std::string belowOrAbove = value < bins->low ? "below" : "above";
const auto msg = "Index " + belowOrAbove + " bounds in Binning for input argument " + std::to_string(variableIdx_) + " value: " + std::to_string(value);
throw std::runtime_error(std::move(msg));
}
}

return contents_[binIdx + 1u]; // skipping the default value at index 0
}

// otherwise we have non-uniform binning
// (std::get returns a reference, but `const auto` takes a copy of the edges here)
const auto bins = std::get<_NonUniformBins>(bins_);

// first edge strictly greater than value; its offset from begin is used directly as the
// index into contents_ (in-range values give offset >= 1, offset 0 selects the element
// at the front, which the comments in this function describe as the default value)
auto it = std::upper_bound(std::begin(bins), std::end(bins), value);
if ( it == std::begin(bins) ) {
if ( flow_ == _FlowBehavior::value ) {
// default value already at std::begin
}
else if ( flow_ == _FlowBehavior::error ) {
throw std::runtime_error("Index below bounds in Binning for input argument " + std::to_string(variableIdx_) + " value: " + std::to_string(value));
}
else { // clamp
it++;
}
}
else if ( it == std::end(bins) ) {
if ( flow_ == _FlowBehavior::value ) {
it = std::begin(bins);
}
else if ( flow_ == _FlowBehavior::error ) {
throw std::runtime_error("Index above bounds in Binning for input argument " + std::to_string(variableIdx_) + " value: " + std::to_string(value));
}
else { // clamp
it--;
}
}

return contents_[std::distance(std::begin(bins), it)];
// NOTE(review): as written, the two statements below are unreachable -- every path above
// ends in a return or a throw. They index contents_ with the result of find_bin_idx
// without the +1 offset used by the uniform branch above.
std::size_t binIdx = find_bin_idx(value, bins_, flow_, variableIdx_, "Binning");
return contents_[binIdx];
}

MultiBinning::MultiBinning(const JSONObject& json, const Correction& context)
Expand Down Expand Up @@ -529,60 +540,18 @@ MultiBinning::MultiBinning(const JSONObject& json, const Correction& context)
}
}

// TODO factor out logic in common with Binning::child.
// One notable difference is that MultiBinning stores the default value at the end of content_ instead of the beginning.
const Content& MultiBinning::child(const std::vector<Variable::Type>& values) const {
size_t idx {0};
size_t localidx {0};
size_t dim {0};

for (const auto& [variableIdx, stride, edgesVariant] : axes_) {
double value = std::get<double>(values[variableIdx]);

if ( auto *bins = std::get_if<_UniformBins>(&edgesVariant) ) { // uniform bins
std::size_t binIdx = bins->n * ((value - bins->low) / (bins->high - bins->low));
if (value < bins->low || value >= bins->high) {
switch (flow_) {
case _FlowBehavior::value:
return content_.back(); // the default value
case _FlowBehavior::clamp:
binIdx = value < bins->low ? 0 : bins->n - 1; // assuming we always have at least 1 bin
break;
case _FlowBehavior::error:
const std::string belowOrAbove = value < bins->low ? "below" : "above";
const auto msg = "Index " + belowOrAbove + " bounds in MultiBinning for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value);
throw std::runtime_error(std::move(msg));
}
}
localidx = binIdx;
} else { // non-uniform bins
const auto edges = std::get<_NonUniformBins>(edgesVariant);
auto it = std::upper_bound(std::begin(edges), std::end(edges), value);
if ( it == std::begin(edges) ) {
if ( flow_ == _FlowBehavior::value ) {
return *content_.rbegin();
}
else if ( flow_ == _FlowBehavior::error ) {
throw std::runtime_error("Index below bounds in MultiBinning for input argument " + std::to_string(variableIdx) + " val: " + std::to_string(value));
}
else { // clamp
it++;
}
}
else if ( it == std::end(edges) ) {
if ( flow_ == _FlowBehavior::value ) {
return content_.back();
}
else if ( flow_ == _FlowBehavior::error ) {
throw std::runtime_error("Index above bounds in MultiBinning input argument" + std::to_string(variableIdx) + " val: " + std::to_string(value));
}
else { // clamp
it--;
}
}
localidx = std::distance(std::begin(edges), it) - 1;
}

localidx = find_bin_idx(value, edgesVariant, flow_, variableIdx, "MultiBinning");
if ( localidx == nbins(dim) ) // find_bin_idx is indicating we need to return the default value
return content_.back();
idx += localidx * stride;
++dim;
}

return content_.at(idx);
Expand Down
Loading