
Commit: more updates
sguequierre committed Jan 29, 2025
1 parent 8f316e2 commit 14a8aea
Showing 1 changed file with 31 additions and 16 deletions: docs/data-ai/ai/train.md
@@ -63,7 +63,7 @@ my-training/

Add the following code to `setup.py` and add additional required packages on line 11:

-```python {class="line-numbers linkable-line-numbers" data-line="11"}
+```python {class="line-numbers linkable-line-numbers" data-line="9"}
from setuptools import find_packages, setup

setup(
@@ -92,11 +92,13 @@ If you haven't already, create a folder called <file>model</file> and create an

{{% expand "Click to see the template" %}}

-```python {class="line-numbers linkable-line-numbers" data-line="126,170" }
+```python {class="line-numbers linkable-line-numbers" data-line="139" }
import argparse
import json
import os
import typing as ty
+from tensorflow.keras import Model # Add proper import
+import tensorflow as tf # Add proper import

single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"
@@ -106,32 +108,37 @@ unknown_label = "UNKNOWN"
API_KEY = os.environ['API_KEY']
API_KEY_ID = os.environ['API_KEY_ID']

+DEFAULT_EPOCHS = 200

# This parses the required args for the training script.
# The model_dir variable will contain the output directory where
# the ML model that this script creates should be stored.
# The data_json variable will contain the metadata for the dataset
# that you should use to train the model.
def parse_args():
"""Returns dataset file, model output directory, and num_epochs if present.
"""Returns dataset file, model output directory, labels, and num_epochs if present.
These must be parsed as command line arguments and then used as the model
input and output, respectively. The number of epochs can be used to
optionally override the default.
"""
    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_file", dest="data_json", type=str)
-    parser.add_argument("--model_output_directory", dest="model_dir", type=str)
+    parser.add_argument("--dataset_file", dest="data_json", type=str, required=True)
+    parser.add_argument("--model_output_directory", dest="model_dir", type=str, required=True)
    parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    parser.add_argument(
        "--labels",
        dest="labels",
        type=str,
-        required=False,
+        required=True,
        help="Space-separated list of labels, enclosed in single quotes (e.g., 'label1 label2').",
    )
    args = parser.parse_args()
-    return args.data_json, args.model_dir, args.num_epochs, args.labels

+    if not args.labels:
+        raise ValueError("Labels must be provided")
+
+    labels = [label.strip() for label in args.labels.strip("'").split()]
+    return args.data_json, args.model_dir, args.num_epochs, labels

# This is used for parsing the dataset file (produced and stored in Viam),
# parse it to get the label annotations
@@ -215,8 +222,7 @@ def parse_filenames_and_bboxes_from_json(

# Build the model
def build_and_compile_model(
-    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]
-) -> Model:
+    labels: ty.List[str], model_type: str, input_shape: ty.Tuple[int, int, int]) -> Model:
"""Builds and compiles a model
Args:
labels: list of string lists, where each string list contains up to
@@ -255,12 +261,16 @@ def save_model(
        model_dir: output directory for model artifacts
        model_name: name of saved model
    """
file_type = ""

# Save the model to the output directory.
file_type = "tflite" # Add proper file type
filename = os.path.join(model_dir, f"{model_name}.{file_type}")

# Example: Convert to TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model
with open(filename, "wb") as f:
f.write(model)
f.write(tflite_model)


@@ -275,14 +285,19 @@ if __name__ == "__main__":
    image_filenames, image_labels = parse_filenames_and_labels_from_json(
        DATA_JSON, LABELS, model_type)

+    # Validate epochs
+    epochs = (
+        DEFAULT_EPOCHS if NUM_EPOCHS is None or NUM_EPOCHS <= 0 else int(NUM_EPOCHS)
+    )

    # Build and compile model on data
-    model = build_and_compile_model()
+    model = build_and_compile_model(image_labels, model_type, IMG_SIZE + (3,))

    # Save labels.txt file
    save_labels(LABELS + [unknown_label], MODEL_DIR)
+    # Convert the model to tflite
    save_model(
-        model, MODEL_DIR, "classification_model", IMG_SIZE + (3,)
+        model, MODEL_DIR, "classification_model"
    )
```
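
The updated `parse_args` now requires `--dataset_file`, `--model_output_directory`, and `--labels`, and returns the labels as a parsed list rather than the raw string. Below is a minimal sketch for checking that argument handling locally; it assumes the template above is saved as <file>model/training.py</file> inside the <file>my-training</file> project (with <file>my-training</file> as the working directory and <file>model</file> importable as a package), and the dataset path, output directory, label names, and API key values are placeholders.

```python {class="line-numbers linkable-line-numbers"}
import os
import sys

# Placeholder credentials so the module-level os.environ lookups in
# training.py succeed when the module is imported.
os.environ.setdefault("API_KEY", "placeholder")
os.environ.setdefault("API_KEY_ID", "placeholder")

# Simulate the flags a training job would pass to the script.
# The paths and label names below are placeholders.
sys.argv = [
    "training.py",
    "--dataset_file", "/path/to/dataset.jsonl",
    "--model_output_directory", "/path/to/output",
    "--num_epochs", "50",
    "--labels", "'label1 label2'",
]

from model.training import parse_args

DATA_JSON, MODEL_DIR, NUM_EPOCHS, LABELS = parse_args()
print(LABELS)      # ['label1', 'label2']
print(NUM_EPOCHS)  # 50
```

Because the dataset, output directory, and labels flags are now `required=True`, omitting any of them makes `argparse` exit with an error instead of the script training with missing inputs.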
