
Commit

feat: more augmentation
anna-grim authored Jan 18, 2025
1 parent f0ae3f1 commit 4b6f285
Showing 2 changed files with 86 additions and 23 deletions.
61 changes: 48 additions & 13 deletions src/aind_exaspim_soma_detection/machine_learning/data_handling.py
@@ -14,7 +14,7 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from scipy.ndimage import rotate
from scipy.ndimage import rotate, zoom
from scipy.spatial import distance
from torch.utils.data import Dataset

@@ -82,7 +82,9 @@ def __init__(self, patch_shape, transform=False):
[
RandomFlip3D(),
RandomRotation3D(),
RandomScale3D(),
RandomContrast3D(),
RandomBrightness3D(),
RandomNoise3D(),
lambda x: torch.tensor(x, dtype=torch.float32).unsqueeze(
0
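
(Not part of the commit.) A minimal sketch of how the expanded pipeline above might be exercised on a synthetic patch; it assumes each transform accepts and returns a NumPy volume, as their __call__ signatures in this diff suggest:

import numpy as np
import torch
import torchvision.transforms as transforms
from aind_exaspim_soma_detection.machine_learning.data_handling import (
    RandomBrightness3D, RandomContrast3D, RandomFlip3D,
    RandomNoise3D, RandomRotation3D, RandomScale3D,
)

# Mirror of the Compose built in __init__ when transform=True
pipeline = transforms.Compose([
    RandomFlip3D(),
    RandomRotation3D(),
    RandomScale3D(),
    RandomContrast3D(),
    RandomBrightness3D(),
    RandomNoise3D(),
    lambda x: torch.tensor(x, dtype=torch.float32).unsqueeze(0),
])

patch = np.random.rand(64, 64, 64).astype(np.float32)
augmented = pipeline(patch)
print(augmented.shape)  # roughly (1, 64, 64, 64); RandomScale3D can change the spatial size
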
@@ -132,7 +134,7 @@ def __getitem__(self, key):
# Get voxel
brain_id, voxel = key
if self.transform:
voxel = [voxel_i + random.randint(-8, 8) for voxel_i in voxel]
voxel = [voxel_i + random.randint(-6, 6) for voxel_i in voxel]

# Get image patch
try:
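
(Not part of the commit.) The change above tightens the random jitter so each proposal center moves by at most 6 voxels per axis (down from 8) before the patch is read. A small illustrative sketch of that recentering, using a hypothetical center coordinate:

import random

def jitter(voxel, max_shift=6):
    # Same offset scheme as __getitem__: each coordinate shifts by up to max_shift voxels
    return [v + random.randint(-max_shift, max_shift) for v in voxel]

center = (5120, 3800, 2950)  # hypothetical proposal center in voxel coordinates
print(jitter(center))        # e.g. [5117, 3804, 2946]
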
@@ -253,6 +255,29 @@ def visualize_augmented_proposal(self, key):


# --- Data Augmentation ---
class RandomBrightness3D:
def __init__(self, delta=0.1):
self.delta = delta

def __call__(self, img):
factor = 1 + np.random.uniform(-self.delta, self.delta)
return img * factor


class RandomContrast3D:
"""
Adjusts the contrast of a 3D image by scaling voxel intensities.
"""

def __init__(self, factor_range=(0.8, 1.2)):
self.factor_range = factor_range

def __call__(self, img):
factor = random.uniform(*self.factor_range)
return np.clip(img * factor, img.min(), img.max())


class RandomFlip3D:
"""
Randomly flip a 3D image along one or more axes.
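
(Not part of the commit.) A usage sketch of the two intensity transforms added above, on a synthetic patch: RandomBrightness3D applies a multiplicative jitter of roughly ±10%, while RandomContrast3D rescales intensities and clips back into the patch's original range:

import numpy as np
from aind_exaspim_soma_detection.machine_learning.data_handling import (
    RandomBrightness3D, RandomContrast3D,
)

patch = np.random.rand(16, 16, 16).astype(np.float32)
bright = RandomBrightness3D(delta=0.1)(patch)
contrast = RandomContrast3D(factor_range=(0.8, 1.2))(patch)
print(patch.mean(), bright.mean(), contrast.mean())
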
@@ -294,26 +319,36 @@ def __init__(self, angles=(-45, 45), axes=((0, 1), (0, 2), (1, 2))):
self.angles = angles
self.axes = axes

def __call__(self, img):
for _ in range(3):
def __call__(self, img, mode="grid-mirror"):
for axis in self.axes:
angle = random.uniform(*self.angles)
axis = random.choice(self.axes)
img = rotate(img, angle, axes=axis, reshape=False, order=1)
img = rotate(
img, angle, axes=axis, mode=mode, reshape=False, order=1
)
return img


class RandomContrast3D:
class RandomScale3D:
"""
Adjusts the contrast of a 3D image by scaling voxel intensities.
Applies random scaling to an image along each axis.
"""

def __init__(self, factor_range=(0.7, 1.3)):
self.factor_range = factor_range
def __init__(self, scale_range=(0.9, 1.1)):
self.scale_range = scale_range

def __call__(self, img):
factor = random.uniform(*self.factor_range)
return np.clip(img * factor, img.min(), img.max())
# Sample new image shape
alpha = np.random.uniform(self.scale_range[0], self.scale_range[1])
new_shape = (int(img.shape[0] * alpha),
int(img.shape[1] * alpha),
int(img.shape[2] * alpha))

# Compute the zoom factors
shape = img.shape
zoom_factors = [
new_dim / old_dim for old_dim, new_dim in zip(shape, new_shape)
]
return zoom(img, zoom_factors, order=3)


# --- Custom Dataloader ---
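(Not part of the commit.) A usage sketch of the geometric transforms from this file: RandomRotation3D keeps the input shape (reshape=False, with grid-mirror padding at the borders), whereas the new RandomScale3D resamples with scipy's zoom and therefore changes the array shape with the sampled factor:

import numpy as np
from aind_exaspim_soma_detection.machine_learning.data_handling import (
    RandomRotation3D, RandomScale3D,
)

patch = np.random.rand(64, 64, 64).astype(np.float32)

rotated = RandomRotation3D(angles=(-45, 45))(patch)
print(rotated.shape)  # (64, 64, 64)

scaled = RandomScale3D(scale_range=(0.9, 1.1))(patch)
print(scaled.shape)   # e.g. (60, 60, 60) or (70, 70, 70), depending on the sampled factor
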
48 changes: 38 additions & 10 deletions src/aind_exaspim_soma_detection/machine_learning/models.py
@@ -11,6 +11,7 @@

import torch
import torch.nn as nn
import torch.nn.init as init


class Fast3dCNN(nn.Module):
@@ -36,41 +37,67 @@ def __init__(self, patch_shape):
super(Fast3dCNN, self).__init__()
self.patch_shape = patch_shape

# Convolutional Layer 1
# Convolutional layer 1
self.layer1 = nn.Sequential(
FastConvLayer(1, 32),
nn.BatchNorm3d(32),
FastConvLayer(1, 16),
nn.BatchNorm3d(16),
nn.ReLU(),
nn.Dropout3d(0.3),
nn.Dropout3d(0.25),
nn.MaxPool3d(kernel_size=2, stride=2),
)

# Convolutional Layer 2
# Convolutional layer 2
self.layer2 = nn.Sequential(
FastConvLayer(16, 32),
nn.BatchNorm3d(32),
nn.ReLU(),
nn.Dropout3d(0.25),
nn.MaxPool3d(kernel_size=2, stride=2),
)

# Convolutional layer 3
self.layer3 = nn.Sequential(
FastConvLayer(32, 64),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Dropout3d(0.3),
nn.Dropout3d(0.25),
nn.MaxPool3d(kernel_size=2, stride=2),
)

# Convolutional Layer 3
self.layer3 = nn.Sequential(
# Convolutional layer 4
self.layer4 = nn.Sequential(
FastConvLayer(64, 128),
nn.BatchNorm3d(128),
nn.ReLU(),
nn.Dropout3d(0.3),
nn.Dropout3d(0.25),
nn.MaxPool3d(kernel_size=2, stride=2),
)

# Final fully connected layers
self.output = nn.Sequential(
nn.Linear(128 * (self.patch_shape[0] // 8) ** 3, 128),
nn.Linear(128 * (self.patch_shape[0] // 16) ** 3, 128),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(128, 1),
)

# Initialize weights
self.apply(self.init_weights)

@staticmethod
def init_weights(m):
if isinstance(m, nn.Conv3d):
init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.xavier_normal_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)

def forward(self, x):
"""
Forward pass of the 2.5D convolutional neural network.
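
(Not part of the commit.) A quick check of the fully connected input size implied by the changes above: with four MaxPool3d(kernel_size=2, stride=2) stages the spatial side shrinks by a factor of 16, so for an assumed cubic patch of side 64 the flattened feature vector has 128 * 4**3 = 8192 entries, matching 128 * (patch_shape[0] // 16) ** 3:

patch_side = 64                       # assumed cubic patch
downsample = 2 ** 4                   # four stride-2 max-pool stages
flat_features = 128 * (patch_side // downsample) ** 3
print(flat_features)                  # 8192
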
@@ -90,6 +117,7 @@ def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

# Output layer
x = torch.flatten(x, start_dim=1)
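(Not part of the commit.) A smoke-test sketch for the four-layer model; it assumes a cubic patch of side 64 and that FastConvLayer, which is not shown in this diff, preserves spatial dimensions so the pooling arithmetic above holds:

import torch
from aind_exaspim_soma_detection.machine_learning.models import Fast3dCNN

model = Fast3dCNN(patch_shape=(64, 64, 64))
model.eval()

x = torch.randn(2, 1, 64, 64, 64)     # (batch, channel, depth, height, width)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                   # expected: torch.Size([2, 1]), one logit per patch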
