Skip to content

Commit

Permalink
experimental: Add minimal rust backend.
Browse files Browse the repository at this point in the history
First pass at adding a simple rust backend that use nalgebra for its matrix library.

Closes #405

GitOrigin-RevId: 28bfcfb8f956bf32a8a499093eb8f5e7878d4100
  • Loading branch information
matte1 authored and aaron-skydio committed Dec 24, 2024
1 parent 6cf7fbd commit 079830a
Show file tree
Hide file tree
Showing 29 changed files with 2,156 additions and 113 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ jobs:
libgmp-dev \
pandoc
- name: Install rust
run: |
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > installer.sh
chmod +x installer.sh
./installer.sh -y
# NOTE(brad): libunwind-dev is a broken dependency of libgoogle-glog-dev, itself
# a dependency of ceres. Without this step on jammy, apt-get install libgoogle-glog-dev
# would fail. If this step could be removed and still have the build succeed, it should.
Expand Down
3 changes: 3 additions & 0 deletions symforce/codegen/backends/rust/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
***THIS MODULE IS EXPERIMENTAL***

Backend for Rust. This currently only supports vector/matrices inputs and outputs, we do not have geo or cam types for Rust yet.
12 changes: 12 additions & 0 deletions symforce/codegen/backends/rust/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from pathlib import Path

__doc__ = (Path(__file__).parent / "README.rst").read_text()

from .rust_code_printer import RustCodePrinter
from .rust_code_printer import ScalarType
from .rust_config import RustConfig
146 changes: 146 additions & 0 deletions symforce/codegen/backends/rust/rust_code_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from enum import Enum

import sympy
from sympy.codegen.ast import float32
from sympy.codegen.ast import float64
from sympy.printing.rust import RustCodePrinter as SympyRustCodePrinter

from symforce import typing as T


class ScalarType(Enum):
FLOAT = float32
DOUBLE = float64


class RustCodePrinter(SympyRustCodePrinter):
"""
SymForce code printer for Rust. Based on the SymPy Rust printer.
"""

def __init__(
self,
scalar_type: ScalarType,
settings: T.Optional[T.Dict[str, T.Any]] = None,
override_methods: T.Optional[T.Dict[sympy.Function, str]] = None,
) -> None:
super().__init__(dict(settings or {}))

self.scalar_type = scalar_type.value
self.override_methods = override_methods or {}
for expr, name in self.override_methods.items():
self._set_override_methods(expr, name)

def _set_override_methods(self, expr: sympy.Function, name: str) -> None:
method_name = f"_print_{str(expr)}"

def _print_expr(expr: sympy.Expr) -> str:
expr_string = ", ".join(map(self._print, expr.args))
return f"{name}({expr_string})"

setattr(self, method_name, _print_expr)

@staticmethod
def _print_Zero(expr: sympy.Expr) -> str:
return "0.0"

def _print_Integer(self, expr: sympy.Integer, _type: T.Any = None) -> T.Any:
"""
Customizations:
* Cast all integers to either f32 or f64 because Rust does not have implicit casting
and needs to know the type of the literal at compile time. We assume that we are only
ever operating on floats in SymForce which should make this safe.
"""
if self.scalar_type is float32:
return f"{expr.p}_f32"
if self.scalar_type is float64:
return f"{expr.p}_f64"
assert False, f"Scalar type {self.scalar_type} not supported"

def _print_Pow(self, expr: T.Any, rational: T.Any = None) -> str:
if expr.exp.is_rational:
power = self._print_Rational(expr.exp)
func = "powf"
return f"{self._print(expr.base)}.{func}({power})"
else:
power = self._print(expr.exp)

if expr.exp.is_integer:
func = "powi"
else:
func = "powf"

return f"{expr.base}.{func}({power})"

@staticmethod
def _print_ImaginaryUnit(expr: sympy.Expr) -> str:
"""
Customizations:
* Print 1i instead of I
* Cast to Scalar, since the literal is of type std::complex<double>
"""
return "Scalar(1i)"

def _print_Float(self, flt: sympy.Float, _type: T.Any = None) -> T.Any:
"""
Customizations:
* Cast all literals to Scalar at compile time instead of using a suffix at codegen time
"""
if self.scalar_type is float32:
return f"{super()._print_Float(flt)}_f32"
if self.scalar_type is float64:
return f"{super()._print_Float(flt)}_f64"

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

