# Source code for fvdb_reality_capture.transforms.base_transform

# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
from abc import ABC, abstractmethod
from typing import Any, TypeVar

from fvdb_reality_capture.sfm_scene import SfmScene

# Registry of transform classes, keyed by the string each class returns from its
# ``name()`` static method. Entries are added by the ``transform`` decorator.
REGISTERED_TRANSFORMS: dict[str, type] = dict()


# Type variable for the ``transform`` decorator so it hands back the same class
# type it receives.
DerivedTransform = TypeVar("DerivedTransform", bound=type)


def transform(cls: DerivedTransform) -> DerivedTransform:
    """
    Decorator to register a transform class which inherits from :class:`BaseTransform`.

    The class is stored in ``REGISTERED_TRANSFORMS`` under the key returned by its
    ``name()`` static method. :meth:`BaseTransform.from_state_dict` uses that key to find
    the class again. If a class is already registered under the same name, it is replaced,
    and the new entry is placed last in the registry's insertion order.

    Args:
        cls: The transform class to register.

    Returns:
        cls: The registered transform class, unchanged.

    Raises:
        TypeError: If ``cls`` is not a class, does not inherit from :class:`BaseTransform`,
            or its ``name()`` does not return a string.
    """
    # ``issubclass`` raises its own, less helpful TypeError for non-class arguments
    # (e.g. if the decorator is applied to a function), so reject those first.
    if not isinstance(cls, type) or not issubclass(cls, BaseTransform):
        raise TypeError(f"Transform {cls} must inherit from BaseTransform.")

    transform_name = cls.name()
    # A subclass that does not override the abstract ``name()`` inherits a body that
    # returns None. Registering under that key would make the transform unreachable
    # by name, so fail loudly instead.
    if not isinstance(transform_name, str):
        raise TypeError(
            f"Transform {cls} must implement name() returning a string, got {transform_name!r}."
        )

    # Remove any previous registration before inserting, so a re-registered class
    # moves to the end of the registry's insertion order.
    REGISTERED_TRANSFORMS.pop(transform_name, None)
    REGISTERED_TRANSFORMS[transform_name] = cls

    return cls


class BaseTransform(ABC):
    """
    Base class for all transforms.

    Transforms are used to modify an :class:`~fvdb_reality_capture.sfm_scene.SfmScene` before it is used for
    reconstruction or other processing. They can be used to filter images, adjust camera parameters, or perform
    other modifications to the scene.

    Subclasses of :class:`BaseTransform` must implement the following methods:

    - :meth:`__call__`: apply the transform to an input scene and return the transformed scene.
    - :meth:`name`: return the name under which the transform class is registered.
    - :meth:`state_dict`: return a dictionary describing the transform, for serialization.
    - :meth:`from_state_dict`: rebuild an instance of the transform from such a dictionary.

    Subclasses should also be decorated with :func:`transform`, so that :meth:`BaseTransform.from_state_dict`
    can find them by name.
    """

    @abstractmethod
    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """
        Abstract method to apply the transform to the input scene and return the transformed scene.

        Args:
            input_scene (SfmScene): The input scene to transform.

        Returns:
            output_scene (SfmScene): The transformed scene.
        """
        pass

    @staticmethod
    @abstractmethod
    def name() -> str:
        """
        Abstract method to return the name of the transform.

        The :func:`transform` decorator uses this value as the registry key. For
        :meth:`BaseTransform.from_state_dict` dispatch to work, :meth:`state_dict` should store
        the same value under the ``"name"`` key.

        Returns:
            str: The name of the transform.
        """
        pass

    @abstractmethod
    def state_dict(self) -> dict[str, Any]:
        """
        Abstract method to return a dictionary containing information to serialize/deserialize the transform.

        Returns:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.
        """
        pass

    @staticmethod
    @abstractmethod
    def from_state_dict(state_dict: dict[str, Any]) -> "BaseTransform":
        """
        Abstract method to create a transform from a state dictionary generated with :meth:`state_dict`.

        Subclasses must override this to build an instance of themselves. When called on
        :class:`BaseTransform` itself, this implementation looks up the class registered
        (via the :func:`transform` decorator) under ``state_dict["name"]``. It then delegates
        to that class's ``from_state_dict``.

        Args:
            state_dict (dict[str, Any]): A dictionary containing information to serialize/deserialize the transform.
                It must contain a ``"name"`` entry identifying the registered transform class.

        Returns:
            transform (BaseTransform): An instance of the transform.

        Raises:
            KeyError: If ``state_dict`` has no ``"name"`` entry.
            ValueError: If no transform is registered under ``state_dict["name"]``.
            NotImplementedError: If the registered class does not override :meth:`from_state_dict`.
        """
        StateDictType = REGISTERED_TRANSFORMS.get(state_dict["name"], None)
        if StateDictType is None:
            raise ValueError(
                f"Transform '{state_dict['name']}' is not registered. Transform classes must be registered "
                f"with the `transform` decorator which will be called when the transform is defined. "
                f"Ensure the transform class uses the `transform` decorator and was imported before calling from_state_dict."
            )
        # A registered subclass that does not override ``from_state_dict`` resolves back to
        # this very function. Delegating to it would recurse until RecursionError, so raise
        # a clear error instead. NotImplementedError is also a RuntimeError subclass.
        if StateDictType.from_state_dict is BaseTransform.from_state_dict:
            raise NotImplementedError(
                f"Transform '{state_dict['name']}' ({StateDictType.__name__}) does not implement from_state_dict."
            )
        return StateDictType.from_state_dict(state_dict)

    def __repr__(self) -> str:
        # The class name alone identifies the transform type in logs and reprs.
        return self.__class__.__name__