-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_setup.py
134 lines (105 loc) · 4.32 KB
/
data_setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Data Setup Module for MAE-VAE Model.
This module handles the creation of symbolic links and preprocessing of 3D NIfTI data,
specifically focusing on files containing 'org' in their names.
"""
import os
import glob
import shutil
import nibabel as nib
import numpy as np
from tqdm import tqdm
def create_symlinks(source_dir, target_dir, pattern='*org*.nii.gz'):
    """Create symbolic links for NIfTI files matching the pattern.

    Every file in ``source_dir`` whose name matches ``pattern`` gets a link of
    the same name inside ``target_dir``. An entry already present at a link
    location, including a dangling symlink, is removed and replaced.

    Args:
        source_dir (str): Source directory containing the original data
        target_dir (str): Target directory for symbolic links
        pattern (str): Pattern to match files (default: '*org*.nii.gz')

    Returns:
        list: List of created symbolic link paths, ordered by source path

    Raises:
        ValueError: If source_dir and target_dir resolve to the same directory
        OSError: If an existing entry cannot be removed or a link cannot be
            created
    """
    # Linking a directory into itself would delete every matching original
    # and leave a self-referential link in its place, so refuse up front.
    if os.path.realpath(source_dir) == os.path.realpath(target_dir):
        raise ValueError(
            f"source_dir and target_dir must be different directories: {source_dir!r}"
        )

    # Create target directory if it doesn't exist
    os.makedirs(target_dir, exist_ok=True)

    # glob returns matches in arbitrary order; sort so the processing order
    # and the returned list are deterministic.
    source_files = sorted(glob.glob(os.path.join(source_dir, pattern)))
    symlinks = []

    print(f"Creating symbolic links for {len(source_files)} files...")
    for source_file in tqdm(source_files):
        target_path = os.path.join(target_dir, os.path.basename(source_file))

        # lexists (unlike exists) is also true for a dangling symlink, which
        # would otherwise make os.symlink fail with FileExistsError.
        if os.path.lexists(target_path):
            os.remove(target_path)

        # A relative link target is resolved against the link's own
        # directory, not the current working directory, so store an
        # absolute path to keep the link valid.
        os.symlink(os.path.abspath(source_file), target_path)
        symlinks.append(target_path)

    return symlinks
def extract_slices(nifti_path, output_dir, axis=2):
    """Extract 2D slices from a 3D NIfTI file.

    Each slice along ``axis`` is written to ``output_dir`` as
    ``<base_name>_slice_<index>.npy``, where ``<index>`` is zero-padded to
    four digits and ``<base_name>`` is the file name without its extension
    (both parts of a compressed ``.nii.gz`` name are removed).

    Args:
        nifti_path (str): Path to the NIfTI file
        output_dir (str): Directory to save the extracted slices
        axis (int): Axis along which to extract slices (0: sagittal, 1: coronal, 2: axial).
            Negative values count from the last axis.

    Returns:
        int: Number of slices extracted

    Raises:
        IndexError: If axis is out of range for the image data
    """
    # Load NIfTI file
    img = nib.load(nifti_path)
    data = img.get_fdata()

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Strip one extension, and the inner one too only when the outer one is
    # '.gz', so 'scan.nii.gz' -> 'scan' while 'sub.01.nii' keeps 'sub.01'.
    base_name, extension = os.path.splitext(os.path.basename(nifti_path))
    if extension.lower() == '.gz':
        base_name = os.path.splitext(base_name)[0]

    n_slices = data.shape[axis]

    # moveaxis gives a view with the slicing axis first, so iterating it
    # yields data[i, :, :], data[:, i, :] or data[:, :, i] for axis 0, 1, 2
    # and keeps the slice count and the sliced axis in agreement for any
    # valid axis value.
    for i, slice_data in enumerate(np.moveaxis(data, axis, 0)):
        # Save slice as .npy file
        output_path = os.path.join(output_dir, f"{base_name}_slice_{i:04d}.npy")
        np.save(output_path, slice_data)

    return n_slices
def process_dataset(source_dir, target_base_dir, axis=2, pattern='*org*.nii.gz'):
    """Process the entire dataset by creating symlinks and extracting slices.

    Links are created under ``<target_base_dir>/symlinks`` and the slices of
    each linked file are written to ``<target_base_dir>/slices/<subject_name>``,
    where ``<subject_name>`` is the link's file name without its extension.

    Args:
        source_dir (str): Source directory containing the original data
        target_base_dir (str): Base directory for processed data
        axis (int): Axis along which to extract slices (0: sagittal, 1: coronal, 2: axial)
        pattern (str): Pattern selecting the files to process
            (default: '*org*.nii.gz')

    Returns:
        tuple: Number of processed files and total number of slices
    """
    # Output locations
    symlink_dir = os.path.join(target_base_dir, 'symlinks')
    slices_dir = os.path.join(target_base_dir, 'slices')

    # Create symlinks
    symlinks = create_symlinks(source_dir, symlink_dir, pattern=pattern)

    # Process each file
    total_slices = 0
    print("\nExtracting slices from NIfTI files...")
    for symlink in tqdm(symlinks):
        # Strip one extension, and the inner one too only when the outer one
        # is '.gz', so names matched by an uncompressed pattern such as
        # '*.nii' keep any dots that belong to the subject name.
        subject_name, extension = os.path.splitext(os.path.basename(symlink))
        if extension.lower() == '.gz':
            subject_name = os.path.splitext(subject_name)[0]

        # Each subject gets its own directory of slices
        subject_slice_dir = os.path.join(slices_dir, subject_name)
        total_slices += extract_slices(symlink, subject_slice_dir, axis)

    return len(symlinks), total_slices
def main():
    """Set up the dataset using the locations defined in this function.

    Runs process_dataset on the configured source directory and prints a
    short summary of how many files and slices were handled and where the
    output was written.
    """
    data_root = "/path/to/your/data"  # Replace with your data directory
    output_root = "processed_data"

    file_count, slice_count = process_dataset(data_root, output_root)

    # Report the outcome, one line per message
    summary_lines = (
        "\nProcessing complete!",
        f"Processed {file_count} files",
        f"Extracted {slice_count} total slices",
        f"\nProcessed data is stored in: {os.path.abspath(output_root)}",
    )
    for line in summary_lines:
        print(line)


# Run the setup only when executed as a script, not on import
if __name__ == '__main__':
    main()