Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass train_state to update_params #790

Merged
merged 10 commits into from
Oct 29, 2024
Merged
2 changes: 2 additions & 0 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def update_params(
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState
Expand All @@ -212,6 +213,7 @@ def update_params(
- The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used.
- Allowed to update state for the optimizer.
- Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
- The submission can access the elapsed training time and get further information about the evaluation through `train_state`.
- The submission can access the target evaluation metric via the `workload` variable.
- **A call to this function will be considered a step**
- The time between a call to this function and the next call to this function will be considered the per-step time.
Expand Down
6 changes: 4 additions & 2 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload,
OptimizerState,
List[Tuple[int, float]],
int,
RandomState
RandomState,
Optional[Dict[str, Any]]
],
UpdateReturn]

Expand All @@ -424,7 +425,8 @@ def update_params(workload: Workload,
optimizer_state: OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
rng: RandomState,
train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,23 @@ def _loss_fn(params):
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
return optimizer_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for CIFAR10."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -110,21 +110,24 @@ def _loss_fn(params):

# Not allowed to update the model parameters, hyperparameters, global step, or
# optimzier state.
def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState,
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del global_step
del train_state
del eval_results
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
Expand Down
Loading
Loading