Skip to content

Commit

Permalink
Add is_zero impl for all types minus complex(won't build yet)
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Mar 9, 2024
1 parent 99880c8 commit 3502123
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
19 changes: 18 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ strum = "0.26"
strum_macros = "0.26"
xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "dev" }
thiserror = "1"
half = "2.4.0"

[features]
default = ["util"]
Expand Down
50 changes: 46 additions & 4 deletions src/core/graph/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use slotmap::new_key_type;
use xla::Literal;
use std::{fmt::{Display, Formatter, Result}, error::Error};

use half::bf16;
use half::f16;

/// A node in the compute graph
#[derive(Clone, Debug)]
pub struct Node {
Expand Down Expand Up @@ -40,6 +43,23 @@ impl Node {
return match &self.operation {
Operation::Constant(a) => {
match a.value.element_type()? {
xla::ElementType::Pred => {
let data_ref = a.value.to_vec::<u8>()?;
for i in data_ref.iter() {
if !i.is_zero() {
return Ok(false);
}
}

},
xla::ElementType::F16 => {
let data_ref = a.value.to_vec::<f16>()?;
for i in data_ref.iter() {
if *i != f16::ZERO {
return Ok(false);
}
}
},
xla::ElementType::F32 => {
let data_ref = a.value.to_vec::<f32>()?;
for i in data_ref.iter() {
Expand All @@ -57,6 +77,15 @@ impl Node {
}
},

xla::ElementType::U8 => {
let data_ref = a.value.to_vec::<u8>()?;
for i in data_ref.iter() {
if !i.is_zero() {
return Ok(false);
}
}
},

xla::ElementType::U16 => {
let data_ref = a.value.to_vec::<u16>()?;
for i in data_ref.iter() {
Expand All @@ -81,6 +110,16 @@ impl Node {
}
}
},

xla::ElementType::S8 => {
let data_ref = a.value.to_vec::<i8>()?;
for i in data_ref.iter() {
if !i.is_zero() {
return Ok(false);
}
}
},

xla::ElementType::S16 => {
let data_ref = a.value.to_vec::<i16>()?;
for i in data_ref.iter() {
Expand Down Expand Up @@ -114,10 +153,13 @@ impl Node {
return Ok(false);
},
xla::ElementType::Bf16 => {
//TODO
return Ok(false);
},
_ => { return Ok(false); }
let data_ref: Vec<bf16> = a.value.to_vec()?;
for i in data_ref.iter() {
if *i != bf16::ZERO {
return Ok(false);
}
}
}
}

Ok(true)
Expand Down

0 comments on commit 3502123

Please sign in to comment.