Skip to content

Commit

Permalink
apply cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
imihalcea committed Nov 26, 2024
1 parent 80a0ff2 commit 83723a4
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 31 deletions.
2 changes: 1 addition & 1 deletion candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use crate::ops::{registry, ComputeNode};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::{HashMap, HashSet};
use crate::ops::{registry, ComputeNode};

pub type Value = Tensor;

Expand Down
12 changes: 6 additions & 6 deletions candle-onnx/src/ops/compute_node.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use std::collections::HashMap;
use candle::Tensor;
use crate::onnx::NodeProto;
use candle::Tensor;
use std::collections::HashMap;

//This struct is used to represent a node in the computation graph
//The idea is not to use the NodeProto directly in the computation graph
//On a longer term, this can lead to a more optimized representation of the computation graph.
//For now, it is just a wrapper around the NodeProto and the context
pub struct ComputeNode<'a>{
pub struct ComputeNode<'a> {
node_proto: &'a NodeProto,
context: &'a HashMap<String, Tensor>
context: &'a HashMap<String, Tensor>,
}

impl<'a> ComputeNode<'a> {
pub fn new(node_proto: &'a NodeProto, context: &'a HashMap<String, Tensor>) -> Self {
ComputeNode {
node_proto,
context
context,
}
}

Expand All @@ -27,4 +27,4 @@ impl<'a> ComputeNode<'a> {
pub fn get_output(&self, index: usize) -> Option<&String> {
self.node_proto.output.get(index)
}
}
}
2 changes: 1 addition & 1 deletion candle-onnx/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub(crate) mod sign;
pub(crate) mod sign;
15 changes: 9 additions & 6 deletions candle-onnx/src/ops/math/sign.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use crate::ops::compute_node::ComputeNode;
use crate::ops::{OnnxOp, OnnxOpError, OpOutput};
use crate::ops::OnnxOpError::ComputationFailed;
use crate::ops::{OnnxOp, OnnxOpError, OpOutput};

pub(crate) struct Sign;
impl OnnxOp for Sign {
fn eval(&self, node: &ComputeNode) -> Result<OpOutput, OnnxOpError> {
let input = node.get_input(0)
let input = node
.get_input(0)
.ok_or_else(|| ComputationFailed("input 0 not found".to_string()))?;

let output = input.sign()
.map_err(|err| ComputationFailed(format!("{:?}",err)))?;
let output = input
.sign()
.map_err(|err| ComputationFailed(format!("{:?}", err)))?;

let output_name = node.get_output(0)
let output_name = node
.get_output(0)
.ok_or_else(|| ComputationFailed("output 0 not found".to_string()))?;

Ok((output_name.clone(), output))
}
}
}
3 changes: 1 addition & 2 deletions candle-onnx/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ pub use onnxop::{OnnxOp, OnnxOpError, OnnxOpRegistry, OpOutput};
pub mod compute_node;
pub use compute_node::ComputeNode;


mod math;
use math::sign;

pub fn registry() -> Result<OnnxOpRegistry, OnnxOpError> {
let mut registry = OnnxOpRegistry::new();
registry.insert("Sign", Box::new(sign::Sign))?;
Ok(registry)
}
}
23 changes: 9 additions & 14 deletions candle-onnx/src/ops/onnxop.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::ops::ComputeNode;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use crate::ops::ComputeNode;

pub type OpOutput = (String, candle::Tensor);

Expand Down Expand Up @@ -45,11 +45,7 @@ impl OnnxOpRegistry {
ops: HashMap::new(),
}
}
pub fn insert(
&mut self,
name: &str,
op: Box<dyn OnnxOp>,
) -> Result<(), OnnxOpError> {
pub fn insert(&mut self, name: &str, op: Box<dyn OnnxOp>) -> Result<(), OnnxOpError> {
match self.ops.entry(name.to_string()) {
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(op);
Expand All @@ -67,11 +63,10 @@ impl OnnxOpRegistry {
}
}


#[cfg(test)]
mod onnxop_registry_tests {
use candle::Device;
use super::*;
use candle::Device;
#[test]
fn nominal_case() {
//Given
Expand All @@ -84,7 +79,6 @@ mod onnxop_registry_tests {

//Then
assert!(op.is_ok());

}

#[test]
Expand All @@ -97,7 +91,7 @@ mod onnxop_registry_tests {

//Then
match op {
Err(OnnxOpError::UnsupportedOp(_)) => {},
Err(OnnxOpError::UnsupportedOp(_)) => {}
_ => panic!("Expected unsupported op error"),
}
}
Expand All @@ -115,17 +109,18 @@ mod onnxop_registry_tests {

//Then
match result {
Err(OnnxOpError::DuplicateOp(_)) => {},
Err(OnnxOpError::DuplicateOp(_)) => {}
_ => panic!("Expected duplicate op error"),
}
}


struct DummyOp;
impl OnnxOp for DummyOp {
fn eval(&self, _node: &ComputeNode) -> Result<OpOutput, OnnxOpError> {
Ok(("dummy".to_string(), candle::Tensor::new(vec![1u8,1], &Device::Cpu).unwrap()))
Ok((
"dummy".to_string(),
candle::Tensor::new(vec![1u8, 1], &Device::Cpu).unwrap(),
))
}
}

}
2 changes: 1 addition & 1 deletion candle-onnx/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ pub fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
doc_string: "".to_string(),
graph,
}
}
}

0 comments on commit 83723a4

Please sign in to comment.