Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some small things #18

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions configs/models/cc12m_1024x1024.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ unet_config:
temporal_mode: false
temporal_positional_encoding: false
temporal_spatial_ds: false
+ diffusion_config:
+ sampler_config:
+ num_diffusion_steps: 1000
+ reproject_signal: False
+ schedule_type: DEEPFLOYD
+ prediction_type: V_PREDICTION
+ loss_target_type: DDPM
+ beta_start: 0.0001
+ beta_end: 0.02
+ threshold_function: CLIP
+ rescale_schedule: 1.0
+ schedule_shifted: True
+ model_output_scale: 0.0
+ use_vdm_loss_weights: False
+ no_use_residual: true
tolgacangoz marked this conversation as resolved.
Show resolved Hide resolved

# import defaults
# reader-config-file: configs/datasets/reader_config.yaml
Expand Down
17 changes: 16 additions & 1 deletion configs/models/cc12m_256x256.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ unet_config:
conditioning_feature_dim: -1
conditioning_feature_proj_dim: -1
freeze_inner_unet: false
- initialize_inner_with_pretrained: None
+ initialize_inner_with_pretrained: null
tolgacangoz marked this conversation as resolved.
Show resolved Hide resolved
inner_config:
attention_levels: [1, 2]
conditioning_feature_dim: -1
Expand Down Expand Up @@ -76,6 +76,21 @@ unet_config:
temporal_mode: false
temporal_positional_encoding: false
temporal_spatial_ds: false
+ diffusion_config:
+ sampler_config:
+ num_diffusion_steps: 1000
+ reproject_signal: False
+ schedule_type: DEEPFLOYD
+ prediction_type: V_PREDICTION
+ loss_target_type: DDPM
+ beta_start: 0.0001
+ beta_end: 0.02
+ threshold_function: CLIP
+ rescale_schedule: 1.0
+ schedule_shifted: True
+ model_output_scale: 0.0
+ use_vdm_loss_weights: False
+ no_use_residual: true

reader_config:
image_size: 256
Expand Down
4 changes: 2 additions & 2 deletions ml_mdm/clis/generate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def generate(
prompt = input_template.format(prompt=prompt)
if len(negative_template) > 0:
negative_prompt = negative_prompt + negative_template
- print(f"Postive: {prompt} / Negative: {negative_prompt}")
+ print(f"Positive: {prompt} / Negative: {negative_prompt}")

if not os.path.exists(ckpt_name):
logging.info(f"Did not generate because {ckpt_name} does not exist")
Expand Down Expand Up @@ -478,7 +478,7 @@ def main(args):

with gr.Column(scale=2):
with gr.Accordion(
- "Addditional outputs", open=False, elem_id="output-accordion"
+ "Additional outputs", open=False, elem_id="output-accordion"
):
with gr.Row(equal_height=True):
output_text = gr.Textbox(value=None, label="System output")
Expand Down
3 changes: 3 additions & 0 deletions ml_mdm/models/nested_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Nested4UNetConfig(Nested3UNetConfig):


def download(vision_model_path):
+ from ml_mdm.utils import fix_old_checkpoints
+ fix_old_checkpoints.mimic_old_modules()
tolgacangoz marked this conversation as resolved.
Show resolved Hide resolved

import os

from distributed import get_local_rank
Expand Down
8 changes: 4 additions & 4 deletions ml_mdm/samplers.py
tolgacangoz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def get_xt(self, x0, eps, g, scales):
x_t += [
super().get_xt(
self.get_image_rescaled(x, s)
- if not self._config.schedule_shifted
+ if self._config.schedule_shifted
else x,
e,
gi,
Expand All @@ -611,7 +611,7 @@ def get_prediction_targets(self, x0, eps, g, g_last, scales, prediction_type=Non
tgt += [
super().get_prediction_targets(
self.get_image_rescaled(x, s)
- if not self._config.schedule_shifted
+ if self._config.schedule_shifted
else x,
e,
gi,
Expand Down Expand Up @@ -668,7 +668,7 @@ def get_xt_minus_1(
need_noise=time_step != 1,
ddim_eta=ddim_eta,
clip_fn=self.clip_sample,
- image_scale=s if not self._config.schedule_shifted else 1,
+ image_scale=s if self._config.schedule_shifted else 1,
)
for x, p, g, g_last, s in zip(x_t, p_t, g_t, g_s, scales)
]
Expand All @@ -693,7 +693,7 @@ def _postprocess(
):
scales = [
x_t[i].size(-1) / x_t[-1].size(-1)
- if not self._config.schedule_shifted
+ if self._config.schedule_shifted
else 1
for i in range(len(x_t))
]
Expand Down