Skip to content

Commit

Permalink
to device support (#10)
Browse files Browse the repository at this point in the history
* to device support
  • Loading branch information
shonenkov authored Nov 3, 2021
1 parent acce991 commit 0b3d648
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ runs/
jupyters/custom_*

*logs/
.DS_store
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
### Generate images from texts

```
pip install rudalle==0.0.1rc2
pip install rudalle==0.0.1rc3
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
Expand Down
2 changes: 1 addition & 1 deletion rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
'image_prompts',
]

__version__ = '0.0.1-rc2'
__version__ = '0.0.1-rc3'
3 changes: 3 additions & 0 deletions rudalle/dalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
# TODO docstring
assert name in MODELS

if fp16 and device == 'cpu':
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')

config = MODELS[name]
model = DalleModel(device=device, fp16=fp16, **config['model_params'])
if pretrained:
Expand Down
4 changes: 4 additions & 0 deletions rudalle/dalle/fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ def load_state_dict(self, state_dict, strict=True):

def get_param(self, item):
return self.module.get_param(item)

def to(self, device, *args, **kwargs):
self.module.to(device)
return super().to(device, *args, **kwargs)
6 changes: 6 additions & 0 deletions rudalle/dalle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,9 @@ def forward(

loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}

def to(self, device, *args, **kwargs):
self.device = device
self._mask_map = [mask.to(device) for mask in self._mask_map]
self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
return super().to(device, *args, **kwargs)

0 comments on commit 0b3d648

Please sign in to comment.