# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
import logging
import os
from typing import Literal
import cv2
import numpy as np
import torch
import tqdm
from fvdb.types import NumericMaxRank1, to_VecNf
from scipy.spatial import ConvexHull
from fvdb_reality_capture.sfm_scene import SfmCache, SfmPosedImageMetadata, SfmScene
from .base_transform import BaseTransform, transform
def _bbox_cache_prefix(
    transform_name: str,
    bbox: np.ndarray,
    mask_format: str,
    composite_with_existing_masks: bool,
) -> str:
    """Build the cache folder name identifying one crop configuration."""
    prefix = (
        f"{transform_name}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}_{bbox[4]}_{bbox[5]}"
        f"_{mask_format}_{composite_with_existing_masks}"
    )
    # Spaces, dots and dashes are replaced so the prefix contains none of them.
    return prefix.replace(" ", "_").replace(".", "_").replace("-", "neg")


def _image_meta_with_mask(image_meta: SfmPosedImageMetadata, mask_path: str) -> SfmPosedImageMetadata:
    """Copy ``image_meta`` field by field, replacing only its ``mask_path``."""
    return SfmPosedImageMetadata(
        world_to_camera_matrix=image_meta.world_to_camera_matrix,
        camera_to_world_matrix=image_meta.camera_to_world_matrix,
        camera_metadata=image_meta.camera_metadata,
        camera_id=image_meta.camera_id,
        image_id=image_meta.image_id,
        image_path=image_meta.image_path,
        mask_path=mask_path,
        point_indices=image_meta.point_indices,
    )


def _cached_transform_matches(output_cache, expected_transform, logger: logging.Logger) -> bool:
    """Return True if the cache holds a (4, 4) "transform" entry close to ``expected_transform``."""
    if not output_cache.has_file("transform"):
        logger.info("No transform found in cache, regenerating.")
        return False

    _, transform_data = output_cache.read_file("transform")
    cached_transform: np.ndarray | None = transform_data.get("transform", None)
    if cached_transform is None:
        logger.info("Transform metadata does not match expected format. No 'transform' key in cached file.")
        return False
    if not isinstance(cached_transform, np.ndarray) or cached_transform.shape != (4, 4):
        logger.info(
            "Transform metadata does not match expected format. Expected 'transform'. "
            "Clearing the cache and regenerating transform."
        )
        return False
    if not np.allclose(cached_transform, expected_transform):
        logger.info(
            "Cached transform does not match input scene transform. Clearing the cache and regenerating transform."
        )
        return False
    return True


def _load_existing_mask(mask_path: str) -> np.ndarray:
    """Load a ``.npy``, ``.png`` or ``.jpg`` mask; raise ``ValueError`` otherwise or on read failure."""
    normalized_path = mask_path.strip()
    if normalized_path.endswith(".npy"):
        return np.load(mask_path)
    if normalized_path.endswith((".png", ".jpg")):
        loaded = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if loaded is None:
            raise ValueError(f"Failed to load mask {mask_path}")
        return loaded
    raise ValueError(f"Unsupported mask file format: {mask_path}")


def _bbox_inside_mask(image_meta: SfmPosedImageMetadata, corners_world_homogeneous: np.ndarray) -> np.ndarray:
    """Return a boolean [height, width] mask of pixels inside the projected box's 2D convex hull."""
    cam_meta = image_meta.camera_metadata

    # World -> camera space: [4, 4] @ [4, 8] -> [4, 8], then divide out w -> [3, 8].
    corners_cam = image_meta.world_to_camera_matrix @ corners_world_homogeneous.T
    corners_cam = corners_cam[:3, :] / corners_cam[-1, :]

    # Camera -> pixel space: [3, 3] @ [3, 8] -> [3, 8], then perspective divide -> [8, 2].
    # NOTE(review): corners with non-positive camera-space depth are projected like any other
    # corner; confirm this is acceptable when a camera sits inside or very close to the box.
    corners_pixel = cam_meta.projection_matrix @ corners_cam
    corners_pixel = (corners_pixel[:2, :] / corners_pixel[2, :]).T

    # Each hull facet is a half-plane `normal . p + offset <= 0`; a pixel is inside the hull
    # exactly when it satisfies every one of them.
    hull = ConvexHull(corners_pixel)
    hull_normals = hull.equations[:, :-1]  # [num_faces, 2]
    hull_offsets = hull.equations[:, -1]  # [num_faces]

    pixel_u, pixel_v = np.meshgrid(np.arange(cam_meta.width), np.arange(cam_meta.height), indexing="xy")
    pixel_coords = np.stack([pixel_u, pixel_v], axis=-1)  # [height, width, 2]

    # [height, width, num_faces] signed offsets of every pixel from every hull facet.
    signed_distances = pixel_coords @ hull_normals.T + hull_offsets[np.newaxis, np.newaxis, :]
    return np.all(signed_distances <= 0.0, axis=-1)


