Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #60

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-yaml
args: [--unsafe]
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.9.3
hooks:
- id: ruff
# types_or: [ python, pyi, jupyter ]
Expand Down
6 changes: 3 additions & 3 deletions bayes3d/colmap/colmap_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def read_intrinsics_text(path):
elems = line.split()
camera_id = int(elems[0])
model = elems[1]
assert (
model == "PINHOLE"
), "While the loader support other types, the rest of the code assumes PINHOLE"
assert model == "PINHOLE", (
"While the loader support other types, the rest of the code assumes PINHOLE"
)
width = int(elems[2])
height = int(elems[3])
params = np.array(tuple(map(float, elems[4:])))
Expand Down
4 changes: 3 additions & 1 deletion bayes3d/colmap/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
FovY = focal2fov(focal_length_y, height)
FovX = focal2fov(focal_length_x, width)
else:
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
assert False, (
"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
)

image_path = os.path.join(images_folder, os.path.basename(extr.name))
image_name = os.path.basename(image_path).split(".")[0]
Expand Down
10 changes: 5 additions & 5 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def get_far_plane(trace):
def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = b.get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
choices[f"parent_{N - 1}"] = parent
choices[f"id_{N - 1}"] = obj_id
choices[f"face_parent_{N - 1}"] = face_parent
choices[f"face_child_{N - 1}"] = face_child
choices[f"contact_params_{N - 1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]


Expand Down
2 changes: 1 addition & 1 deletion bayes3d/neural/cosypose_baseline/cosypose_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def cosypose_interface(rgb_imgs, camera_k):
all_scores = []
for i, rgb_img in enumerate(rgb_imgs):
pred = COSYPOSE_MODEL.inference(rgb_img, camera_k)
print(f"{i+1}/{num_imgs} inference done")
print(f"{i + 1}/{num_imgs} inference done")

pred_poses = np.asarray(pred.poses.cpu())
pred_ids = [
Expand Down
16 changes: 9 additions & 7 deletions bayes3d/neural/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
return model

stride = nn_utils._pair(stride)
assert all(
[(patch_size // s_) * s_ == patch_size for s_ in stride]
), f"stride {stride} should divide patch_size {patch_size}"
assert all([(patch_size // s_) * s_ == patch_size for s_ in stride]), (
f"stride {stride} should divide patch_size {patch_size}"
)

# fix the stride
model.patch_embed.proj.stride = stride
Expand Down Expand Up @@ -415,7 +415,9 @@ def extract_descriptors(
if not include_cls:
x = x[:, :, 1:, :] # remove cls token
else:
assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
assert not bin, (
"bin = True and include_cls = True are not supported together, set one of them False."
)
if not bin:
desc = (
x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)
Expand All @@ -431,9 +433,9 @@ def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
:param batch: batch to extract saliency maps for. Has shape BxCxHxW.
:return: a tensor of saliency maps. has shape Bxt-1
"""
assert (
self.model_type == "dino_vits8"
), "saliency maps are supported only for dino_vits model_type."
assert self.model_type == "dino_vits8", (
"saliency maps are supported only for dino_vits model_type."
)
self._extract_features(batch, [11], "attn")
head_idxs = [0, 2, 4, 5]
curr_feats = self._feats[0] # Bxhxtxt
Expand Down
2 changes: 1 addition & 1 deletion bayes3d/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def add_mesh(self, mesh, mesh_name=None, scaling_factor=1.0, center_mesh=True):
bounding_box_dims, bounding_box_pose = bayes3d.utils.aabb(mesh.vertices)
if center_mesh:
if not jnp.isclose(bounding_box_pose[:3, 3], 0.0).all():
print(f"Centering mesh with translation {bounding_box_pose[:3,3]}")
print(f"Centering mesh with translation {bounding_box_pose[:3, 3]}")
mesh.vertices = mesh.vertices - bounding_box_pose[:3, 3]

self.meshes.append(mesh)
Expand Down
8 changes: 4 additions & 4 deletions bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def xfm_points(points, matrix):
pos_grads.min().item(),
pos_grads.max().item(),
)
print(f"JAX rasterization (eval + grad): {(end_time - start_time)*1000} ms")
print(f"JAX rasterization (eval + grad): {(end_time - start_time) * 1000} ms")

# save viz
b.viz.get_depth_image(rast_out[0][:, :, 2]).save("img_jax.png")
Expand Down Expand Up @@ -229,7 +229,7 @@ def xfm_points(points, matrix):
pos_grads.min().item(),
pos_grads.max().item(),
)
print(f"Torch rasterization (eval + grad): {(end_time - start_time)*1000} ms")
print(f"Torch rasterization (eval + grad): {(end_time - start_time) * 1000} ms")

# save viz
b.viz.get_depth_image(jnp.array(rast_out[0][:, :, 2].cpu())).save("img_torch.png")
Expand Down Expand Up @@ -278,7 +278,7 @@ def xfm_points(points, matrix):
print(
f"JAX BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}"
)
print(f"JAX interpolation: {(end_time - start_time)*1000} ms")
print(f"JAX interpolation: {(end_time - start_time) * 1000} ms")

# save viz
b.viz.get_depth_image(gb_pos[0][:, :, 2]).save("interpolate_jax.png")
Expand Down Expand Up @@ -316,7 +316,7 @@ def xfm_points(points, matrix):
print(
f"TORCH BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}"
)
print(f"Torch interpolation: {(end_time - start_time)*1000} ms")
print(f"Torch interpolation: {(end_time - start_time) * 1000} ms")

# save viz
b.viz.get_depth_image(jnp.array(gb_pos[0][:, :, 2].cpu())).save(
Expand Down
26 changes: 13 additions & 13 deletions scripts/_mkl/notebooks/00a - Types.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|default_exp types"
"# |default_exp types"
]
},
{
Expand All @@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"# |export\n",
"from typing import Any, NamedTuple\n",
"import numpy as np\n",
"import jax\n",
Expand All @@ -29,18 +29,18 @@
"Int = Array\n",
"FaceIndex = int\n",
"FaceIndices = Array\n",
"ArrayN = Array\n",
"Array3 = Array\n",
"Array2 = Array\n",
"ArrayNx2 = Array\n",
"ArrayNx3 = Array\n",
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
"PrecisionMatrix = Matrix\n",
"ArrayN = Array\n",
"Array3 = Array\n",
"Array2 = Array\n",
"ArrayNx2 = Array\n",
"ArrayNx3 = Array\n",
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
"PrecisionMatrix = Matrix\n",
"CovarianceMatrix = Matrix\n",
"CholeskyMatrix = Matrix\n",
"SquareMatrix = Matrix\n",
"Vector = Array\n",
"Direction = Vector\n",
"CholeskyMatrix = Matrix\n",
"SquareMatrix = Matrix\n",
"Vector = Array\n",
"Direction = Vector\n",
"BaseVector = Vector"
]
},
Expand Down
Loading
Loading