Skip to content

Commit

Permalink
chore(components): Update Starry Net description and add nan- and zer…
Browse files Browse the repository at this point in the history
…o-threshold args

Signed-off-by: Googler <[email protected]>
PiperOrigin-RevId: 650295521
  • Loading branch information
Googler committed Jul 8, 2024
1 parent 38ef986 commit 0eae430
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions components/google-cloud/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
## Upcoming release
* Updated the Starry Net pipeline's template gallery description, and added dataprep_nan_threshold and dataprep_zero_threshold args to the Starry Net pipeline.
* Add support for running tasks on a `PersistentResource` (see [CustomJobSpec](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/CustomJobSpec)) via `persistent_resource_id` parameter on `v1.custom_job.CustomTrainingJobOp` and `v1.custom_job.create_custom_training_job_from_component`

## Release 2.15.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def dataprep(
disk_size_gb: int,
test_set_only: bool,
bigquery_output: str,
nan_threshold: float,
zero_threshold: float,
gcs_source: str,
gcs_static_covariate_source: str,
encryption_spec_key_name: str,
Expand Down Expand Up @@ -90,6 +92,13 @@ def dataprep(
to create TFRecords for training and validation.
bigquery_output: The BigQuery dataset where the test set is written in the
form bq://project.dataset.
nan_threshold: Series having more nan / missing values than
nan_threshold (inclusive) in percentage for either backtest or forecast
will not be sampled in the training set (including missing due to
train_start and train_end). All existing nans are replaced by zeros.
zero_threshold: Series having more 0.0 values than zero_threshold
(inclusive) in percentage for either backtest or forecast will not be
sampled in the training set.
gcs_source: The path to the csv file of the data source.
gcs_static_covariate_source: The path to the csv file of static covariates.
encryption_spec_key_name: Customer-managed encryption key options for the
Expand Down Expand Up @@ -129,6 +138,8 @@ def dataprep(
f'--config.datasets.val_rolling_window_size={test_set_stride}',
f'--config.datasets.n_test_windows={n_test_windows}',
f'--config.datasets.test_rolling_window_size={test_set_stride}',
f'--config.datasets.nan_threshold={nan_threshold}',
f'--config.datasets.zero_threshold={zero_threshold}',
f'--config.model.static_cov_names={static_covariate_columns}',
f'--config.model.blocks_list={model_blocks}',
f'--bigquery_source={bigquery_source}',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def train(
n_val_windows: int,
n_test_windows: int,
test_set_stride: int,
nan_threshold: float,
zero_threshold: float,
cleaning_activation_regularizer_coeff: float,
change_point_activation_regularizer_coeff: float,
change_point_output_regularizer_coeff: float,
Expand Down Expand Up @@ -88,6 +90,13 @@ def train(
n_test_windows: The number of windows to use for the test set. Must be >= 1.
test_set_stride: The number of timestamps to roll forward when
constructing the val and test sets.
nan_threshold: Series having more nan / missing values than
nan_threshold (inclusive) in percentage for either backtest or forecast
will not be sampled in the training set (including missing due to
train_start and train_end). All existing nans are replaced by zeros.
zero_threshold: Series having more 0.0 values than zero_threshold
(inclusive) in percentage for either backtest or forecast will not be
sampled in the training set.
cleaning_activation_regularizer_coeff: The regularization coefficient for
the cleaning param estimator's final layer's activation in the cleaning
block.
Expand Down Expand Up @@ -182,6 +191,8 @@ def train(
f'--config.datasets.val_rolling_window_size={test_set_stride}',
f'--config.datasets.n_test_windows={n_test_windows}',
f'--config.datasets.test_rolling_window_size={test_set_stride}',
f'--config.datasets.nan_threshold={nan_threshold}',
f'--config.datasets.zero_threshold={zero_threshold}',
f'--config.model.regularizer_coeff={cleaning_activation_regularizer_coeff}',
f'--config.model.activation_regularizer_coeff={change_point_activation_regularizer_coeff}',
f'--config.model.output_regularizer_coeff={change_point_output_regularizer_coeff}',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def starry_net( # pylint: disable=dangerous-default-value
dataprep_target_column: str = '',
dataprep_static_covariate_columns: List[str] = [],
dataprep_previous_run_dir: str = '',
dataprep_nan_threshold: float = 0.2,
dataprep_zero_threshold: float = 0.2,
trainer_machine_type: str = 'n1-standard-4',
trainer_accelerator_type: str = 'NVIDIA_TESLA_V100',
trainer_num_epochs: int = 50,
Expand Down Expand Up @@ -84,7 +86,16 @@ def starry_net( # pylint: disable=dangerous-default-value
project: str = _placeholders.PROJECT_ID_PLACEHOLDER,
):
# fmt: off
"""Trains a STARRY-Net model.
"""Starry Net is a state-of-the-art forecaster used internally by Google.
Starry Net is a glass-box neural network inspired by statistical time series
models, capable of cleaning step changes and spikes, modeling seasonality and
events, forecasting trend, and providing both point and prediction interval
forecasts in a single, lightweight model. Starry Net stands out among neural
network based forecasting models by providing the explainability,
interpretability and tunability of traditional statistical forecasters.
For example, it features time series feature decomposition and damped local
linear exponential smoothing model as the trend structure.
Args:
tensorboard_instance_id: The tensorboard instance ID. This must be in same
Expand Down Expand Up @@ -149,6 +160,13 @@ def starry_net( # pylint: disable=dangerous-default-value
dataprep_previous_run_dir: The dataprep dir from a previous run. Use this
to save time if you've already created TFRecords from your BigQuery
dataset with the same dataprep parameters as this run.
dataprep_nan_threshold: Series having more nan / missing values than
nan_threshold (inclusive) in percentage for either backtest or forecast
will not be sampled in the training set (including missing due to
train_start and train_end). All existing nans are replaced by zeros.
dataprep_zero_threshold: Series having more 0.0 values than zero_threshold
(inclusive) in percentage for either backtest or forecast will not be
sampled in the training set.
trainer_machine_type: The machine type for training. Must be compatible with
trainer_accelerator_type.
trainer_accelerator_type: The accelerator type for training.
Expand Down Expand Up @@ -247,6 +265,8 @@ def starry_net( # pylint: disable=dangerous-default-value
disk_size_gb=dataflow_disk_size_gb,
test_set_only=True,
bigquery_output=dataprep_test_set_bigquery_dataset,
nan_threshold=dataprep_nan_threshold,
zero_threshold=dataprep_zero_threshold,
gcs_source=dataprep_csv_data_path,
gcs_static_covariate_source=dataprep_csv_static_covariates_path,
encryption_spec_key_name=encryption_spec_key_name
Expand Down Expand Up @@ -282,6 +302,8 @@ def starry_net( # pylint: disable=dangerous-default-value
disk_size_gb=dataflow_disk_size_gb,
test_set_only=False,
bigquery_output=dataprep_test_set_bigquery_dataset,
nan_threshold=dataprep_nan_threshold,
zero_threshold=dataprep_zero_threshold,
gcs_source=dataprep_csv_data_path,
gcs_static_covariate_source=dataprep_csv_static_covariates_path,
encryption_spec_key_name=encryption_spec_key_name
Expand Down Expand Up @@ -330,6 +352,8 @@ def starry_net( # pylint: disable=dangerous-default-value
n_val_windows=dataprep_n_val_windows,
n_test_windows=dataprep_n_test_windows,
test_set_stride=dataprep_test_set_stride,
nan_threshold=dataprep_nan_threshold,
zero_threshold=dataprep_zero_threshold,
cleaning_activation_regularizer_coeff=trainer_cleaning_activation_regularizer_coeff,
change_point_activation_regularizer_coeff=trainer_change_point_activation_regularizer_coeff,
change_point_output_regularizer_coeff=trainer_change_point_output_regularizer_coeff,
Expand Down

0 comments on commit 0eae430

Please sign in to comment.