Skip to content

Commit

Permalink
fix all optional types in train config
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 7, 2023
1 parent 512b52b commit c6c3882
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
36 changes: 18 additions & 18 deletions dalle2_pytorch/train_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool =
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
base_model_kwargs: Optional[Dict[str, Any]] = None

def create(self):
if self.make == "openai":
Expand All @@ -134,8 +134,8 @@ def create(self):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
max_text_len: Optional[int] = None
num_timesteps: Optional[int] = None
num_time_embeds: int = 1
num_image_embeds: int = 1
num_text_embeds: int = 1
Expand All @@ -158,7 +158,7 @@ def create(self):
return DiffusionPriorNetwork(**kwargs)

class DiffusionPriorConfig(BaseModel):
clip: AdapterConfig = None
clip: Optional[AdapterConfig] = None
net: DiffusionPriorNetworkConfig
image_embed_dim: int
image_size: int
Expand Down Expand Up @@ -195,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
warmup_steps: int = None # number of warmup steps
warmup_steps: Optional[int] = None # number of warmup steps
save_every_seconds: int = 3600 # how often to save
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
best_validation_loss: float = 1e9 # the current best validation loss observed
Expand Down Expand Up @@ -228,10 +228,10 @@ def from_json_path(cls, json_path):
class UnetConfig(BaseModel):
dim: int
dim_mults: ListOrTuple[int]
image_embed_dim: int = None
text_embed_dim: int = None
cond_on_text_encodings: bool = None
cond_dim: int = None
image_embed_dim: Optional[int] = None
text_embed_dim: Optional[int] = None
cond_on_text_encodings: Optional[bool] = None
cond_dim: Optional[int] = None
channels: int = 3
self_attn: ListOrTuple[int]
attn_dim_head: int = 32
Expand All @@ -243,14 +243,14 @@ class Config:

class DecoderConfig(BaseModel):
unets: ListOrTuple[UnetConfig]
image_size: int = None
image_size: Optional[int] = None
image_sizes: ListOrTuple[int] = None
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine
beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
Expand Down Expand Up @@ -320,20 +320,20 @@ class DecoderTrainConfig(BaseModel):
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
cond_scale: Union[float, List[float]] = 1.0
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: Optional[int] = None # Same as above but for validation.
save_immediately: bool = False
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets

class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
FID: Optional[Dict[str, Any]] = None
IS: Optional[Dict[str, Any]] = None
KID: Optional[Dict[str, Any]] = None
LPIPS: Optional[Dict[str, Any]] = None

class TrainDecoderConfig(BaseModel):
decoder: DecoderConfig
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.15.2'
__version__ = '1.15.3'

0 comments on commit c6c3882

Please sign in to comment.