-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from MaximilienNaveau/topic/humble-devel/refactor
Add a first shell of the agimus_controller API
- Loading branch information
Showing
16 changed files
with
394 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
*__pycache__* | ||
*.npy | ||
deprecated |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from agimus_controller.ocp_base import OCPBase | ||
|
||
|
||
def _create_ocp_hpp_crocco() -> OCPBase: | ||
pass | ||
|
||
|
||
def _create_ocp_collision_avoidance() -> OCPBase: | ||
pass | ||
|
||
|
||
def _create_ocp_single_ee_ref() -> OCPBase: | ||
pass | ||
|
||
|
||
def create_ocp(name: str) -> OCPBase: | ||
if name == "hpp_crocco": | ||
return _create_ocp_hpp_crocco() | ||
|
||
if name == "collision_avoidance": | ||
return _create_ocp_collision_avoidance() | ||
|
||
if name == "single_ee_ref": | ||
return _create_ocp_single_ee_ref() |
96 changes: 96 additions & 0 deletions
96
agimus_controller/agimus_controller/factory/robot_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from copy import deepcopy | ||
from dataclasses import dataclass | ||
import numpy as np | ||
from pathlib import Path | ||
import pinocchio as pin | ||
from typing import Union | ||
|
||
|
||
@dataclass | ||
class RobotModelParameters: | ||
q0_name = str() | ||
free_flyer = False | ||
locked_joint_names = [] | ||
urdf = Path() | str | ||
srdf = Path() | str | ||
collision_as_capsule = False | ||
self_collision = False | ||
|
||
|
||
class RobotModelFactory: | ||
"""Parse the robot model, reduce it and filter the collision model.""" | ||
|
||
"""Complete model of the robot.""" | ||
_complete_model = pin.Model() | ||
""" Complete model of the robot with collision meshes. """ | ||
_complete_collision_model = pin.GeometryModel() | ||
""" Complete model of the robot with visualization meshes. """ | ||
_complete_visual_model = pin.GeometryModel() | ||
""" Reduced model of the robot. """ | ||
_rmodel = pin.Model() | ||
""" Reduced model of the robot with visualization meshes. """ | ||
_rcmodel = pin.GeometryModel() | ||
""" Reduced model of the robot with collision meshes. """ | ||
_rvmodel = pin.GeometryModel() | ||
""" Default configuration q0. """ | ||
_q0 = np.array([]) | ||
""" Parameters of the model. """ | ||
_params = RobotModelParameters() | ||
""" Path to the collisions environment. """ | ||
_env = Path() | ||
|
||
def load_model(self, param: RobotModelParameters, env: Union[Path, None]) -> None: | ||
self._params = param | ||
self._env = env | ||
self._load_pinocchio_models(param.urdf, param.free_flyer) | ||
self._load_default_configuration(param.srdf, param.q0_name) | ||
self._load_reduced_model(param.locked_joint_names, param.q0_name) | ||
self._update_collision_model( | ||
env, param.collision_as_capsule, param.self_collision, param.srdf | ||
) | ||
|
||
def _load_pinocchio_models(self, urdf: Path, free_flyer: bool) -> None: | ||
pass | ||
|
||
def _load_default_configuration(self, srdf_path: Path, q0_name: str) -> None: | ||
pass | ||
|
||
def _load_reduced_model(self, locked_joint_names, q0_name) -> None: | ||
pass | ||
|
||
def _update_collision_model( | ||
self, | ||
env: Union[Path, None], | ||
collision_as_capsule: bool, | ||
self_collision: bool, | ||
srdf: Path, | ||
) -> None: | ||
pass | ||
|
||
def create_complete_robot_model(self) -> pin.Model: | ||
return self._complete_model.copy() | ||
|
||
def create_complete_collision_model(self) -> pin.GeometryModel: | ||
return self._complete_collision_model.copy() | ||
|
||
def create_complete_visual_model(self) -> pin.GeometryModel: | ||
return self._complete_visual_model.copy() | ||
|
||
def create_reduced_robot_model(self) -> pin.Model: | ||
return self._rmodel.copy() | ||
|
||
def create_reduced_collision_model(self) -> pin.GeometryModel: | ||
return self._rcmodel.copy() | ||
|
||
def create_reduced_visual_model(self) -> pin.GeometryModel: | ||
return self._rvmodel.copy() | ||
|
||
def create_default_configuration(self) -> np.array: | ||
return self._q0.copy() | ||
|
||
def create_model_parameters(self) -> RobotModelParameters: | ||
return deepcopy(self._params) | ||
|
||
def print_model(self): | ||
print("full model =\n", self._complete_model) | ||
print("reduced model =\n", self._rmodel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from agimus_controller.warm_start_base import WarmStartBase | ||
|
||
|
||
def _create_warm_start_from_previous_solution() -> WarmStartBase: | ||
pass | ||
|
||
|
||
def _create_warm_start_from_diffusion_model() -> WarmStartBase: | ||
pass | ||
|
||
|
||
def create_warm_start(name: str) -> WarmStartBase: | ||
if name == "from_previous_solution": | ||
return _create_warm_start_from_previous_solution() | ||
|
||
if name == "from_diffusion_model": | ||
return _create_warm_start_from_diffusion_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import time | ||
|
||
from agimus_controller.mpc_data import OCPResults, MPCDebugData | ||
from agimus_controller.ocp_base import OCPBase | ||
from agimus_controller.trajectory import ( | ||
TrajectoryBuffer, | ||
TrajectoryPoint, | ||
WeightedTrajectoryPoint, | ||
) | ||
from agimus_controller.warm_start_base import WarmStartBase | ||
|
||
|
||
class MPC(object): | ||
def __init__(self) -> None: | ||
self._ocp = None | ||
self._warm_start = None | ||
self._mpc_debug_data: MPCDebugData = None | ||
self._buffer = None | ||
|
||
def setup( | ||
self, | ||
ocp: OCPBase, | ||
warm_start: WarmStartBase, | ||
buffer: TrajectoryBuffer = TrajectoryBuffer(), | ||
) -> None: | ||
self._ocp = ocp | ||
self._warm_start = warm_start | ||
self._buffer = buffer | ||
|
||
def run(self, initial_state: TrajectoryPoint, current_time_ns: int) -> OCPResults: | ||
assert self._ocp is not None | ||
assert self._warm_start is not None | ||
timer1 = time.perf_counter_ns() | ||
self._buffer.clear_past(current_time_ns) | ||
reference_trajectory = self._extract_horizon_from_buffer() | ||
self._ocp.set_reference_horizon(reference_trajectory) | ||
timer2 = time.perf_counter_ns() | ||
x0, x_init, u_init = self._warm_start.generate( | ||
initial_state, reference_trajectory, self._ocp.debug_data.result | ||
) | ||
timer3 = time.perf_counter_ns() | ||
self._ocp.solve(x0, x_init, u_init) | ||
self._warm_start.update_previous_solution(self._ocp.debug_data.result) | ||
timer4 = time.perf_counter_ns() | ||
|
||
# Extract the solution. | ||
self._mpc_debug_data = self._ocp.debug_data | ||
self._mpc_debug_data.duration_iteration_ns = timer4 - timer1 | ||
self._mpc_debug_data.duration_horizon_update_ns = timer2 - timer1 | ||
self._mpc_debug_data.duration_generate_warm_start_ns = timer3 - timer2 | ||
self._mpc_debug_data.duration_ocp_solve_ns = timer4 - timer3 | ||
|
||
return self._ocp.ocp_results | ||
|
||
@property | ||
def mpc_debug_data(self) -> MPCDebugData: | ||
return self._mpc_debug_data | ||
|
||
def append_trajectory_point(self, trajectory_point: WeightedTrajectoryPoint): | ||
self._buffer.append(trajectory_point) | ||
|
||
def append_trajectory_points( | ||
self, trajectory_points: list[WeightedTrajectoryPoint] | ||
): | ||
self._buffer.extend(trajectory_points) | ||
|
||
def _extract_horizon_from_buffer(self): | ||
return self._buffer.horizon(self._ocp.horizon_size, self._ocp.dt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from dataclasses import dataclass | ||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from agimus_controller.trajectory import TrajectoryPoint | ||
|
||
|
||
@dataclass | ||
class OCPResults: | ||
"""Output data structure of the MPC.""" | ||
|
||
states: list[npt.NDArray[np.float64]] | ||
ricatti_gains: list[npt.NDArray[np.float64]] | ||
feed_forward_terms: list[npt.NDArray[np.float64]] | ||
|
||
|
||
@dataclass | ||
class OCPDebugData: | ||
# Solver infos | ||
problem_solved: bool = False | ||
|
||
# Debug data | ||
result: list[TrajectoryPoint] | ||
references: list[TrajectoryPoint] | ||
kkt_norms: list[np.float64] | ||
collision_distance_residuals: list[dict[np.float64]] | ||
|
||
|
||
@dataclass | ||
class MPCDebugData: | ||
ocp: OCPDebugData | ||
# Timers | ||
duration_iteration_ns: int = 0 | ||
duration_horizon_update_ns: int = 0 | ||
duration_generate_warm_start_ns: int = 0 | ||
duration_ocp_solve_ns: int = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from abc import ABC, abstractmethod | ||
import numpy as np | ||
|
||
from agimus_controller.mpc_data import OCPResults, OCPDebugData | ||
from agimus_controller.trajectory import WeightedTrajectoryPoint | ||
|
||
|
||
class OCPBase(ABC): | ||
def __init__(self) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def set_reference_horizon( | ||
self, reference_trajectory: list[WeightedTrajectoryPoint] | ||
) -> None: | ||
... | ||
|
||
@abstractmethod | ||
@property | ||
def horizon_size() -> int: | ||
... | ||
|
||
@abstractmethod | ||
@property | ||
def dt() -> int: | ||
... | ||
|
||
@abstractmethod | ||
def solve( | ||
self, x0: np.ndarray, x_init: list[np.ndarray], u_init: list[np.ndarray] | ||
) -> None: | ||
... | ||
|
||
@abstractmethod | ||
@property | ||
def ocp_results(self) -> OCPResults: | ||
... | ||
|
||
@abstractmethod | ||
@property | ||
def debug_data(self) -> OCPDebugData: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from collections import deque | ||
from dataclasses import dataclass | ||
import numpy as np | ||
from pinocchio import SE3, Force | ||
|
||
|
||
@dataclass | ||
class TrajectoryPoint: | ||
"""Trajectory point aiming at being a reference for the MPC.""" | ||
|
||
time_ns: int | ||
robot_configuration: np.ndarray | ||
robot_velocity: np.ndarray | ||
robot_acceleration: np.ndarray | ||
robot_effort: np.ndarray | ||
forces: dict[Force] # Dictionary of pinocchio.Force | ||
end_effector_poses: dict[SE3] # Dictionary of pinocchio.SE3 | ||
|
||
|
||
@dataclass | ||
class TrajectoryPointWeights: | ||
"""Trajectory point weights aiming at being set in the MPC costs.""" | ||
|
||
w_robot_configuration: np.ndarray | ||
w_robot_velocity: np.ndarray | ||
w_robot_acceleration: np.ndarray | ||
w_robot_effort: np.ndarray | ||
w_forces: dict[np.ndarray] | ||
w_end_effector_poses: dict[np.ndarray] | ||
|
||
|
||
@dataclass | ||
class WeightedTrajectoryPoint: | ||
"""Trajectory point and it's corresponding weights.""" | ||
|
||
point: TrajectoryPoint | ||
weight: TrajectoryPointWeights | ||
|
||
|
||
class TrajectoryBuffer(deque): | ||
"""List of variable size in which the HPP trajectory nodes will be.""" | ||
|
||
def clear_past(self, current_time_ns): | ||
while self and self[0].point.time_ns < current_time_ns: | ||
self.popleft() | ||
|
||
def horizon(self, horizon_size, dt_ocp): | ||
# TBD improve this implementation in case the dt_mpc != dt_ocp | ||
return self._buffer[: self._ocp.horizon_size] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from abc import ABC, abstractmethod | ||
import numpy as np | ||
|
||
from agimus_controller.trajectory import TrajectoryPoint | ||
|
||
|
||
class WarmStartBase(ABC): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self._previous_solution: list[TrajectoryPoint] = list() | ||
|
||
@abstractmethod | ||
def generate( | ||
self, | ||
reference_trajectory: list[TrajectoryPoint], | ||
) -> tuple(np.ndarray, np.ndarray): | ||
"""Returns x_init, u_init.""" | ||
... | ||
|
||
def update_previous_solution( | ||
self, previous_solution: list[TrajectoryPoint] | ||
) -> None: | ||
"""Stores internally the previous solution of the OCP""" | ||
self._previous_solution = previous_solution |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from setuptools import find_packages, setup | ||
|
||
PACKAGE_NAME = "agimus_controller" | ||
REQUIRES_PYTHON = ">=3.10.0" | ||
|
||
setup( | ||
name=PACKAGE_NAME, | ||
version="0.0.0", | ||
packages=find_packages(exclude=["tests"]), | ||
python_requires=REQUIRES_PYTHON, | ||
install_requires=[ | ||
"setuptools", | ||
"numpy==1.21.5", | ||
], | ||
zip_safe=True, | ||
maintainer="Guilhem Saurel", | ||
maintainer_email="[email protected]", | ||
description="Implements whole body MPC in python using the Croccodyl framework.", | ||
license="BSD-2", | ||
tests_require=["pytest"], | ||
) |
Oops, something went wrong.