# Source code for fvdb_reality_capture.transforms.downsample_images

# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
import logging
import pathlib
from typing import Any, Literal

import cv2
import tqdm

from fvdb_reality_capture.sfm_scene import SfmCache, SfmPosedImageMetadata, SfmScene

from .base_transform import BaseTransform, transform


@transform
class DownsampleImages(BaseTransform):
    """
    A :class:`~base_transform.BaseTransform` which downsamples all images in an
    :class:`~fvdb_reality_capture.sfm_scene.SfmScene` by a specified factor and caches the
    downsampled images for future use.

    You can specify the cached downsampled image type (e.g., ``"jpg"`` or ``"png"``), the mode for
    downsampling (e.g., ``cv2.INTER_AREA``), and the rescaled JPEG quality (if using JPEG).

    If the downsampled images already exist in the scene's cache with the correct parameters, they
    will be loaded from the cache instead of being regenerated.

    Example usage:

    .. code-block:: python

        from fvdb_reality_capture import transforms
        from fvdb_reality_capture.sfm_scene import SfmScene

        scene_transform = transforms.DownsampleImages(4)

        input_scene: SfmScene = ...  # Load or create an SfmScene

        # The returned scene will have paths pointing to downsampled images by a factor of 4.
        transformed_scene: SfmScene = scene_transform(input_scene)
    """

    version = "1.0.0"

    def __init__(
        self,
        image_downsample_factor: int,
        image_type: Literal["jpg", "png"] = "jpg",
        rescale_sampling_mode: int = cv2.INTER_AREA,
        rescaled_jpeg_quality: int = 98,
    ):
        """
        Create a new :class:`DownsampleImages` transform with the specified downsampling factor and
        image caching parameters (image type, downsampling mode, and quality).

        .. note::

            We use enums from `OpenCV <https://opencv.org/>`_ for the ``rescale_sampling_mode``
            parameter, e.g., ``cv2.INTER_AREA``, ``cv2.INTER_LINEAR``, ``cv2.INTER_CUBIC``, etc.
            This means if you want to change the resampling mode, you will need to ``import cv2``
            and pass in the appropriate enum value. See the
            `OpenCV documentation <https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121>`_
            for more details on valid enum values.

        Args:
            image_downsample_factor (int): The factor by which to downsample the images.
            image_type (str): The type of the cached downsampled images, either "jpg" or "png".
            rescale_sampling_mode (int): The interpolation method to use for rescaling images.
                Note that we use enums from `OpenCV <https://opencv.org/>`_ for this parameter,
                e.g., ``cv2.INTER_AREA``, ``cv2.INTER_LINEAR``, ``cv2.INTER_CUBIC``, etc.
            rescaled_jpeg_quality (int): The quality of the JPEG images when saving them to the
                cache (1-100).
        """
        super().__init__()
        self._image_downsample_factor = image_downsample_factor
        self._image_type = image_type
        self._rescale_sampling_mode = rescale_sampling_mode
        self._rescaled_jpeg_quality = rescaled_jpeg_quality
        self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """
        Return a new :class:`~fvdb_reality_capture.sfm_scene.SfmScene` with images downsampled by
        the specified factor.

        *i.e.* images will be resized to
        ``(int(width / image_downsample_factor), int(height / image_downsample_factor))``.

        The input scene is returned unchanged when the downsample factor is 1, or when the scene
        has no images or no cameras.

        Args:
            input_scene (SfmScene): The input scene with images to be downsampled.

        Returns:
            output_scene (SfmScene): The scene with downsampled images.
        """
        if self._image_downsample_factor == 1:
            self._logger.info("Image downsample factor is 1, skipping downsampling.")
            return input_scene

        if len(input_scene.images) == 0:
            self._logger.warning("No images found in the SfmScene. Returning the input scene unchanged.")
            return input_scene

        if len(input_scene.cameras) == 0:
            self._logger.warning("No cameras found in the SfmScene. Returning the input scene unchanged.")
            return input_scene

        input_cache: SfmCache = input_scene.cache

        # Every parameter that affects the pixels on disk is part of the folder name, so
        # differently-configured transforms never share a cache folder.
        cache_prefix = f"downsampled_{self._image_downsample_factor}x_{self._image_type}_q{self._rescaled_jpeg_quality}_m{self._rescale_sampling_mode}"
        output_cache = input_cache.make_folder(
            cache_prefix, description=f"Rescaled images by a factor of {self._image_downsample_factor}"
        )

        # Camera metadata resized to the downsampled resolution (truncating toward zero).
        new_camera_metadata = {}
        for cam_id, cam_meta in input_scene.cameras.items():
            rescaled_cam_w = int(cam_meta.width / self._image_downsample_factor)
            rescaled_cam_h = int(cam_meta.height / self._image_downsample_factor)
            new_camera_metadata[cam_id] = cam_meta.resize(rescaled_cam_w, rescaled_cam_h)

        self._logger.info(
            f"Rescaling images using downsample factor {self._image_downsample_factor}, "
            f"sampling mode {self._rescale_sampling_mode}, and quality {self._rescaled_jpeg_quality}."
        )
        self._logger.info("Attempting to load downsampled images from cache.")

        # How many digits to zero-pad the image id to in the cached image file names.
        num_zeropad = len(str(len(input_scene.images))) + 2

        new_image_metadata = self._load_cached_images(input_scene, output_cache, new_camera_metadata, num_zeropad)
        if new_image_metadata is None:
            new_image_metadata = self._generate_downsampled_images(
                input_scene, output_cache, new_camera_metadata, num_zeropad
            )

        output_scene = SfmScene(
            cameras=new_camera_metadata,
            images=new_image_metadata,
            points=input_scene.points,
            points_err=input_scene.points_err,
            points_rgb=input_scene.points_rgb,
            scene_bbox=input_scene.scene_bbox,
            transformation_matrix=input_scene.transformation_matrix,
            cache=output_cache,
        )
        return output_scene

    @staticmethod
    def _with_image_path(
        image_meta: SfmPosedImageMetadata, new_camera_metadata: dict, image_path: str
    ) -> SfmPosedImageMetadata:
        """
        Copy ``image_meta``, pointing it at ``image_path`` and at the resized camera metadata.

        Args:
            image_meta (SfmPosedImageMetadata): Metadata of the full-resolution image.
            new_camera_metadata (dict): Resized camera metadata keyed by camera id.
            image_path (str): Path of the downsampled image.

        Returns:
            SfmPosedImageMetadata: Metadata for the downsampled image. All other fields are copied
            from ``image_meta`` unchanged.
        """
        return SfmPosedImageMetadata(
            world_to_camera_matrix=image_meta.world_to_camera_matrix,
            camera_to_world_matrix=image_meta.camera_to_world_matrix,
            camera_metadata=new_camera_metadata[image_meta.camera_id],
            camera_id=image_meta.camera_id,
            image_path=image_path,
            mask_path=image_meta.mask_path,
            point_indices=image_meta.point_indices,
            image_id=image_meta.image_id,
        )

    def _load_cached_images(
        self,
        input_scene: SfmScene,
        output_cache: SfmCache,
        new_camera_metadata: dict,
        num_zeropad: int,
    ) -> "list[SfmPosedImageMetadata] | None":
        """
        Try to build the downsampled image metadata purely from files already in ``output_cache``.

        Args:
            input_scene (SfmScene): The scene whose images should have cached counterparts.
            output_cache (SfmCache): The cache folder holding the downsampled images.
            new_camera_metadata (dict): Resized camera metadata keyed by camera id.
            num_zeropad (int): Zero-padding width used in the cached image file names.

        Returns:
            The list of image metadata pointing at the cached files, or ``None`` if the cache is
            missing, incomplete or was written with different parameters. When ``None`` is
            returned the cache folder has been cleared and must be regenerated.
        """
        if output_cache.num_files != input_scene.num_images:
            if output_cache.num_files == 0:
                self._logger.info("No downsampled images found in the cache.")
            else:
                self._logger.info(
                    f"Inconsistent number of downsampled images in the cache. "
                    f"Expected {input_scene.num_images}, found {output_cache.num_files}. "
                    f"Clearing cache and regenerating downsampled images."
                )
            # Always start regeneration from a clean folder.
            output_cache.clear_current_folder()
            return None

        new_image_metadata = []
        for image_meta in input_scene.images:
            # Key the lookup on the image's own id, exactly as the write path does. Using the
            # position in the list instead would miss (or mismatch) cached files whenever the
            # image ids are not 0..N-1 in order.
            cache_image_filename = f"image_{image_meta.image_id:0{num_zeropad}}"

            if not output_cache.has_file(cache_image_filename):
                self._logger.info(
                    f"Image {cache_image_filename} not found in the cache. " f"Clearing cache and regenerating."
                )
                output_cache.clear_current_folder()
                return None

            cache_file_meta = output_cache.get_file_metadata(cache_image_filename)
            value_meta = cache_file_meta["metadata"]
            value_quality = value_meta.get("quality", -1)
            value_mode = value_meta.get("downsample_mode", -1)
            if (
                cache_file_meta.get("data_type", "") != self._image_type
                or value_quality != self._rescaled_jpeg_quality
                or value_mode != self._rescale_sampling_mode
            ):
                self._logger.info(
                    f"Output cache image metadata does not match expected format. "
                    f"Clearing the cache and regenerating downsampled images."
                )
                output_cache.clear_current_folder()
                return None

            new_image_metadata.append(
                self._with_image_path(image_meta, new_camera_metadata, str(cache_file_meta["path"]))
            )

        return new_image_metadata

    def _generate_downsampled_images(
        self,
        input_scene: SfmScene,
        output_cache: SfmCache,
        new_camera_metadata: dict,
        num_zeropad: int,
    ) -> "list[SfmPosedImageMetadata]":
        """
        Downsample every image of ``input_scene``, write it to ``output_cache`` and return the
        metadata pointing at the written files.

        Args:
            input_scene (SfmScene): The scene whose images are downsampled.
            output_cache (SfmCache): The (empty) cache folder to write the downsampled images to.
            new_camera_metadata (dict): Resized camera metadata keyed by camera id.
            num_zeropad (int): Zero-padding width used in the cached image file names.

        Returns:
            list[SfmPosedImageMetadata]: One entry per input image, in input order.

        Raises:
            AssertionError: If an image cannot be read, or if its downsampled size does not match
                the resized metadata of its camera.
        """
        self._logger.info(
            f"Generating images downsampled by a factor of {self._image_downsample_factor} and saving to cache."
        )

        new_image_metadata = []
        # The context manager closes the progress bar even if an image fails part-way through.
        with tqdm.tqdm(input_scene.images, unit="imgs") as pbar:
            for image_meta in pbar:
                image_filename = pathlib.Path(image_meta.image_path).name
                full_res_image_path = image_meta.image_path
                full_res_img = cv2.imread(full_res_image_path)
                # cv2.imread signals failure by returning None rather than raising.
                assert full_res_img is not None, f"Failed to load image {full_res_image_path}"

                img_h, img_w = full_res_img.shape[:2]
                rescaled_img_h = int(img_h / self._image_downsample_factor)
                rescaled_img_w = int(img_w / self._image_downsample_factor)
                assert rescaled_img_w == new_camera_metadata[image_meta.camera_id].width, "Got mismatched widths!"
                assert rescaled_img_h == new_camera_metadata[image_meta.camera_id].height, "Got mismatched heights!"

                pbar.set_description(
                    f"Rescaling {image_filename} from {img_w} x {img_h} to {rescaled_img_w} x {rescaled_img_h}"
                )
                # cv2.resize takes the target size as (width, height).
                rescaled_image = cv2.resize(
                    full_res_img, (rescaled_img_w, rescaled_img_h), interpolation=self._rescale_sampling_mode
                )
                assert (
                    rescaled_image.shape[0] == rescaled_img_h and rescaled_image.shape[1] == rescaled_img_w
                ), f"Rescaled image {image_filename} has shape {rescaled_image.shape} but expected {rescaled_img_h, rescaled_img_w}"

                # Save the rescaled image to the cache, recording the parameters used so a later
                # run can tell whether the cached file is still valid.
                cache_image_filename = f"image_{image_meta.image_id:0{num_zeropad}}"
                cache_file_meta = output_cache.write_file(
                    name=cache_image_filename,
                    data=rescaled_image,
                    data_type=self._image_type,
                    quality=self._rescaled_jpeg_quality,
                    metadata={
                        "quality": self._rescaled_jpeg_quality,
                        "downsample_mode": self._rescale_sampling_mode,
                    },
                )

                new_image_metadata.append(
                    self._with_image_path(image_meta, new_camera_metadata, str(cache_file_meta["path"]))
                )

        self._logger.info(
            f"Rescaled {input_scene.num_images} images by a factor of {self._image_downsample_factor} "
            f"and saved to cache with sampling mode {self._rescale_sampling_mode} and quality "
            f"{self._rescaled_jpeg_quality}."
        )
        return new_image_metadata

    @staticmethod
    def name() -> str:
        """
        Return the name of the :class:`DownsampleImages` transform. *i.e.* ``"DownsampleImages"``.

        Returns:
            str: The name of the :class:`DownsampleImages` transform. *i.e.* ``"DownsampleImages"``.
        """
        return "DownsampleImages"

    def state_dict(self) -> dict[str, Any]:
        """
        Return the state of the :class:`DownsampleImages` transform for serialization.

        You can use this state dictionary to recreate the transform using :meth:`from_state_dict`.

        Returns:
            state_dict (dict[str, Any]): A dictionary containing information to
            serialize/deserialize the transform.
        """
        return {
            "name": self.name(),
            "version": self.version,
            "image_downsample_factor": self._image_downsample_factor,
            "image_type": self._image_type,
            "rescale_sampling_mode": self._rescale_sampling_mode,
            "rescaled_jpeg_quality": self._rescaled_jpeg_quality,
        }

    @staticmethod
    def from_state_dict(state_dict: dict[str, Any]) -> "DownsampleImages":
        """
        Create a :class:`DownsampleImages` transform from a state dictionary generated with
        :meth:`state_dict`.

        Args:
            state_dict (dict): The state dictionary for the transform.

        Returns:
            transform (DownsampleImages): An instance of the :class:`DownsampleImages` transform.

        Raises:
            ValueError: If ``state_dict["name"]`` is not ``"DownsampleImages"``.
        """
        if state_dict["name"] != "DownsampleImages":
            raise ValueError(f"Expected state_dict with name 'DownsampleImages', got {state_dict['name']} instead.")

        return DownsampleImages(
            image_downsample_factor=state_dict["image_downsample_factor"],
            image_type=state_dict["image_type"],
            rescale_sampling_mode=state_dict["rescale_sampling_mode"],
            rescaled_jpeg_quality=state_dict["rescaled_jpeg_quality"],
        )