Commit

Merge branch 'master' into xla/mnist
BradenEverson authored Apr 4, 2024
2 parents 2f50db8 + a569ba5 commit 2f71239
Showing 9 changed files with 102 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -5,7 +5,7 @@ authors = ["BradenEverson <[email protected]>", "Ebanflo42", "Atlas Dostal
description = "General purpose machine learning crate for neural network development and analysis"
edition = "2021"
license = "MIT"
repository = "https://github.com/BradenEverson/unda"
repository = "https://github.com/unda-ai/unda"
categories = ["science"]
keywords = [
"machine-learning",
File renamed without changes.
21 changes: 21 additions & 0 deletions LICENSE-MIT
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Braden Everson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 1 addition & 1 deletion README.md
@@ -90,4 +90,4 @@ Gradient descent currently can happen both syncronously as stochastic gradient d
#### If open source development is your thing, we at Unda would love additional work on anything that can be implemented, please contact **[email protected]** if you'd like to help out!

# License
-Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0. This file may not be copied, modified, or distributed except according to those terms.
+Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.
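
For reference, dual MIT/Apache-2.0 licensing in a Rust crate is conventionally mirrored in Cargo.toml with the SPDX expression license = "MIT OR Apache-2.0"; the Cargo.toml hunk earlier in this commit still declares license = "MIT".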
2 changes: 1 addition & 1 deletion src/core/graph/constant.rs
@@ -1,5 +1,5 @@
use super::*;
-use std::{path::Path, hash::Hash};
+use std::{path::Path};
use xla::FromRawBytes;

#[derive(Debug, Clone)]
4 changes: 2 additions & 2 deletions src/core/graph/consteval.rs
@@ -5,9 +5,9 @@ use super::*;
impl Context {
fn collect_deps(&self, node: NodeIdentifier) -> Vec<NodeIdentifier> {
if self.dependent_nodes.contains_key(&node) {
-return self.dependent_nodes[&node].to_vec();
+self.dependent_nodes[&node].to_vec()
} else {
-return vec![];
+vec![]
}
}

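
The consteval.rs change above swaps explicit `return` statements for Rust's expression-position returns, where the last expression of a block is the block's value (Clippy's needless_return lint flags the older form). A minimal standalone sketch of the same pattern, using an illustrative helper rather than the crate's real types:

use std::collections::HashMap;

// Illustrative helper, not part of unda: returns a copy of the entry's Vec,
// or an empty Vec if the key is absent. The if/else is itself an expression,
// so neither branch needs the `return` keyword.
fn deps_or_empty(map: &HashMap<u32, Vec<u32>>, key: u32) -> Vec<u32> {
    if map.contains_key(&key) {
        map[&key].to_vec()
    } else {
        vec![]
    }
}

fn main() {
    let mut map = HashMap::new();
    map.insert(1, vec![2, 3]);
    assert_eq!(deps_or_empty(&map, 1), vec![2, 3]);
    assert!(deps_or_empty(&map, 7).is_empty());
    println!("expression-return sketch ok");
}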
2 changes: 1 addition & 1 deletion src/core/graph/operation.rs
@@ -153,7 +153,7 @@ impl Hash for Operation {

impl PartialEq for Operation {
fn eq(&self, other: &Self) -> bool {
-return match (&self, &other) {
+match (&self, &other) {
//Order not matering. Ex: 1 + 2 equals 2 + 1, but 1 / 2 doesnt equal 2 /1 so we can
//check these separately
(&Self::Mul(a, b), &Self::Mul(c, d))
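
The operation.rs hunk is part of a hand-written PartialEq for Operation that, per the comment in the diff, treats commutative operations such as Mul as equal regardless of operand order while keeping order significant for operations such as Div. A reduced sketch of that idea on a hypothetical enum, not the crate's real Operation type:

// Hypothetical stand-in for an operation node; operands are plain ids here.
#[derive(Debug)]
enum Op {
    Mul(u32, u32), // commutative: Mul(a, b) == Mul(b, a)
    Div(u32, u32), // not commutative: operand order matters
}

impl PartialEq for Op {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // Accept either operand ordering for the commutative case.
            (Op::Mul(a, b), Op::Mul(c, d)) => (a == c && b == d) || (a == d && b == c),
            // Require exact ordering for the non-commutative case.
            (Op::Div(a, b), Op::Div(c, d)) => a == c && b == d,
            _ => false,
        }
    }
}

fn main() {
    assert_eq!(Op::Mul(1, 2), Op::Mul(2, 1));
    assert_ne!(Op::Div(1, 2), Op::Div(2, 1));
    println!("commutative-eq sketch ok");
}

Because the same file also hand-implements Hash for Operation (visible in the hunk header above), the hash of Mul(a, b) must match the hash of Mul(b, a) for this equality to be safe in hash maps; that is a requirement of Rust's Hash/Eq contract rather than something shown in this diff.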
106 changes: 45 additions & 61 deletions src/core/graph/subterm.rs
@@ -27,73 +27,57 @@ impl Context {
if visited.contains(&node_id) || modifications >= modification_limit {
continue;
}

-if node_map.contains_key(&self.nodes[node_id]) {
+if node_map.contains_key(&self.nodes[node_id]) && node_map[&self.nodes[node_id]] != node_id {
self.replace_index(node_id, node_map[&self.nodes[node_id]])?;
modifications += 1;
changed = true;
} else {
node_map.insert(self.nodes[node_id].clone(), node_id);
}

visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim {
node: a,
start: _,
stop: _,
stride: _,
dim: _,
visited.insert(node_id);
//Add operation nodes to the queue
match self.nodes[node_id].operation {
Operation::Add(a, b)
| Operation::Sub(a, b)
| Operation::Mul(a, b)
| Operation::Div(a, b)
| Operation::NotEqual(a, b)
| Operation::Equal(a, b)
| Operation::LessThan(a, b)
| Operation::GreaterThan(a, b)
| Operation::GreaterThanEq(a, b)
| Operation::LessThanEq(a, b)
| Operation::MatMul(a, b)
| Operation::Pow(a, b) => {
to_visit.push(a);
to_visit.push(b);
}
Operation::Neg(a)
| Operation::StopGradient(a)
| Operation::Log(a)
| Operation::Exp(a)
| Operation::TypeCast(a, _)
| Operation::Transpose(a, _)
| Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ }
| Operation::TileInDim { node: a, n_tiles: _, dim: _ }
| Operation::Reshape(a)
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ReduceMean { node, dim: _ }
| Operation::ReduceSum { node, dim: _ } => {
to_visit.push(node);
}
Operation::Select { pred, on_true, on_false } => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
}
Operation::Constant(_) | Operation::Parameter(_) => {}
}
| Operation::TileInDim {
node: a,
n_tiles: _,
dim: _,
}
| Operation::Reshape(a)
| Operation::ZerosLike(a) => {
to_visit.push(a);
}
Operation::ReduceMax { node, dim: _ }
| Operation::ReduceMean { node, dim: _ }
| Operation::ReduceSum { node, dim: _ }
| Operation::ReduceArgmax { node, dim: _ } => {
to_visit.push(node);
}
Operation::OneHot(node) => to_visit.push(node),
Operation::Select {
pred,
on_true,
on_false,
} => {
to_visit.push(pred);
to_visit.push(on_true);
to_visit.push(on_false);
}
Operation::Constant(_) | Operation::Parameter(_) => {}
node_map.insert(self.nodes[node_id].clone(), node_id);
}

}

//Recursive recall if we changed something and modifications are still available
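
In the subterm.rs pass above, node_map is keyed by a node's contents and records a canonical NodeIdentifier for each distinct expression. When a structurally identical node shows up again under a different id, it is rewritten via replace_index and counted as a modification; the added `&& node_map[&self.nodes[node_id]] != node_id` check keeps a node from being replaced by itself, and the big match over Operation pushes operand ids onto to_visit to continue the traversal. A much-simplified sketch of the underlying hash-consing idea, using toy types rather than the crate's graph structures:

use std::collections::HashMap;

// Toy expression node: either a leaf parameter or an addition of two node ids.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
enum Node {
    Param(String),
    Add(usize, usize),
}

// Map every structurally identical node onto one canonical id, returning a
// redirection table (duplicate id -> canonical id).
fn extract_subterms(nodes: &[Node]) -> HashMap<usize, usize> {
    let mut canonical: HashMap<Node, usize> = HashMap::new();
    let mut redirect = HashMap::new();
    for (id, node) in nodes.iter().enumerate() {
        if canonical.contains_key(node) {
            let canon = canonical[node];
            if canon != id {
                // Same contents already seen under a different id: redirect it.
                redirect.insert(id, canon);
            }
        } else {
            canonical.insert(node.clone(), id);
        }
    }
    redirect
}

fn main() {
    // a + a built twice: nodes 1 and 2 are structural duplicates.
    let nodes = vec![
        Node::Param("a".into()),
        Node::Add(0, 0),
        Node::Add(0, 0),
    ];
    let redirect = extract_subterms(&nodes);
    assert_eq!(redirect.get(&2), Some(&1));
    println!("duplicates redirected: {:?}", redirect);
}

The real pass additionally rewrites uses in the graph and, per the comment above, re-runs itself while changes were made and the modification limit allows, which this sketch omits.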
30 changes: 30 additions & 0 deletions src/core/graph/tests_cpu.rs
@@ -68,6 +68,36 @@ mod tests {
create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);

#[test]
fn test_large_cte() {
let mut ctx = Context::new();
let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
let two = ctx.scalar(2, xla::ElementType::F32).expect("2");

let a_2 = ctx.pow(a, two).expect("a^2");
let a_21 = ctx.pow(a, two).expect("a^2");
let a_22 = ctx.pow(a, two).expect("a^2");
let a_23 = ctx.pow(a, two).expect("a^2");
let a_24 = ctx.pow(a, two).expect("a^2");
let a_25 = ctx.pow(a, two).expect("a^2");
let a_26 = ctx.pow(a, two).expect("a^2");
let a_27 = ctx.pow(a, two).expect("a^2");


let sum1 = ctx.add(a_2, a_21).expect("a^2 + a^2");
let sum2 = ctx.add(a_22, a_23).expect("a^2 + a^2");
let sum3 = ctx.add(a_24, a_25).expect("a^2 + a^2");
let sum4 = ctx.add(a_26, a_27).expect("a^2 + a^2");

let nest_sum1 = ctx.add(sum1, sum2).expect("(a^2 + a^2) + (a^2 + a^2)");
let nest_sum2 = ctx.add(sum3, sum4).expect("(a^2 + a^2) + (a^2 + a^2)");

let res = ctx.add(nest_sum1, nest_sum2).expect("((a^2 + a^2) + (a^2 + a^2)) + ((a^2 + a^2) + (a^2 + a^2))");
let subterm_extract = ctx.extract_subterms(&[res], usize::MAX).expect("CTE");

assert!(subterm_extract);
}

#[test]
fn test_cte_happened() {
let mut ctx = Context::new();
