Skip to content

Commit

Permalink
Merge pull request #124 from ai-forever/feature/kandinsky
Browse files Browse the repository at this point in the history
Feature/kandinsky
  • Loading branch information
shonenkov authored Jun 22, 2022
2 parents 08c18de + 131e3e5 commit 2015898
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 121 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

```
pip install rudalle==1.1.0rc0
pip install rudalle==1.1.0
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) \
[ruDALL-E Emojich (XL)](https://huggingface.co/sberbank-ai/rudalle-Emojich) (readme [here](https://github.com/sberbank-ai/ru-dalle/blob/master/Emojich.md)) \
[ruDALL-E Surrealist (XL)](https://huggingface.co/shonenkov-AI/rudalle-xl-surrealist)

[ruDALL-E Surrealist (XL)](https://huggingface.co/shonenkov-AI/rudalle-xl-surrealist) \
ruDALL-E Kandinsky (XXL) (soon)

### Minimal Example:

Expand Down Expand Up @@ -100,6 +100,15 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
![](https://raw.githubusercontent.com/shonenkov-AI/rudalle-aspect-ratio/main/pics/h_example.jpg)


### [Kandinsky]()

`роботы акварелью в стиле ван гога`
![](./pics/kandinsky/example-robots.png)

[![](./pics/habr_eng.svg)](https://habr.com/ru/company/sberbank/blog/671210/)

![](./pics/kandinsky/loss.jpg)
`FID = 15.4 (COCO Valid)`

### 🚀 Contributors 🚀

Expand Down
Binary file added pics/kandinsky/example-robots.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pics/kandinsky/loss.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
'image_prompts',
]

__version__ = '1.1.0rc0'
__version__ = '1.1.0'
49 changes: 36 additions & 13 deletions rudalle/dalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
),
repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model_v3.bin',
authors='SberAI, SberDevices',
authors='SberAI, SberDevices, shonenkovAI',
full_description='', # TODO
),
'Malevich_v2': dict(
Expand All @@ -52,12 +52,12 @@
),
repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model_v2.bin',
authors='SberAI, SberDevices',
authors='SberAI, SberDevices, shonenkovAI',
full_description='', # TODO
),
'Emojich': dict(
hf_version='v2',
description='😋 Emojich is a 1.3 billion params model from the family GPT3-like, '
description='😋 Emojich is 1.3 billion params model from the family GPT3-like, '
'it generates emoji-style images with the brain of ◾ Malevich.',
model_params=dict(
num_layers=24,
Expand All @@ -75,9 +75,32 @@
),
repo_id='sberbank-ai/rudalle-Emojich',
filename='pytorch_model.bin',
authors='SberAI',
authors='SberAI, SberDevices, shonenkovAI',
full_description='', # TODO
),
'Surrealist_XL': dict(
hf_version='v3',
description='Surrealist is 1.3 billion params model from the family GPT3-like, '
'that was trained on surrealism and Russian.',
model_params=dict(
num_layers=24,
hidden_size=2048,
num_attention_heads=16,
embedding_dropout_prob=0.1,
output_dropout_prob=0.1,
attention_dropout_prob=0.1,
image_tokens_per_dim=32,
text_seq_length=128,
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384 + 128,
image_vocab_size=8192,
),
repo_id='shonenkov-AI/rudalle-xl-surrealist',
filename='pytorch_model.bin',
authors='shonenkovAI',
full_description='',
),
'Kandinsky': dict(
hf_version='v3',
description='Kandinsky is large 12 billion params model from the family GPT3-like, '
Expand All @@ -93,17 +116,16 @@
text_seq_length=128,
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
cogview_layernorm_prescale=True,
custom_relax=True,
vocab_size=16384 + 128,
image_vocab_size=8192,
),
repo_id='',
filename='',
authors='SberAI, SberDevices',
repo_id='shonenkov-AI/Kandinsky',
filename='pytorch_model.bin',
authors='SberAI, SberDevices, shonenkovAI',
full_description='', # TODO
),
'dummy': dict(
hf_version='v3',
description='',
model_params=dict(
num_layers=12,
Expand All @@ -126,20 +148,21 @@
}


def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs):
# TODO docstring
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', use_auth_token=None,
cache_dir='/tmp/rudalle', **model_kwargs):
assert name in MODELS

if fp16 and device == 'cpu':
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')

config = MODELS[name].copy()
config['model_params'].update(model_kwargs)
model = DalleModel(device=device, **config['model_params'])
model = DalleModel(device=device, hf_version=config['hf_version'], **config['model_params'])
if pretrained:
cache_dir = os.path.join(cache_dir, name)
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'],
use_auth_token=use_auth_token)
checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
model.load_state_dict(checkpoint)
if fp16:
Expand Down
12 changes: 4 additions & 8 deletions rudalle/dalle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def __init__(self,
loss_img_weight=7,
cogview_sandwich_layernorm=False,
cogview_pb_relax=False,
cogview_layernorm_prescale=False,
custom_relax=False,
is_bool_mask=True,
mlp_activation='gelu_jit',
hf_version='v3'):
Expand Down Expand Up @@ -73,8 +71,6 @@ def __init__(self,
image_tokens_per_dim=image_tokens_per_dim,
cogview_sandwich_layernorm=cogview_sandwich_layernorm,
cogview_pb_relax=cogview_pb_relax,
cogview_layernorm_prescale=cogview_layernorm_prescale,
custom_relax=custom_relax,
mlp_activation=mlp_activation,
is_bool_mask=is_bool_mask,
hf_version=self.hf_version,
Expand Down Expand Up @@ -110,13 +106,13 @@ def forward(
if self.hf_version == 'v2':
# some hardcode :)
text = F.pad(text, (1, 0), value=2)
text_pos = self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
text_embeddings = self.text_embeddings(text) + text_pos
text_embeddings = self.text_embeddings(text) + \
self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
image_input_ids = input_ids[:, self.text_seq_length:]

if exists(image_input_ids) and not is_empty(image_input_ids):
img_pos = self.get_image_pos_embeddings(image_input_ids, past_length=0)
image_embeddings = self.image_embeddings(image_input_ids) + img_pos
image_embeddings = self.image_embeddings(image_input_ids) + \
self.get_image_pos_embeddings(image_input_ids, past_length=0)
embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
else:
embeddings = text_embeddings
Expand Down
Loading

0 comments on commit 2015898

Please sign in to comment.