Skip to content

Commit

Permalink
fixed issues with keepdims
Browse files Browse the repository at this point in the history
  • Loading branch information
Ebanflo42 committed Mar 8, 2024
1 parent d02dd18 commit ddbbb22
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 193 deletions.
80 changes: 38 additions & 42 deletions src/core/graph/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl Context {
}

Operation::Reshape(node) => {
let next_pullback = self.diff(node, dependent_node)?;
let next_pullback = self.diff(output, dependent_node)?;
let node_sh = self.nodes[node].shape.clone();
let pullback = self.reshape(next_pullback, node_sh)?;
dependent_pullbacks.push(pullback);
Expand Down Expand Up @@ -135,11 +135,14 @@ impl Context {

Operation::Mul(a, b) => {
let next_pullback = self.diff(output, dependent_node)?;
if a == with_respect_to {
if a == b && a == with_respect_to {
let two = self.scalar(2, wrt_dtype)?;
let mul = self.mul(two, a)?;
dependent_pullbacks.push(self.mul(mul, next_pullback)?);
} else if a == with_respect_to {
let mul = self.mul(next_pullback, a)?;
dependent_pullbacks.push(mul);
}
if b == with_respect_to {
} else if b == with_respect_to {
let mul = self.mul(a, next_pullback)?;
dependent_pullbacks.push(mul);
}
Expand All @@ -166,7 +169,7 @@ impl Context {
}

Operation::TileInDim { node, n_tiles, dim } => {
let next_pullback = self.diff(node, dependent_node)?;
let next_pullback = self.diff(output, dependent_node)?;

let mut new_sizes = SmallVec::new();
for i in (0..self.nodes[node].shape.ndims()).rev() {
Expand All @@ -178,7 +181,7 @@ impl Context {

let reshaped_pullback =
self.reshape(next_pullback, Shape { sizes: new_sizes })?;
dependent_pullbacks.push(self.reduce_sum(reshaped_pullback, dim, false));
dependent_pullbacks.push(self.reduce_sum(reshaped_pullback, dim, false)?);
}

Operation::SliceInDim {
Expand All @@ -200,7 +203,6 @@ impl Context {
Operation::ReduceMax {
node,
dim,
keepdims,
} => {
if self.gradient_is_dependent(node, dependent_node) {
panic!(
Expand All @@ -214,52 +216,46 @@ impl Context {
Operation::ReduceSum {
node,
dim,
keepdims,
} => {
let next_pullback = self.diff(output, dependent_node)?;
let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64;
let tiled_pullback = if keepdims {
self.tile_in_dim(next_pullback, n_tiles, dim)?
} else {
let mut new_sizes = SmallVec::new();
for i in (0..self.nodes[node].shape.ndims()).rev() {
new_sizes.push(self.nodes[node].shape.sizes[i]);
if i as i64 == dim {
new_sizes.push(1u32);
}

let mut new_sizes = SmallVec::new();
for i in (0..self.nodes[next_pullback].shape.ndims()).rev() {
new_sizes.push(self.nodes[next_pullback].shape.sizes[i]);
if i as i64 == dim {
new_sizes.push(1u32);
}
let reshaped_pullback =
self.reshape(next_pullback, Shape { sizes: new_sizes })?;
self.tile_in_dim(reshaped_pullback, n_tiles, dim)?
};
}
if self.nodes[next_pullback].shape.ndims() == 0 {
new_sizes.push(1u32);
}
let reshaped_pullback =
self.reshape(next_pullback, Shape { sizes: new_sizes })?;
let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?;

dependent_pullbacks.push(tiled_pullback);
}

Operation::ReduceMean {
node,
dim,
keepdims,
} => {
Operation::ReduceMean { node, dim } => {
let next_pullback = self.diff(output, dependent_node)?;
let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64;
let tiled_pullback = if keepdims {
self.tile_in_dim(next_pullback, n_tiles, dim)?
} else {
let mut new_sizes = SmallVec::new();
for i in (0..self.nodes[next_pullback].shape.ndims()).rev() {
new_sizes.push(self.nodes[next_pullback].shape.sizes[i]);
if i as i64 == dim {
new_sizes.push(1u32);
}
}
if self.nodes[next_pullback].shape.ndims() == 0 {

let mut new_sizes = SmallVec::new();
for i in (0..self.nodes[next_pullback].shape.ndims()).rev() {
new_sizes.push(self.nodes[next_pullback].shape.sizes[i]);
if i as i64 == dim {
new_sizes.push(1u32);
}
let reshaped_pullback =
self.reshape(next_pullback, Shape { sizes: new_sizes })?;
self.tile_in_dim(reshaped_pullback, n_tiles, dim)?
};
let scale = self.scalar(1.0/(n_tiles as f32), self.nodes[node].dtype)?;
}
if self.nodes[next_pullback].shape.ndims() == 0 {
new_sizes.push(1u32);
}
let reshaped_pullback =
self.reshape(next_pullback, Shape { sizes: new_sizes })?;
let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?;

let scale = self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?;
let rescaled_pullback = self.mul(scale, tiled_pullback)?;
dependent_pullbacks.push(rescaled_pullback);
}
Expand Down
9 changes: 3 additions & 6 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,10 @@ impl Context {
Operation::ReduceMax {
node,
dim,
keepdims,
} => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], keepdims)?;
xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], false)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
Expand All @@ -394,11 +393,10 @@ impl Context {
Operation::ReduceSum {
node,
dim,
keepdims,
} => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].reduce_sum(&[dim], keepdims)?;
xla_op_slotmap[unda_xla_map[&node]].reduce_sum(&[dim], false)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
Expand All @@ -408,11 +406,10 @@ impl Context {
Operation::ReduceMean {
node,
dim,
keepdims,
} => {
if xla_op_slotmap.contains_key(unda_xla_map[&node]) {
let xla_op =
xla_op_slotmap[unda_xla_map[&node]].reduce_mean(&[dim], keepdims)?;
xla_op_slotmap[unda_xla_map[&node]].reduce_mean(&[dim], false)?;
let xla_id = xla_op_slotmap.insert(xla_op);
unda_xla_map.insert(*dependent_op, xla_id);
unda_op_queue.push_back(*dependent_op);
Expand Down
Loading

0 comments on commit ddbbb22

Please sign in to comment.