Skip to content

Commit

Permalink
Add a hack for generating random uniform/normal for f16/bf16.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 31, 2023
1 parent c12ad45 commit 1099504
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
}
Expand All @@ -213,8 +219,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
}
Expand Down

0 comments on commit 1099504

Please sign in to comment.