From 416a76ba69d0d1704503174d36b66bff86fba317 Mon Sep 17 00:00:00 2001 From: Thomas Witzel Date: Mon, 19 Jun 2023 11:53:16 -0700 Subject: [PATCH] remove safe_transmute from writing --- Cargo.toml | 3 +- src/volume/element.rs | 44 +++++------------ src/writer.rs | 13 +++-- tests/writer.rs | 112 ++++++++++++++++++++++++++++++++++++++---- 4 files changed, 122 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a5eaafc..a28a6b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,9 @@ num-traits = "0.2" quick-error = "2.0" safe-transmute = "0.11" either = "1.6" -num-complex = "0.4.3" +num-complex = {version = "0.4.3", features=["bytemuck"]} rgb = "0.8.36" +bytemuck = "1.13.1" [dependencies.nalgebra] optional = true diff --git a/src/volume/element.rs b/src/volume/element.rs index 8f02968..6f0cf4e 100644 --- a/src/volume/element.rs +++ b/src/volume/element.rs @@ -177,15 +177,15 @@ impl NiftiDataRescaler for Complex64 { } // Nifti 1.1 specifies that RGB data must NOT be rescaled -impl NiftiDataRescaler for NiftiRGB { - fn nifti_rescale(value: NiftiRGB, _slope: f32, _intercept: f32) -> NiftiRGB { +impl NiftiDataRescaler for RGB8 { + fn nifti_rescale(value: RGB8, _slope: f32, _intercept: f32) -> RGB8 { return value; } } // Nifti 1.1 specifies that RGB(A) data must NOT be rescaled -impl NiftiDataRescaler for NiftiRGBA { - fn nifti_rescale(value: NiftiRGBA, _slope: f32, _intercept: f32) -> NiftiRGBA { +impl NiftiDataRescaler for RGBA8 { + fn nifti_rescale(value: RGBA8, _slope: f32, _intercept: f32) -> RGBA8 { return value; } } @@ -977,9 +977,9 @@ impl From for NiftiRGB { } } -unsafe impl TriviallyTransmutable for NiftiRGB {} -impl DataElement for NiftiRGB { + +impl DataElement for RGB8 { const DATA_TYPE: NiftiType = NiftiType::Rgb24; type Transform = NoTransform; @@ -989,7 +989,7 @@ impl DataElement for NiftiRGB { { Ok(convert_bytes_to::<[u8; 3], _>(vec, e) .into_iter() - .map(|x| NiftiRGB::new(x[0], x[1], x[2])) + .map(|x| RGB8::new(x[0], x[1], x[2])) .collect()) } @@ -1017,7 +1017,7 @@ impl DataElement for NiftiRGB { let g = ByteOrdered::native(&mut src).read_u8()?; let b = ByteOrdered::native(&mut src).read_u8()?; - Ok(NiftiRGB::new(r, g, b)) + Ok(RGB8::new(r, g, b)) } } @@ -1060,30 +1060,8 @@ impl DataElement for [u8; 3] { } } -#[repr(C)] -#[derive(Debug, Copy, Clone)] -struct NiftiRGBA { - r: u8, - g: u8, - b: u8, - a: u8, -} - -impl NiftiRGBA { - fn new(r: u8, g: u8, b: u8, a: u8) -> Self { - NiftiRGBA { r, g, b, a } - } -} - -impl Into for NiftiRGBA { - fn into(self) -> RGBA8 { - RGBA8::new(self.r, self.g, self.b, self.a) - } -} - -unsafe impl TriviallyTransmutable for NiftiRGBA {} -impl DataElement for NiftiRGBA { +impl DataElement for RGBA8 { const DATA_TYPE: NiftiType = NiftiType::Rgba32; type Transform = NoTransform; @@ -1093,7 +1071,7 @@ impl DataElement for NiftiRGBA { { Ok(convert_bytes_to::<[u8; 4], _>(vec, e) .into_iter() - .map(|x| NiftiRGBA::new(x[0], x[1], x[2], x[3])) + .map(|x| RGBA8::new(x[0], x[1], x[2], x[3])) .collect()) } @@ -1122,7 +1100,7 @@ impl DataElement for NiftiRGBA { let b = ByteOrdered::native(&mut src).read_u8()?; let a = ByteOrdered::native(&mut src).read_u8()?; - Ok(NiftiRGBA::new(r, g, b, a)) + Ok(RGBA8::new(r, g, b, a)) } } diff --git a/src/writer.rs b/src/writer.rs index c2ab82f..e44c324 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,7 +9,7 @@ use byteordered::{ByteOrdered, Endian}; use flate2::write::GzEncoder; use flate2::Compression; use ndarray::{ArrayBase, Axis, Data, Dimension, RemoveAxis}; -use safe_transmute::{transmute_to_bytes, TriviallyTransmutable}; +use bytemuck::{Pod, cast_slice}; use crate::{ header::{MAGIC_CODE_NI1, MAGIC_CODE_NIP1}, @@ -142,7 +142,7 @@ impl<'a> WriterOptions<'a> { pub fn write_nifti_tt(&self, data: &ArrayBase, datatype: NiftiType) -> Result<()> where S: Data, - A: TriviallyTransmutable, + A: Pod, D: Dimension + RemoveAxis, { @@ -210,8 +210,7 @@ impl<'a> WriterOptions<'a> { pub fn write_nifti(&self, data: &ArrayBase) -> Result<()> where S: Data, - A: DataElement, - A: TriviallyTransmutable, + A: DataElement + Pod, D: Dimension + RemoveAxis, { self.write_nifti_tt(data, A::DATA_TYPE) @@ -440,7 +439,7 @@ where fn write_data(mut writer: ByteOrdered, data: ArrayBase) -> Result<()> where S: Data, - A: TriviallyTransmutable, + A: Pod, D: Dimension + RemoveAxis, W: Write, E: Endian + Copy, @@ -470,7 +469,7 @@ fn write_slice( ) -> Result<()> where S: Data, - A: Clone + TriviallyTransmutable, + A: Clone + Pod, D: Dimension, W: Write, E: Endian, @@ -478,7 +477,7 @@ where let len = data.len(); let arr_data = data.into_shape(len).unwrap(); let slice = arr_data.as_slice().unwrap(); - let bytes = transmute_to_bytes(slice); + let bytes = cast_slice(slice); let (writer, endianness) = writer.into_parts(); let bytes = adapt_bytes::(bytes, endianness); writer.write_all(&bytes)?; diff --git a/tests/writer.rs b/tests/writer.rs index ced93ce..672c730 100644 --- a/tests/writer.rs +++ b/tests/writer.rs @@ -9,6 +9,7 @@ extern crate tempfile; mod util; +use num_complex::*; #[cfg(feature = "ndarray_volumes")] mod tests { use std::{ @@ -21,6 +22,7 @@ mod tests { use ndarray::{ s, Array, Array1, Array2, Array3, Array4, Array5, Axis, Dimension, Ix2, IxDyn, ShapeBuilder, }; + use rgb::{RGB8, RGBA8}; use tempfile::tempdir; use nifti::{ @@ -28,7 +30,7 @@ mod tests { object::NiftiObject, volume::shape::Dim, writer::WriterOptions, - DataElement, IntoNdArray, NiftiHeader, NiftiType, ReaderOptions, volume::element::NiftiRGB, + DataElement, IntoNdArray, NiftiHeader, NiftiType, ReaderOptions, }; use super::util::rgb_header_gt; @@ -434,15 +436,40 @@ mod tests { } #[test] - fn write_4d_rgb2() { - let mut data = Array::from_elem((3, 3, 3, 2), NiftiRGB::new(0u8, 0u8, 0u8)); + fn write_4d_rgba_direct() { + let mut data = Array::from_elem((3, 3, 3, 2), [0u8, 0u8, 0u8, 0u8]); + data[(0, 0, 0, 0)] = [55, 55, 0, 0]; + data[(0, 0, 1, 0)] = [55, 0, 55, 0]; + data[(0, 1, 0, 0)] = [0, 55, 55, 0]; + data[(0, 0, 0, 1)] = [55, 55, 0, 0]; + data[(0, 1, 0, 1)] = [55, 0, 55, 0]; + data[(1, 0, 0, 1)] = [0, 55, 55, 0]; - data[(0, 0, 0, 0)] = NiftiRGB::new(55, 55, 0); - data[(0, 0, 1, 0)] = NiftiRGB::new(55, 0, 55); - data[(0, 1, 0, 0)] = NiftiRGB::new(0, 55, 55); - data[(0, 0, 0, 1)] = NiftiRGB::new(55, 55, 0); - data[(0, 1, 0, 1)] = NiftiRGB::new(55, 0, 55); - data[(1, 0, 0, 1)] = NiftiRGB::new(0, 55, 55); + let path = get_temporary_path("rgb.nii"); + let header = rgb_header_gt(); + WriterOptions::new(&path) + .reference_header(&header) + .write_nifti_tt(&data, NiftiType::Rgba32) + .unwrap(); + + // Until we are able to read RGB images, we simply compare the bytes of the newly created + // image to the bytes of the prepared 4D RGB image in ressources/rgb/. + assert_eq!( + fs::read(path).unwrap(), + fs::read("resources/rgba/4D.nii").unwrap() + ); + } + + #[test] + fn write_4d_rgb_rgbtype() { + let mut data = Array::from_elem((3, 3, 3, 2), RGB8::new(0u8, 0u8, 0u8)); + + data[(0, 0, 0, 0)] = RGB8::new(55, 55, 0); + data[(0, 0, 1, 0)] = RGB8::new(55, 0, 55); + data[(0, 1, 0, 0)] = RGB8::new(0, 55, 55); + data[(0, 0, 0, 1)] = RGB8::new(55, 55, 0); + data[(0, 1, 0, 1)] = RGB8::new(55, 0, 55); + data[(1, 0, 0, 1)] = RGB8::new(0, 55, 55); let path = get_temporary_path("rgb.nii"); let header = rgb_header_gt(); @@ -459,6 +486,73 @@ mod tests { ); } + #[test] + fn write_4d_rgb_rgbatype() { + let mut data = Array::from_elem((3, 3, 3, 2), RGBA8::new(0u8, 0u8, 0u8, 0u8)); + + data[(0, 0, 0, 0)] = RGBA8::new(55, 55, 0, 0); + data[(0, 0, 1, 0)] = RGBA8::new(55, 0, 55, 0); + data[(0, 1, 0, 0)] = RGBA8::new(0, 55, 55, 0); + data[(0, 0, 0, 1)] = RGBA8::new(55, 55, 0, 0); + data[(0, 1, 0, 1)] = RGBA8::new(55, 0, 55, 0); + data[(1, 0, 0, 1)] = RGBA8::new(0, 55, 55, 0); + + let path = get_temporary_path("rgb.nii"); + let header = rgb_header_gt(); + WriterOptions::new(&path) + .reference_header(&header) + .write_nifti(&data) + .unwrap(); + + // Until we are able to read RGB images, we simply compare the bytes of the newly created + // image to the bytes of the prepared 4D RGB image in ressources/rgb/. + assert_eq!( + fs::read(path).unwrap(), + fs::read("resources/rgba/4D.nii").unwrap() + ); + } + + #[test] + fn write_2d_complex32() { + let mut data = Array::from_elem((3, 3), num_complex::Complex32::new(0.0, 0.0)); + + data[(0, 0)] = num_complex::Complex32::new(1.0, 1.0); + data[(0, 1)] = num_complex::Complex32::new(2.0, 2.0); + data[(1, 0)] = num_complex::Complex32::new(3.0, 3.0); + + let path = get_temporary_path("complex32.nii"); + let header = generate_nifti_header([2, 3, 3, 1, 1, 1, 1, 1], 1.0, 0.0, NiftiType::Complex64); + WriterOptions::new(&path) + .reference_header(&header) + .write_nifti(&data).unwrap(); + + assert_eq!( + fs::read(path).unwrap(), + fs::read("resources/complex/complex32.nii").unwrap() + ); + } + + #[test] + fn write_2d_complex64() { + let mut data = Array::from_elem((3, 3), num_complex::Complex64::new(0.0, 0.0)); + + data[(0, 0)] = num_complex::Complex64::new(1.0, 1.0); + data[(0, 1)] = num_complex::Complex64::new(2.0, 2.0); + data[(1, 0)] = num_complex::Complex64::new(3.0, 3.0); + + let path = get_temporary_path("complex32.nii"); + let header = generate_nifti_header([2, 3, 3, 1, 1, 1, 1, 1], 1.0, 0.0, NiftiType::Complex128); + WriterOptions::new(&path) + .reference_header(&header) + .write_nifti(&data).unwrap(); + + assert_eq!( + fs::read(path).unwrap(), + fs::read("resources/complex/complex32.nii").unwrap() + ); + } + + #[test] fn write_extended_header() { let data: Array2 = Array2::zeros((8, 8));