Skip to content

Commit

Permalink
Fix GetItem (#1114)
Browse files Browse the repository at this point in the history
* fix runtime

* Apply code-format changes

---------

Co-authored-by: FusionBolt <[email protected]>
  • Loading branch information
FusionBolt and FusionBolt authored Oct 30, 2023
1 parent e65d9b1 commit 390b827
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/Native/src/kernels/stackvm/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,25 @@ result<value_t> nncase::kernels::stackvm::get_item(
#undef RETURN_RESULT
return err(std::errc::not_supported);
}

if (input_tensor->shape().size() == 2 && begins_value.size() == 1) {
auto get_item_index = begins_value[0];
auto out_shape = dims_t{input_tensor->shape()[1]};
try_output(out_mem, output, input_tensor->dtype(), out_shape);
auto size = input_tensor->shape()[1];
#define RETURN_RESULT(_in_type) \
if (cmp_type<_in_type>(input_tensor->dtype())) { \
for (int i = 0; i < size; ++i) { \
OUT_CAST(_in_type, out_mem) \
[i] = IN_CAST(_in_type, in_mem)[get_item_index * size + i]; \
} \
return ok(output); \
}
RETURN_RESULT_SELECT(RETURN_RESULT);
#undef RETURN_RESULT
return err(std::errc::not_supported);
}

auto n = begins_value.size();
auto in_shape = input_tensor->shape();
auto ends_value = axes_t(n, 0);
Expand All @@ -421,6 +440,8 @@ result<value_t> nncase::kernels::stackvm::get_item(
out_mem, in_shape, input_tensor->strides(),
output_tensor->strides(), begin_values, end_values,
strides_values, context);
output = tensor_reshape(output_tensor,
dims_t(out_shape.begin() + n, out_shape.end()));
KERNEL_FINISH;
}
}
Expand Down Expand Up @@ -769,6 +790,18 @@ result<value_t> nncase::kernels::stackvm::bucket_pad(
try_dims_v(shape);
auto in_tensor = input.as<tensor>().expect("input is not a tensor");
auto in_shape = in_tensor->shape();
if (compute_size(in_shape) > compute_size(shape_value)) {
std::cout << "in shape" << std::endl;
for (int i = 0; i < in_shape.size(); ++i) {
std::cout << in_shape[i] << std::endl;
}
std::cout << "shape_value shape" << std::endl;
for (int i = 0; i < shape_value.size(); ++i) {
std::cout << shape_value[i] << std::endl;
}
return err(std::errc::invalid_argument);
}

auto paddings = std::vector<int>(8);
auto rank = shape_value.size();
for (int i = 0; i < rank; ++i) {
Expand Down Expand Up @@ -1103,6 +1136,7 @@ nncase::kernels::stackvm::squeeze(value_t input, value_t dim, value_t output,
try_var(in_tensor, input.as<tensor>());
auto in_shape = in_tensor->shape();
not_impl_no_contiguous(in_tensor);
// todo: dim is scalar
try_positive_axes(axes, dim, in_tensor->shape().size());
auto new_shape = squeeze_infer_shape(in_shape, axes);
output = tensor_reshape(in_tensor, new_shape);
Expand Down

0 comments on commit 390b827

Please sign in to comment.