-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
experimental: Add minimal rust backend.
First pass at adding a simple rust backend that use nalgebra for its matrix library. Closes #405 GitOrigin-RevId: 28bfcfb8f956bf32a8a499093eb8f5e7878d4100
- Loading branch information
1 parent
6cf7fbd
commit 079830a
Showing
29 changed files
with
2,156 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
***THIS MODULE IS EXPERIMENTAL*** | ||
|
||
Backend for Rust. This currently only supports vector/matrices inputs and outputs, we do not have geo or cam types for Rust yet. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# ---------------------------------------------------------------------------- | ||
# SymForce - Copyright 2022, Skydio, Inc. | ||
# This source code is under the Apache 2.0 license found in the LICENSE file. | ||
# ---------------------------------------------------------------------------- | ||
|
||
from pathlib import Path | ||
|
||
__doc__ = (Path(__file__).parent / "README.rst").read_text() | ||
|
||
from .rust_code_printer import RustCodePrinter | ||
from .rust_code_printer import ScalarType | ||
from .rust_config import RustConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# ---------------------------------------------------------------------------- | ||
# SymForce - Copyright 2022, Skydio, Inc. | ||
# This source code is under the Apache 2.0 license found in the LICENSE file. | ||
# ---------------------------------------------------------------------------- | ||
|
||
from enum import Enum | ||
|
||
import sympy | ||
from sympy.codegen.ast import float32 | ||
from sympy.codegen.ast import float64 | ||
from sympy.printing.rust import RustCodePrinter as SympyRustCodePrinter | ||
|
||
from symforce import typing as T | ||
|
||
|
||
class ScalarType(Enum): | ||
FLOAT = float32 | ||
DOUBLE = float64 | ||
|
||
|
||
class RustCodePrinter(SympyRustCodePrinter): | ||
""" | ||
SymForce code printer for Rust. Based on the SymPy Rust printer. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
scalar_type: ScalarType, | ||
settings: T.Optional[T.Dict[str, T.Any]] = None, | ||
override_methods: T.Optional[T.Dict[sympy.Function, str]] = None, | ||
) -> None: | ||
super().__init__(dict(settings or {})) | ||
|
||
self.scalar_type = scalar_type.value | ||
self.override_methods = override_methods or {} | ||
for expr, name in self.override_methods.items(): | ||
self._set_override_methods(expr, name) | ||
|
||
def _set_override_methods(self, expr: sympy.Function, name: str) -> None: | ||
method_name = f"_print_{str(expr)}" | ||
|
||
def _print_expr(expr: sympy.Expr) -> str: | ||
expr_string = ", ".join(map(self._print, expr.args)) | ||
return f"{name}({expr_string})" | ||
|
||
setattr(self, method_name, _print_expr) | ||
|
||
@staticmethod | ||
def _print_Zero(expr: sympy.Expr) -> str: | ||
return "0.0" | ||
|
||
def _print_Integer(self, expr: sympy.Integer, _type: T.Any = None) -> T.Any: | ||
""" | ||
Customizations: | ||
* Cast all integers to either f32 or f64 because Rust does not have implicit casting | ||
and needs to know the type of the literal at compile time. We assume that we are only | ||
ever operating on floats in SymForce which should make this safe. | ||
""" | ||
if self.scalar_type is float32: | ||
return f"{expr.p}_f32" | ||
if self.scalar_type is float64: | ||
return f"{expr.p}_f64" | ||
assert False, f"Scalar type {self.scalar_type} not supported" | ||
|
||
def _print_Pow(self, expr: T.Any, rational: T.Any = None) -> str: | ||
if expr.exp.is_rational: | ||
power = self._print_Rational(expr.exp) | ||
func = "powf" | ||
return f"{self._print(expr.base)}.{func}({power})" | ||
else: | ||
power = self._print(expr.exp) | ||
|
||
if expr.exp.is_integer: | ||
func = "powi" | ||
else: | ||
func = "powf" | ||
|
||
return f"{expr.base}.{func}({power})" | ||
|
||
@staticmethod | ||
def _print_ImaginaryUnit(expr: sympy.Expr) -> str: | ||
""" | ||
Customizations: | ||
* Print 1i instead of I | ||
* Cast to Scalar, since the literal is of type std::complex<double> | ||
""" | ||
return "Scalar(1i)" | ||
|
||
def _print_Float(self, flt: sympy.Float, _type: T.Any = None) -> T.Any: | ||
""" | ||
Customizations: | ||
* Cast all literals to Scalar at compile time instead of using a suffix at codegen time | ||
""" | ||
if self.scalar_type is float32: | ||
return f"{super()._print_Float(flt)}_f32" | ||
if self.scalar_type is float64: | ||
return f"{super()._print_Float(flt)}_f64" | ||
|
||
raise NotImplementedError(f"Scalar type {self.scalar_type} not supported") | ||
|
||
def _print_Pi(self, expr: T.Any, _type: bool = False) -> str: | ||
if self.scalar_type is float32: | ||
return "core::f32::consts::PI" | ||
if self.scalar_type is float64: | ||
return "core::f64::consts::PI" | ||
|
||
raise NotImplementedError(f"Scalar type {self.scalar_type} not supported") | ||
|
||
def _print_Max(self, expr: sympy.Max) -> str: | ||
""" | ||
Customizations: | ||
* The first argument calls the max method on the second argument. | ||
""" | ||
return "{}.max({})".format(self._print(expr.args[0]), self._print(expr.args[1])) | ||
|
||
def _print_Min(self, expr: sympy.Min) -> str: | ||
""" | ||
Customizations: | ||
* The first argument calls the min method on the second argument. | ||
""" | ||
return "{}.min({})".format(self._print(expr.args[0]), self._print(expr.args[1])) | ||
|
||
def _print_log(self, expr: sympy.log) -> str: | ||
""" | ||
Customizations: | ||
""" | ||
return "{}.ln()".format(self._print(expr.args[0])) | ||
|
||
def _print_Rational(self, expr: sympy.Rational) -> str: | ||
p, q = int(expr.p), int(expr.q) | ||
|
||
float_suffix = None | ||
if self.scalar_type is float32: | ||
float_suffix = "f32" | ||
elif self.scalar_type is float64: | ||
float_suffix = "f64" | ||
|
||
return f"({p}_{float_suffix}/{q}_{float_suffix})" | ||
|
||
def _print_Exp1(self, expr: T.Any, _type: bool = False) -> str: | ||
if self.scalar_type is float32: | ||
return "core::f32::consts::E" | ||
elif self.scalar_type is float64: | ||
return "core::f64::consts::E" | ||
|
||
raise NotImplementedError(f"Scalar type {self.scalar_type} not supported") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# ---------------------------------------------------------------------------- | ||
# SymForce - Copyright 2022, Skydio, Inc. | ||
# This source code is under the Apache 2.0 license found in the LICENSE file. | ||
# ---------------------------------------------------------------------------- | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
from sympy.printing.codeprinter import CodePrinter | ||
|
||
from symforce import typing as T | ||
from symforce.codegen.backends.rust import rust_code_printer | ||
from symforce.codegen.codegen_config import CodegenConfig | ||
|
||
CURRENT_DIR = Path(__file__).parent | ||
|
||
|
||
@dataclass | ||
class RustConfig(CodegenConfig): | ||
""" | ||
Code generation config for the Rust backend. | ||
Args: | ||
doc_comment_line_prefix: Prefix applied to each line in a docstring | ||
line_length: Maximum allowed line length in docstrings; used for formatting docstrings. | ||
scala_type: The scalar type to use (float or double) | ||
use_eigen_types: Use eigen_lcm types for vectors instead of lists | ||
render_template_config: Configuration for template rendering, see RenderTemplateConfig for | ||
more information | ||
cse_optimizations: Optimizations argument to pass to :func:`sf.cse <symforce.symbolic.cse>` | ||
zero_epsilon_behavior: What should codegen do if a default epsilon is not set? | ||
normalize_results: Should function outputs be explicitly projected onto the manifold before | ||
returning? | ||
""" | ||
|
||
doc_comment_line_prefix: str = "///" | ||
line_length: int = 100 | ||
scalar_type: rust_code_printer.ScalarType = rust_code_printer.ScalarType.DOUBLE | ||
use_eigen_types: bool = False | ||
|
||
@classmethod | ||
def backend_name(cls) -> str: | ||
return "rust" | ||
|
||
@classmethod | ||
def template_dir(cls) -> Path: | ||
return CURRENT_DIR / "templates" | ||
|
||
@staticmethod | ||
def templates_to_render(generated_file_name: str) -> T.List[T.Tuple[str, str]]: | ||
return [("function/FUNCTION.rs.jinja", f"{generated_file_name}.rs")] | ||
|
||
def printer(self) -> CodePrinter: | ||
kwargs: T.Mapping[str, T.Any] = {} | ||
return rust_code_printer.RustCodePrinter(scalar_type=self.scalar_type, **kwargs) | ||
|
||
@staticmethod | ||
def format_matrix_accessor(key: str, i: int, j: int, *, shape: T.Tuple[int, int]) -> str: | ||
""" | ||
Format accessor for matrix types. | ||
Assumes matrices are row-major. | ||
""" | ||
RustConfig._assert_indices_in_bounds(i, j, shape) | ||
if shape[1] == 1: | ||
return f"{key}[{i}]" | ||
if shape[0] == 1: | ||
return f"{key}[{j}]" | ||
return f"{key}[({i}, {j})]" | ||
|
||
@staticmethod | ||
def format_eigen_lcm_accessor(key: str, i: int) -> str: | ||
""" | ||
Format accessor for eigen_lcm types. | ||
""" | ||
raise NotImplementedError("Rust does not support eigen_lcm") |
20 changes: 20 additions & 0 deletions
20
symforce/codegen/backends/rust/templates/function/FUNCTION.rs.jinja
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
{# ---------------------------------------------------------------------------- | ||
# SymForce - Copyright 2022, Skydio, Inc. | ||
# This source code is under the Apache 2.0 license found in the LICENSE file. | ||
# ---------------------------------------------------------------------------- #} | ||
|
||
{%- import "../util/util.jinja" as util with context -%} | ||
|
||
|
||
pub mod {{ spec.namespace }} { | ||
|
||
#[allow(unused_parens)] | ||
|
||
{% if spec.docstring %} | ||
{{ util.print_docstring(spec.docstring) }} | ||
{% endif %} | ||
{{ util.function_declaration(spec) }} { | ||
{{ util.expr_code(spec) }} | ||
} | ||
|
||
} // mod {{ spec.namespace }} |
Oops, something went wrong.