From 14a8aea04f41b906145bd6c15f7be66cdd6fab72 Mon Sep 17 00:00:00 2001
From: Sierra Guequierre
Date: Wed, 29 Jan 2025 12:51:57 -0500
Subject: [PATCH] more updates

---
 docs/data-ai/ai/train.md | 47 +++++++++++++++++++++++++++++++----------------
 1 file changed, 31 insertions(+), 16 deletions(-)

diff --git a/docs/data-ai/ai/train.md b/docs/data-ai/ai/train.md
index 129731c5ad..0585ab1cc8 100644
--- a/docs/data-ai/ai/train.md
+++ b/docs/data-ai/ai/train.md
@@ -63,7 +63,7 @@ my-training/
 
 Add the following code to `setup.py` and add additional required packages on line 11:
 
-```python {class="line-numbers linkable-line-numbers" data-line="11"}
+```python {class="line-numbers linkable-line-numbers" data-line="9"}
 from setuptools import find_packages, setup
 
 setup(
@@ -92,11 +92,13 @@ If you haven't already, create a folder called model and create an
 
 {{% expand "Click to see the template" %}}
 
-```python {class="line-numbers linkable-line-numbers" data-line="126,170" }
+```python {class="line-numbers linkable-line-numbers" data-line="139" }
 import argparse
 import json
 import os
 import typing as ty
+from tensorflow.keras import Model  # Add proper import
+import tensorflow as tf  # Add proper import
 
 single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
 multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
@@ -106,6 +108,7 @@ unknown_label = "UNKNOWN"
 
 API_KEY = os.environ['API_KEY']
 API_KEY_ID = os.environ['API_KEY_ID']
+DEFAULT_EPOCHS = 200
 
 # This parses the required args for the training script.
 # The model_dir variable will contain the output directory where
@@ -113,25 +116,29 @@ API_KEY_ID = os.environ['API_KEY_ID']
 # The data_json variable will contain the metadata for the dataset
 # that you should use to train the model.
 def parse_args():
-    """Returns dataset file, model output directory, and num_epochs if present.
+    """Returns dataset file, model output directory, labels, and num_epochs if present.
     These must be parsed as command line arguments and then used as the model
     input and output, respectively. The number of epochs can be used to
     optionally override the default.
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_file", dest="data_json", type=str)
-    parser.add_argument("--model_output_directory", dest="model_dir", type=str)
+    parser.add_argument("--dataset_file", dest="data_json", type=str, required=True)
+    parser.add_argument("--model_output_directory", dest="model_dir", type=str, required=True)
     parser.add_argument("--num_epochs", dest="num_epochs", type=int)
     parser.add_argument(
         "--labels",
         dest="labels",
         type=str,
-        required=False,
+        required=True,
         help="Space-separated list of labels, enclosed in single quotes (e.g., 'label1 label2').",
     )
     args = parser.parse_args()
-    return args.data_json, args.model_dir, args.num_epochs, args.labels
-
+
+    if not args.labels:
+        raise ValueError("Labels must be provided")
+
+    labels = [label.strip() for label in args.labels.strip("'").split()]
+    return args.data_json, args.model_dir, args.num_epochs, labels
 
 # This is used for parsing the dataset file (produced and stored in Viam),
 # parse it to get the label annotations
@@ -215,8 +222,7 @@ def parse_filenames_and_bboxes_from_json(
 
 # Build the model
 def build_and_compile_model(
-    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]
-) -> Model:
+    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]) -> Model:
     """Builds and compiles a model
     Args:
         labels: list of string lists, where each string list contains up to
@@ -255,12 +261,16 @@ def save_model(
         model_dir: output directory for model artifacts
         model_name: name of saved model
     """
-    file_type = ""
-
-    # Save the model to the output directory.
+    file_type = "tflite"  # Add proper file type
     filename = os.path.join(model_dir, f"{model_name}.{file_type}")
+
+    # Example: Convert to TFLite
+    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    tflite_model = converter.convert()
+
+    # Save the model
     with open(filename, "wb") as f:
-        f.write(model)
+        f.write(tflite_model)
 
 
 if __name__ == "__main__":
@@ -275,14 +285,19 @@ if __name__ == "__main__":
     image_filenames, image_labels = parse_filenames_and_labels_from_json(
         DATA_JSON, LABELS, model_type)
 
+    # Validate epochs
+    epochs = (
+        DEFAULT_EPOCHS if NUM_EPOCHS is None or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
+    )
+
     # Build and compile model on data
-    model = build_and_compile_model()
+    model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))
 
     # Save labels.txt file
     save_labels(LABELS + [unknown_label], MODEL_DIR)
 
     # Convert the model to tflite
     save_model(
-        model, MODEL_DIR, "classification_model"
+        model, MODEL_DIR, "classification_model"
     )
 ```