forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbucketize_op.cc
70 lines (56 loc) · 2.08 KB
/
bucketize_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include "caffe2/operators/bucketize_op.h"

#include <algorithm>

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
// Maps every element x of the float input tensor X to the index of the bucket
// it falls into: the smallest i such that x <= boundaries_[i], i.e.
// boundaries_[i-1] < x <= boundaries_[i]. Values less than or equal to the
// first boundary map to 0; values greater than every boundary map to
// boundaries_.size(). The output has the same shape as the input and element
// type T.
//
// Assumes boundaries_ is sorted in non-decreasing order (a documented
// requirement of the operator); it is not re-checked here, and unsorted
// boundaries give unspecified bucket indices.
//
// Returns true on success; CAFFE_ENFORCE throws if the input is a scalar.
template <typename T, class Context>
bool BucketizeOp<T, Context>::RunOnDevice() {
  const auto& input = Input(X);
  CAFFE_ENFORCE_GE(input.dim(), 1);
  const auto N = input.numel();
  auto* output = Output(INDICES, input.sizes(), at::dtype<T>());
  const auto* input_data = input.template data<float>();
  auto* output_data = output->template mutable_data<T>();

  // Every output element is assigned below, so no zero-fill pass is needed.
  const auto first = boundaries_.begin();
  const auto last = boundaries_.end();
  for (int64_t pos = 0; pos < N; ++pos) {
    // lower_bound yields the first boundary >= x, so a value equal to
    // boundaries_[i] lands in bucket i (the lower bucket on ties).
    output_data[pos] =
        static_cast<T>(std::lower_bound(first, last, input_data[pos]) - first);
  }
  return true;
}
// CPU implementation of Bucketize; emits int32 bucket indices.
REGISTER_CPU_OPERATOR(Bucketize, BucketizeOp<int32_t, CPUContext>);
OPERATOR_SCHEMA(Bucketize)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc(R"DOC(
This operator is similar to bucketize in TensorFlow and digitize in NumPy
(it matches numpy.digitize with right=True). It bucketizes the input 'data'
based on the argument 'boundaries'.
For each value x in input 'data', the operator returns the index i such that
boundaries[i-1] < x <= boundaries[i].
If values in 'data' are beyond the bounds of boundaries, 0 or
len(boundaries) is returned as appropriate.
The boundaries need to be monotonically increasing.
For example:
If data = [2, 4, 1] and boundaries = [0.1, 2.5], then
output = [1, 2, 1]
If data = [[2, 3], [4, 1], [2, 5]] and boundaries = [0.1, 2.5], then
output = [[1, 2], [2, 1], [1, 2]]
)DOC")
    .Input(0, "data", "input tensor of floats, at least 1-D")
    .Output(
        0,
        "output",
        "indices of bins given by boundaries to which each value "
        "in data belongs (int32, same shape as data)")
    .TensorInferenceFunction([](const OperatorDef& /* def */,
                                const vector<TensorShape>& in) {
      // The output keeps the input's shape; only the element type changes.
      vector<TensorShape> out(in);
      out[0].set_data_type(TensorProto::INT32);
      return out;
    })
    .Arg(
        "boundaries",
        "bucketization boundaries; must be monotonically increasing");
// Bucketize produces integer indices and is not differentiable. Gradient
// makers are looked up by operator type name (the name passed to
// REGISTER_CPU_OPERATOR / OPERATOR_SCHEMA, "Bucketize"), not by the C++ class
// name, so the registration must use that type name.
NO_GRADIENT(Bucketize);
} // namespace caffe2