Skip to content

Commit

Permalink
make sure nested unet can accommodate feature dimensions other than p…
Browse files Browse the repository at this point in the history
…ower of 2
  • Loading branch information
lucidrains committed Sep 10, 2022
1 parent 392a0b1 commit 47571f2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-unet',
packages = find_packages(exclude=[]),
version = '0.0.21',
version = '0.0.22',
license='MIT',
description = 'X-Unet',
long_description_content_type = 'text/markdown',
Expand Down
7 changes: 6 additions & 1 deletion x_unet/x_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def default(val, d):
def is_power_two(n):
return math.log2(n).is_integer()

def divisible_by(num, denom):
return (num % denom) == 0

def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)
Expand Down Expand Up @@ -583,8 +586,10 @@ def forward(self, x, residual = None):

*_, h, w = x.shape

layers = len(self.ups)

assert h == w, 'only works with square images'
assert is_power_two(h), 'height and width must be power of two'
assert divisible_by(h, 2 ** len(self.ups)), f'dimension {h} must be divisible by {2 ** layers} ({layers} layers in nested unet)'
assert (h % (2 ** self.depth)) == 0, 'the unet has too much depth for the image being passed in'

# hiddens
Expand Down

0 comments on commit 47571f2

Please sign in to comment.