Sparse Convolution

class fvdb.ConvolutionPlan(_pack_info: SparseConvPackInfo, _channel_pairs: tuple[tuple[int, int], ...], _transposed: bool, _expert_config: dict[str, Any], _backend: ConvPackBackend)

A pre-configured plan for efficient sparse 3D convolution operations on fvdb.Grid and fvdb.GridBatch.

ConvolutionPlan encapsulates all the configuration and optimization structures needed to perform sparse convolution operations efficiently. Like FFT plans in signal processing libraries, a ConvolutionPlan represents a single direction of computation: either regular convolution or transposed convolution.

The plan handles the complex sparse data structures and backend optimizations internally, allowing users to focus on the core convolution parameters: input/output channels, kernel size, stride, and the grid structure.
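To make the "build once, execute many" idea concrete, here is a toy pure-Python sketch; it is not fvdb's implementation. Building the plan precomputes a kernel map (for each kernel offset, the (input index, output index) pairs whose coordinates differ by that offset), and execution only multiplies features by weights along that map. The helper names `build_kernel_map` and `execute_with_map` are hypothetical, the weight layout `weights[offset][out][in]` is an assumption of this sketch (fvdb uses a (out_channels, in_channels, k0, k1, k2) tensor), and it assumes stride 1 with the output domain equal to the input domain.

```python
from itertools import product


def build_kernel_map(coords, kernel_size=3):
    """Precompute which (input, output) voxel pairs interact for each kernel offset.

    coords: list of (i, j, k) integer voxel coordinates. The output domain is
    assumed equal to the input domain (stride 1, odd kernel centered on 0).
    """
    index = {c: n for n, c in enumerate(coords)}
    r = kernel_size // 2
    kmap = {}
    for d in product(range(-r, r + 1), repeat=3):
        pairs = []
        for out_idx, (x, y, z) in enumerate(coords):
            in_idx = index.get((x + d[0], y + d[1], z + d[2]))
            if in_idx is not None:
                pairs.append((in_idx, out_idx))
        kmap[d] = pairs
    return kmap


def execute_with_map(kmap, num_out, features, weights):
    """Accumulate weights[d][o][c] * features[in][c] along the precomputed map."""
    out_channels = len(next(iter(weights.values())))
    out = [[0.0] * out_channels for _ in range(num_out)]
    for d, pairs in kmap.items():
        w = weights[d]
        for in_idx, out_idx in pairs:
            f = features[in_idx]
            for o in range(out_channels):
                out[out_idx][o] += sum(w[o][c] * f[c] for c in range(len(f)))
    return out


coords = [(0, 0, 0), (1, 0, 0), (5, 5, 5)]
kmap = build_kernel_map(coords)  # built once per grid structure
feats = [[1.0], [2.0], [3.0]]    # one input channel per voxel
ones = {d: [[1.0]] for d in kmap}
print(execute_with_map(kmap, 3, feats, ones))  # [[3.0], [3.0], [3.0]]
```

Calling `execute_with_map` again with a different weights dict reuses `kmap` unchanged; that reuse is the kind of work a plan is meant to amortize.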

Transposition is treated as just a different kind of kernel: inputs, outputs, and weights are handled exactly as for a regular convolution. In the default padded case, a transposed plan cannot infer target_grid automatically, so target_grid must be provided unless the dense, halo, or lggs backend is used.
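The idea that a transposed convolution is "just a different kind of kernel" can be checked on a toy case. The 1D, dense, pure-Python sketch below assumes stride 1, an odd kernel length, zero padding, and identical input and output domains (all simplifications made for illustration). It computes the transposed convolution as the adjoint of convolution and shows that, under these assumptions, it equals an ordinary convolution with the kernel reversed. It is not how fvdb implements transposition; with stride greater than 1 the input and output domains differ.

```python
def conv1d(x, w):
    """Stride-1, zero-padded cross-correlation: y[o] = sum_k w[k] * x[o + k - c]."""
    c = len(w) // 2  # odd kernel length assumed, centered at index c
    n = len(x)
    return [
        sum(w[k] * x[o + k - c] for k in range(len(w)) if 0 <= o + k - c < n)
        for o in range(n)
    ]


def conv1d_transposed(x, w):
    """Adjoint of conv1d: scatter each x[q] to the positions conv1d read when writing q."""
    c = len(w) // 2
    n = len(x)
    y = [0.0] * n
    for q in range(n):
        for k in range(len(w)):
            p = q + k - c  # conv1d applied w[k] to position p when producing output q
            if 0 <= p < n:
                y[p] += w[k] * x[q]
    return y


x = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 10.0, 100.0]
print(conv1d_transposed(x, w))  # [12.0, 123.0, 234.0, 340.0]
print(conv1d(x, w[::-1]))       # [12.0, 123.0, 234.0, 340.0]
```

Both calls share the same input/output roles and the same weight shape; only the kernel differs, which mirrors how the plan treats transposition.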

Usage Pattern:

  1. Create a plan using one of the from_* class methods (see from_grid() and from_grid_batch()).

  2. Use the execute() method to perform convolutions with different weights and data on the same grid structures.

  3. Reuse the same plan for multiple convolutions with the same configuration.

Example Usage:

import torch
from fvdb import Grid, ConvolutionPlan

# Create a grid
my_grid = Grid.from_ijk(...)

# Create a plan for 3x3x3 convolution with stride 1
plan = ConvolutionPlan.from_grid(
    kernel_size=3,
    stride=1,
    source_grid=my_grid
)

# Execute the convolution (the plan can be reused with other weights)
features = torch.randn(num_voxels, 32, device="cuda")  # num_voxels: active voxel count of my_grid
weights = torch.randn(64, 32, 3, 3, 3, device="cuda")  # (out_channels, in_channels, 3, 3, 3)
output = plan.execute(features, weights)

Note

  • Always create plans using the from_* class methods; never call __init__ directly

  • Plans are immutable once created

  • The same plan can be reused for multiple execute() calls with different data/weights

  • Channel pairs can be specified at plan creation time for optimal backend selection

execute(data: Tensor, weights: Tensor) → Tensor
execute(data: JaggedTensor, weights: Tensor) → JaggedTensor

Execute this ConvolutionPlan with the input data and weights.

This is the main method for performing convolution operations. It applies the convolution kernel to the sparse voxel data according to the plan’s pre-configured structure and optimizations.

If this plan was created for a single grid (e.g. using from_grid() or from_grid_transposed()), then data should be a torch.Tensor with shape (total_voxels, in_channels).

If this plan was created for a batch of grids (e.g. using from_grid_batch() or from_grid_batch_transposed()), then data should be a JaggedTensor with shape (batch_size, num_voxels_in_grid_b, in_channels).
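A jagged batch can be pictured as every grid's voxel features stacked into one (total_voxels, in_channels) block, plus offsets marking where each grid's rows start. The pure-Python sketch below illustrates only that layout; `pack_jagged` and `unpack_jagged` are hypothetical helpers, not the JaggedTensor API.

```python
def pack_jagged(per_grid_features):
    """Concatenate per-grid rows and record where each grid's rows start."""
    data, offsets = [], [0]
    for rows in per_grid_features:
        data.extend(rows)
        offsets.append(len(data))
    return data, offsets


def unpack_jagged(data, offsets):
    """Recover per-grid rows: grid b owns rows offsets[b]:offsets[b + 1]."""
    return [data[offsets[b]:offsets[b + 1]] for b in range(len(offsets) - 1)]


# Two grids: 2 voxels and 1 voxel, each voxel with 2 channels.
data, offsets = pack_jagged([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]])
print(offsets)  # [0, 2, 3]
```

With this layout, per-voxel work can run over all rows in a single pass while the offsets keep track of which grid each row belongs to.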

Note

  • The same plan can be reused with different weights and data

  • The (in_channels, out_channels) pair implied by the weights must be among the channel_pairs specified at plan creation (any pair is accepted when none were specified)

  • The plan automatically handles the sparse structure and backend optimizations

  • For transposed convolution plans, this performs the transposed convolution

Parameters:
  • data (torch.Tensor | JaggedTensor) – Input voxel features. Can be either: (i) torch.Tensor for single grids: shape (total_voxels, in_channels) or (ii) JaggedTensor for batches of grids: shape (batch_size, num_voxels_in_grid_b, in_channels)

  • weights (torch.Tensor) – Convolution kernel weights with shape: (out_channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])

Returns:

output_features (torch.Tensor | JaggedTensor) – Convolved features with the same type as input: (i) torch.Tensor with shape (total_output_voxels, out_channels) for single grids or (ii) JaggedTensor with shape (batch_size, output_voxels_per_grid, out_channels) for batches

Raises:

ValueError – If the channel pair (in_channels, out_channels) from the weights is not supported by this plan’s channel_pairs configuration.
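The ValueError condition can be sketched as follows, assuming (as the default `channel_pairs=()` in the from_* signatures suggests) that an empty tuple means any pair is accepted. `check_channel_pair` is a hypothetical helper written for illustration, not part of fvdb. Note the order swap: weights store out_channels first, while channel pairs are written (in_channels, out_channels).

```python
def check_channel_pair(weight_shape, channel_pairs=()):
    """Derive (in_channels, out_channels) from a 5-D weight shape and validate it."""
    if len(weight_shape) != 5:
        raise ValueError(f"expected 5-D weights, got shape {tuple(weight_shape)}")
    out_channels, in_channels = weight_shape[0], weight_shape[1]
    # Empty channel_pairs stands for "any channel pair" in this sketch.
    if channel_pairs and (in_channels, out_channels) not in channel_pairs:
        raise ValueError(
            f"channel pair ({in_channels}, {out_channels}) not in {channel_pairs}"
        )
    return in_channels, out_channels


print(check_channel_pair((64, 32, 3, 3, 3), ((32, 64), (64, 128))))  # (32, 64)
```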

Example:

import torch
from fvdb import JaggedTensor

# Single grid example (plan created with from_grid())
features = torch.randn(1000, 32, device="cuda")  # 1000 voxels, 32 channels
weights = torch.randn(64, 32, 3, 3, 3, device="cuda")  # 32->64 channels, 3x3x3 kernel
output = plan.execute(features, weights)  # Shape: (output_voxels, 64)

# Batched example (batch_plan created with from_grid_batch() for 5 grids of 1000 voxels each)
batch_features = JaggedTensor([torch.randn(1000, 32, device="cuda") for _ in range(5)])
output = batch_plan.execute(batch_features, weights)  # Shape: (5, output_voxels_in_grid_b, 64)

classmethod from_grid(kernel_size: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, stride: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, source_grid: Grid, target_grid: Grid | None = None, *, expert_config: dict[str, Any] = {'allow_tf32': True, 'backend': 'default', 'feature_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64), 'weight_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64)}, channel_pairs: tuple[tuple[int, int], ...] = ()) → ConvolutionPlan

Create a ConvolutionPlan for convolution on a single grid, i.e., convolution where the input and output domains are both of type fvdb.Grid.

This method creates a plan for processing a single grid, which is suitable when you have individual grids rather than batched data (for that case, use from_grid_batch()).

Parameters:
  • kernel_size (NumericMaxRank1) – Size of the convolution kernel. Can be a single int (cubic kernel) or a 3-element sequence for (x, y, z) dimensions.

  • stride (NumericMaxRank1) – Convolution stride. Can be a single int or 3-element sequence.

  • source_grid (Grid) – fvdb.Grid encoding the structure of the input domain.

  • target_grid (Grid | None) – fvdb.Grid encoding the structure of the output domain. If None, target_grid is computed automatically by applying kernel_size and stride to source_grid. The exception is the dense, halo, and lggs backends, which use target_grid = source_grid; for those backends, target_grid must be None.

  • expert_config (dict[str, Any]) – Advanced configuration options (rarely needed by typical users).

  • channel_pairs (tuple[tuple[int, int], ...]) – Supported input/output channel combinations, each given as an (input_channels, output_channels) tuple. For example, ((32, 64), (64, 128)) supports 32->64 and 64->128 convolutions. Specifying the pairs in advance can let the plan select a more performant backend. Defaults to _ANY_CHANNEL_PAIRS (the empty tuple), which means any channel pair is supported.
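kernel_size and stride accept either a single number or a 3-element sequence. The sketch below shows one way such a NumericMaxRank1 argument could be normalized to an (x, y, z) triple. `to_vec3` is hypothetical, it handles only plain Python numbers and sequences (not tensors or NumPy arrays), and its positive-integer check is an assumption of this sketch rather than documented fvdb behavior.

```python
from numbers import Real


def to_vec3(value, name="value"):
    """Normalize a scalar or 3-element sequence to an (x, y, z) tuple of ints."""
    if isinstance(value, Real):
        vec = (value, value, value)  # a single number means cubic kernel / uniform stride
    else:
        vec = tuple(value)
        if len(vec) != 3:
            raise ValueError(f"{name} must be a scalar or have 3 elements, got {len(vec)}")
    # Assumption of this sketch: sizes and strides are positive whole numbers.
    if any(v < 1 or int(v) != v for v in vec):
        raise ValueError(f"{name} entries must be positive integers, got {vec}")
    return tuple(int(v) for v in vec)


print(to_vec3(3, "kernel_size"))          # (3, 3, 3)
print(to_vec3([3, 1, 3], "kernel_size"))  # (3, 1, 3)
```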

Returns:

convolution_plan (ConvolutionPlan) – Configured plan ready for execute() operations.

Example:

import torch
from fvdb import Grid, ConvolutionPlan

# Create a single grid (assumed here to have 100 active voxels)
grid = Grid.from_ijk(...)

# Create plan for 3x3x3 convolution
plan = ConvolutionPlan.from_grid(
    kernel_size=3,
    stride=1,
    source_grid=grid
)

# Execute on single-grid data: one feature row per active voxel
features = torch.randn(100, 8, device="cuda")
weights = torch.randn(16, 8, 3, 3, 3, device="cuda")  # 8->16 channels
output = plan.execute(features, weights)

classmethod from_grid_batch(kernel_size: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, stride: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, source_grid: GridBatch, target_grid: GridBatch | None = None, *, expert_config: dict[str, Any] = {'allow_tf32': True, 'backend': 'default', 'feature_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64), 'weight_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64)}, channel_pairs: tuple[tuple[int, int], ...] = ()) → ConvolutionPlan

Create a ConvolutionPlan for convolution on batches of grids, i.e., convolution where the input and output domains are both of type fvdb.GridBatch.

The plan returned by this method is optimized for running convolution on a batch of grids simultaneously and in parallel, which is more efficient than processing individual grids separately when you have a batch of data.

Parameters:
  • kernel_size (NumericMaxRank1) – Size of the convolution kernel. Can be a single int (cubic kernel) or a 3-element sequence for (x, y, z) dimensions.

  • stride (NumericMaxRank1) – Convolution stride. Can be a single int or 3-element sequence.

  • source_grid (GridBatch) – fvdb.GridBatch encoding the structure of the input domain.

  • target_grid (GridBatch | None) – fvdb.GridBatch encoding the structure of the output domain. If None, target_grid is computed automatically by applying kernel_size and stride to source_grid. The exception is the dense, halo, and lggs backends, which use target_grid = source_grid; for those backends, target_grid must be None.

  • expert_config (dict[str, Any]) – Advanced configuration options (rarely needed by typical users).

  • channel_pairs (tuple[tuple[int, int], ...]) – Supported input/output channel combinations, each given as an (input_channels, output_channels) tuple. For example, ((32, 64), (64, 128)) supports 32->64 and 64->128 convolutions. Specifying the pairs in advance can let the plan select a more performant backend. Defaults to _ANY_CHANNEL_PAIRS (the empty tuple), which means any channel pair is supported.
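When target_grid is None, the output domain is derived from source_grid, kernel_size, and stride. The exact rule fvdb applies is not spelled out here, so the sketch below uses a common convention for strided sparse convolution, stated as an assumption: an output voxel o exists if some source voxel i satisfies i = o * stride + d for a kernel offset d, with offsets centered on zero (odd kernel sizes only). `kernel_offsets` and `conv_output_coords` are hypothetical helpers.

```python
from itertools import product


def kernel_offsets(kernel_size):
    """All 3D offsets of an odd-sized kernel, centered on (0, 0, 0)."""
    if any(k % 2 == 0 for k in kernel_size):
        raise ValueError("this sketch only handles odd kernel sizes")
    return list(product(*[range(-(k // 2), k // 2 + 1) for k in kernel_size]))


def conv_output_coords(source, kernel_size, stride):
    """Output voxels o such that o * stride + d is a source voxel for some offset d."""
    out = set()
    for i in source:
        for d in kernel_offsets(kernel_size):
            shifted = [i[a] - d[a] for a in range(3)]
            if all(shifted[a] % stride[a] == 0 for a in range(3)):
                out.add(tuple(shifted[a] // stride[a] for a in range(3)))
    return out


print(sorted(conv_output_coords({(3, 0, 0)}, (3, 3, 3), (2, 2, 2))))  # [(1, 0, 0), (2, 0, 0)]
```

With stride 1 this rule dilates the source by the kernel footprint; with stride 2 it maps fine voxels onto a coarser lattice.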

Returns:

convolution_plan (ConvolutionPlan) – Configured plan ready for execute() operations.

Example:

import torch
from fvdb import GridBatch, JaggedTensor, ConvolutionPlan

# Create a batched grid (assumed here: 5 grids with 1000 active voxels each)
grid_batch = GridBatch.from_points(...)

# Create plan for 3x3x3 convolution on batched grids
plan = ConvolutionPlan.from_grid_batch(
    kernel_size=3,
    stride=1,
    source_grid=grid_batch
)

# Execute on batched data: one (num_voxels_in_grid_b, 8) tensor per grid
batch_data = JaggedTensor([torch.randn(1000, 8, device="cuda") for _ in range(5)])
weights = torch.randn(16, 8, 3, 3, 3, device="cuda")
output = plan.execute(batch_data, weights)

classmethod from_grid_batch_transposed(kernel_size: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, stride: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, source_grid: GridBatch, target_grid: GridBatch | None = None, *, expert_config: dict[str, Any] = {'allow_tf32': True, 'backend': 'default', 'feature_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64), 'weight_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64)}, channel_pairs: tuple[tuple[int, int], ...] = ()) → ConvolutionPlan

Create a ConvolutionPlan for transposed convolution on batches of grids, i.e., transposed convolution where the input and output domains are both of type fvdb.GridBatch.

Transposed convolution (also known as deconvolution) is commonly used for upsampling operations, such as in decoder networks or generative models. It performs the mathematical transpose of the convolution operation.
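For the upsampling direction, a mirror-image rule can be sketched: each source voxel i scatters to the fine coordinates i * stride + d for every kernel offset d. As with any sketch here, the centered-offset convention and the odd-kernel restriction are assumptions, not necessarily fvdb's rule, and `transposed_output_coords` is a hypothetical helper.

```python
from itertools import product


def transposed_output_coords(source, kernel_size, stride):
    """Fine coordinates i * stride + d reached by a strided transposed convolution."""
    if any(k % 2 == 0 for k in kernel_size):
        raise ValueError("this sketch only handles odd kernel sizes")
    ranges = [range(-(k // 2), k // 2 + 1) for k in kernel_size]
    out = set()
    for i in source:
        for d in product(*ranges):
            out.add(tuple(i[a] * stride[a] + d[a] for a in range(3)))
    return out


# One coarse voxel expands to a block of fine voxels along x.
print(sorted(transposed_output_coords({(1, 0, 0)}, (3, 1, 1), (2, 1, 1))))
# [(1, 0, 0), (2, 0, 0), (3, 0, 0)]
```

Because every coarse voxel produces a block of fine voxels, the upsampled set can contain more voxels than the fine grid a preceding downsampling started from.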

Note

Though deconvolution is the “reverse” of convolution in some sense, this configuration still treats input and output channels as inputs and outputs; it does not swap them. The source and target grids are not swapped either. It is best to think of deconvolution as convolution with a different kernel than regular convolution, but otherwise the same kind of abstract operation.

Note

For most backends, target_grid can be computed automatically. Only the dense, halo, and lggs expert backends impose a specific configuration: they use target_grid = source_grid and require target_grid to be None.

Parameters:
  • kernel_size (NumericMaxRank1) – Size of the convolution kernel. Can be a single int (cubic kernel) or a 3-element sequence for (x, y, z) dimensions.

  • stride (NumericMaxRank1) – Convolution stride. Can be a single int or 3-element sequence.

  • source_grid (GridBatch) – fvdb.GridBatch encoding the structure of the input domain.

  • target_grid (GridBatch | None) – fvdb.GridBatch encoding the structure of the output domain. If None, target_grid is computed automatically by applying kernel_size and stride to source_grid. The exception is the dense, halo, and lggs backends, which use target_grid = source_grid; for those backends, target_grid must be None.

  • expert_config (dict[str, Any]) – Advanced configuration options (rarely needed by typical users).

  • channel_pairs (tuple[tuple[int, int], ...]) – Supported input/output channel combinations, each given as an (input_channels, output_channels) tuple. For example, ((32, 64), (64, 128)) supports 32->64 and 64->128 convolutions. Specifying the pairs in advance can let the plan select a more performant backend. Defaults to _ANY_CHANNEL_PAIRS (the empty tuple), which means any channel pair is supported.

Returns:

convolution_plan (ConvolutionPlan) – Configured plan ready for transposed convolution operations via execute().

classmethod from_grid_transposed(kernel_size: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, stride: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, source_grid: Grid, target_grid: Grid | None = None, *, expert_config: dict[str, Any] = {'allow_tf32': True, 'backend': 'default', 'feature_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64), 'weight_dtypes': (torch.float16, torch.bfloat16, torch.float32, torch.float64)}, channel_pairs: tuple[tuple[int, int], ...] = ()) → ConvolutionPlan

Create a ConvolutionPlan for transposed convolution on a single grid, i.e., transposed convolution where the input and output domains are both of type fvdb.Grid.

Transposed convolution (also known as deconvolution) is commonly used for upsampling operations, such as in decoder networks or generative models. It performs the mathematical transpose of the convolution operation.

Note

Though deconvolution is the “reverse” of convolution in some sense, this configuration still treats input and output channels as inputs and outputs; it does not swap them. The source and target grids are not swapped either. It is best to think of deconvolution as convolution with a different kernel than regular convolution, but otherwise the same kind of abstract operation.

Note

For most backends, target_grid can be computed automatically. Only the dense, halo, and lggs expert backends impose a specific configuration: they use target_grid = source_grid and require target_grid to be None.

Parameters:
  • kernel_size (NumericMaxRank1) – Size of the convolution kernel. Can be a single int (cubic kernel) or a 3-element sequence for (x, y, z) dimensions.

  • stride (NumericMaxRank1) – Convolution stride. Can be a single int or 3-element sequence.

  • source_grid (Grid) – fvdb.Grid encoding the structure of the input domain.

  • target_grid (Grid | None) – fvdb.Grid encoding the structure of the output domain. If None, target_grid is computed automatically by applying kernel_size and stride to source_grid. The exception is the dense, halo, and lggs backends, which use target_grid = source_grid; for those backends, target_grid must be None.

  • expert_config (dict[str, Any]) – Advanced configuration options (rarely needed by typical users).

  • channel_pairs (tuple[tuple[int, int], ...]) – Supported input/output channel combinations, each given as an (input_channels, output_channels) tuple. For example, ((32, 64), (64, 128)) supports 32->64 and 64->128 convolutions. Specifying the pairs in advance can let the plan select a more performant backend. Defaults to _ANY_CHANNEL_PAIRS (the empty tuple), which means any channel pair is supported.

Returns:

convolution_plan (ConvolutionPlan) – Configured plan ready for transposed convolution operations.

classmethod from_plan_transposed(plan: ConvolutionPlan) → ConvolutionPlan

Create a transposed version of an existing ConvolutionPlan.

This method creates a new plan that performs the transpose of the given plan (i.e., convolution becomes transposed convolution and vice versa). It automatically swaps the source and target grids, reverses the channel pairs, and flips the transposed flag.
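The three bookkeeping steps described above can be sketched with a frozen dataclass, which also mirrors the note that plans are immutable: transposing builds a new object instead of modifying the old one. `PlanSpec` and `transposed_spec` are hypothetical stand-ins, the grids are represented by plain strings, and "reverses the channel pairs" is interpreted here as swapping each (in, out) pair.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PlanSpec:
    """Hypothetical, immutable stand-in for the bookkeeping a plan carries."""
    source: str
    target: str
    channel_pairs: tuple = ()
    transposed: bool = False


def transposed_spec(spec):
    """Swap source/target, swap each (in, out) channel pair, and flip the flag."""
    return replace(
        spec,
        source=spec.target,
        target=spec.source,
        channel_pairs=tuple((out_c, in_c) for in_c, out_c in spec.channel_pairs),
        transposed=not spec.transposed,
    )


encoder = PlanSpec("fine", "coarse", ((32, 64),))
decoder = transposed_spec(encoder)
print(decoder)
# PlanSpec(source='coarse', target='fine', channel_pairs=((64, 32),), transposed=True)
```

Applying `transposed_spec` twice returns a spec equal to the original, which is the property an encoder/decoder pairing relies on.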

Note

This is particularly useful for creating encoder-decoder pairs where the decoder needs to undo the operations of the encoder.

Parameters:

plan (ConvolutionPlan) – An existing ConvolutionPlan to transpose.

Returns:

convolution_plan (ConvolutionPlan) – A new plan that performs the transpose of the input plan.

Example:

from fvdb import ConvolutionPlan

# Create forward plan (input_grid is an existing fvdb.Grid)
forward_plan = ConvolutionPlan.from_grid(
    kernel_size=3,
    stride=1,
    source_grid=input_grid
)

# Create the corresponding transposed plan
transposed_plan = ConvolutionPlan.from_plan_transposed(forward_plan)

property has_fixed_topology: bool

Returns True if the source and target grids have the same topology, i.e., the same voxel structure. Some backends require this.

Returns:

has_fixed_topology (bool) – True if the source and target grids have the same topology, False otherwise.
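Conceptually, "same topology" means both grids activate exactly the same voxel coordinates, regardless of storage order. The one-function sketch below captures that meaning on plain coordinate lists; `same_topology` is a hypothetical helper and says nothing about how fvdb performs the comparison internally.

```python
def same_topology(source_coords, target_coords):
    """True when both grids activate exactly the same voxel coordinates (order ignored)."""
    return set(map(tuple, source_coords)) == set(map(tuple, target_coords))


print(same_topology([(0, 0, 0), (1, 0, 0)], [(1, 0, 0), (0, 0, 0)]))  # True
print(same_topology([(0, 0, 0), (1, 0, 0)], [(0, 0, 0)]))             # False
```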

property source_grid: Grid

Return the fvdb.Grid representing the source domain of the convolution, or raise an error if the plan was created for a batch of grids.

Returns:

source_grid (Grid) – The source fvdb.Grid of the convolution plan.

Raises:

ValueError – If the plan was created for a batch of grids.

property source_grid_batch: GridBatch

Return the fvdb.GridBatch representing the source domain of the convolution. If the plan was created for a single grid, it is returned as a batch of size 1.

Returns:

source_grid_batch (GridBatch) – The source fvdb.GridBatch of the convolution plan.

property target_grid: Grid

Return the fvdb.Grid representing the target domain of the convolution, or raise an error if the plan was created for a batch of grids.

Returns:

target_grid (Grid) – The target fvdb.Grid of the convolution plan.

Raises:

ValueError – If the plan was created for a batch of grids.

property target_grid_batch: GridBatch

Return the fvdb.GridBatch representing the target domain of the convolution. If the plan was created for a single grid, it is returned as a batch of size 1.

Returns:

target_grid_batch (GridBatch) – The target fvdb.GridBatch of the convolution plan.

valid_usage(in_channels: int, out_channels: int, kernel_size: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, stride: Tensor | ndarray | int | float | integer | floating | Sequence[int | float | integer | floating] | Size, transposed: bool) → bool

Check if this ConvolutionPlan is valid for the given usage. This method returns True if the plan can apply a (transposed) convolution with the given kernel_size and stride from in_channels to out_channels.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (NumericMaxRank1) – Kernel size. Can be a single int or 3-element sequence.

  • stride (NumericMaxRank1) – Stride. Can be a single int or 3-element sequence.

  • transposed (bool) – Whether the intended usage is a transposed convolution.

Returns:

is_valid (bool) – True if the plan is valid for the given configuration, False otherwise.
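The documented behavior amounts to comparing the requested configuration with the one the plan was built for. The sketch below stores that configuration in a plain dict with hypothetical keys and treats an empty channel_pairs as "any pair"; `valid_usage_sketch` only mirrors the description above and is not fvdb's implementation.

```python
from numbers import Real


def _vec3(v):
    """Broadcast a scalar to a 3-tuple; pass 3-element sequences through as tuples."""
    return (v, v, v) if isinstance(v, Real) else tuple(v)


def valid_usage_sketch(config, in_channels, out_channels, kernel_size, stride, transposed):
    """True if a plan built with `config` could serve the requested convolution."""
    pairs = config["channel_pairs"]
    return (
        _vec3(kernel_size) == config["kernel_size"]
        and _vec3(stride) == config["stride"]
        and transposed == config["transposed"]
        # An empty channel_pairs tuple is treated as "any pair" here.
        and (not pairs or (in_channels, out_channels) in pairs)
    )


config = {"kernel_size": (3, 3, 3), "stride": (1, 1, 1), "transposed": False, "channel_pairs": ((32, 64),)}
print(valid_usage_sketch(config, 32, 64, 3, 1, False))  # True
print(valid_usage_sketch(config, 64, 32, 3, 1, False))  # False
```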