Skip to content

Commit

Permalink
remove safe_transmute from writing
Browse files Browse the repository at this point in the history
  • Loading branch information
twitzelbos committed Jun 19, 2023
1 parent c1cac1f commit 416a76b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 50 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ num-traits = "0.2"
quick-error = "2.0"
safe-transmute = "0.11"
either = "1.6"
num-complex = "0.4.3"
num-complex = {version = "0.4.3", features=["bytemuck"]}
rgb = "0.8.36"
bytemuck = "1.13.1"

[dependencies.nalgebra]
optional = true
Expand Down
44 changes: 11 additions & 33 deletions src/volume/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ impl NiftiDataRescaler<Complex64> for Complex64 {
}

// Nifti 1.1 specifies that RGB data must NOT be rescaled
impl NiftiDataRescaler<NiftiRGB> for NiftiRGB {
fn nifti_rescale(value: NiftiRGB, _slope: f32, _intercept: f32) -> NiftiRGB {
impl NiftiDataRescaler<RGB8> for RGB8 {
fn nifti_rescale(value: RGB8, _slope: f32, _intercept: f32) -> RGB8 {
return value;
}
}

// Nifti 1.1 specifies that RGB(A) data must NOT be rescaled
impl NiftiDataRescaler<NiftiRGBA> for NiftiRGBA {
fn nifti_rescale(value: NiftiRGBA, _slope: f32, _intercept: f32) -> NiftiRGBA {
impl NiftiDataRescaler<RGBA8> for RGBA8 {
fn nifti_rescale(value: RGBA8, _slope: f32, _intercept: f32) -> RGBA8 {
return value;
}
}
Expand Down Expand Up @@ -977,9 +977,9 @@ impl From<RGB8> for NiftiRGB {
}
}

unsafe impl TriviallyTransmutable for NiftiRGB {}

impl DataElement for NiftiRGB {

impl DataElement for RGB8 {
const DATA_TYPE: NiftiType = NiftiType::Rgb24;
type Transform = NoTransform;

Expand All @@ -989,7 +989,7 @@ impl DataElement for NiftiRGB {
{
Ok(convert_bytes_to::<[u8; 3], _>(vec, e)
.into_iter()
.map(|x| NiftiRGB::new(x[0], x[1], x[2]))
.map(|x| RGB8::new(x[0], x[1], x[2]))
.collect())
}

Expand Down Expand Up @@ -1017,7 +1017,7 @@ impl DataElement for NiftiRGB {
let g = ByteOrdered::native(&mut src).read_u8()?;
let b = ByteOrdered::native(&mut src).read_u8()?;

Ok(NiftiRGB::new(r, g, b))
Ok(RGB8::new(r, g, b))
}
}

Expand Down Expand Up @@ -1060,30 +1060,8 @@ impl DataElement for [u8; 3] {
}
}

#[repr(C)]
#[derive(Debug, Copy, Clone)]
struct NiftiRGBA {
r: u8,
g: u8,
b: u8,
a: u8,
}

impl NiftiRGBA {
fn new(r: u8, g: u8, b: u8, a: u8) -> Self {
NiftiRGBA { r, g, b, a }
}
}

impl Into<RGBA8> for NiftiRGBA {
fn into(self) -> RGBA8 {
RGBA8::new(self.r, self.g, self.b, self.a)
}
}

unsafe impl TriviallyTransmutable for NiftiRGBA {}

impl DataElement for NiftiRGBA {
impl DataElement for RGBA8 {
const DATA_TYPE: NiftiType = NiftiType::Rgba32;
type Transform = NoTransform;

Expand All @@ -1093,7 +1071,7 @@ impl DataElement for NiftiRGBA {
{
Ok(convert_bytes_to::<[u8; 4], _>(vec, e)
.into_iter()
.map(|x| NiftiRGBA::new(x[0], x[1], x[2], x[3]))
.map(|x| RGBA8::new(x[0], x[1], x[2], x[3]))
.collect())
}

Expand Down Expand Up @@ -1122,7 +1100,7 @@ impl DataElement for NiftiRGBA {
let b = ByteOrdered::native(&mut src).read_u8()?;
let a = ByteOrdered::native(&mut src).read_u8()?;

Ok(NiftiRGBA::new(r, g, b, a))
Ok(RGBA8::new(r, g, b, a))
}
}

Expand Down
13 changes: 6 additions & 7 deletions src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use byteordered::{ByteOrdered, Endian};
use flate2::write::GzEncoder;
use flate2::Compression;
use ndarray::{ArrayBase, Axis, Data, Dimension, RemoveAxis};
use safe_transmute::{transmute_to_bytes, TriviallyTransmutable};
use bytemuck::{Pod, cast_slice};

use crate::{
header::{MAGIC_CODE_NI1, MAGIC_CODE_NIP1},
Expand Down Expand Up @@ -142,7 +142,7 @@ impl<'a> WriterOptions<'a> {
pub fn write_nifti_tt<A, S, D>(&self, data: &ArrayBase<S, D>, datatype: NiftiType) -> Result<()>
where
S: Data<Elem = A>,
A: TriviallyTransmutable,
A: Pod,
D: Dimension + RemoveAxis,
{

Expand Down Expand Up @@ -210,8 +210,7 @@ impl<'a> WriterOptions<'a> {
pub fn write_nifti<A, S, D>(&self, data: &ArrayBase<S, D>) -> Result<()>
where
S: Data<Elem = A>,
A: DataElement,
A: TriviallyTransmutable,
A: DataElement + Pod,
D: Dimension + RemoveAxis,
{
self.write_nifti_tt(data, A::DATA_TYPE)
Expand Down Expand Up @@ -440,7 +439,7 @@ where
fn write_data<A, B, S, D, W, E>(mut writer: ByteOrdered<W, E>, data: ArrayBase<S, D>) -> Result<()>
where
S: Data<Elem = A>,
A: TriviallyTransmutable,
A: Pod,
D: Dimension + RemoveAxis,
W: Write,
E: Endian + Copy,
Expand Down Expand Up @@ -470,15 +469,15 @@ fn write_slice<A, B, S, D, W, E>(
) -> Result<()>
where
S: Data<Elem = A>,
A: Clone + TriviallyTransmutable,
A: Clone + Pod,
D: Dimension,
W: Write,
E: Endian,
{
let len = data.len();
let arr_data = data.into_shape(len).unwrap();
let slice = arr_data.as_slice().unwrap();
let bytes = transmute_to_bytes(slice);
let bytes = cast_slice(slice);
let (writer, endianness) = writer.into_parts();
let bytes = adapt_bytes::<B, _>(bytes, endianness);
writer.write_all(&bytes)?;
Expand Down
112 changes: 103 additions & 9 deletions tests/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ extern crate tempfile;

mod util;

use num_complex::*;
#[cfg(feature = "ndarray_volumes")]
mod tests {
use std::{
Expand All @@ -21,14 +22,15 @@ mod tests {
use ndarray::{
s, Array, Array1, Array2, Array3, Array4, Array5, Axis, Dimension, Ix2, IxDyn, ShapeBuilder,
};
use rgb::{RGB8, RGBA8};
use tempfile::tempdir;

use nifti::{
header::{MAGIC_CODE_NI1, MAGIC_CODE_NIP1},
object::NiftiObject,
volume::shape::Dim,
writer::WriterOptions,
DataElement, IntoNdArray, NiftiHeader, NiftiType, ReaderOptions, volume::element::NiftiRGB,
DataElement, IntoNdArray, NiftiHeader, NiftiType, ReaderOptions,
};

use super::util::rgb_header_gt;
Expand Down Expand Up @@ -434,15 +436,40 @@ mod tests {
}

#[test]
fn write_4d_rgb2() {
let mut data = Array::from_elem((3, 3, 3, 2), NiftiRGB::new(0u8, 0u8, 0u8));
fn write_4d_rgba_direct() {
let mut data = Array::from_elem((3, 3, 3, 2), [0u8, 0u8, 0u8, 0u8]);
data[(0, 0, 0, 0)] = [55, 55, 0, 0];
data[(0, 0, 1, 0)] = [55, 0, 55, 0];
data[(0, 1, 0, 0)] = [0, 55, 55, 0];
data[(0, 0, 0, 1)] = [55, 55, 0, 0];
data[(0, 1, 0, 1)] = [55, 0, 55, 0];
data[(1, 0, 0, 1)] = [0, 55, 55, 0];

data[(0, 0, 0, 0)] = NiftiRGB::new(55, 55, 0);
data[(0, 0, 1, 0)] = NiftiRGB::new(55, 0, 55);
data[(0, 1, 0, 0)] = NiftiRGB::new(0, 55, 55);
data[(0, 0, 0, 1)] = NiftiRGB::new(55, 55, 0);
data[(0, 1, 0, 1)] = NiftiRGB::new(55, 0, 55);
data[(1, 0, 0, 1)] = NiftiRGB::new(0, 55, 55);
let path = get_temporary_path("rgb.nii");
let header = rgb_header_gt();
WriterOptions::new(&path)
.reference_header(&header)
.write_nifti_tt(&data, NiftiType::Rgba32)
.unwrap();

// Until we are able to read RGB images, we simply compare the bytes of the newly created
// image to the bytes of the prepared 4D RGB image in ressources/rgb/.
assert_eq!(
fs::read(path).unwrap(),
fs::read("resources/rgba/4D.nii").unwrap()
);
}

#[test]
fn write_4d_rgb_rgbtype() {
let mut data = Array::from_elem((3, 3, 3, 2), RGB8::new(0u8, 0u8, 0u8));

data[(0, 0, 0, 0)] = RGB8::new(55, 55, 0);
data[(0, 0, 1, 0)] = RGB8::new(55, 0, 55);
data[(0, 1, 0, 0)] = RGB8::new(0, 55, 55);
data[(0, 0, 0, 1)] = RGB8::new(55, 55, 0);
data[(0, 1, 0, 1)] = RGB8::new(55, 0, 55);
data[(1, 0, 0, 1)] = RGB8::new(0, 55, 55);

let path = get_temporary_path("rgb.nii");
let header = rgb_header_gt();
Expand All @@ -459,6 +486,73 @@ mod tests {
);
}

#[test]
fn write_4d_rgb_rgbatype() {
let mut data = Array::from_elem((3, 3, 3, 2), RGBA8::new(0u8, 0u8, 0u8, 0u8));

data[(0, 0, 0, 0)] = RGBA8::new(55, 55, 0, 0);
data[(0, 0, 1, 0)] = RGBA8::new(55, 0, 55, 0);
data[(0, 1, 0, 0)] = RGBA8::new(0, 55, 55, 0);
data[(0, 0, 0, 1)] = RGBA8::new(55, 55, 0, 0);
data[(0, 1, 0, 1)] = RGBA8::new(55, 0, 55, 0);
data[(1, 0, 0, 1)] = RGBA8::new(0, 55, 55, 0);

let path = get_temporary_path("rgb.nii");
let header = rgb_header_gt();
WriterOptions::new(&path)
.reference_header(&header)
.write_nifti(&data)
.unwrap();

// Until we are able to read RGB images, we simply compare the bytes of the newly created
// image to the bytes of the prepared 4D RGB image in ressources/rgb/.
assert_eq!(
fs::read(path).unwrap(),
fs::read("resources/rgba/4D.nii").unwrap()
);
}

#[test]
fn write_2d_complex32() {
let mut data = Array::from_elem((3, 3), num_complex::Complex32::new(0.0, 0.0));

data[(0, 0)] = num_complex::Complex32::new(1.0, 1.0);
data[(0, 1)] = num_complex::Complex32::new(2.0, 2.0);
data[(1, 0)] = num_complex::Complex32::new(3.0, 3.0);

let path = get_temporary_path("complex32.nii");
let header = generate_nifti_header([2, 3, 3, 1, 1, 1, 1, 1], 1.0, 0.0, NiftiType::Complex64);
WriterOptions::new(&path)
.reference_header(&header)
.write_nifti(&data).unwrap();

assert_eq!(
fs::read(path).unwrap(),
fs::read("resources/complex/complex32.nii").unwrap()
);
}

#[test]
fn write_2d_complex64() {
let mut data = Array::from_elem((3, 3), num_complex::Complex64::new(0.0, 0.0));

data[(0, 0)] = num_complex::Complex64::new(1.0, 1.0);
data[(0, 1)] = num_complex::Complex64::new(2.0, 2.0);
data[(1, 0)] = num_complex::Complex64::new(3.0, 3.0);

let path = get_temporary_path("complex32.nii");
let header = generate_nifti_header([2, 3, 3, 1, 1, 1, 1, 1], 1.0, 0.0, NiftiType::Complex128);
WriterOptions::new(&path)
.reference_header(&header)
.write_nifti(&data).unwrap();

assert_eq!(
fs::read(path).unwrap(),
fs::read("resources/complex/complex32.nii").unwrap()
);
}


#[test]
fn write_extended_header() {
let data: Array2<f64> = Array2::zeros((8, 8));
Expand Down

0 comments on commit 416a76b

Please sign in to comment.