def _crop_scene_to_bbox(
    input_scene: SfmScene,
    transform_name: str,
    composite_with_existing_masks: bool,
    mask_format: str,
    bbox: np.ndarray,
    logger: logging.Logger,
) -> SfmScene:
    """
    Crop a scene's points to a bounding box and attach per-image masks of the projected box.

    Points strictly inside ``bbox`` are kept. For every image, a mask is computed marking the pixels
    that fall inside the 2D convex hull of the projected box corners, and is written to a cache folder
    derived from the transform name, the box, the mask format and the compositing flag. Cached masks
    are reused when the folder holds exactly one mask per image plus a ``"transform"`` entry matching
    the scene's transformation matrix, and every mask has the requested data type.

    Args:
        input_scene (SfmScene): The scene to crop.
        transform_name (str): Name of the calling transform; used in the cache folder name.
        composite_with_existing_masks (bool): If ``True`` and an image's existing mask file exists,
            the new mask is multiplied into it.
        mask_format (str): Data type passed to the cache when writing masks.
        bbox (np.ndarray): Shape-(6,) array ``(min_x, min_y, min_z, max_x, max_y, max_z)``.
        logger (logging.Logger): Logger for progress and cache diagnostics.

    Returns:
        SfmScene: A scene with filtered points, updated mask paths, ``scene_bbox`` set to ``bbox``
        and the cache set to the mask folder.

    Raises:
        ValueError: If ``bbox`` is not of shape (6,), or an existing mask cannot be loaded, has an
            unsupported file extension, has an unsupported rank, or does not match the image size.
    """
    if bbox.shape != (6,):
        raise ValueError("Bounding box must be a 1D array of shape (6,)")

    input_cache: SfmCache = input_scene.cache
    output_cache = input_cache.make_folder(
        _bbox_cache_prefix(transform_name, bbox, mask_format, composite_with_existing_masks),
        description=f"Image masks ({mask_format}) for cropping to bounding box {bbox}",
    )

    # Keep only the points strictly inside the box on every axis.
    points = input_scene.points
    points_mask = np.logical_and.reduce(
        [
            points[:, 0] > bbox[0],
            points[:, 0] < bbox[3],
            points[:, 1] > bbox[1],
            points[:, 1] < bbox[4],
            points[:, 2] > bbox[2],
            points[:, 2] < bbox[5],
        ]
    )
    masked_scene = input_scene.filter_points(points_mask)

    num_images = len(masked_scene.images)
    # How many zeros to pad the image id with in the mask file names.
    num_zeropad = len(str(num_images)) + 2
    # One mask per image plus the "transform" entry.
    expected_num_files = num_images + 1

    new_image_metadata = []
    regenerate_cache = False

    if output_cache.num_files != expected_num_files:
        if output_cache.num_files == 0:
            logger.info("No masks found in the cache for cropping.")
        else:
            logger.info(
                f"Inconsistent number of cached files. Expected {expected_num_files} "
                f"({num_images} masks plus the transform), found {output_cache.num_files}. "
                f"Clearing cache and regenerating masks."
            )
        regenerate_cache = True

    if not regenerate_cache and not _cached_transform_matches(
        output_cache, input_scene.transformation_matrix, logger
    ):
        regenerate_cache = True

    if not regenerate_cache:
        for image_meta in masked_scene.images:
            # Masks are written under the image's own id (see the generation loop below), so the
            # lookup must use that same id rather than the image's position in the list.
            image_cache_filename = f"mask_{image_meta.image_id:0{num_zeropad}}"
            if not output_cache.has_file(image_cache_filename):
                logger.info(
                    f"Mask for image {image_meta.image_id} not found in cache. "
                    f"Clearing cache and regenerating masks."
                )
                regenerate_cache = True
                break
            key_meta = output_cache.get_file_metadata(image_cache_filename)
            if key_meta.get("data_type", "") != mask_format:
                logger.info(
                    f"Output cache masks metadata does not match expected format. Expected '{mask_format}'. "
                    f"Clearing the cache and regenerating masks."
                )
                regenerate_cache = True
                break
            new_image_metadata.append(_image_meta_with_mask(image_meta, str(key_meta["path"])))

    if regenerate_cache:
        # Start from an empty folder so stale masks cannot survive a regeneration.
        output_cache.clear_current_folder()
        output_cache.write_file("transform", {"transform": input_scene.transformation_matrix}, data_type="pt")
        logger.info("Computing image masks for cropping and saving to cache.")
        new_image_metadata = []

        min_x, min_y, min_z, max_x, max_y, max_z = bbox
        # (8, 4)-shaped array with the corners of the bounding box in homogeneous coordinates.
        cube_bounds_world_space_homogeneous = np.array(
            [
                [min_x, min_y, min_z, 1.0],
                [min_x, min_y, max_z, 1.0],
                [min_x, max_y, min_z, 1.0],
                [min_x, max_y, max_z, 1.0],
                [max_x, min_y, min_z, 1.0],
                [max_x, min_y, max_z, 1.0],
                [max_x, max_y, min_z, 1.0],
                [max_x, max_y, max_z, 1.0],
            ]
        )

        for image_meta in tqdm.tqdm(masked_scene.images, unit="imgs", desc="Computing image masks for cropping"):
            inside_mask = _bbox_inside_mask(image_meta, cube_bounds_world_space_homogeneous)
            mask_to_save = inside_mask.astype(np.uint8) * 255  # Convert to uint8 mask

            # If a mask file already exists for this image, composite the new mask into it.
            if composite_with_existing_masks and os.path.exists(image_meta.mask_path):
                existing_mask = _load_existing_mask(image_meta.mask_path)
                if existing_mask.ndim == 3:
                    # Add a trailing axis so the new mask broadcasts over the existing channels.
                    inside_mask = inside_mask[..., np.newaxis]
                elif existing_mask.ndim != 2:
                    raise ValueError(f"Unsupported mask shape: {existing_mask.shape}. Must have 2D or 3D shape.")
                if existing_mask.shape[:2] != inside_mask.shape[:2]:
                    raise ValueError(
                        f"Existing mask shape {existing_mask.shape[:2]} does not match computed mask shape {inside_mask.shape[:2]}."
                    )
                # Zero out every existing-mask pixel that lies outside the projected box.
                mask_to_save = existing_mask * inside_mask

            cache_file_meta = output_cache.write_file(
                name=f"mask_{image_meta.image_id:0{num_zeropad}}",
                data=mask_to_save,
                data_type=mask_format,
            )
            new_image_metadata.append(_image_meta_with_mask(image_meta, str(cache_file_meta["path"])))

    return SfmScene(
        cameras=masked_scene.cameras,
        images=new_image_metadata,
        points=masked_scene.points,
        points_rgb=masked_scene.points_rgb,
        points_err=masked_scene.points_err,
        scene_bbox=bbox,
        transformation_matrix=input_scene.transformation_matrix,
        cache=output_cache,
    )
