JaggedTensor

class fvdb.JaggedTensor(tensors: Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]] | None = None, *, impl: JaggedTensor | None = None)

A jagged (ragged) tensor data structure with support for efficient operations.

JaggedTensor represents sequences of tensors with varying lengths, stored efficiently in a flat contiguous format with associated index/offset structures. This is useful for batch processing of variable-length sequences on the GPU while maintaining memory efficiency and enabling vectorized operations.
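To make the flat layout concrete, the sketch below builds the flattened data and the offsets for three variable-length tensors by hand in plain PyTorch. It illustrates the layout only, not how fvdb builds it internally. The variable names mirror the jdata and joffsets properties documented below, and the int64 offset dtype is an assumption of the sketch.

```python
import torch

# Illustration of the flat layout in plain PyTorch (not fvdb internals).
# Three tensors whose first (jagged) dimension varies; the trailing shape (4,) is shared.
tensors = [torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)]

# Flat contiguous storage: concatenate along the first dimension -> shape (10, 4).
jdata = torch.cat(tensors, dim=0)

# Offsets: a leading 0 followed by the running total of tensor lengths.
lengths = torch.tensor([t.shape[0] for t in tensors])
joffsets = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lengths, dim=0)])
print(joffsets.tolist())  # [0, 3, 5, 10]

# Tensor i is recovered by slicing the flat storage between consecutive offsets.
second = jdata[int(joffsets[1]):int(joffsets[2])]
```

Storing a single flat tensor plus a small offsets array is what lets one kernel operate over every element regardless of how the rows are split between tensors.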

A JaggedTensor can represent:

1. A sequence of tensors with varying shapes along the first dimension. These are usually written as [tensor_1, tensor_2, ..., tensor_N] where each tensor_i can have a different shape along the first dimension.

2. Nested sequences (list of lists) with varying lengths at multiple levels. These are usually written as [[tensor_11, tensor_12, ...], [tensor_21, tensor_22, ...], ...] where both the outer and inner sequences can have varying lengths, and each tensor_ij can have a different shape along the first dimension.

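The JaggedTensor data structure consists of the following components:

  • jdata – the flattened data tensor containing all elements, of shape (total_elements, ...).

  • joffsets – offsets marking the boundaries between tensors in jdata; tensor i occupies jdata[joffsets[i]:joffsets[i+1]].

  • jidx – a per-element index mapping each row of jdata to the tensor it belongs to.

  • jlidx – for nested structures, the list index and in-list position of each tensor.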

JaggedTensor integrates with PyTorch through __torch_function__, allowing many torch operations to work directly on jagged tensors while preserving the jagged structure. Operations that preserve the leading (flattened) dimension work seamlessly, while shape-changing operations require specialized j* methods.
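The dispatch idea can be illustrated with PyTorch's documented __torch_function__ protocol. The sketch below uses a hypothetical ToyJagged class (flat data plus offsets), which is not fvdb's implementation: it unwraps its arguments, runs the torch function on the flat data, and re-wraps the result only when the leading dimension is unchanged. Reductions that remove that dimension fall back to returning a plain tensor.

```python
import torch


class ToyJagged:
    """Hypothetical jagged container: flat data plus offsets (illustration only)."""

    def __init__(self, data: torch.Tensor, offsets: torch.Tensor):
        self.data = data
        self.offsets = offsets

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Use the first positional ToyJagged argument as the structure template.
        template = next((a for a in args if isinstance(a, ToyJagged)), None)
        if template is None:
            return NotImplemented  # e.g. ToyJagged nested inside a list argument

        def unwrap(x):
            return x.data if isinstance(x, ToyJagged) else x

        out = func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})
        # Re-wrap only when the leading (flattened) dimension is preserved.
        if isinstance(out, torch.Tensor) and out.dim() > 0 and out.shape[0] == template.data.shape[0]:
            return ToyJagged(out, template.offsets)
        return out


jt = ToyJagged(torch.tensor([-1.0, 2.0, -3.0]), torch.tensor([0, 1, 3]))
relu_jt = torch.relu(jt)  # element-wise: structure preserved
total = torch.sum(jt)     # removes the leading dimension: plain tensor
print(relu_jt.data.tolist(), relu_jt.offsets.tolist())  # [0.0, 2.0, 0.0] [0, 1, 3]
```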

Example usage:

import torch
from fvdb import JaggedTensor

# Create a JaggedTensor from a list of tensors
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)])

# Perform element-wise operations
jt2 = jt + 1.0
jt3 = torch.relu(jt2)

# Access jagged data and structure
data = jt3.jdata
offsets = jt3.joffsets

# Get the first tensor in the jagged sequence
first_tensor = jt3[0]

# Get the last tensor in the jagged sequence
last_tensor = jt3[-1]

Note

The JaggedTensor should be constructed using the explicit classmethods:

  • from_tensor() for a single tensor

  • from_list_of_tensors() for a list of tensors

  • from_list_of_lists_of_tensors() for nested lists of tensors

  • from_data_and_indices() for pre-computed flat format

  • from_data_and_offsets() for pre-computed flat format with offsets

abs() → JaggedTensor

Compute the absolute value element-wise.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with absolute values.

abs_() → JaggedTensor

Compute the absolute value element-wise in-place.

Returns:

jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).

ceil() → JaggedTensor

Round elements up to the nearest integer.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with ceiling applied.

ceil_() → JaggedTensor

Round elements up to the nearest integer in-place.

Returns:

jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).

clone() → JaggedTensor

Create a deep copy of the JaggedTensor.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with copied data and structure.

cpu() → JaggedTensor

Move the JaggedTensor to CPU memory.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor on CPU device.

cuda() → JaggedTensor

Move the JaggedTensor to CUDA (GPU) memory.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor on CUDA device.

detach() → JaggedTensor

Detach the JaggedTensor from the autograd graph.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor detached from the computation graph.

property device: device

Device where this JaggedTensor is stored.

Returns:

torch.device – The device of this JaggedTensor.

double() → JaggedTensor

Convert elements to double (float64) dtype.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with double precision.

property dtype: dtype

Data type of the elements in this JaggedTensor.

Returns:

torch.dtype – The data type of this JaggedTensor.

property edim: int

Dimensionality of the element (regular) structure.

For example, if each tensor in the jagged sequence has shape (?, 4, 5), then edim will be 2 since there are two regular dimensions per element.

Returns:

int – The dimensionality of the element structure.

property eshape: list[int]

Shape of the element dimensions.

For example, if each tensor in the jagged sequence has shape (?, 4, 5), then eshape will be [4, 5].

Returns:

list[int] – The shape of the element dimensions.
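The relationship between eshape and edim, and the rule that only the first dimension may vary between tensors, can be sketched in plain PyTorch. element_shape below is a hypothetical helper for illustration, not part of fvdb.

```python
import torch


def element_shape(tensors):
    """Hypothetical helper: return (eshape, edim) shared by a list of tensors."""
    eshape = list(tensors[0].shape[1:])  # drop the jagged first dimension
    for t in tensors[1:]:
        # Only the first dimension may differ between tensors.
        if list(t.shape[1:]) != eshape:
            raise ValueError(f"trailing shape {list(t.shape[1:])} does not match {eshape}")
    return eshape, len(eshape)


eshape, edim = element_shape([torch.zeros(2, 4, 5), torch.zeros(7, 4, 5)])
print(eshape, edim)  # [4, 5] 2
```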

float() → JaggedTensor

Convert elements to float (float32) dtype.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with float32 precision.

floor() → JaggedTensor

Round elements down to the nearest integer.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with floor applied.

floor_() → JaggedTensor

Round elements down to the nearest integer in-place.

Returns:

jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).

classmethod from_data_and_indices(data: Tensor, indices: Tensor, num_tensors: int) → JaggedTensor

Create a JaggedTensor from flattened data and per-element indices.

Example

data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])

jt = JaggedTensor.from_data_and_indices(data, indices, num_tensors=3)

# jt represents:
#  - tensor 0: [1, 2]
#  - tensor 1: [3, 4, 5]
#  - tensor 2: [6]

Parameters:
  • data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).

  • indices (torch.Tensor) – Index tensor mapping each element to its parent tensor. Shape: (total_elements,). Values in range [0, num_tensors).

  • num_tensors (int) – Total number of tensors in the sequence.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and indices.
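The per-element indices carry the same information as the offsets used by from_data_and_offsets(). Assuming the indices are sorted so that each tensor's elements are contiguous (an assumption; the requirement is not stated above), the conversion is a per-tensor count followed by a cumulative sum, sketched here in plain PyTorch with a hypothetical helper. It also suggests one likely reason num_tensors is passed explicitly: an empty trailing tensor has no element carrying its index.

```python
import torch


def indices_to_offsets(indices: torch.Tensor, num_tensors: int) -> torch.Tensor:
    """Hypothetical helper: offsets of shape (num_tensors + 1,) from sorted indices."""
    # Number of elements in each tensor; minlength keeps trailing empty tensors.
    counts = torch.bincount(indices, minlength=num_tensors)
    return torch.cat([torch.zeros(1, dtype=counts.dtype), torch.cumsum(counts, dim=0)])


print(indices_to_offsets(torch.tensor([0, 0, 1, 1, 1, 2]), 3).tolist())  # [0, 2, 5, 6]
# Tensor 2 is empty, so no element carries index 2.
print(indices_to_offsets(torch.tensor([0, 0, 1]), 3).tolist())  # [0, 2, 3, 3]
```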

classmethod from_data_and_offsets(data: Tensor, offsets: Tensor) → JaggedTensor

Create a JaggedTensor from flattened data and offset array.

Offsets define boundaries between tensors in the flattened data array. Tensor i contains elements data[offsets[i]:offsets[i+1]].

Example:

data = torch.tensor([1, 2, 3, 4, 5, 6])

offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]

jt = JaggedTensor.from_data_and_offsets(data, offsets)

# jt represents:
#  - tensor 0: [1, 2]
#  - tensor 1: [3, 4, 5]
#  - tensor 2: [6]
Parameters:
  • data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).

  • offsets (torch.Tensor) – Offset tensor marking tensor boundaries. Shape: (num_tensors + 1,). Must be monotonically increasing.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and offsets.
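Offsets can also be expanded into the per-element indices accepted by from_data_and_indices(): tensor id i is simply repeated once per element of tensor i. A minimal plain-PyTorch sketch, with offsets_to_indices as a hypothetical helper name:

```python
import torch


def offsets_to_indices(offsets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-element tensor index from offsets of shape (num_tensors + 1,)."""
    lengths = offsets[1:] - offsets[:-1]  # number of elements in each tensor
    # Repeat tensor id i exactly lengths[i] times.
    return torch.repeat_interleave(torch.arange(len(lengths)), lengths)


print(offsets_to_indices(torch.tensor([0, 2, 5, 6])).tolist())  # [0, 0, 1, 1, 1, 2]
# An empty tensor (equal consecutive offsets) contributes no indices.
print(offsets_to_indices(torch.tensor([0, 2, 2, 3])).tolist())  # [0, 0, 2]
```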

classmethod from_data_indices_and_list_ids(data: Tensor, indices: Tensor, list_ids: Tensor, num_tensors: int) → JaggedTensor

Create a nested JaggedTensor from data, indices, and list IDs.

Creates a multi-level jagged structure where list_ids provide an additional level of grouping beyond the basic indices.

Parameters:
  • data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, …).

  • indices (torch.Tensor) – Index tensor mapping each element to its tensor. Shape: (total_elements,).

  • list_ids (torch.Tensor) – List ID tensor for nested structure, with one row per tensor giving its (list index, position within list). Shape: (num_tensors, 2).

  • num_tensors (int) – Total number of tensors.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.

classmethod from_data_offsets_and_list_ids(data: Tensor, offsets: Tensor, list_ids: Tensor) → JaggedTensor

Create a nested JaggedTensor from data, offsets, and list IDs.

The offsets are used to define boundaries between tensors in the flattened array, and the list ids provide an additional level of grouping.

Example:

data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]
list_ids = torch.tensor([[0, 0], [0, 1], [1, 0]]) # First two tensors in list 0, last in list 1

jt = JaggedTensor.from_data_offsets_and_list_ids(data, offsets, list_ids)

# jt represents the structure [[t_00, t_01], [t_10]]
# where t_00 = [1, 2], t_01 = [3, 4, 5], t_10 = [6]
Parameters:
  • data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).

  • offsets (torch.Tensor) – Offset tensor marking tensor boundaries. Shape: (num_tensors + 1,).

  • list_ids (torch.Tensor) – List ID tensor for nested structure. Shape: (num_tensors, 2).

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.
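To make the role of list_ids concrete, the sketch below regroups the flat data from the example above into nested Python lists in plain PyTorch, assuming (as in that example) that tensors are stored in list order. The result has the list-of-lists shape that unbind() is documented to return for a nested JaggedTensor. regroup is a hypothetical helper, not an fvdb function.

```python
import torch


def regroup(data: torch.Tensor, offsets: torch.Tensor, list_ids: torch.Tensor):
    """Hypothetical helper: nested lists of tensors from flat data, offsets and list_ids."""
    # Split the flat data into one piece per tensor.
    pieces = torch.split(data, (offsets[1:] - offsets[:-1]).tolist())
    nested = []
    for piece, (outer, _) in zip(pieces, list_ids.tolist()):
        while len(nested) <= outer:
            nested.append([])  # open a new outer list
        # The in-list position is implied by arrival order (assumed sorted).
        nested[outer].append(piece)
    return nested


nested = regroup(torch.tensor([1, 2, 3, 4, 5, 6]),
                 torch.tensor([0, 2, 5, 6]),
                 torch.tensor([[0, 0], [0, 1], [1, 0]]))
print([[t.tolist() for t in lst] for lst in nested])  # [[[1, 2], [3, 4, 5]], [[6]]]
```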

classmethod from_list_of_lists_of_tensors(tensors: Sequence[Sequence[Tensor]]) → JaggedTensor

Create a JaggedTensor from a nested sequence of torch.Tensor objects.

Creates a multi-level jagged structure where both outer and inner sequences can have varying lengths.

Parameters:

tensors (Sequence[Sequence[torch.Tensor]]) – Nested list/tuple of torch.Tensor objects.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.

classmethod from_list_of_tensors(tensors: Sequence[Tensor]) → JaggedTensor

Create a JaggedTensor from a sequence of tensors with varying first dimensions.

All tensors must have the same shape except for the first dimension, which can vary: in [tensor_1, tensor_2, ..., tensor_N], each tensor_i has shape (L_i, D_1, D_2, ...), where only L_i differs between tensors.

Parameters:

tensors (Sequence[torch.Tensor]) – List or tuple of torch.Tensor with compatible shapes.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor containing the sequence of tensors.

classmethod from_tensor(data: Tensor) → JaggedTensor

Create a JaggedTensor from a single torch.Tensor.

Parameters:

data (torch.Tensor) – The input tensor.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor wrapping the input tensor.

int() → JaggedTensor

Convert elements to int (int32) dtype.

Returns:

JaggedTensor – A new JaggedTensor with int32 dtype.

property is_cpu: bool

Whether this JaggedTensor is stored on the CPU.

Returns:

bool – True if on CPU, False otherwise.

property is_cuda: bool

Whether this JaggedTensor is stored on a CUDA device.

Returns:

bool – True if on CUDA, False otherwise.

jagged_like(data: Tensor) → JaggedTensor

Create a new JaggedTensor with the same structure but different data.

The new JaggedTensor will have the same jagged structure (joffsets, jidx, etc.) as the current one, but with new jdata values.

Parameters:

data (torch.Tensor) – New data tensor with compatible shape. Must have the same leading dimension as self.jdata.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with the same structure but new data.

property jdata: Tensor

Flattened data tensor containing all elements in this JaggedTensor.

For example, if this JaggedTensor represents three tensors of shapes (2, 4), (3, 4), and (1, 4), then jdata will have shape (6, 4).

Returns:

torch.Tensor – The data tensor.

jflatten(dim: int = 0) → JaggedTensor

Flatten the jagged dimensions starting from the specified dimension.

Example

# Original jagged tensor with 2 jagged dimensions
# representing a tensor of shape [ [ t_00, t_01, … ], [ t_b0, t_b1, … ] ]
jt = JaggedTensor.from_list_of_lists_of_tensors(…)

# Flatten starting from dim=0
jt_flat = jt.jflatten(dim=0)

# jt_flat is now a jagged tensor with 1 jagged dimension and represents
# [ t_00, t_01, …, t_b0, t_b1, … ]

Parameters:

dim (int) – The dimension from which to start flattening. Defaults to 0.

Returns:

jagged_tensor (JaggedTensor) – A new JaggedTensor with flattened jagged structure.

property jidx: Tensor

Indices for each element in the jagged structure: jidx[k] is the index of the tensor that row k of jdata belongs to (the same per-element indices accepted by from_data_and_indices()).

Example:

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``jidx`` tensor holds the values [0, 0, 1, 1, 1, 2].

jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.jidx.tolist())  # Output: [0, 0, 1, 1, 1, 2]
Returns:

torch.Tensor – The jagged indices tensor.

property jlidx: Tensor

List indices for nested jagged structures. This is a torch.Tensor with one row per tensor, giving the index of the list that tensor belongs to and its position within that list.

Example:

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the jlidx tensor would be: tensor([[0, 0], [0, 1], [1, 0]]).

jt = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt.jlidx)  # Output: tensor([[0, 0], [0, 1], [1, 0]])
Returns:

torch.Tensor – The jagged list indices tensor.

jmax(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]

Compute the maximum along a dimension of each tensor in the jagged structure.

Returns both the maximum values and the indices where they occur.

Example

# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the maximum along the jagged dimension (dim=0)
values, indices = jt.jmax(dim=0)

# values is now a jagged tensor containing the maximum values from each tensor
# along dim=0
# this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).indices for t in [t1, t2, t3]])

Parameters:
  • dim (int) – The dimension along which to compute max for each tensor. Defaults to 0.

  • keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.

Returns:
  • values (JaggedTensor) – A JaggedTensor containing the maximum values.

  • indices (JaggedTensor) – A JaggedTensor containing the indices of the maximum values.
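One way to avoid a per-tensor loop is a segmented reduction over the flat data, keyed by the per-element tensor index. The plain-PyTorch sketch below illustrates that general technique (not necessarily how fvdb implements jmax): it computes the per-tensor maximum values with a single scatter_reduce and checks them against the loop. It covers the values only, not the indices.

```python
import torch

tensors = [torch.randn(3, 2), torch.randn(5, 2), torch.randn(1, 2)]
lengths = torch.tensor([t.shape[0] for t in tensors])
jdata = torch.cat(tensors)  # flat data, shape (9, 2)
jidx = torch.repeat_interleave(torch.arange(len(tensors)), lengths)  # tensor id per row

# Loop reference: one reduction per tensor.
loop_max = torch.stack([torch.max(t, dim=0).values for t in tensors])  # (3, 2)

# Single pass: scatter each row into its tensor's slot, keeping the element-wise max.
# include_self=False so the zero initial values never take part in the max.
index = jidx.unsqueeze(1).expand_as(jdata)
seg_max = torch.zeros(len(tensors), jdata.shape[1]).scatter_reduce(
    0, index, jdata, reduce="amax", include_self=False
)
print(torch.equal(seg_max, loop_max))  # True
```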

jmin(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]

Compute the minimum along a dimension of each tensor in the jagged structure.

Returns both the minimum values and the indices where they occur.

Example:

# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the minimum along the jagged dimension (dim=0)
values, indices = jt.jmin(dim=0)

# values is now a jagged tensor containing the minimum values from each tensor
# along dim=0
# this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).indices for t in [t1, t2, t3]])
Parameters:
  • dim (int) – The dimension along which to compute min for each tensor. Defaults to 0.

  • keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.

Returns:

list[JaggedTensor] – A list containing [values, indices] as JaggedTensors.

property joffsets: Tensor

Offsets marking boundaries between tensors: tensor i occupies jdata[joffsets[i]:joffsets[i+1]], so joffsets has shape (num_tensors + 1,).

Example:

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``joffsets`` tensor would be: ``tensor([0, 2, 5, 6])``.
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.joffsets)  # Output: tensor([0, 2, 5, 6])

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the joffsets tensor would be: tensor([0, 2, 3, 7]).
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.joffsets)  # Output: tensor([0, 2, 3, 7])
Returns:

torch.Tensor – The jagged offsets tensor.

jreshape(lshape: Sequence[int] | Sequence[Sequence[int]]) → JaggedTensor

Reshape the jagged dimensions to new sizes.

Parameters:

lshape (Sequence[int] | Sequence[Sequence[int]]) – New shape(s) for jagged dimensions. Can be a single sequence of sizes or nested sequences for multi-level structure.

Returns:

JaggedTensor – A new JaggedTensor with reshaped jagged structure.

jreshape_as(other: JaggedTensor | Tensor) → JaggedTensor

Reshape the jagged structure to match another JaggedTensor or Tensor.

Parameters:

other (JaggedTensor | torch.Tensor) – The target structure to match.

Returns:

JaggedTensor – A new JaggedTensor with structure matching other.

jsqueeze(dim: int | None = None) → JaggedTensor

Remove singleton dimensions from the jagged structure.

Parameters:

dim (int | None) – Specific dimension to squeeze, or None to squeeze all singleton dimensions. Defaults to None.

Returns:

JaggedTensor – A new JaggedTensor with singleton dimensions removed.

jsum(dim: int = 0, keepdim: bool = False) → JaggedTensor

Sum along a jagged dimension.

Parameters:
  • dim (int) – The jagged dimension along which to sum. Defaults to 0.

  • keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.

Returns:

JaggedTensor – A new JaggedTensor with values summed along the specified dimension.
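For the jagged dimension (dim=0), a sum can likewise be computed in one pass by accumulating each row of the flat data into the slot of the tensor it belongs to. The plain-PyTorch sketch below shows this segmented sum with index_add_ and checks it against a per-tensor loop. It illustrates the technique rather than fvdb's implementation; integer-valued data keeps the comparison exact.

```python
import torch

# Integer-valued floats so both summation orders give bit-identical results.
tensors = [torch.randint(-5, 5, (n, 2)).float() for n in (3, 5, 1)]
lengths = torch.tensor([t.shape[0] for t in tensors])
jdata = torch.cat(tensors)
jidx = torch.repeat_interleave(torch.arange(len(tensors)), lengths)

# Add row k of jdata into row jidx[k] of the output.
seg_sum = torch.zeros(len(tensors), jdata.shape[1]).index_add_(0, jidx, jdata)

# Loop reference: one sum per tensor.
loop_sum = torch.stack([t.sum(dim=0) for t in tensors])
print(torch.equal(seg_sum, loop_sum))  # True
```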

property ldim: int

Dimensionality of the jagged (leading) structure, i.e., the number of jagged levels.

If the JaggedTensor represents a simple jagged structure (a single list of tensors), then ldim will be 1. For nested jagged structures (lists of lists of tensors), ldim will be greater than 1.

Returns:

int – The dimensionality of the jagged structure.

long() → JaggedTensor

Convert elements to long (int64) dtype.

Returns:

JaggedTensor – A new JaggedTensor with int64 dtype.

property lshape: list[int] | list[list[int]]

List structure shape(s) of the jagged dimensions.

Example:

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``lshape`` will be: ``[2, 3, 1]`` (one entry per tensor: its size along the jagged dimension).
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.lshape)  # Output: [2, 3, 1]

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the ``lshape`` will be: ``[[2, 1], [4]]``.
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.lshape)  # Output: [[2, 1], [4]]
Returns:

list[int] | list[list[int]] – The jagged structure shapes.

property num_tensors: int

Return the total number of tensors in the jagged sequence.

Returns:

int – Number of tensors in this JaggedTensor.

property requires_grad: bool

Whether this JaggedTensor requires gradient computation.

Returns:

bool – True if gradients are tracked, False otherwise.

requires_grad_(requires_grad: bool) → JaggedTensor

Set the requires_grad attribute in-place.

Parameters:

requires_grad (bool) – Whether to track gradients for this tensor.

Returns:

JaggedTensor – The modified JaggedTensor (self).

rmask(mask: Tensor) → JaggedTensor

Apply a mask to filter elements along the regular (non-jagged) dimension.

Parameters:

mask (torch.Tensor) – Boolean mask tensor to apply. Shape must be compatible with the regular dimensions.

Returns:

JaggedTensor – A new JaggedTensor with masked elements.

round(decimals: int = 0) → JaggedTensor

Round elements to the specified number of decimals.

Parameters:

decimals (int) – Number of decimal places to round to. Defaults to 0.

Returns:

JaggedTensor – A new JaggedTensor with rounded values.

round_(decimals: int = 0) → JaggedTensor

Round elements to the specified number of decimals in-place.

Parameters:

decimals (int) – Number of decimal places to round to. Defaults to 0.

Returns:

JaggedTensor – The modified JaggedTensor (self).

property rshape: tuple[int, ...]

Return the shape of the jdata tensor.

Note

rshape stands for “raw shape” and represents the full shape of the underlying data tensor, including both jagged and regular dimensions.

Returns:

tuple[int, …] – Shape of the underlying data tensor.

sqrt() → JaggedTensor

Compute the square root element-wise.

Returns:

JaggedTensor – A new JaggedTensor with square root applied.

sqrt_() → JaggedTensor

Compute the square root element-wise in-place.

Returns:

JaggedTensor – The modified JaggedTensor (self).

to(device_or_dtype: device | str | dtype) → JaggedTensor

Move the JaggedTensor to a device or convert to a dtype.

Parameters:

device_or_dtype (torch.device | str | torch.dtype) – Target torch.device or torch.dtype. Can be a device (“cpu”, “cuda”), or a dtype (torch.float32, etc.).

Returns:

JaggedTensor – A new JaggedTensor on the specified device or with specified dtype.

type(dtype: dtype) → JaggedTensor

Convert the JaggedTensor to a specific dtype.

Parameters:

dtype (torch.dtype) – Target data type (e.g. torch.float32, torch.int64).

Returns:

JaggedTensor – A new JaggedTensor with the specified dtype.

type_as(other: JaggedTensor | Tensor) → JaggedTensor

Convert the JaggedTensor to match the dtype of another tensor.

Parameters:

other (JaggedTensor | torch.Tensor) – Reference torch.Tensor or JaggedTensor whose dtype to match.

Returns:

JaggedTensor – A new JaggedTensor with dtype matching other.

unbind() → list[Tensor] | list[list[Tensor]]

Unbind the JaggedTensor into its constituent tensors.

Returns:

list[torch.Tensor] | list[list[torch.Tensor]] – A list of torch.Tensor (for simple jagged structure) or a list of lists of torch.Tensor (for nested structure).