Skip to content

Commit

Permalink
Error on empty set, backward compatibility (13 and below) with 'axes'
Browse files Browse the repository at this point in the history
  • Loading branch information
AnubhabB committed Oct 15, 2024
1 parent 4f8cd5c commit b40491a
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 20 deletions.
40 changes: 32 additions & 8 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1190,17 +1190,26 @@ fn simple_eval_(
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
// Version 18 impl
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.to_vec1::<i64>()?
.iter()
.map(|a| {
let axis = if *a < 0 {
Expand Down Expand Up @@ -1230,7 +1239,10 @@ fn simple_eval_(
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// Numpy yields error: ValueError: zero-size array to reduction operation maximum which has no identity
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
Expand Down Expand Up @@ -1286,18 +1298,27 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-20
// Version 18 impl
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.to_vec1::<i64>()?
.iter()
.map(|a| {
let axis = if *a < 0 {
Expand Down Expand Up @@ -1326,8 +1347,11 @@ fn simple_eval_(

// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// Numpy yields error: ValueError: zero-size array to reduction operation maximum which has no identity
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
Expand Down
Loading

0 comments on commit b40491a

Please sign in to comment.