@transform
class CropScene(BaseTransform):
    """
    A :class:`~base_transform.BaseTransform` which crops the input
    :class:`~fvdb_reality_capture.sfm_scene.SfmScene` points to lie within a specified bounding box.
    This transform additionally updates the scene's masks to nullify pixels whose rays do not intersect
    the bounding box.

    .. note::

        If the input scene already has masks, these new masks will be composited with the existing masks to ensure
        that pixels outside the cropped region are properly masked. This can be disabled by setting
        ``composite_with_existing_masks`` to ``False``.

    Example usage:

    .. code-block:: python

        from fvdb_reality_capture import transforms
        from fvdb_reality_capture.sfm_scene import SfmScene
        import numpy as np

        # Bounding box in the format (min_x, min_y, min_z, max_x, max_y, max_z)
        scene_transform = transforms.CropScene(bbox=np.array([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]))

        input_scene: SfmScene = ...  # Load or create an SfmScene

        # The transformed scene will have points only within the bounding box, and posed images will have
        # masks updated to nullify pixels corresponding to regions outside the cropped scene.
        transformed_scene: SfmScene = scene_transform(input_scene)
    """

    version = "1.0.0"

    def __init__(
        self,
        bbox: NumericMaxRank1,
        mask_format: Literal["png", "jpg", "npy"] = "png",
        composite_with_existing_masks: bool = True,
    ):
        """
        Create a new :class:`CropScene` transform with a bounding box.

        Args:
            bbox (NumericMaxRank1): A bounding box in the format ``(min_x, min_y, min_z, max_x, max_y, max_z)``.
            mask_format (Literal["png", "jpg", "npy"]): The format to save the masks in. Defaults to "png".
            composite_with_existing_masks (bool): Whether to composite the masks generated into existing masks for
                pixels corresponding to regions outside the cropped scene. If set to ``True``, existing masks
                will be loaded and composited with the new mask. Defaults to ``True``. The resulting composited
                mask will allow a pixel to be valid if it is valid in both the existing and new mask.

        Raises:
            ValueError: If ``bbox`` does not have six elements or ``mask_format`` is not supported.
        """
        super().__init__()
        bbox = to_VecNf(bbox, 6, dtype=torch.float64).numpy()
        self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        if not len(bbox) == 6:
            raise ValueError("Bounding box must be a tuple of the form (min_x, min_y, min_z, max_x, max_y, max_z).")
        # The box is stored in single precision.
        self._bbox = np.asarray(bbox).astype(np.float32)
        self._mask_format = mask_format
        if self._mask_format not in ["png", "jpg", "npy"]:
            raise ValueError(
                f"Unsupported mask format: {self._mask_format}. Supported formats are 'png', 'jpg', and 'npy'."
            )
        self._composite_with_existing_masks = composite_with_existing_masks

    @staticmethod
    def name() -> str:
        """
        Return the name of the :class:`CropScene` transform. *i.e.* ``"CropScene"``.

        Returns:
            str: The name of the :class:`CropScene` transform. *i.e.* ``"CropScene"``.
        """
        return "CropScene"

    @staticmethod
    def from_state_dict(state_dict: dict) -> "CropScene":
        """
        Create a :class:`CropScene` transform from a state dictionary created with :meth:`state_dict`.

        The ``"mask_format"`` and ``"composite_into_existing_masks"`` entries written by :meth:`state_dict`
        are restored when present; when absent they fall back to the constructor defaults
        (``"png"`` and ``True``).

        Args:
            state_dict (dict): The state dictionary for the transform.

        Returns:
            transform (CropScene): An instance of the :class:`CropScene` transform.

        Raises:
            ValueError: If ``"bbox"`` is missing or is not a 6-element array-like, if the mask format is
                unsupported, or if ``"composite_into_existing_masks"`` is not a boolean.
        """
        bbox = state_dict.get("bbox", None)
        if bbox is None:
            raise ValueError("State dictionary must contain 'bbox' key with bounding box coordinates.")
        # Accept tuples and lists as well as arrays, as the error message below promises.
        bbox = np.asarray(bbox)
        if bbox.shape != (6,):
            raise ValueError(
                "Bounding box must be a tuple or array of the form (min_x, min_y, min_z, max_x, max_y, max_z)."
            )

        # An unsupported mask format is rejected by the constructor.
        mask_format = state_dict.get("mask_format", "png")
        composite_into_existing_masks = state_dict.get("composite_into_existing_masks", True)
        if not isinstance(composite_into_existing_masks, bool):
            raise ValueError("composite_into_existing_masks must be a boolean.")

        return CropScene(
            bbox,
            mask_format=mask_format,
            composite_with_existing_masks=composite_into_existing_masks,
        )

    def state_dict(self) -> dict:
        """
        Return the state of the :class:`CropScene` transform for serialization.

        You can use this state dictionary to recreate the transform using :meth:`from_state_dict`.

        Returns:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.
        """
        return {
            "name": self.name(),
            "version": self.version,
            "bbox": self._bbox,
            "mask_format": self._mask_format,
            # The serialized key says "into" while the constructor argument says "with".
            "composite_into_existing_masks": self._composite_with_existing_masks,
        }

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """
        Return a new :class:`~fvdb_reality_capture.sfm_scene.SfmScene` with points cropped to lie within the bounding
        box specified at initialization, and with masks updated to nullify pixels whose rays do not intersect the
        bounding box.

        Args:
            input_scene (SfmScene): The scene to be cropped.

        Returns:
            output_scene (SfmScene): The cropped scene.

        Raises:
            ValueError: If the stored bounding box is not a 1D array of shape (6,).
        """
        # Ensure the bounding box is a numpy array of length 6
        bbox = np.asarray(self._bbox, dtype=np.float32)
        if bbox.shape != (6,):
            raise ValueError("Bounding box must be a 1D array of shape (6,)")

        self._logger.info(f"Cropping scene to bounding box: {self._bbox}")

        return _crop_scene_to_bbox(
            input_scene=input_scene,
            transform_name=self.name(),
            composite_with_existing_masks=self._composite_with_existing_masks,
            mask_format=self._mask_format,
            bbox=bbox,
            logger=self._logger,
        )
