Skip to content

Commit

Permalink
use content size for output
Browse files Browse the repository at this point in the history
  • Loading branch information
vangj committed Jan 22, 2020
1 parent f57f393 commit e9d4cbf
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions dl-transfer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,31 @@
def get_device():
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_raw_image_size(fpath):
with Image.open(fpath) as img:
width, height = img.size
return width, height

def get_image_size():
imsize = 512 if torch.cuda.is_available() else 128
imsize = (512, 512) if torch.cuda.is_available() else (128, 128)
return imsize

def get_loader():
image_size = get_image_size()
def get_loader(image_size=None):
image_size = get_image_size() if image_size is None else image_size
loader = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.Resize((image_size[1], image_size[0])),
transforms.ToTensor()])
return loader

def get_unloader():
unloader = transforms.ToPILImage()
return unloader

def image_loader(image_name):
def image_loader(image_name, image_size=None):
device = get_device()
image = Image.open(image_name)
# fake batch dimension required to fit network's input dimensions
loader = get_loader()
loader = get_loader(image_size=image_size)
image = loader(image).unsqueeze(0)
return image.to(device, torch.float)

Expand Down Expand Up @@ -247,8 +252,11 @@ def parse_args(args):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

style_img = image_loader(style_path)
content_img = image_loader(content_path)
image_size = get_raw_image_size(content_path)
print(f'target width={image_size[0]} and height={image_size[1]}')

style_img = image_loader(style_path, image_size)
content_img = image_loader(content_path, image_size)
input_img = content_img.clone()

assert style_img.size() == content_img.size(), \
Expand All @@ -267,4 +275,6 @@ def parse_args(args):
num_steps=num_steps, style_weight=style_weight,
content_weight=content_weight)
output_img = to_pil_image(output)
output_img.save(output_path)
output_img.save(output_path)

print(f'style:{style_path} + content:{content_path} ==> output:{output_path}')
Binary file modified dl-transfer/image/final.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified dl-transfer/image/output.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e9d4cbf

Please sign in to comment.