def _print_Pi(self, expr: T.Any, _type: bool = False) -> str:
if self.scalar_type is float32:
return "core::f32::consts::PI"
if self.scalar_type is float64:
return "core::f64::consts::PI"

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

def _print_Max(self, expr: sympy.Max) -> str:
"""
Customizations:
* The first argument calls the max method on the second argument.
"""
return "{}.max({})".format(self._print(expr.args[0]), self._print(expr.args[1]))

def _print_Min(self, expr: sympy.Min) -> str:
"""
Customizations:
* The first argument calls the min method on the second argument.
"""
return "{}.min({})".format(self._print(expr.args[0]), self._print(expr.args[1]))

def _print_log(self, expr: sympy.log) -> str:
"""
Customizations:
"""
return "{}.ln()".format(self._print(expr.args[0]))

def _print_Rational(self, expr: sympy.Rational) -> str:
p, q = int(expr.p), int(expr.q)

float_suffix = None
if self.scalar_type is float32:
float_suffix = "f32"
elif self.scalar_type is float64:
float_suffix = "f64"

return f"({p}_{float_suffix}/{q}_{float_suffix})"

def _print_Exp1(self, expr: T.Any, _type: bool = False) -> str:
if self.scalar_type is float32:
return "core::f32::consts::E"
elif self.scalar_type is float64:
return "core::f64::consts::E"

raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")
75 changes: 75 additions & 0 deletions symforce/codegen/backends/rust/rust_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from dataclasses import dataclass
from pathlib import Path

from sympy.printing.codeprinter import CodePrinter

from symforce import typing as T
from symforce.codegen.backends.rust import rust_code_printer
from symforce.codegen.codegen_config import CodegenConfig

CURRENT_DIR = Path(__file__).parent


@dataclass
class RustConfig(CodegenConfig):
"""
Code generation config for the Rust backend.
Args:
doc_comment_line_prefix: Prefix applied to each line in a docstring
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
scala_type: The scalar type to use (float or double)
use_eigen_types: Use eigen_lcm types for vectors instead of lists
render_template_config: Configuration for template rendering, see RenderTemplateConfig for
more information
cse_optimizations: Optimizations argument to pass to :func:`sf.cse <symforce.symbolic.cse>`
zero_epsilon_behavior: What should codegen do if a default epsilon is not set?
normalize_results: Should function outputs be explicitly projected onto the manifold before
returning?
"""

doc_comment_line_prefix: str = "///"
line_length: int = 100
scalar_type: rust_code_printer.ScalarType = rust_code_printer.ScalarType.DOUBLE
use_eigen_types: bool = False

@classmethod
def backend_name(cls) -> str:
return "rust"

@classmethod
def template_dir(cls) -> Path:
return CURRENT_DIR / "templates"

@staticmethod
def templates_to_render(generated_file_name: str) -> T.List[T.Tuple[str, str]]:
return [("function/FUNCTION.rs.jinja", f"{generated_file_name}.rs")]

def printer(self) -> CodePrinter:
kwargs: T.Mapping[str, T.Any] = {}
return rust_code_printer.RustCodePrinter(scalar_type=self.scalar_type, **kwargs)

@staticmethod
def format_matrix_accessor(key: str, i: int, j: int, *, shape: T.Tuple[int, int]) -> str:
"""
Format accessor for matrix types.
Assumes matrices are row-major.
"""
RustConfig._assert_indices_in_bounds(i, j, shape)
if shape[1] == 1:
return f"{key}[{i}]"
if shape[0] == 1:
return f"{key}[{j}]"
return f"{key}[({i}, {j})]"

@staticmethod
def format_eigen_lcm_accessor(key: str, i: int) -> str:
"""
Format accessor for eigen_lcm types.
"""
raise NotImplementedError("Rust does not support eigen_lcm")
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ---------------------------------------------------------------------------- #}

{%- import "../util/util.jinja" as util with context -%}


pub mod {{ spec.namespace }} {

#[allow(unused_parens)]

{% if spec.docstring %}
{{ util.print_docstring(spec.docstring) }}
{% endif %}
{{ util.function_declaration(spec) }} {
{{ util.expr_code(spec) }}
}

} // mod {{ spec.namespace }}
Loading

0 comments on commit 079830a

Please sign in to comment.