Skip to content
This repository has been archived by the owner on Feb 11, 2025. It is now read-only.

Commit

Permalink
add functions to identify a particular index for movement
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Jul 23, 2024
1 parent bd4a88e commit 329bea9
Showing 1 changed file with 111 additions and 16 deletions.
127 changes: 111 additions & 16 deletions src/social_norms_trees/behaviortreeworld_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field, replace
from functools import partial, wraps
from itertools import islice
from typing import TypeVar, Optional
from typing import TypeVar, Optional, List

import click
import py_trees
Expand Down Expand Up @@ -111,24 +111,93 @@ def enumerate_nodes(tree: py_trees.behaviour.Behaviour):
return enumerate(iterate_nodes(tree))


def format_tree_with_indices(tree: py_trees.behaviour.Behaviour):
def label_tree_lines(
tree: py_trees.behaviour.Behaviour,
labels: List[str],
representation=py_trees.display.unicode_tree,
) -> str:
max_len = max([len(s) for s in labels])
padded_labels = [s.rjust(max_len) for s in labels]

tree_representation_lines = representation(tree).split("\n")
enumerated_tree_representation_lines = [
f"{i}: {t}" for i, t in zip(padded_labels, tree_representation_lines)
]

output = "\n".join(enumerated_tree_representation_lines)
return output


def format_children_with_indices(composite: py_trees.composites.Composite) -> str:
"""
Examples:
>>> tree = py_trees.composites.Sequence("s1", False, children=[
... py_trees.behaviours.Dummy(),
... py_trees.behaviours.Success(),
... py_trees.composites.Sequence("s2", False, children=[
... py_trees.behaviours.Dummy(),
... ]),
... py_trees.composites.Sequence("", False, children=[
... py_trees.behaviours.Failure(),
... py_trees.behaviours.Periodic("p", n=1),
... ]),
... ])
>>> print(format_children_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE
_: [-] s1
0: --> Dummy
1: --> Success
2: [-] s2
_: --> Dummy
3: [-]
_: --> Failure
_: --> p
"""
index_strings = []
i = 0
for b in iterate_nodes(composite):
if b in composite.children:
index_strings.append(str(i))
i += 1
else:
index_strings.append("_")

output = label_tree_lines(composite, index_strings)
return output


def format_tree_with_indices(tree: py_trees.behaviour.Behaviour) -> str:
"""
Examples:
>>> print(format_tree_with_indices(py_trees.behaviours.Dummy()))
0: --> Dummy
>>> tree = py_trees.composites.Sequence("s1", False, children=[
... py_trees.behaviours.Dummy(),
... py_trees.behaviours.Success(),
... py_trees.composites.Sequence("s2", False, children=[
... py_trees.behaviours.Dummy(),
... ]),
... py_trees.composites.Sequence("", False, children=[
... py_trees.behaviours.Failure(),
... py_trees.behaviours.Periodic("p", n=1),
... ]),
... ])
>>> print(format_tree_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE
0: [-] s1
1: --> Dummy
2: --> Success
3: [-] s2
4: --> Dummy
5: [-]
6: --> Failure
7: --> p
"""

index_strings = [str(i) for i, _ in enumerate_nodes(tree)]
max_len = max([len(s) for s in index_strings])
padded_index_strings = [s.rjust(max_len) for s in index_strings]

tree_representation = py_trees.display.unicode_tree(tree)
tree_representation_lines = tree_representation.split("\n")
enumerated_tree_representation_lines = [
f"{i}: {t}" for i, t in zip(padded_index_strings, tree_representation_lines)
]
output = label_tree_lines(tree, index_strings)

output = "\n".join(enumerated_tree_representation_lines)
return output


Expand All @@ -141,6 +210,18 @@ def prompt_identify_node(
message: str = "Which node?",
display_nodes: bool = True,
) -> py_trees.behaviour.Behaviour:
node_index = prompt_identify_tree_iterator_index(
tree=tree, message=message, display_nodes=display_nodes
)
node = next(islice(iterate_nodes(tree), node_index, node_index + 1))
return node


def prompt_identify_tree_iterator_index(
tree: py_trees.behaviour.Behaviour,
message: str = "Which position?",
display_nodes: bool = True,
) -> int:
if display_nodes:
text = f"{format_tree_with_indices(tree)}\n{message}"
else:
Expand All @@ -149,8 +230,23 @@ def prompt_identify_node(
text=text,
type=int,
)
node = next(islice(iterate_nodes(tree), node_index, node_index + 1))
return node
return node_index


def prompt_identify_child_index(
tree: py_trees.behaviour.Behaviour,
message: str = "Which position?",
display_nodes: bool = True,
) -> int:
if display_nodes:
text = f"{format_children_with_indices(tree)}\n{message}"
else:
text = f"{message}"
node_index = click.prompt(
text=text,
type=int,
)
return node_index


def add_child(
Expand Down Expand Up @@ -247,7 +343,7 @@ def move_node(
tree, f"What should its parent be?", display_nodes=False
)
if index is None:
index = click.prompt(f"What should its position be?", type=int)
index = prompt_identify_child_index(new_parent)

assert isinstance(new_parent, py_trees.composites.Composite)
assert isinstance(node.parent, py_trees.composites.Composite)
Expand Down Expand Up @@ -354,12 +450,11 @@ def exchange_nodes(
)

print(py_trees.display.ascii_tree(tree))
print(format_tree_with_indices(tree))
move_node(tree)
print(format_tree_with_indices(tree))
exchange_nodes(tree)
remove_node(tree)

print(format_tree_with_indices(tree))
remove_node(tree)

pass

Expand Down

0 comments on commit 329bea9

Please sign in to comment.