diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 63cf25fe5..b390639f3 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index ab0ee82b1..88725d5c3 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 72a3bf289..3fc054984 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 934538b63..f218184d7 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index f6ada3c8e..14bca5730 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 9c7f66c43..4e1e523a2 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f968d4abf..076658093 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 14c22141c..d9dde586e 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 7e41e9fd7..abb598fd4 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 81110bae6..def94296b 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -61,10 +61,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 3f75c9904..4fd7d2212 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -83,10 +83,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index d326f4035..c14de49ab 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -40,10 +40,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 39cf3d4f9..ce4bfebb0 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 880f9168d..17c5d8a03 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -198,10 +198,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 06eeacb39..793a3f1de 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 0710fb9a0..225924b98 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -59,10 +59,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 891da63be..63b0cb219 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -126,10 +126,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7886dc75d..7c545d7ab 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -197,10 +197,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index dc101896b..b173ba8ba 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 52aba82bf..c063f0a64 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 63cf25fe5..b390639f3 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 72a3bf289..3fc054984 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index e47c7fa0c..35ef2bfa8 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 442949866..0b7cc570b 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 95bea68aa..da2208519 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -205,10 +205,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 15b6b6858..a793673f9 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -139,10 +139,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index e853a821b..504dff0d1 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -121,10 +121,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index a98d134fc..999422fb0 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -77,10 +77,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 586429e37..92f222a18 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -20,10 +20,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 9bfb23367..b8a394322 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from algorithmic_efficiency import spec @@ -30,10 +30,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn)