Skip to content

Commit

Permalink
Onnx Support for Sign operation #2641 (#2642)
Browse files Browse the repository at this point in the history
* Support for Sign operation #2641

* Apply rustfmt.

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
imihalcea and LaurentMazare authored Nov 26, 2024
1 parent b4deb5c commit 21c6863
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
6 changes: 6 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,12 @@ fn simple_eval_(

values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__Sign.html
"Sign" => {
let input = get(&node.input[0])?;
let output = input.sign()?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
Expand Down
41 changes: 41 additions & 0 deletions candle-onnx/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> {
}
Ok(())
}

#[test]
fn test_sign_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sign".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));

let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(
INPUT_X.to_string(),
Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?,
);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;

let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
assert_eq!(
z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(),
vec![-1, -1, 0, 1, 1]
);
Ok(())
}

0 comments on commit 21c6863

Please sign in to comment.