From ddbbb2237ecd64448162ea7ed08843ee55ce5822 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 8 Mar 2024 17:13:38 +0100 Subject: [PATCH] fixed issues with keepdims --- src/core/graph/autodiff.rs | 80 +++++---- src/core/graph/compile.rs | 9 +- src/core/graph/consteval.rs | 319 +++++++++++++++++++++++------------- src/core/graph/context.rs | 9 +- src/core/graph/math.rs | 78 +++++++-- src/core/graph/operation.rs | 6 +- src/core/graph/tests.rs | 4 +- 7 files changed, 312 insertions(+), 193 deletions(-) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 223d54b..ab2a465 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -106,7 +106,7 @@ impl Context { } Operation::Reshape(node) => { - let next_pullback = self.diff(node, dependent_node)?; + let next_pullback = self.diff(output, dependent_node)?; let node_sh = self.nodes[node].shape.clone(); let pullback = self.reshape(next_pullback, node_sh)?; dependent_pullbacks.push(pullback); @@ -135,11 +135,14 @@ impl Context { Operation::Mul(a, b) => { let next_pullback = self.diff(output, dependent_node)?; - if a == with_respect_to { + if a == b && a == with_respect_to { + let two = self.scalar(2, wrt_dtype)?; + let mul = self.mul(two, a)?; + dependent_pullbacks.push(self.mul(mul, next_pullback)?); + } else if a == with_respect_to { let mul = self.mul(next_pullback, a)?; dependent_pullbacks.push(mul); - } - if b == with_respect_to { + } else if b == with_respect_to { let mul = self.mul(a, next_pullback)?; dependent_pullbacks.push(mul); } @@ -166,7 +169,7 @@ impl Context { } Operation::TileInDim { node, n_tiles, dim } => { - let next_pullback = self.diff(node, dependent_node)?; + let next_pullback = self.diff(output, dependent_node)?; let mut new_sizes = SmallVec::new(); for i in (0..self.nodes[node].shape.ndims()).rev() { @@ -178,7 +181,7 @@ impl Context { let reshaped_pullback = self.reshape(next_pullback, Shape { sizes: new_sizes })?; - dependent_pullbacks.push(self.reduce_sum(reshaped_pullback, dim, false)); + dependent_pullbacks.push(self.reduce_sum(reshaped_pullback, dim, false)?); } Operation::SliceInDim { @@ -200,7 +203,6 @@ impl Context { Operation::ReduceMax { node, dim, - keepdims, } => { if self.gradient_is_dependent(node, dependent_node) { panic!( @@ -214,52 +216,46 @@ impl Context { Operation::ReduceSum { node, dim, - keepdims, } => { let next_pullback = self.diff(output, dependent_node)?; let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; - let tiled_pullback = if keepdims { - self.tile_in_dim(next_pullback, n_tiles, dim)? - } else { - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[node].shape.ndims()).rev() { - new_sizes.push(self.nodes[node].shape.sizes[i]); - if i as i64 == dim { - new_sizes.push(1u32); - } + + let mut new_sizes = SmallVec::new(); + for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { + new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); + if i as i64 == dim { + new_sizes.push(1u32); } - let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; - self.tile_in_dim(reshaped_pullback, n_tiles, dim)? - }; + } + if self.nodes[next_pullback].shape.ndims() == 0 { + new_sizes.push(1u32); + } + let reshaped_pullback = + self.reshape(next_pullback, Shape { sizes: new_sizes })?; + let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; + dependent_pullbacks.push(tiled_pullback); } - Operation::ReduceMean { - node, - dim, - keepdims, - } => { + Operation::ReduceMean { node, dim } => { let next_pullback = self.diff(output, dependent_node)?; let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; - let tiled_pullback = if keepdims { - self.tile_in_dim(next_pullback, n_tiles, dim)? - } else { - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { - new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); - if i as i64 == dim { - new_sizes.push(1u32); - } - } - if self.nodes[next_pullback].shape.ndims() == 0 { + + let mut new_sizes = SmallVec::new(); + for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { + new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); + if i as i64 == dim { new_sizes.push(1u32); } - let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; - self.tile_in_dim(reshaped_pullback, n_tiles, dim)? - }; - let scale = self.scalar(1.0/(n_tiles as f32), self.nodes[node].dtype)?; + } + if self.nodes[next_pullback].shape.ndims() == 0 { + new_sizes.push(1u32); + } + let reshaped_pullback = + self.reshape(next_pullback, Shape { sizes: new_sizes })?; + let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; + + let scale = self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?; let rescaled_pullback = self.mul(scale, tiled_pullback)?; dependent_pullbacks.push(rescaled_pullback); } diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index 446dde2..2a8832e 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -380,11 +380,10 @@ impl Context { Operation::ReduceMax { node, dim, - keepdims, } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = - xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], keepdims)?; + xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], false)?; let xla_id = xla_op_slotmap.insert(xla_op); unda_xla_map.insert(*dependent_op, xla_id); unda_op_queue.push_back(*dependent_op); @@ -394,11 +393,10 @@ impl Context { Operation::ReduceSum { node, dim, - keepdims, } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = - xla_op_slotmap[unda_xla_map[&node]].reduce_sum(&[dim], keepdims)?; + xla_op_slotmap[unda_xla_map[&node]].reduce_sum(&[dim], false)?; let xla_id = xla_op_slotmap.insert(xla_op); unda_xla_map.insert(*dependent_op, xla_id); unda_op_queue.push_back(*dependent_op); @@ -408,11 +406,10 @@ impl Context { Operation::ReduceMean { node, dim, - keepdims, } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = - xla_op_slotmap[unda_xla_map[&node]].reduce_mean(&[dim], keepdims)?; + xla_op_slotmap[unda_xla_map[&node]].reduce_mean(&[dim], false)?; let xla_id = xla_op_slotmap.insert(xla_op); unda_xla_map.insert(*dependent_op, xla_id); unda_op_queue.push_back(*dependent_op); diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index 4364de4..db7fc42 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -5,12 +5,9 @@ use xla::ElementType; use super::*; impl Context { - - fn collect_deps( - &self, - node: NodeIdentifier, - ) -> Vec { - self.dependent_nodes[&node].iter() + fn collect_deps(&self, node: NodeIdentifier) -> Vec { + self.dependent_nodes[&node] + .iter() .map(|node| node.clone()) .collect::>() } @@ -18,8 +15,8 @@ impl Context { fn replace_index( &mut self, to_remove: NodeIdentifier, - rep_with: NodeIdentifier - ) -> Result { + rep_with: NodeIdentifier, + ) -> Result { let mut changed = false; let deps = self.collect_deps(to_remove); @@ -37,8 +34,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Add(a, rep_with); changed = true; } - - }, + } Operation::Sub(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Sub(rep_with, rep_with); @@ -50,8 +46,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Sub(a, rep_with); changed = true; } - - }, + } Operation::Mul(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Mul(rep_with, rep_with); @@ -63,8 +58,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Mul(a, rep_with); changed = true; } - - }, + } Operation::Div(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Div(rep_with, rep_with); @@ -76,8 +70,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Div(a, rep_with); changed = true; } - - }, + } Operation::GreaterThan(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, rep_with); @@ -89,11 +82,12 @@ impl Context { self.nodes[dep_node].operation = Operation::GreaterThan(a, rep_with); changed = true; } - }, + } Operation::GreaterThanEq(a, b) => { if a == b { - self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, rep_with); + self.nodes[dep_node].operation = + Operation::GreaterThanEq(rep_with, rep_with); changed = true; } else if a == to_remove { self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, b); @@ -102,7 +96,7 @@ impl Context { self.nodes[dep_node].operation = Operation::GreaterThanEq(a, rep_with); changed = true; } - }, + } Operation::Equal(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Equal(rep_with, rep_with); @@ -114,7 +108,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Equal(a, rep_with); changed = true; } - }, + } Operation::NotEqual(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::NotEqual(rep_with, rep_with); @@ -126,7 +120,7 @@ impl Context { self.nodes[dep_node].operation = Operation::NotEqual(a, rep_with); changed = true; } - }, + } Operation::LessThan(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::LessThan(rep_with, rep_with); @@ -138,7 +132,7 @@ impl Context { self.nodes[dep_node].operation = Operation::LessThan(a, rep_with); changed = true; } - }, + } Operation::LessThanEq(a, b) => { if a == b { @@ -151,83 +145,131 @@ impl Context { self.nodes[dep_node].operation = Operation::LessThanEq(a, rep_with); changed = true; } - }, - Operation::Constant(_) - | Operation::Parameter(_) => { + } + Operation::Constant(_) | Operation::Parameter(_) => { unreachable!("Constants or Parameters cannot depend on nodes"); - }, + } Operation::StopGradient(a) => { if a == to_remove { self.nodes[dep_node].operation = Operation::StopGradient(rep_with); changed = true; } - }, + } Operation::Neg(a) => { if a == to_remove { self.nodes[dep_node].operation = Operation::Neg(rep_with); changed = true; } - }, + } Operation::ZerosLike(a) => { if a == to_remove { self.nodes[dep_node].operation = Operation::ZerosLike(rep_with); changed = true; } - }, + } Operation::TypeCast(_, t) => { changed = true; self.nodes[dep_node].operation = Operation::TypeCast(rep_with, t) - }, + } Operation::Reshape(_) => { changed = true; self.nodes[dep_node].operation = Operation::Reshape(rep_with) - }, - Operation::Select { pred, on_true, on_false } => { + } + Operation::Select { + pred, + on_true, + on_false, + } => { if pred == to_remove { if pred == on_true { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true: rep_with, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true: rep_with, + on_false, + } } else if pred == on_false { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true, on_false: rep_with } + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true, + on_false: rep_with, + } } else { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true, + on_false, + } } changed = true; } else if on_true == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::Select { pred, on_true: rep_with, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred, + on_true: rep_with, + on_false, + } } else if on_false == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::Select { pred, on_true, on_false: rep_with } + self.nodes[dep_node].operation = Operation::Select { + pred, + on_true, + on_false: rep_with, + } } - }, - Operation::ReduceMax { node, dim, keepdims } => { + } + Operation::ReduceMax { node, dim } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::ReduceMax { node: rep_with, dim, keepdims } + self.nodes[dep_node].operation = Operation::ReduceMax { + node: rep_with, + dim, + } } - }, - Operation::ReduceSum { node, dim, keepdims } => { + } + Operation::ReduceSum { node, dim } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::ReduceSum { node: rep_with, dim, keepdims } + self.nodes[dep_node].operation = Operation::ReduceSum { + node: rep_with, + dim, + } } - }, - Operation::ReduceMean { node, dim, keepdims } => { + } + Operation::ReduceMean { node, dim } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::ReduceMean { node: rep_with, dim, keepdims } + self.nodes[dep_node].operation = Operation::ReduceMean { + node: rep_with, + dim, + } } - }, - Operation::SliceInDim { node, start, stop, stride, dim } => { + } + Operation::SliceInDim { + node, + start, + stop, + stride, + dim, + } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::SliceInDim { node: rep_with, start, stop, stride, dim } + self.nodes[dep_node].operation = Operation::SliceInDim { + node: rep_with, + start, + stop, + stride, + dim, + } } - }, + } Operation::TileInDim { node, n_tiles, dim } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::TileInDim { node: rep_with, n_tiles, dim } + self.nodes[dep_node].operation = Operation::TileInDim { + node: rep_with, + n_tiles, + dim, + } } } } @@ -243,7 +285,7 @@ impl Context { &mut self, input: NodeIdentifier, modification_limit: usize, - ) -> Result { + ) -> Result { if modification_limit == 0 { return Ok(true); } @@ -259,30 +301,28 @@ impl Context { continue; } match self.nodes[node_id].operation { - Operation::Add(a, b) - | Operation::Sub(a, b) => { - if self.nodes[a].is_zero()? { - self.replace_index(node_id, b)?; - modifications += 1; - changed = true; - - } else if self.nodes[b].is_zero()? { - self.replace_index(node_id, a)?; - modifications += 1; - changed = true; - } - //Enqueue the dependent nodes to check both of them for constant - //mul/adding + Operation::Add(a, b) | Operation::Sub(a, b) => { + if self.nodes[a].is_zero()? { + self.replace_index(node_id, b)?; + modifications += 1; + changed = true; + } else if self.nodes[b].is_zero()? { + self.replace_index(node_id, a)?; + modifications += 1; + changed = true; + } + //Enqueue the dependent nodes to check both of them for constant + //mul/adding - //TODO: Once we create a new Node based on the constant propegation, - //use insert_with_key to 'replace existant node' - if self.nodes.get(a).unwrap().is_const().is_none() { - to_visit.push(a); - } - if self.nodes.get(b).unwrap().is_const().is_none() { - to_visit.push(b); - } - }, + //TODO: Once we create a new Node based on the constant propegation, + //use insert_with_key to 'replace existant node' + if self.nodes.get(a).unwrap().is_const().is_none() { + to_visit.push(a); + } + if self.nodes.get(b).unwrap().is_const().is_none() { + to_visit.push(b); + } + } Operation::Mul(a, b) => { if self.nodes[a].is_zero()? { self.replace_index(node_id, a)?; @@ -291,7 +331,8 @@ impl Context { } if let Some(literal) = self.nodes[a].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -305,10 +346,12 @@ impl Context { changed = true; } } + // TODO: Clean this up! Too many cases!! if let Operation::TileInDim { node, n_tiles, dim } = self.nodes[a].operation { if let Some(literal) = self.nodes[node].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -324,7 +367,8 @@ impl Context { } else if let Operation::Reshape(x) = self.nodes[node].operation { if let Some(literal) = self.nodes[x].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -337,17 +381,36 @@ impl Context { modifications += 1; changed = true; } + } else if let Operation::Reshape(y) = self.nodes[x].operation { + if let Some(literal) = self.nodes[y].is_const() { + //Check for mul by 1 + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let mut all_one = true; + floating_literal.iter().for_each(|elem| { + if *elem != 1f32 { + all_one = false; + } + }); + if all_one { + //a is all ones, replace node_id with a + self.replace_index(node_id, b)?; + modifications += 1; + changed = true; + } + } } } } - if self.nodes[b].is_zero()?{ + if self.nodes[b].is_zero()? { self.replace_index(node_id, b)?; modifications += 1; changed = true; } if let Some(literal) = self.nodes[b].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -364,7 +427,8 @@ impl Context { if let Operation::TileInDim { node, n_tiles, dim } = self.nodes[b].operation { if let Some(literal) = self.nodes[node].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -380,7 +444,8 @@ impl Context { } else if let Operation::Reshape(x) = self.nodes[node].operation { if let Some(literal) = self.nodes[x].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -393,6 +458,24 @@ impl Context { modifications += 1; changed = true; } + } else if let Operation::Reshape(y) = self.nodes[x].operation { + if let Some(literal) = self.nodes[y].is_const() { + //Check for mul by 1 + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let mut all_one = true; + floating_literal.iter().for_each(|elem| { + if *elem != 1f32 { + all_one = false; + } + }); + if all_one { + //a is all ones, replace node_id with a + self.replace_index(node_id, a)?; + modifications += 1; + changed = true; + } + } } } } @@ -402,41 +485,40 @@ impl Context { if let None = self.nodes[b].is_const() { to_visit.push(b); } - - }, + } Operation::Neg(a) => { if let None = self.nodes[a].is_const() { to_visit.push(a); } } Operation::GreaterThan(a, b) - | Operation::GreaterThanEq(a, b) - | Operation::LessThan(a, b) - | Operation::LessThanEq(a, b) - | Operation::Equal(a, b) - | Operation::NotEqual(a, b) - | Operation::Div(a, b) - => { - - if let None = self.nodes[a].is_const() { - to_visit.push(a); - } - - if let None = self.nodes[b].is_const() { - to_visit.push(b); - } + | Operation::GreaterThanEq(a, b) + | Operation::LessThan(a, b) + | Operation::LessThanEq(a, b) + | Operation::Equal(a, b) + | Operation::NotEqual(a, b) + | Operation::Div(a, b) => { + if let None = self.nodes[a].is_const() { + to_visit.push(a); + } - }, + if let None = self.nodes[b].is_const() { + to_visit.push(b); + } + } Operation::StopGradient(a) - | Operation::TypeCast(a, _) - | Operation::Reshape(a) - | Operation::ZerosLike(a) - => { + | Operation::TypeCast(a, _) + | Operation::Reshape(a) + | Operation::ZerosLike(a) => { if let None = self.nodes[a].is_const() { to_visit.push(a); } - }, - Operation::Select { pred, on_true, on_false } => { + } + Operation::Select { + pred, + on_true, + on_false, + } => { if let None = self.nodes[pred].is_const() { to_visit.push(pred) } @@ -446,26 +528,31 @@ impl Context { if let None = self.nodes[on_false].is_const() { to_visit.push(on_false) } - }, - Operation::SliceInDim { node, start, stop, stride, dim } => { + } + Operation::SliceInDim { + node, + start, + stop, + stride, + dim, + } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } - }, + } Operation::TileInDim { node, n_tiles, dim } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } - }, - Operation::ReduceMax { node, dim, keepdims } - | Operation::ReduceSum { node, dim, keepdims } - | Operation::ReduceMean { node, dim, keepdims } => { + } + Operation::ReduceMax { node, dim } + | Operation::ReduceSum { node, dim } + | Operation::ReduceMean { node, dim } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } - }, - Operation::Constant(_) - | Operation::Parameter(_) => {} + } + Operation::Constant(_) | Operation::Parameter(_) => {} } visitied.insert(node_id); } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 9dd7b2c..cdee12f 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -139,18 +139,15 @@ impl Context { Operation::ReduceMax { node, dim, - keepdims, - } => format!("ReduceMax {} {} {}", self.to_string(node), dim, keepdims), + } => format!("ReduceMax {} {}", self.to_string(node), dim), Operation::ReduceSum { node, dim, - keepdims, - } => format!("ReduceSum {} {} {}", self.to_string(node), dim, keepdims), + } => format!("ReduceSum {} {}", self.to_string(node), dim), Operation::ReduceMean { node, dim, - keepdims, - } => format!("ReduceMean {} {} {}", self.to_string(node), dim, keepdims), + } => format!("ReduceMean {} {}", self.to_string(node), dim), } } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 78ccd9a..136e38f 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -587,12 +587,15 @@ impl Context { node_id } - pub fn reduce_max(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> NodeIdentifier { + pub fn reduce_max( + &mut self, + a: NodeIdentifier, + dim: i64, + keepdims: bool, + ) -> Result { let mut s = Shape::new(); for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim && keepdims { - s.sizes.push(1) - } else { + if d as i64 != dim { s.sizes.push(self.nodes[a].shape.sizes[d]) } } @@ -602,7 +605,6 @@ impl Context { operation: Operation::ReduceMax { node: a, dim: dim, - keepdims: keepdims, }, dtype: self.nodes[a].dtype, }); @@ -610,15 +612,30 @@ impl Context { .entry(a) .or_insert(Vec::new()) .push(node_id); - node_id + if keepdims { + let mut s_keepdim = Shape::new(); + for d in (0..self.nodes[a].shape.ndims()).rev() { + if d as i64 == dim { + s_keepdim.sizes.push(1u32) + } else { + s_keepdim.sizes.push(self.nodes[a].shape.sizes[d]) + } + } + self.reshape(node_id, s_keepdim) + } else { + Ok(node_id) + } } - pub fn reduce_sum(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> NodeIdentifier { + pub fn reduce_sum( + &mut self, + a: NodeIdentifier, + dim: i64, + keepdims: bool, + ) -> Result { let mut s = Shape::new(); for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim && keepdims { - s.sizes.push(1) - } else { + if d as i64 != dim { s.sizes.push(self.nodes[a].shape.sizes[d]) } } @@ -628,7 +645,6 @@ impl Context { operation: Operation::ReduceSum { node: a, dim: dim, - keepdims: keepdims, }, dtype: self.nodes[a].dtype, }); @@ -636,15 +652,30 @@ impl Context { .entry(a) .or_insert(Vec::new()) .push(node_id); - node_id + if keepdims { + let mut s_keepdim = Shape::new(); + for d in (0..self.nodes[a].shape.ndims()).rev() { + if d as i64 == dim { + s_keepdim.sizes.push(1u32) + } else { + s_keepdim.sizes.push(self.nodes[a].shape.sizes[d]) + } + } + self.reshape(node_id, s_keepdim) + } else { + Ok(node_id) + } } - pub fn reduce_mean(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> NodeIdentifier { + pub fn reduce_mean( + &mut self, + a: NodeIdentifier, + dim: i64, + keepdims: bool, + ) -> Result { let mut s = Shape::new(); for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim && keepdims { - s.sizes.push(1) - } else { + if d as i64 != dim { s.sizes.push(self.nodes[a].shape.sizes[d]) } } @@ -654,7 +685,6 @@ impl Context { operation: Operation::ReduceMean { node: a, dim: dim, - keepdims: keepdims, }, dtype: self.nodes[a].dtype, }); @@ -662,6 +692,18 @@ impl Context { .entry(a) .or_insert(Vec::new()) .push(node_id); - node_id + if keepdims { + let mut s_keepdim = Shape::new(); + for d in (0..self.nodes[a].shape.ndims()).rev() { + if d as i64 == dim { + s_keepdim.sizes.push(1u32) + } else { + s_keepdim.sizes.push(self.nodes[a].shape.sizes[d]) + } + } + self.reshape(node_id, s_keepdim) + } else { + Ok(node_id) + } } } diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index d748edf..ddcb696 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -30,10 +30,10 @@ pub enum Operation { ZerosLike(NodeIdentifier), - ReduceMax{ node: NodeIdentifier, dim: i64, keepdims: bool }, - ReduceSum{ node: NodeIdentifier, dim: i64, keepdims: bool }, + ReduceMax{ node: NodeIdentifier, dim: i64, }, + ReduceSum{ node: NodeIdentifier, dim: i64, }, // TODO: This might not behave well for integral types! Figure out behavior. - ReduceMean{ node: NodeIdentifier, dim: i64, keepdims: bool }, + ReduceMean{ node: NodeIdentifier, dim: i64, }, } impl Display for Operation { diff --git a/src/core/graph/tests.rs b/src/core/graph/tests.rs index 7770d2a..3a7a5bb 100644 --- a/src/core/graph/tests.rs +++ b/src/core/graph/tests.rs @@ -593,12 +593,12 @@ mod tests { let x = ctx.parameter("x", [2], xla::ElementType::F32).expect("x"); let x2 = ctx.mul(x, x).expect("x2"); - let y = ctx.reduce_mean(x2, 0, false); - println!("{}", ctx.nodes[y].shape); + let y = ctx.reduce_mean(x2, 0, true).expect("y"); let dydx = ctx.diff(y, x.into()).expect("dydx"); ctx.fold_consts(dydx, usize::max_value()).expect("fold_consts"); println!("{}", ctx.to_string(dydx)); + assert_eq!(ctx.to_string(dydx), "Mul (Mul (Constant Scalar 2) (Parameter Vector2 x)) (Constant Scalar 0.5)"); let lr = ctx.scalar(1, xla::ElementType::F32).expect("lr"); let update = ctx.mul(lr, dydx).expect("update"); let new_x = ctx.sub(x, update).expect("new_x");