@transform
class CropSceneToPoints(BaseTransform):
    """
    A :class:`~base_transform.BaseTransform` which crops the input
    :class:`~fvdb_reality_capture.sfm_scene.SfmScene` points to lie within the bounding box around its points plus
    or minus a padding margin. This transform additionally updates the scene's masks to nullify pixels whose rays
    do not intersect the bounding box.

    .. note::

        If the input scene already has masks, these new masks will be composited with the existing masks to ensure
        that pixels outside the cropped region are properly masked. This can be disabled by setting
        ``composite_with_existing_masks`` to ``False``.

    .. note::

        You may want to use this over :class:`CropScene` if you want the bounding box to depend on the input scene
        points rather than being fixed (*e.g.* if you don't know the bounding box ahead of time). This transform
        is also useful if you just want to apply conservative masking to the input scene based on its points.

    .. note::

        The margin is specified as a fraction of the bounding box size. For example, a margin of 0.1 will expand the
        bounding box by 10% (5% in all directions). So if the scene's bounding box is ``(0, 0, 0)`` to ``(1, 1, 1)``,
        a margin of ``0.1`` will result in a bounding box of ``(-0.05, -0.05, -0.05)`` to ``(1.05, 1.05, 1.05)``.
        The margin can also be negative to shrink the bounding box.

    Example usage:

    .. code-block:: python

        from fvdb_reality_capture import transforms
        from fvdb_reality_capture.sfm_scene import SfmScene
        import numpy as np

        # Crop the scene to be 0.1 times smaller than the bounding box around its points
        # (i.e. a margin of -0.1)
        scene_transform = transforms.CropSceneToPoints(margin=-0.1)

        input_scene: SfmScene = ...  # Load or create an SfmScene

        # The transformed scene will have points only within the bounding box of its points
        # minus a factor of 0.1 times the size. (i.e. a margin of -0.1).
        # Posed images will have masks updated to nullify pixels corresponding to regions outside the cropped scene.
        transformed_scene: SfmScene = scene_transform(input_scene)
    """

    version = "1.0.0"

    def __init__(
        self,
        margin: float = 0.0,
        mask_format: Literal["png", "jpg", "npy"] = "png",
        composite_with_existing_masks: bool = True,
    ):
        """
        Create a new :class:`CropSceneToPoints` transform with the given margin.

        Args:
            margin (float): The margin factor to apply around the bounding box of the points.
                Can be negative to shrink the bounding box. This is a fraction of the bounding box size.
                For example, a margin of ``0.1`` will expand the bounding box by 10% (5% in all directions),
                while a margin of ``-0.1`` will shrink the bounding box by 10% (-5% in all directions).
                Defaults to ``0.0``.
            mask_format (Literal["png", "jpg", "npy"]): The format to save the masks in. Defaults to "png".
            composite_with_existing_masks (bool): Whether to composite the masks generated into existing masks for
                pixels corresponding to regions outside the cropped scene. If set to True, existing masks
                will be loaded and composited with the new mask. Defaults to True.

        Raises:
            ValueError: If ``mask_format`` is not one of the supported formats.
        """
        super().__init__()
        self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        self._margin = margin
        self._mask_format = mask_format
        if self._mask_format not in ["png", "jpg", "npy"]:
            raise ValueError(
                f"Unsupported mask format: {self._mask_format}. Supported formats are 'png', 'jpg', and 'npy'."
            )
        self._composite_with_existing_masks = composite_with_existing_masks

    @staticmethod
    def name() -> str:
        """
        Return the name of the :class:`CropSceneToPoints` transform. *i.e.* ``"CropSceneToPoints"``.

        Returns:
            str: The name of the :class:`CropSceneToPoints` transform. *i.e.* ``"CropSceneToPoints"``.
        """
        return "CropSceneToPoints"

    @staticmethod
    def from_state_dict(state_dict: dict) -> "CropSceneToPoints":
        """
        Create a :class:`CropSceneToPoints` transform from a state dictionary generated with :meth:`state_dict`.

        Args:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.

        Returns:
            transform (:class:`CropSceneToPoints`): An instance of the :class:`CropSceneToPoints` transform loaded
                from the state dictionary.

        Raises:
            ValueError: If ``"margin"``, ``"mask_format"`` or ``"composite_into_existing_masks"`` is missing
                or has an invalid value.
        """
        margin = state_dict.get("margin", None)
        if margin is None:
            raise ValueError("State dictionary must contain 'margin' key with margin value.")
        if not isinstance(margin, (float, int)):
            # Negative margins are valid (they shrink the box), so only the type is checked.
            raise ValueError("Margin must be a float.")

        mask_format = state_dict.get("mask_format", None)
        if mask_format is None:
            raise ValueError("State dictionary must contain 'mask_format' key with mask format value.")
        if mask_format not in ["png", "jpg", "npy"]:
            raise ValueError(f"Unsupported mask format: {mask_format}. Supported formats are 'png', 'jpg', and 'npy'.")

        composite_into_existing_masks = state_dict.get("composite_into_existing_masks", None)
        if composite_into_existing_masks is None:
            raise ValueError("State dictionary must contain 'composite_into_existing_masks' key with boolean value.")
        if not isinstance(composite_into_existing_masks, bool):
            raise ValueError("composite_into_existing_masks must be a boolean.")

        return CropSceneToPoints(
            margin=margin, mask_format=mask_format, composite_with_existing_masks=composite_into_existing_masks
        )

    def state_dict(self) -> dict:
        """
        Return the state of the :class:`CropSceneToPoints` transform for serialization.

        You can use this state dictionary to recreate the transform using :meth:`from_state_dict`.

        Returns:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.
        """
        return {
            "name": self.name(),
            "version": self.version,
            "margin": self._margin,
            "mask_format": self._mask_format,
            # The serialized key says "into" while the constructor argument says "with".
            "composite_into_existing_masks": self._composite_with_existing_masks,
        }

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """
        Return a new :class:`~fvdb_reality_capture.sfm_scene.SfmScene` with points cropped to lie within the
        bounding box of the input scene's points plus or minus the margin specified at initialization,
        and with masks updated to nullify pixels whose rays do not intersect the bounding box.

        Args:
            input_scene (SfmScene): The scene to be cropped.

        Returns:
            output_scene (SfmScene): The cropped scene.

        Raises:
            ValueError: If the input scene has no points, so no bounding box can be computed.
        """
        if len(input_scene.points) == 0:
            raise ValueError("Cannot crop the scene to its points because the scene has no points.")

        points_min = input_scene.points.min(axis=0)
        points_max = input_scene.points.max(axis=0)
        box_size = points_max - points_min

        # The margin is a fraction of the total box size, so each side moves by half of it:
        # a margin of 0.1 on a unit box yields (-0.05, ..., 1.05).
        padding = 0.5 * self._margin * box_size
        points_min = points_min - padding
        points_max = points_max + padding

        bbox = np.array(
            [
                points_min[0],
                points_min[1],
                points_min[2],
                points_max[0],
                points_max[1],
                points_max[2],
            ],
            dtype=np.float32,
        )
        if bbox.shape != (6,):
            raise ValueError("Bounding box must be a 1D array of shape (6,)")

        self._logger.info(f"Cropping scene to point bounding box {bbox} using margin {self._margin}")

        return _crop_scene_to_bbox(
            input_scene=input_scene,
            transform_name=self.name(),
            composite_with_existing_masks=self._composite_with_existing_masks,
            mask_format=self._mask_format,
            bbox=bbox,
            logger=self._logger,
        )