Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass train_state to update_params #790

Merged
merged 10 commits into from
Oct 29, 2024
Merged
2 changes: 2 additions & 0 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def update_params(
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState
Expand All @@ -212,6 +213,7 @@ def update_params(
- The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used.
- Allowed to update state for the optimizer.
- Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
- The submission can access the elapsed training time and get further information about the evaluation through `train_state`.
- The submission can access the target evaluation metric via the `workload` variable.
- **A call to this function will be considered a step**
- The time between a call to this function and the next call to this function will be considered the per-step time.
Expand Down
2 changes: 2 additions & 0 deletions algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def init_optimizer_state(workload: Workload,
Dict[str, Tensor],
LossType,
OptimizerState,
Dict[str, Any],
List[Tuple[int, float]],
int,
RandomState
Expand All @@ -422,6 +423,7 @@ def update_params(workload: Workload,
batch: Dict[str, Tensor],
loss_type: LossType,
optimizer_state: OptimizerState,
train_state: Dict[str, Any],
Niccolo-Ajroldi marked this conversation as resolved.
Show resolved Hide resolved
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""

import math
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del hyperparameters

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for CIFAR10."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,13 +118,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del global_step
del train_state
del eval_results
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Training algorithm track submission functions for CIFAR10."""

from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
Expand Down Expand Up @@ -61,13 +61,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Training algorithm track submission functions for MNIST."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -83,12 +83,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results
del global_step

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Training algorithm track submission functions for MNIST."""

from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

import torch

Expand Down Expand Up @@ -40,13 +40,15 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del hyperparameters
del loss_type
del current_params_types
del train_state
del eval_results
del global_step

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an Adafactor optimizer with warmup+cosine LR in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for Adafactor in PyTorch."""

from functools import partial
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -198,12 +198,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

optimizer_state, opt_update_fn = optimizer_state
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch."""

from typing import Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple

from absl import logging
import torch
Expand Down Expand Up @@ -59,12 +59,14 @@ def update_params(workload: spec.Workload,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
train_state: Dict[str, Any],
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del train_state
del eval_results

current_model = current_param_container
Expand Down
Loading
Loading