Skip to content

Commit

Permalink
Added working code; tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyism committed Jul 10, 2021
1 parent 438276f commit f00629b
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/

models/
44 changes: 43 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,43 @@
# py-image-comparer
# py-image-comparer

## Installation
To install, run

```bash
pip install image-comparer
```

## Usage
With PIL

```python
import image_comparer
from PIL import Image

image = Image.open("test/kobe.jpg")
image2 = Image.open("test/kobe2.jpg")
image_comparer.is_similar(image, image2, threshold=0.5)
```
or with OpenCV

```python
import image_comparer
import cv2

image = cv2.imread("test/kobe.jpg")
image2 = cv2.imread("test/kobe2.jpg")
image_comparer.is_similar(image, image2, threshold=0.5)
```

## Development

### Installation
```bash
pip install -r requirements-test.txt
```

### Tests
To run tests, run
```bash
pytest
```
1 change: 1 addition & 0 deletions image_comparer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .compare import is_similar
35 changes: 27 additions & 8 deletions image_comparer/compare.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from typing import Union
from pathlib import Path

import numpy as np
import torch
from model import Siamese
from PIL import Image
from torchvision import transforms

from .model import Siamese
from .download import download_model

ImageType = Union[Image.Image, np.ndarray]
DEFAULT_IMAGE_SIZE = (105, 105)
MODEL_FOLDER = Path(__file__).parents[1] / "models"
MODEL_FOLDER.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_FOLDER / "siamese-model.pt"
if not MODEL_PATH.exists():
download_model(MODEL_PATH)

model = Siamese()
model_dict = torch.load("siamese-model.pt", map_location="cpu")
model_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(model_dict)

transformer = transforms.Compose(
Expand All @@ -14,10 +28,15 @@
]
)

if __name__ == "__main__":
image = Image.open("test/kobe.jpg").resize((105, 105))
image2 = Image.open("test/kobe2.jpg").resize((105, 105))
def _pil_image_(image: ImageType) -> Image:
return Image.fromarray(np.array(image)).resize(DEFAULT_IMAGE_SIZE)

def is_similar(image1: ImageType, image2: ImageType, threshold=0.5):
pil_image1 = _pil_image_(image1)
pil_image2 = _pil_image_(image2)

image_tensor1 = transformer(pil_image1).unsqueeze(0)
image_tensor2 = transformer(pil_image2).unsqueeze(0)

image_tensor = transformer(image).unsqueeze(0)
image2_tensor = transformer(image2).unsqueeze(0)
model(image_tensor, image2_tensor)
results = model(image_tensor1, image_tensor2)
return results.detach().cpu().numpy()[0][0] > threshold
17 changes: 17 additions & 0 deletions image_comparer/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pathlib import Path

from tqdm import tqdm
import requests

def download_model(model_path: Path, version="v1.0.0", block_Size=1024):
model_url = f"https://github.com/joeyism/siamese-pytorch/releases/download/{version}/siamese-model.pt"
response = requests.get(model_url, stream=True)
total_size_in_bytes= int(response.headers.get('content-length', 0))
progress_bar = tqdm(desc="Downloading model", total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(model_path.as_posix(), 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
45 changes: 45 additions & 0 deletions image_comparer/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Siamese(nn.Module):

def __init__(self):
super(Siamese, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, 10), # 64@96*96
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 64@48*48
nn.Conv2d(64, 128, 7),
nn.ReLU(), # 128@42*42
nn.MaxPool2d(2), # 128@21*21
nn.Conv2d(128, 128, 4),
nn.ReLU(), # 128@18*18
nn.MaxPool2d(2), # 128@9*9
nn.Conv2d(128, 256, 4),
nn.ReLU(), # 256@6*6
)
self.liner = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid())
self.out = nn.Linear(4096, 1)

def forward_one(self, x):
x = self.conv(x)
x = x.view(x.size()[0], -1)
x = self.liner(x)
return x

def forward(self, x1, x2):
out1 = self.forward_one(x1)
out2 = self.forward_one(x2)
dis = torch.abs(out1 - out2)
out = self.out(dis)
return torch.sigmoid(out)


# for test
if __name__ == '__main__':
net = Siamese()
print(net)
print(list(net.parameters()))

1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
numpy==1.16.1
Pillow==5.4.1
torch==1.0.1.post2
torchvision==0.2.1
opencv-python==4.5.2.54
tqdm==4.61.2
requests==2.25.1
13 changes: 13 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from setuptools import find_packages, setup

PACKAGE_NAME = "image-comparer"

setup(
name=PACKAGE_NAME,
packages=find_packages(exclude=["tests"]),
install_requires=[package for package in open("requirements.txt").read().split("\n")],
entry_points={
},
package_data={"": ["*.txt", "*.cfg"]},
include_package_data=True,
)
16 changes: 16 additions & 0 deletions tests/test_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from unittest import TestCase

import cv2
from PIL import Image

import image_comparer

class TestCompare(TestCase):

def setUp(self):
self.image = Image.open("tests/images/kobe.jpg")
self.image2 = cv2.imread("tests/images/kobe2.jpg")

def test_compare_success(self):
assert not image_comparer.is_similar(self.image, self.image2)
assert image_comparer.is_similar(self.image, self.image)

0 comments on commit f00629b

Please sign in to comment.