Skip to content

Commit

Permalink
rustfmt ran
Browse files Browse the repository at this point in the history
  • Loading branch information
AnubhabB committed Oct 1, 2024
1 parent bc60875 commit 76e6b53
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 21 deletions.
5 changes: 3 additions & 2 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1939,8 +1939,9 @@ impl BackendDevice for MetalDevice {
name,
shape.elem_count(),
&buffer,
1.
).map_err(MetalError::from)?;
1.,
)
.map_err(MetalError::from)?;

Ok(MetalStorage::new(
buffer,
Expand Down
26 changes: 24 additions & 2 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,33 @@ fn ones(device: &Device) -> Result<()> {
);
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[[half::f16::from_f32(1.0), half::f16::from_f32(1.0), half::f16::from_f32(1.0)], [half::f16::from_f32(1.0), half::f16::from_f32(1.0), half::f16::from_f32(1.0)]],
[
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
],
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
]
],
);
assert_eq!(
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
[[half::bf16::from_f32(1.0), half::bf16::from_f32(1.0), half::bf16::from_f32(1.0)], [half::bf16::from_f32(1.0), half::bf16::from_f32(1.0), half::bf16::from_f32(1.0)]],
[
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
],
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
]
],
);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2367,7 +2367,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32
v: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
Expand All @@ -2376,7 +2376,7 @@ pub fn call_const_fill(
encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (output, v, length));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

encoder.use_resource(output, metal::MTLResourceUsage::Write);
Expand Down
25 changes: 10 additions & 15 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2315,17 +2315,12 @@ fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T>
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let buffer = dev.new_buffer((len * std::mem::size_of::<T>()) as u64, MTLResourceOptions::StorageModePrivate);
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);

call_const_fill(
&dev,
command_buffer,
&kernels,
name,
len,
&buffer,
value
).unwrap();
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();

command_buffer.commit();
command_buffer.wait_until_completed();
Expand All @@ -2341,12 +2336,12 @@ fn const_fill() {
"fill_i64",
"fill_f16",
"fill_bf16",
"fill_f32"
"fill_f32",
];

for name in fills {
let len = rand::thread_rng().gen_range(2 .. 16) * rand::thread_rng().gen_range(4 .. 16);
let value = rand::thread_rng().gen_range(1. .. 19.);
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);

match name {
"fill_u8" => {
Expand All @@ -2373,7 +2368,7 @@ fn const_fill() {
let v = constant_fill::<f32>(name, len, value);
assert_eq!(v, vec![value; len])
}
_ => unimplemented!()
_ => unimplemented!(),
};
}
}
}

0 comments on commit 76e6b53

Please sign in to comment.