Skip to content

Commit

Permalink
Create tests for uniform dist
Browse files Browse the repository at this point in the history
  • Loading branch information
BradenEverson committed Apr 8, 2024
1 parent c218991 commit db42976
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/core/graph/tests_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,35 @@ mod tests {
}
}

#[test]
fn test_uniform_dist() {
let mut ctx = Context::new();
let min = ctx.scalar(0, xla::ElementType::F32).expect("min = 0");
let max = ctx.scalar(1, xla::ElementType::F32).expect("max = 10");
let mat = ctx.rng_uniform(min, max, &[10,1]).expect("sample the uniform distribution");

let client = xla::PjRtClient::cpu().expect("client");
let name = "test";
let executable = ctx.compile(&name, [mat], &client).expect("executable");

let device_result = executable.execute::<Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

match untupled_result.shape().unwrap() {
Shape::Array(shape) => {
assert_eq!(shape.dims(), &[10,1]);
},
_ => {
panic!("Shape is not correct");
}
}
}


#[test]
fn test_large_cte() {
Expand Down

0 comments on commit db42976

Please sign in to comment.