Source code for fvdb.utils._build_ext

# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
import os
from typing import Any, Sequence

import setuptools
from torch.utils import cpp_extension

import fvdb


[docs] def fvdbCudaExtension(name: str, sources: Sequence[str], *args: Any, **kwargs: Any) -> setuptools.Extension: """ Utility function for creating pytorch extensions that depend on fvdb. You then have access to all fVDB's internal headers to program with. Example usage: .. code-block:: python from fvdb.utils import FVDBExtension ext = FVDBExtension( name='my_extension', sources=['my_extension.cpp'], extra_compile_args={'cxx': ['-std=c++17']}, libraries=['mylib'], ) Args: name (str): The name of the extension. sources (Sequence[str]): The list of source files. args (list[Any]): Other arguments to pass to :func:`torch.utils.cpp_extension.CppExtension`. kwargs (dict): Other keyword arguments to pass to :func:`torch.utils.cpp_extension.CppExtension`. Returns: cpp_extension (setuptools.Extension) A :class:`setuptools.Extension` object which can be used to build a PyTorch C++ extension that depends on fVDB. """ libraries = kwargs.get("libraries", []) libraries.append("fvdb") kwargs["libraries"] = libraries library_dirs = kwargs.get("library_dirs", []) library_dirs.append(os.path.dirname(fvdb.__file__)) kwargs["library_dirs"] = library_dirs include_dirs = kwargs.get("include_dirs", []) include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include")) # We also need to add this because fvdb internally will refer to their headers without the fvdb/ prefix. include_dirs.append(os.path.join(os.path.dirname(fvdb.__file__), "include/fvdb")) kwargs["include_dirs"] = include_dirs extra_link_args = kwargs.get("extra_link_args", []) extra_link_args.append(f"-Wl,-rpath={os.path.dirname(fvdb.__file__)}") kwargs["extra_link_args"] = extra_link_args extra_compile_args = kwargs.get("extra_compile_args", {}) extra_compile_args["nvcc"] = extra_compile_args.get("nvcc", []) if "--extended-lambda" not in extra_compile_args["nvcc"]: extra_compile_args["nvcc"].append("--extended-lambda") kwargs["extra_compile_args"] = extra_compile_args return cpp_extension.CUDAExtension(name, sources, *args, **kwargs)