From 76e6b535db8f2f694226850c8477728b052e077d Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Tue, 1 Oct 2024 21:40:44 +0530 Subject: [PATCH] rustfmt ran --- candle-core/src/metal_backend/mod.rs | 5 +++-- candle-core/tests/tensor_tests.rs | 26 ++++++++++++++++++++++++-- candle-metal-kernels/src/lib.rs | 4 ++-- candle-metal-kernels/src/tests.rs | 25 ++++++++++--------------- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index af9d791c8c..6f560c02ee 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1939,8 +1939,9 @@ impl BackendDevice for MetalDevice { name, shape.elem_count(), &buffer, - 1. - ).map_err(MetalError::from)?; + 1., + ) + .map_err(MetalError::from)?; Ok(MetalStorage::new( buffer, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 4bc35d3795..e0cea15c61 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -31,11 +31,33 @@ fn ones(device: &Device) -> Result<()> { ); assert_eq!( Tensor::ones((2, 3), DType::F16, device)?.to_vec2::()?, - [[half::f16::from_f32(1.0), half::f16::from_f32(1.0), half::f16::from_f32(1.0)], [half::f16::from_f32(1.0), half::f16::from_f32(1.0), half::f16::from_f32(1.0)]], + [ + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ], + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ] + ], ); assert_eq!( Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::()?, - [[half::bf16::from_f32(1.0), half::bf16::from_f32(1.0), half::bf16::from_f32(1.0)], [half::bf16::from_f32(1.0), half::bf16::from_f32(1.0), half::bf16::from_f32(1.0)]], + [ + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ], + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ] + ], ); Ok(()) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 832f35fb83..a270bb2888 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2367,7 +2367,7 @@ pub fn call_const_fill( name: &'static str, length: usize, output: &Buffer, - v: f32 + v: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; let encoder = ep.encoder(); @@ -2376,7 +2376,7 @@ pub fn call_const_fill( encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (output, v, length)); - + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(output, metal::MTLResourceUsage::Write); diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 1fbd7f5470..f37ab5bb9c 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -2315,17 +2315,12 @@ fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec let command_queue = dev.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let buffer = dev.new_buffer((len * std::mem::size_of::()) as u64, MTLResourceOptions::StorageModePrivate); + let buffer = dev.new_buffer( + (len * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModePrivate, + ); - call_const_fill( - &dev, - command_buffer, - &kernels, - name, - len, - &buffer, - value - ).unwrap(); + call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); @@ -2341,12 +2336,12 @@ fn const_fill() { "fill_i64", "fill_f16", "fill_bf16", - "fill_f32" + "fill_f32", ]; for name in fills { - let len = rand::thread_rng().gen_range(2 .. 16) * rand::thread_rng().gen_range(4 .. 16); - let value = rand::thread_rng().gen_range(1. .. 19.); + let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); + let value = rand::thread_rng().gen_range(1. ..19.); match name { "fill_u8" => { @@ -2373,7 +2368,7 @@ fn const_fill() { let v = constant_fill::(name, len, value); assert_eq!(v, vec![value; len]) } - _ => unimplemented!() + _ => unimplemented!(), }; } -} \ No newline at end of file +}