Skip to content

Commit

Permalink
add premultiply option
Browse files Browse the repository at this point in the history
  • Loading branch information
les-sosna committed Mar 29, 2024
1 parent 263866a commit 8705043
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions collagify/collagify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,19 @@ def find_grid_dimensions(n):
columns = n // rows
return rows, columns

def stitch_images_in_grid(input_pattern, output_filename):
def srgb_to_linear(srgb):
if srgb <= 0.04045:
return srgb / 12.92
else:
return ((srgb + 0.055) / 1.055) ** 2.4

def linear_to_srgb(linear):
if linear <= 0.0031308:
return linear * 12.92
else:
return 1.055 * (linear ** (1/2.4)) - 0.055

def stitch_images_in_grid(input_pattern, output_filename, premultiply_alpha=False):
"""Stitch images from input directory matching the pattern into a grid atlas."""
files = sorted(glob.glob(input_pattern))

Expand All @@ -38,6 +50,24 @@ def stitch_images_in_grid(input_pattern, output_filename):
col = index % cols
atlas.paste(image, (col * img_width, row * img_height))

# Premultiply alpha
if premultiply_alpha:
print("Premultiplying alpha...")
pixels = atlas.load()
for y in range(atlas_height):
for x in range(atlas_width):
r, g, b, a = pixels[x, y]
r = srgb_to_linear(r / 255)
g = srgb_to_linear(g / 255)
b = srgb_to_linear(b / 255)
r *= a / 255
g *= a / 255
b *= a / 255
r = linear_to_srgb(r) * 255
g = linear_to_srgb(g) * 255
b = linear_to_srgb(b) * 255
pixels[x, y] = (round(r), round(g), round(b), a)

# Save the final image
atlas.save(output_filename, format='TGA')

Expand All @@ -47,10 +77,11 @@ def main():
parser = argparse.ArgumentParser(description='Stitch images into a grid atlas.')
parser.add_argument('input_prefix', type=str, help='Prefix for input images')
parser.add_argument('output_filename', type=str, help='Filename for the output image')

parser.add_argument('--premultiply', action='store_true', help='Premultiply alpha channel')

args = parser.parse_args()

stitch_images_in_grid(args.input_prefix, args.output_filename)
stitch_images_in_grid(args.input_prefix, args.output_filename, args.premultiply)

if __name__ == "__main__":
main()
Expand Down

0 comments on commit 8705043

Please sign in to comment.