Skip to content

Commit

Permalink
Merge branch 'sd3' into new_cache
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Dec 9, 2024
2 parents 28e9352 + e425996 commit 70423ec
Show file tree
Hide file tree
Showing 9 changed files with 761 additions and 684 deletions.
48 changes: 25 additions & 23 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,42 +1,44 @@

name: Python package

on: [push]
name: Test with pytest

on:
push:
branches:
- main
- dev
- sd3
pull_request:
branches:
- main
- dev
- sd3

jobs:
build:

runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
python-version: ["3.10"] # Python versions to test
pytorch-version: ["2.4.0"] # PyTorch versions to test

steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- uses: actions/setup-python@v5
with:
python-version: '3.x'

- name: Install dependencies
run: python -m pip install --upgrade pip setuptools wheel
python-version: ${{ matrix.python-version }}
cache: 'pip'

- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
cache: 'pip' # caching pip dependencies
- name: Install and update pip, setuptools, wheel
run: |
# Setuptools, wheel for compiling some packages
python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
pip install -r requirements.txt
- name: Test with pytest
run: |
pip install pytest
pytest
run: pytest # See pytest.ini for configuration

9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ The command to install PyTorch is as follows:

### Recent Updates

Dec 7, 2024:

- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds!
<!--
Also, the ControlNet training script for SD has been changed from `train_controlnet.py` to `train_control_net.py`.
- `train_controlnet.py` is still available, but it will be removed in the future.
-->

- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.

Dec 3, 2024:

Expand Down
2 changes: 1 addition & 1 deletion flux_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def train(args):
# load controlnet
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
controlnet = flux_utils.load_controlnet(
args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
)
controlnet.train()

Expand Down
2 changes: 1 addition & 1 deletion library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
parser.add_argument(
"--controlnet",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
Expand Down
6 changes: 4 additions & 2 deletions library/sd3_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,8 +870,10 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
self.use_scaled_pos_embed = use_scaled_pos_embed

if self.use_scaled_pos_embed:
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None
# # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved
# self.pos_embed = None
# move pos_embed to CPU to free up memory up to 0.4 GB
self.pos_embed = self.pos_embed.cpu()

# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
Expand Down
8 changes: 4 additions & 4 deletions sdxl_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ def unwrap_model(model):

# make control net
logger.info("make ControlNet")
if args.controlnet_model_path:
if args.controlnet_model_name_or_path:
with init_empty_weights():
control_net = SdxlControlNet()

logger.info(f"load ControlNet from {args.controlnet_model_path}")
filename = args.controlnet_model_path
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
filename = args.controlnet_model_name_or_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
Expand Down Expand Up @@ -679,7 +679,7 @@ def setup_parser() -> argparse.ArgumentParser:
sdxl_train_util.add_sdxl_training_arguments(parser)

parser.add_argument(
"--controlnet_model_path",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
Expand Down
41 changes: 41 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Tests

## Install

```
pip install pytest
```

## Usage

```
pytest
```

## Contribution

Pytest is configured to run tests in this directory. It might be a good idea to add tests closer in the code, as well as doctests.

Tests are functions starting with `test_` and files with the pattern `test_*.py`.

```
def test_x():
assert 1 == 2, "Invalid test response"
```

## Resources

### pytest

- https://docs.pytest.org/en/stable/index.html
- https://docs.pytest.org/en/stable/how-to/assert.html
- https://docs.pytest.org/en/stable/how-to/doctest.html

### PyTorch testing

- https://circleci.com/blog/testing-pytorch-model-with-pytest/
- https://pytorch.org/docs/stable/testing.html
- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
- https://github.com/huggingface/pytorch-image-models/tree/main/tests
- https://github.com/pytorch/pytorch/tree/main/test

Loading

0 comments on commit 70423ec

Please sign in to comment.