Skip to content

Commit

Permalink
Save
Browse files Browse the repository at this point in the history
  • Loading branch information
Nishad Gothoskar authored and littleredcomputer committed Feb 23, 2024
1 parent ba4234a commit 3c199c3
Show file tree
Hide file tree
Showing 4 changed files with 407 additions and 6 deletions.
13 changes: 7 additions & 6 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from genjax.incremental import Diff, NoChange, UnknownChange

import bayes3d as b
import bayes3d.scene_graph

from .genjax_distributions import (
contact_params_uniform,
Expand Down Expand Up @@ -127,14 +128,14 @@ def get_far_plane(trace):


def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = b.get_indices(trace).shape[0] + 1
N = get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]


add_object_jit = jax.jit(add_object)
Expand All @@ -151,7 +152,7 @@ def print_trace(trace):


def viz_trace_meshcat(trace, colors=None):
b.clear()
b.clear_visualizer()
b.show_cloud(
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
)
Expand Down Expand Up @@ -223,14 +224,14 @@ def enumerator(trace, key, *args):
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[2]
)[0]

def enumerator_with_weight(trace, key, *args):
return trace.update(
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[1:3]
)[0:2]

def enumerator_score(trace, key, *args):
return enumerator(trace, key, *args).get_score()
Expand Down Expand Up @@ -301,4 +302,4 @@ def update_address(trace, key, address, value):
key,
genjax.choice_map({address: value}),
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
)[2]
)[0]
1 change: 1 addition & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy.spatial.transform import Rotation as R

import bayes3d as b
from bayes3d import distributions

# Can be helpful for debugging:
# jax.config.update('jax_enable_checks', True)
Expand Down
399 changes: 399 additions & 0 deletions demo_c2f.ipynb

Large diffs are not rendered by default.

Binary file added tutorial_real_data.pkl
Binary file not shown.

0 comments on commit 3c199c3

Please sign in to comment.