-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🚀 - Release working version of auto-tag
- Loading branch information
0 parents
commit bf31e7f
Showing
9 changed files
with
16,564 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# Model | ||
model-resnet_custom_v3.h5 | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# pdm | ||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
#pdm.lock | ||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
# in version control. | ||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control | ||
.pdm.toml | ||
.pdm-python | ||
.pdm-build/ | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# PyCharm | ||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Auto-Tag for Anime Images | ||
Original code by [Epsp0](https://github.com/Epsp0) in [auto-tag-anime](https://github.com/Epsp0/auto-tag-anime)\ | ||
Model provided by [KichangKim](https://github.com/KichangKim) in [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) | ||
|
||
## Requirements | ||
- Python 3.10 | ||
|
||
## Instructions | ||
1. **Download the model**: Download from [Google Drive](https://drive.google.com/file/d/1qffwjF-BHV6MkPVliLO1jZwMQatri06v) or [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and place it in the `./model` folder. | ||
2. **Install dependencies**: Run the following command to install required packages: | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Usage | ||
```bash | ||
python3 auto-tag.py "/path/to/directory/" | ||
``` | ||
|
||
## Important Disclaimer | ||
|
||
**Do not modify `tags.txt`**: This file contains all possible tags used by the model and must remain unchanged to ensure the model functions correctly. | ||
|
||
**Modify only `metadata.txt`**: If you want to control which tags should be included in the JSON output, edit the `metadata.txt` file. Add or remove tags from this file to customize the tags that will be saved to the JSON. | ||
|
||
## Simple Version | ||
|
||
For users who prefer a simpler version where character tags are mixed in with other tags and cannot be customized, please switch to the `simple-tag` branch.\ | ||
This version includes all tags by default and does not allow modifying the tags list. | ||
|
||
To switch to the `simple-tag` branch, run: | ||
```bash | ||
git checkout simple-tag | ||
``` | ||
In this version, you will get a single set of tags for each image without the option to exclude characters or specify custom tags. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os | ||
import sys | ||
import time | ||
from tags import load_to_memory, save_to_json | ||
from model import DeepDanbooruModel | ||
|
||
class AddAnimeTags:
    """Classify images with a DeepDanbooruModel and save the results as ``tags.json``.

    ``navigate_directory`` is the entry point: it accepts either a directory
    (walked recursively) or a single file, classifies every file it finds and
    then calls ``save_tags`` once.
    """

    def __init__(self):
        self.model = DeepDanbooruModel()
        # Directory that receives tags.json; filled in by navigate_directory().
        self.directory = None

    @staticmethod
    def _containing_directory(path: str) -> str:
        """Return the directory holding ``path``, or ``"."`` for a bare file name."""
        # os.path.dirname("image.png") is "", which save_tags() would treat as
        # "no directory" and skip saving; fall back to the current directory.
        return os.path.dirname(path) or "."

    def navigate_directory(self, path: str):
        """Classify ``path`` (a directory tree or a single file) and save the tags.

        For a directory, every file below it is classified and ``tags.json`` is
        written into ``path`` itself. For a single file, ``tags.json`` is written
        next to that file. The total elapsed time is printed at the end.
        """
        start_time = time.time()

        if os.path.isdir(path):
            self.directory = path
            for root, _, files in os.walk(path):
                for filename in files:
                    self.classify_and_add_tags(os.path.join(root, filename))
        else:
            self.classify_and_add_tags(path)
            self.directory = self._containing_directory(path)
        self.save_tags()

        total_time = time.time() - start_time
        print(f"Total operation time: {total_time:.2f} seconds")

    def classify_and_add_tags(self, path: str):
        """Classify one file and, on success, pass its tags to ``load_to_memory``.

        A one-line report with the elapsed time is printed for both the success
        and the failure case; failures do not raise.
        """
        image_start_time = time.time()

        status, tags, characters = self.model.classify_image(path)
        if status == 'success':
            load_to_memory(path, tags, characters)
            image_time = time.time() - image_start_time
            print(f"[Success] [{image_time:.2f} seconds] [{len(tags)} tags, {len(characters)} characters added] [{path}]")
        else:
            image_time = time.time() - image_start_time
            print(f"[Failed] [{image_time:.2f} seconds] [No tags] [{path}]")

    def save_tags(self):
        """Call ``save_to_json`` for ``<directory>/tags.json`` if a directory is set."""
        if self.directory:
            json_file_path = os.path.join(self.directory, 'tags.json')
            save_to_json(json_file_path)
            print(f"Tags saved to {json_file_path}")
|
||
def parse_args():
    """Validate the command line; exit with status 1 if it cannot be used.

    Requires at least one argument (``sys.argv[1]``) that names an existing
    path. Returns ``None`` when the arguments are acceptable.

    Raises:
        SystemExit: with code 1 when the argument is missing or the path does
            not exist. The explanation is written to stderr.
    """
    if len(sys.argv) < 2:
        # Derive the name from how the script was actually invoked so the hint
        # stays correct if the file is renamed.
        script_name = os.path.basename(sys.argv[0]) if sys.argv else ""
        script_name = script_name or "auto-tag.py"
        print(f'Usage: python {script_name} "<path>"', file=sys.stderr)
        sys.exit(1)

    if not os.path.exists(sys.argv[1]):
        print('Path does not exist', file=sys.stderr)
        sys.exit(1)
|
||
# Script entry point: validate the command line, then tag the path given as
# the first argument.
if __name__ == "__main__":
    parse_args()
    AddAnimeTags().navigate_directory(sys.argv[1])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import os | ||
import sys | ||
import numpy as np | ||
from PIL import Image | ||
import tensorflow as tf | ||
|
||
|
||
class DeepDanbooruModel:
    """Keras DeepDanbooru model plus the tag lists needed to label an image.

    All paths below are relative to the current working directory. Any failure
    while loading the model or a tag file prints a message and terminates the
    process with exit status 1.
    """

    # A tag is reported only when its score is strictly greater than this value.
    # Raising it yields fewer, higher-scoring tags; lowering it yields more tags.
    THRESHOLD = 0.75
    METADATA_PATH = "./tags/metadata.txt"  # Modify the tags you prefer to display exclusively.
    CHARACTERS_PATH = "./tags/characters.txt"  # Modify the characters you prefer to display exclusively.

    # Do not change the paths or settings below here.
    MODEL_PATH = "./model/model-resnet_custom_v3.h5"
    TAGS_PATH = "./model/tags.txt"
    IMAGE_SIZE = (512, 512)

    def __init__(self):
        self.model = self.load_model()
        self.tags = self.load_tags()
        self.characters = self.load_characters()
        self.metadata_tags = self.load_metadata_tags()

    @staticmethod
    def _read_lines(path: str) -> list[str]:
        """Return the stripped, non-empty lines of the UTF-8 text file at ``path``."""
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(path, 'r', encoding='utf-8') as stream:
            return [stripped for line in stream if (stripped := line.strip())]

    def load_model(self) -> "tf.keras.Model":
        """Load the Keras model at ``MODEL_PATH`` without compiling it.

        Exits the process (status 1) if the file is missing or cannot be loaded.
        """
        print('Loading model...')
        if not os.path.exists(self.MODEL_PATH):
            self.model_not_found_error(self.MODEL_PATH)

        try:
            model = tf.keras.models.load_model(self.MODEL_PATH, compile=False)
        except Exception as e:
            # Deliberately broad: any loader failure ends the program here.
            print(f'Failed to load the model. Error: {e}')
            sys.exit(1)

        print('Model loaded successfully.')
        return model

    def load_tags(self) -> np.ndarray:
        """Return the tags in ``TAGS_PATH`` as an array, one entry per non-empty line.

        Line order is preserved; ``get_result_tags`` pairs entry ``i`` with
        model output ``i``.
        """
        if not os.path.exists(self.TAGS_PATH):
            self.model_not_found_error(self.TAGS_PATH)

        try:
            tags = np.array(self._read_lines(self.TAGS_PATH))
        except (OSError, UnicodeError) as e:
            print(f'Failed to load tags. Error: {e}')
            sys.exit(1)

        print('Tags loaded successfully.')
        return tags

    def load_characters(self) -> set:
        """Return the set of character names listed in ``CHARACTERS_PATH``."""
        if not os.path.exists(self.CHARACTERS_PATH):
            self.model_not_found_error(self.CHARACTERS_PATH)

        try:
            characters = set(self._read_lines(self.CHARACTERS_PATH))
        except (OSError, UnicodeError) as e:
            print(f'Failed to load characters. Error: {e}')
            sys.exit(1)

        print(f'Characters loaded successfully. Number of characters: {len(characters)}')
        return characters

    def load_metadata_tags(self) -> set:
        """Return the set of tags listed in ``METADATA_PATH``."""
        if not os.path.exists(self.METADATA_PATH):
            self.model_not_found_error(self.METADATA_PATH)

        try:
            metadata_tags = set(self._read_lines(self.METADATA_PATH))
        except (OSError, UnicodeError) as e:
            print(f'Failed to load metadata tags. Error: {e}')
            sys.exit(1)

        print(f'Metadata tags loaded successfully. Number of metadata tags: {len(metadata_tags)}')
        return metadata_tags

    @staticmethod
    def model_not_found_error(path: str):
        """Report that ``path`` is missing and exit the process with status 1."""
        print(f'File not found at {path}')
        print('Please download the required file from https://github.com/KichangKim/DeepDanbooru')
        sys.exit(1)

    def classify_image(self, image_path: str) -> tuple[str, list[str], list[str]]:
        """Classify one image file.

        Returns:
            ``(status, tags, characters)``. ``status`` is ``'success'`` or
            ``'fail'``. ``tags`` holds the predicted tags that appear in the
            metadata list, ``characters`` those that appear in the character
            list. Both lists are empty when ``status`` is ``'fail'`` (the file
            could not be read as an image, or the model output width does not
            match the number of tags).
        """
        try:
            image = self.load_image(image_path)
        except OSError:
            # Covers missing/unreadable files; IOError is an alias of OSError.
            return 'fail', [], []

        # The model is fed a batch containing just this one image.
        results = self.model.predict(np.array([image]))

        if results.shape[1] != self.tags.shape[0]:
            print("Mismatch between model output and number of tags!")
            return 'fail', [], []

        result_tags = self.get_result_tags(results.reshape(-1))

        tags = [tag for tag in result_tags if tag in self.metadata_tags]
        characters = [tag for tag in result_tags if tag in self.characters]

        return 'success', tags, characters

    def load_image(self, image_path: str) -> np.ndarray:
        """Open an image, convert it to RGB at ``IMAGE_SIZE`` and scale it to 0..1."""
        # The context manager closes the underlying file once the converted
        # copy has been produced.
        with Image.open(image_path) as source:
            image = source.convert('RGB').resize(self.IMAGE_SIZE)
        return np.array(image) / 255.0

    def get_result_tags(self, results: np.ndarray) -> dict[str, float]:
        """Map each tag whose score is strictly above ``THRESHOLD`` to that score.

        ``results`` is a flat sequence of scores aligned with ``self.tags``.
        Keys and values are converted to built-in ``str`` and ``float``.
        """
        return {
            str(tag): float(score)
            for tag, score in zip(self.tags, results)
            if score > self.THRESHOLD
        }
Oops, something went wrong.