Skip to content

Commit

Permalink
🚀 - Release working version of auto-tag
Browse files Browse the repository at this point in the history
  • Loading branch information
Thusuzzee committed Jul 28, 2024
0 parents commit bf31e7f
Show file tree
Hide file tree
Showing 9 changed files with 16,564 additions and 0 deletions.
165 changes: 165 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Model
model-resnet_custom_v3.h5

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Auto-Tag for Anime Images
Original code by [Epsp0](https://github.com/Epsp0) in [auto-tag-anime](https://github.com/Epsp0/auto-tag-anime)\
Model provided by [KichangKim](https://github.com/KichangKim) in [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)

## Requrements
- Python 3.10

## Instructions
1. **Download the model**: Download from [Google Drive](https://drive.google.com/file/d/1qffwjF-BHV6MkPVliLO1jZwMQatri06v) or [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and place it in the `./model` folder.
2. **Install dependencies**: Run the following command to install required packages:
```bash
pip install -r requirements.txt
```

## Usage
```bash
python3 auto-tag.py "/path/to/directory/"
```

## Important Disclaimer

**Do not modify `tags.txt`**: This file contains all possible tags used by the model and must remain unchanged to ensure the model functions correctly.

**Modify only `metadata.txt`**: If you want to control which tags should be included in the JSON output, edit the `metadata.txt` file. Add or remove tags from this file to customize the tags that will be saved to the JSON.

## Simple Version

For users who prefer a simpler version where character tags are mixed in with other tags and cannot be customized, please switch to the `simple-tag` branch.\
This version includes all tags by default and does not allow modifying the tags list.

To switch to the `simple-tag` branch, run:
```bash
git checkout simple-tag
```
In this version, you will get a single set of tags for each image without the option to exclude characters or specify custom tags.
62 changes: 62 additions & 0 deletions auto-tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import sys
import time
from tags import load_to_memory, save_to_json
from model import DeepDanbooruModel

class AddAnimeTags:
def __init__(self):
self.model = DeepDanbooruModel()
self.directory = None

def navigate_directory(self, path: str):
start_time = time.time()

if os.path.isdir(path):
self.directory = path
for root, _, files in os.walk(path):
for filename in files:
file_path = os.path.join(root, filename)
self.classify_and_add_tags(file_path)
self.save_tags()
else:
self.classify_and_add_tags(path)
self.directory = os.path.dirname(path)
self.save_tags()

total_time = time.time() - start_time
print(f"Total operation time: {total_time:.2f} seconds")

def classify_and_add_tags(self, path: str):
image_start_time = time.time()

status, tags, characters = self.model.classify_image(path)
if status == 'success':
load_to_memory(path, tags, characters)
image_time = time.time() - image_start_time
num_tags = len(tags)
num_characters = len(characters)
print(f"[Success] [{image_time:.2f} seconds] [{num_tags} tags, {num_characters} characters added] [{path}]")
else:
image_time = time.time() - image_start_time
print(f"[Failed] [{image_time:.2f} seconds] [No tags] [{path}]")

def save_tags(self):
if self.directory:
json_file_path = os.path.join(self.directory, 'tags.json')
save_to_json(json_file_path)
print(f"Tags saved to {json_file_path}")

def parse_args():
if len(sys.argv) < 2:
print("Usage: python auto-tag-anime.py \"<path>\"")
sys.exit()

if not os.path.exists(sys.argv[1]):
print('Path does not exist')
sys.exit()

if __name__ == "__main__":
parse_args()
add_anime_tags = AddAnimeTags()
add_anime_tags.navigate_directory(sys.argv[1])
110 changes: 110 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os
import sys
import numpy as np
from PIL import Image
import tensorflow as tf


class DeepDanbooruModel:
THRESHOLD = 0.75 # Increase this to achieve more accurate results or decrease it for less accurate results.
METADATA_PATH = "./tags/metadata.txt" # Modify the tags you prefer to display exclusively.
CHARACTERS_PATH = "./tags/characters.txt" # Modify the characters you prefer to display exclusively.

# Do not change the paths or settings below here.
MODEL_PATH = "./model/model-resnet_custom_v3.h5"
TAGS_PATH = "./model/tags.txt"
IMAGE_SIZE = (512, 512)

def __init__(self):
self.model = self.load_model()
self.tags = self.load_tags()
self.characters = self.load_characters()
self.metadata_tags = self.load_metadata_tags()

def load_model(self) -> tf.keras.Model:
print('Loading model...')
if not os.path.exists(self.MODEL_PATH):
self.model_not_found_error(self.MODEL_PATH)

try:
model = tf.keras.models.load_model(self.MODEL_PATH, compile=False)
print('Model loaded successfully.')
except Exception as e:
print(f'Failed to load the model. Error: {e}')
sys.exit()

return model

def load_tags(self) -> np.ndarray:
if not os.path.exists(self.TAGS_PATH):
self.model_not_found_error(self.TAGS_PATH)

try:
with open(self.TAGS_PATH, 'r') as tags_stream:
tags = np.array([tag.strip() for tag in tags_stream if tag.strip()])
print(f'Tags loaded successfully.')
except Exception as e:
print(f'Failed to load tags. Error: {e}')
sys.exit()

return tags

def load_characters(self) -> set:
if not os.path.exists(self.CHARACTERS_PATH):
self.model_not_found_error(self.CHARACTERS_PATH)

try:
with open(self.CHARACTERS_PATH, 'r') as characters_stream:
characters = {character.strip() for character in characters_stream if character.strip()}
print(f'Characters loaded successfully. Number of characters: {len(characters)}')
except Exception as e:
print(f'Failed to load characters. Error: {e}')
sys.exit()

return characters

def load_metadata_tags(self) -> set:
if not os.path.exists(self.METADATA_PATH):
self.model_not_found_error(self.METADATA_PATH)

try:
with open(self.METADATA_PATH, 'r') as metadata_stream:
metadata_tags = {tag.strip() for tag in metadata_stream if tag.strip()}
print(f'Metadata tags loaded successfully. Number of metadata tags: {len(metadata_tags)}')
except Exception as e:
print(f'Failed to load metadata tags. Error: {e}')
sys.exit()

return metadata_tags

@staticmethod
def model_not_found_error(path: str):
print(f'File not found at {path}')
print('Please download the required file from https://github.com/KichangKim/DeepDanbooru')
sys.exit()

def classify_image(self, image_path: str) -> tuple[str, list[str], list[str]]:
try:
image = self.load_image(image_path)
except IOError:
return 'fail', [], []

results = self.model.predict(np.array([image]))

if results.shape[1] != self.tags.shape[0]:
print("Mismatch between model output and number of tags!")
return 'fail', [], []

result_tags = self.get_result_tags(results.reshape(-1))

tags = [tag for tag in result_tags.keys() if tag in self.metadata_tags]
characters = [tag for tag in result_tags.keys() if tag in self.characters]

return 'success', tags, characters

def load_image(self, image_path: str) -> np.ndarray:
image = Image.open(image_path).convert('RGB').resize(self.IMAGE_SIZE)
return np.array(image) / 255.0

def get_result_tags(self, results: np.ndarray) -> dict[str, float]:
return {self.tags[i]: results[i] for i in range(len(self.tags)) if results[i] > self.THRESHOLD}
Loading

0 comments on commit bf31e7f

Please sign in to comment.