-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
172 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,5 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
models/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,43 @@ | ||
# py-image-comparer | ||
# py-image-comparer | ||
|
||
## Installation | ||
To install, run | ||
|
||
```bash | ||
pip install image-comparer | ||
``` | ||
|
||
## Usage | ||
With PIL | ||
|
||
```python | ||
import image_comparer | ||
from PIL import Image | ||
|
||
image = Image.open("test/kobe.jpg") | ||
image2 = Image.open("test/kobe2.jpg") | ||
image_comparer.is_similar(image, image2, threshold=0.5) | ||
``` | ||
or with OpenCV | ||
|
||
```python | ||
import image_comparer | ||
import cv2 | ||
|
||
image = cv2.imread("test/kobe.jpg") | ||
image2 = cv2.imread("test/kobe2.jpg") | ||
image_comparer.is_similar(image, image2, threshold=0.5) | ||
``` | ||
|
||
## Development | ||
|
||
### Installation | ||
```bash | ||
pip install -r requirements-test.txt | ||
``` | ||
|
||
### Tests | ||
To run tests, run | ||
```bash | ||
pytest | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .compare import is_similar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from pathlib import Path | ||
|
||
from tqdm import tqdm | ||
import requests | ||
|
||
def download_model(model_path: Path, version="v1.0.0", block_Size=1024): | ||
model_url = f"https://github.com/joeyism/siamese-pytorch/releases/download/{version}/siamese-model.pt" | ||
response = requests.get(model_url, stream=True) | ||
total_size_in_bytes= int(response.headers.get('content-length', 0)) | ||
progress_bar = tqdm(desc="Downloading model", total=total_size_in_bytes, unit='iB', unit_scale=True) | ||
with open(model_path.as_posix(), 'wb') as file: | ||
for data in response.iter_content(block_size): | ||
progress_bar.update(len(data)) | ||
file.write(data) | ||
progress_bar.close() | ||
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: | ||
print("ERROR, something went wrong") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Siamese(nn.Module): | ||
|
||
def __init__(self): | ||
super(Siamese, self).__init__() | ||
self.conv = nn.Sequential( | ||
nn.Conv2d(1, 64, 10), # 64@96*96 | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(2), # 64@48*48 | ||
nn.Conv2d(64, 128, 7), | ||
nn.ReLU(), # 128@42*42 | ||
nn.MaxPool2d(2), # 128@21*21 | ||
nn.Conv2d(128, 128, 4), | ||
nn.ReLU(), # 128@18*18 | ||
nn.MaxPool2d(2), # 128@9*9 | ||
nn.Conv2d(128, 256, 4), | ||
nn.ReLU(), # 256@6*6 | ||
) | ||
self.liner = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid()) | ||
self.out = nn.Linear(4096, 1) | ||
|
||
def forward_one(self, x): | ||
x = self.conv(x) | ||
x = x.view(x.size()[0], -1) | ||
x = self.liner(x) | ||
return x | ||
|
||
def forward(self, x1, x2): | ||
out1 = self.forward_one(x1) | ||
out2 = self.forward_one(x2) | ||
dis = torch.abs(out1 - out2) | ||
out = self.out(dis) | ||
return torch.sigmoid(out) | ||
|
||
|
||
# for test | ||
if __name__ == '__main__': | ||
net = Siamese() | ||
print(net) | ||
print(list(net.parameters())) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
numpy==1.16.1 | ||
Pillow==5.4.1 | ||
torch==1.0.1.post2 | ||
torchvision==0.2.1 | ||
opencv-python==4.5.2.54 | ||
tqdm==4.61.2 | ||
requests==2.25.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from setuptools import find_packages, setup | ||
|
||
PACKAGE_NAME = "image-comparer" | ||
|
||
setup( | ||
name=PACKAGE_NAME, | ||
packages=find_packages(exclude=["tests"]), | ||
install_requires=[package for package in open("requirements.txt").read().split("\n")], | ||
entry_points={ | ||
}, | ||
package_data={"": ["*.txt", "*.cfg"]}, | ||
include_package_data=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from unittest import TestCase | ||
|
||
import cv2 | ||
from PIL import Image | ||
|
||
import image_comparer | ||
|
||
class TestCompare(TestCase): | ||
|
||
def setUp(self): | ||
self.image = Image.open("tests/images/kobe.jpg") | ||
self.image2 = cv2.imread("tests/images/kobe2.jpg") | ||
|
||
def test_compare_success(self): | ||
assert not image_comparer.is_similar(self.image, self.image2) | ||
assert image_comparer.is_similar(self.image, self.image) |