Skip to content

Commit

Permalink
more training script updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sguequierre committed Jan 29, 2025
1 parent cd40828 commit 8f316e2
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions docs/data-ai/ai/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def parse_args():
help="Space-separated list of labels, enclosed in single quotes (e.g., 'label1 label2').",
)
args = parser.parse_args()
return args.data_json, args.model_dir, args.num_epochs
return args.data_json, args.model_dir, args.num_epochs, args.labels


# This is used for parsing the dataset file (produced and stored in Viam),
Expand Down Expand Up @@ -264,14 +264,11 @@ def save_model(


if __name__ == "__main__":
DATA_JSON, MODEL_DIR = parse_args()
DATA_JSON, MODEL_DIR, NUM_EPOCHS, LABELS = parse_args()

IMG_SIZE = (256, 256)

# Read dataset file.
# TODO: change labels to the desired model output.
LABELS = ["orange_triangle", "blue_star"]

# The model type can be changed based on whether you want the model to
# output one label per image or multiple labels per image
model_type = multi_label
Expand Down

0 comments on commit 8f316e2

Please sign in to comment.