From db42976ac052bd0abccf7f3ac296cc1f9fb99240 Mon Sep 17 00:00:00 2001 From: Braden Everson Date: Mon, 8 Apr 2024 07:55:29 -0500 Subject: [PATCH] Create tests for uniform dist --- src/core/graph/tests_cpu.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/core/graph/tests_cpu.rs b/src/core/graph/tests_cpu.rs index d10093a..b0d5cea 100644 --- a/src/core/graph/tests_cpu.rs +++ b/src/core/graph/tests_cpu.rs @@ -97,6 +97,35 @@ mod tests { } } + #[test] + fn test_uniform_dist() { + let mut ctx = Context::new(); + let min = ctx.scalar(0, xla::ElementType::F32).expect("min = 0"); + let max = ctx.scalar(1, xla::ElementType::F32).expect("max = 10"); + let mat = ctx.rng_uniform(min, max, &[10,1]).expect("sample the uniform distribution"); + + let client = xla::PjRtClient::cpu().expect("client"); + let name = "test"; + let executable = ctx.compile(&name, [mat], &client).expect("executable"); + + let device_result = executable.execute::(&[]).expect("execute"); + let host_result = device_result[0][0] + .to_literal_sync() + .expect("to_literal_sync"); + let untupled_result = host_result.to_tuple1().expect("untuple"); + let rust_result = untupled_result.to_vec::().expect("to_vec"); + println!("{:?}", rust_result); + + match untupled_result.shape().unwrap() { + Shape::Array(shape) => { + assert_eq!(shape.dims(), &[10,1]); + }, + _ => { + panic!("Shape is not correct"); + } + } + } + #[test] fn test_large_cte() {