JaggedTensor
- class fvdb.JaggedTensor(tensors: Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]] | None = None, *, impl: JaggedTensor | None = None)
A jagged (ragged) tensor data structure with support for efficient operations.
JaggedTensor represents sequences of tensors with varying lengths, stored efficiently in a flat contiguous format with associated index/offset structures. This is useful for batch processing of variable-length sequences on the GPU while maintaining memory efficiency and enabling vectorized operations.
A JaggedTensor can represent:
1. A sequence of tensors with varying shapes along the first dimension. These are usually written as [tensor_1, tensor_2, ..., tensor_N], where each tensor_i can have a different shape along the first dimension.
2. Nested sequences (lists of lists) with varying lengths at multiple levels. These are usually written as [[tensor_11, tensor_12, ...], [tensor_21, tensor_22, ...], ...], where both the outer and inner sequences can have varying lengths, and each tensor_ij can have a different shape along the first dimension.
The JaggedTensor data structure consists of the following components:
- jdata: the flattened data tensor containing all elements.
- joffsets: offsets marking the boundaries between tensors in jdata.
- jidx: indices mapping each element of jdata to the tensor it belongs to.
- jlidx: list indices for nested jagged structures.
JaggedTensor integrates with PyTorch through __torch_function__, allowing many torch operations to work directly on jagged tensors while preserving the jagged structure. Operations that preserve the leading (flattened) dimension work seamlessly, while shape-changing operations require specialized j* methods.
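The flat layout can be pictured with a small pure-Python model. This sketch is illustrative only. It uses plain lists instead of torch tensors, and the helper name `flatten_to_jagged` is hypothetical, not part of fvdb. It concatenates variable-length sequences into one flat list, which plays the role of jdata. It also records where each sequence ends, which plays the role of joffsets.

```python
def flatten_to_jagged(sequences):
    """Hypothetical helper (not fvdb API): concatenate variable-length
    sequences into one flat list (like jdata) and record the boundaries
    between them (like joffsets)."""
    data = []
    offsets = [0]
    for seq in sequences:
        data.extend(seq)            # append this sequence's elements to the flat buffer
        offsets.append(len(data))   # end of this sequence == start of the next one
    return data, offsets


data, offsets = flatten_to_jagged([[1, 2], [3, 4, 5], [6]])
print(data)     # [1, 2, 3, 4, 5, 6]
print(offsets)  # [0, 2, 5, 6]
```

In this model, sequence i is recovered as data[offsets[i]:offsets[i + 1]]. No per-sequence padding is stored.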
Example usage:
```python
import torch
from fvdb import JaggedTensor

# Create a JaggedTensor from a list of tensors
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)])

# Perform element-wise operations
jt2 = jt + 1.0
jt3 = torch.relu(jt2)

# Access jagged data and structure
data = jt3.jdata
offsets = jt3.joffsets

# Get the first tensor in the jagged sequence
first_tensor = jt3[0]

# Get the last tensor in the jagged sequence
last_tensor = jt3[-1]
```
Note
The JaggedTensor should be constructed using the explicit classmethods:
  - from_tensor() for a single tensor
  - from_list_of_tensors() for a list of tensors
  - from_list_of_lists_of_tensors() for nested lists of tensors
  - from_data_and_indices() for pre-computed flat format
  - from_data_and_offsets() for pre-computed flat format with offsets
- abs() → JaggedTensor
Compute the absolute value element-wise.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with absolute values.
- abs_() → JaggedTensor
Compute the absolute value element-wise in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- ceil() → JaggedTensor
Round elements up to the nearest integer.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with ceiling applied.
- ceil_() → JaggedTensor
Round elements up to the nearest integer in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- clone() → JaggedTensor
Create a deep copy of the JaggedTensor.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with copied data and structure.
- cpu() → JaggedTensor
Move the JaggedTensor to CPU memory.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor on the CPU.
- cuda() → JaggedTensor
Move the JaggedTensor to CUDA (GPU) memory.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor on a CUDA device.
- detach() → JaggedTensor
Detach the JaggedTensor from the autograd graph.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor detached from the computation graph.
- property device: device
Device where this JaggedTensor is stored.
- Returns:
torch.device – The device of this JaggedTensor.
- double() → JaggedTensor
Convert elements to double (float64) dtype.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with double precision.
- property dtype: dtype
Data type of the elements in this JaggedTensor.
- Returns:
torch.dtype – The data type of this JaggedTensor.
- property edim: int
Dimensionality of the element (regular) structure.
For example, if each tensor in the jagged sequence has shape (?, 4, 5), then edim will be 2, since there are two regular dimensions per element.
- Returns:
int – The dimensionality of the element structure.
- property eshape: list[int]
Shape of the element dimensions.
For example, if each tensor in the jagged sequence has shape (?, 4, 5), then eshape will be [4, 5].
- Returns:
list[int] – The shape of the element dimensions.
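Here is a rough illustration of how eshape and edim relate to the per-tensor shapes and to the shape of the flattened data. This pure-Python sketch works on shape tuples only. The helper `element_shape_info` is hypothetical, and it assumes every tensor shares the same trailing dimensions.

```python
def element_shape_info(shapes):
    """Hypothetical helper (not fvdb API): given the shapes of the tensors in a
    non-empty jagged sequence, return (eshape, edim, rshape).
    Assumes all tensors share the same trailing (regular) dimensions."""
    eshape = list(shapes[0][1:])        # regular dims, e.g. [4, 5] for (?, 4, 5)
    edim = len(eshape)                  # number of regular dims
    total = sum(s[0] for s in shapes)   # jagged first dims are concatenated
    rshape = (total, *eshape)           # shape of the flattened data
    return eshape, edim, rshape


print(element_shape_info([(2, 4, 5), (3, 4, 5), (1, 4, 5)]))
# ([4, 5], 2, (6, 4, 5))
```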
- float() → JaggedTensor
Convert elements to float (float32) dtype.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with float32 precision.
- floor() → JaggedTensor
Round elements down to the nearest integer.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with floor applied.
- floor_() → JaggedTensor
Round elements down to the nearest integer in-place.
- Returns:
jagged_tensor (JaggedTensor) – The modified JaggedTensor (self).
- classmethod from_data_and_indices(data: Tensor, indices: Tensor, num_tensors: int) → JaggedTensor
Create a JaggedTensor from flattened data and per-element indices.
Example:
```python
import torch
from fvdb import JaggedTensor

data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])
jt = JaggedTensor.from_data_and_indices(data, indices, num_tensors=3)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
```
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
indices (torch.Tensor) – Index tensor mapping each element to its parent tensor. Shape: (total_elements,). Values in range [0, num_tensors).
num_tensors (int) – Total number of tensors in the sequence.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and indices.
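Per-element indices and offsets carry the same information when each tensor's elements are stored contiguously and in order. The following pure-Python sketch converts the indices from the example above into the equivalent offsets. It uses a hypothetical helper `offsets_from_indices` and assumes that contiguous ordering.

```python
def offsets_from_indices(indices, num_tensors):
    """Hypothetical helper (not fvdb API): convert per-element tensor indices
    into an offsets list of length num_tensors + 1.
    Assumes the elements of each tensor are stored contiguously and in order."""
    counts = [0] * num_tensors
    for i in indices:
        if not 0 <= i < num_tensors:
            raise ValueError(f"index {i} outside [0, {num_tensors})")
        counts[i] += 1                    # length of tensor i
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)   # running sum of tensor lengths
    return offsets


print(offsets_from_indices([0, 0, 1, 1, 1, 2], num_tensors=3))  # [0, 2, 5, 6]
print(offsets_from_indices([0, 0, 2], num_tensors=3))           # [0, 2, 2, 3] (tensor 1 is empty)
```

In this sketch, passing num_tensors explicitly is what allows a trailing empty tensor to be represented. An empty tensor contributes no entries to indices.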
- classmethod from_data_and_offsets(data: Tensor, offsets: Tensor) → JaggedTensor
Create a JaggedTensor from flattened data and an offset array.
Offsets define boundaries between tensors in the flattened data array. Tensor i contains elements data[offsets[i]:offsets[i+1]].
Example:
```python
import torch
from fvdb import JaggedTensor

data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]
jt = JaggedTensor.from_data_and_offsets(data, offsets)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
```
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
offsets (torch.Tensor) – Offset tensor marking tensor boundaries. Shape: (num_tensors + 1,). Must be monotonically increasing.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor constructed from the data and offsets.
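The slicing rule above can be sketched in pure Python. The helper `split_by_offsets` is hypothetical. Its checks (offsets start at 0, end at len(data), and never decrease) are assumptions of this sketch rather than documented fvdb validation.

```python
def split_by_offsets(data, offsets):
    """Hypothetical helper (not fvdb API): recover the individual sequences
    from flat data, using tensor i = data[offsets[i]:offsets[i + 1]]."""
    if offsets[0] != 0 or offsets[-1] != len(data):
        raise ValueError("offsets must start at 0 and end at len(data)")
    if any(a > b for a, b in zip(offsets, offsets[1:])):
        raise ValueError("offsets must be non-decreasing")
    return [data[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]


print(split_by_offsets([1, 2, 3, 4, 5, 6], [0, 2, 5, 6]))  # [[1, 2], [3, 4, 5], [6]]
```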
- classmethod from_data_indices_and_list_ids(data: Tensor, indices: Tensor, list_ids: Tensor, num_tensors: int) → JaggedTensor
Create a nested JaggedTensor from data, indices, and list IDs.
Creates a multi-level jagged structure where list_ids provide an additional level of grouping beyond the basic indices.
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, …).
indices (torch.Tensor) – Index tensor mapping each element to its tensor. Shape: (total_elements,).
list_ids (torch.Tensor) – List ID tensor for nested structure. Shape: (total_elements,).
num_tensors (int) – Total number of tensors.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.
- classmethod from_data_offsets_and_list_ids(data: Tensor, offsets: Tensor, list_ids: Tensor) → JaggedTensor
Create a nested JaggedTensor from data, offsets, and list IDs.
The offsets define boundaries between tensors in the flattened array, and the list IDs provide an additional level of grouping.
Example:
```python
import torch
from fvdb import JaggedTensor

data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6])  # 3 tensors: [0:2], [2:5], [5:6]
list_ids = torch.tensor([[0, 0], [0, 1], [1, 0]])  # First two tensors in list 0, last in list 1
jt = JaggedTensor.from_data_offsets_and_list_ids(data, offsets, list_ids)
# jt represents the structure [[t_00, t_01], [t_10]]
# where t_00 = [1, 2], t_01 = [3, 4, 5], t_10 = [6]
```
- Parameters:
data (torch.Tensor) – Flattened data tensor containing all elements. Shape: (total_elements, ...).
offsets (torch.Tensor) – Offset tensor marking tensor boundaries. Shape: (num_tensors + 1,).
list_ids (torch.Tensor) – List ID tensor for nested structure. Shape: (num_tensors, 2).
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.
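One way to read list_ids is as an [outer list, position within that list] pair per tensor. Under that reading, the pure-Python sketch below rebuilds the nested structure from the example above. It uses a hypothetical helper `nest_by_list_ids` and assumes list_ids is sorted by outer list and then by position.

```python
def nest_by_list_ids(data, offsets, list_ids):
    """Hypothetical helper (not fvdb API): group the tensors defined by
    offsets into nested lists, where list_ids[t] = [outer, inner].
    Assumes list_ids is sorted by (outer, inner) with consecutive inner ids."""
    nested = []
    for t, (outer, inner) in enumerate(list_ids):
        tensor = data[offsets[t]:offsets[t + 1]]   # tensor t's slice of the flat data
        while len(nested) <= outer:
            nested.append([])                      # open a new outer list when needed
        if inner != len(nested[outer]):
            raise ValueError(f"unexpected list id {[outer, inner]} for tensor {t}")
        nested[outer].append(tensor)
    return nested


print(nest_by_list_ids([1, 2, 3, 4, 5, 6], [0, 2, 5, 6], [[0, 0], [0, 1], [1, 0]]))
# [[[1, 2], [3, 4, 5]], [[6]]]
```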
- classmethod from_list_of_lists_of_tensors(tensors: Sequence[Sequence[Tensor]]) → JaggedTensor
Create a JaggedTensor from nested sequences of torch.Tensors.
Creates a multi-level jagged structure where both outer and inner sequences can have varying lengths.
- Parameters:
tensors (Sequence[Sequence[torch.Tensor]]) – Nested list/tuple of torch.Tensors.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with nested jagged structure.
- classmethod from_list_of_tensors(tensors: Sequence[Tensor]) → JaggedTensor
Create a JaggedTensor from a sequence of tensors with varying first dimensions.
All tensors must have the same shape except for the first dimension, which can vary, e.g. [tensor_1, tensor_2, ..., tensor_N] where each tensor_i has shape (L_i, D_1, D_2, ...) with varying L_i.
- Parameters:
tensors (Sequence[torch.Tensor]) – List or tuple of torch.Tensor with compatible shapes.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor containing the sequence of tensors.
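The shape requirement above can be checked before construction. This is a minimal pure-Python sketch over shape tuples. The helper `check_list_of_tensors_shapes` and its error messages are hypothetical and do not reflect how fvdb itself reports incompatible shapes. Rejecting an empty list and 0-dimensional shapes is also a choice of this sketch.

```python
def check_list_of_tensors_shapes(shapes):
    """Hypothetical helper (not fvdb API): validate that shapes look like
    (L_i, D_1, D_2, ...) with varying L_i but identical trailing dimensions,
    and return the list of first-dimension lengths L_i."""
    if not shapes:
        raise ValueError("need at least one tensor")
    trailing = tuple(shapes[0][1:])
    for i, s in enumerate(shapes):
        if len(s) == 0:
            raise ValueError(f"tensor {i} has no first dimension")
        if tuple(s[1:]) != trailing:
            raise ValueError(f"tensor {i} has trailing shape {tuple(s[1:])}, expected {trailing}")
    return [s[0] for s in shapes]


print(check_list_of_tensors_shapes([(3, 4), (2, 4), (5, 4)]))  # [3, 2, 5]
```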
- classmethod from_tensor(data: Tensor) → JaggedTensor
Create a JaggedTensor from a single torch.Tensor.
- Parameters:
data (torch.Tensor) – The input tensor.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor wrapping the input tensor.
- int() → JaggedTensor
Convert elements to int (int32) dtype.
- Returns:
JaggedTensor – A new JaggedTensor with int32 dtype.
- property is_cpu: bool
Whether this JaggedTensor is stored on the CPU.
- Returns:
bool – True if on CPU, False otherwise.
- property is_cuda: bool
Whether this JaggedTensor is stored on a CUDA device.
- Returns:
bool – True if on CUDA, False otherwise.
- jagged_like(data: Tensor) → JaggedTensor
Create a new JaggedTensor with the same structure but different data.
The new JaggedTensor will have the same jagged structure (joffsets, jidx, etc.) as the current one, but with new jdata values.
- Parameters:
data (torch.Tensor) – New data tensor with compatible shape. Must have the same leading dimension as self.jdata.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with the same structure but new data.
- property jdata: Tensor
Flattened data tensor containing all elements in this JaggedTensor.
For example, if this JaggedTensor represents three tensors of shapes (2, 4), (3, 4), and (1, 4), then jdata will have shape (6, 4).
- Returns:
torch.Tensor – The data tensor.
- jflatten(dim: int = 0) → JaggedTensor
Flatten the jagged dimensions starting from the specified dimension.
Example:
```python
# Original jagged tensor with 2 jagged dimensions
# representing a tensor of shape [ [ t_00, t_01, … ], [ t_b0, t_b1, … ] ]
jt = JaggedTensor.from_list_of_lists_of_tensors(…)

# Flatten starting from dim=0
jt_flat = jt.jflatten(dim=0)

# jt_flat is now a jagged tensor with 1 jagged dimension and represents
# [ t_00, t_01, …, t_b0, t_b1, … ]
```
- Parameters:
dim (int) – The dimension from which to start flattening. Defaults to 0.
- Returns:
jagged_tensor (JaggedTensor) – A new JaggedTensor with flattened jagged structure.
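Conceptually, flattening from dim=0 concatenates the inner lists in order. Here is a pure-Python model of that, using a hypothetical helper `flatten_one_level` and nested lists in place of tensors:

```python
def flatten_one_level(nested):
    """Hypothetical pure-Python analogue (not fvdb API) of flattening a
    two-level structure: [[t_00, t_01], [t_10]] -> [t_00, t_01, t_10]."""
    return [tensor for inner in nested for tensor in inner]


nested = [[[1, 2], [3, 4, 5]], [[6]]]
print(flatten_one_level(nested))  # [[1, 2], [3, 4, 5], [6]]
```

In this model the elements, read in storage order, are the same before and after. Only the grouping into outer lists is dropped.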
- property jidx: Tensor
Indices for each element in the jagged structure. This maps each element in the jdata tensor to the index of the tensor it belongs to.
Example:
```python
import torch
from fvdb import JaggedTensor

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``jidx`` tensor would be: ``tensor([0, 0, 1, 1, 1, 2])``.
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.jidx)  # Output: tensor([0, 0, 1, 1, 1, 2])
```
- Returns:
torch.Tensor – The jagged indices tensor.
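Per-element tensor indices of this kind can be derived from the offsets, since each element is assigned the index of the tensor it belongs to. Here is a pure-Python sketch with a hypothetical helper `jidx_from_offsets`:

```python
def jidx_from_offsets(offsets):
    """Hypothetical helper (not fvdb API): for every element of the flat data,
    record the index of the tensor it belongs to."""
    jidx = []
    for i in range(len(offsets) - 1):
        jidx.extend([i] * (offsets[i + 1] - offsets[i]))  # tensor i's length copies of i
    return jidx


print(jidx_from_offsets([0, 2, 5, 6]))  # [0, 0, 1, 1, 1, 2]
```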
- property jlidx: Tensor
List indices for nested jagged structures. This is a torch.Tensor with one row per tensor in the jagged layout, giving that tensor's position in the nested lists as an [outer list index, index within that list] pair.
Example:
```python
import torch
from fvdb import JaggedTensor

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the jlidx tensor would be: tensor([[0, 0], [0, 1], [1, 0]]).
jt = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt.jlidx)  # Output: tensor([[0, 0], [0, 1], [1, 0]])
```
- Returns:
torch.Tensor – The jagged list indices tensor.
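The jlidx rows in the example above follow from the nested lengths in lshape. Here is a pure-Python sketch for the two-level case, using a hypothetical helper `jlidx_from_lshape`:

```python
def jlidx_from_lshape(lshape):
    """Hypothetical helper (not fvdb API): for a two-level lshape such as
    [[2, 1], [4]], list [outer, inner] for each tensor in storage order."""
    return [[outer, inner]
            for outer, lengths in enumerate(lshape)
            for inner in range(len(lengths))]


print(jlidx_from_lshape([[2, 1], [4]]))  # [[0, 0], [0, 1], [1, 0]]
```

Only the number of tensors in each inner list matters here. The lengths themselves determine the offsets, not the list indices.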
- jmax(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]
Compute the maximum along a dimension of each tensor in the jagged structure.
Returns both the maximum values and the indices where they occur.
Example:
```python
import torch
from fvdb import JaggedTensor

t1, t2, t3 = torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)

# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the maximum along the jagged dimension (dim=0)
values, indices = jt.jmax(dim=0)

# values is now a jagged tensor containing the maximum values from each tensor
# along dim=0; this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).indices for t in [t1, t2, t3]])
```
- Parameters:
dim (int) – The dimension along which to compute max for each tensor. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
values (JaggedTensor) – A JaggedTensor containing the maximum values.
indices (JaggedTensor) – A JaggedTensor containing the indices of the maximum values.
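For 1-D tensors, the per-tensor reduction can be modelled in pure Python over the flat data and offsets. In the sketch below, which uses a hypothetical helper `segment_max`, indices are positions within each tensor. This matches the per-tensor torch.max equivalence above. Ties resolve to the first occurrence and empty tensors are rejected; both are choices of this sketch.

```python
def segment_max(data, offsets):
    """Hypothetical pure-Python analogue (not fvdb API) of a per-tensor max
    for 1-D tensors: returns max values and their positions within each tensor."""
    values, indices = [], []
    for i in range(len(offsets) - 1):
        segment = data[offsets[i]:offsets[i + 1]]
        if not segment:
            raise ValueError(f"tensor {i} is empty; max is undefined")
        best = max(range(len(segment)), key=segment.__getitem__)  # first index of the max
        values.append(segment[best])
        indices.append(best)  # local index within tensor i
    return values, indices


print(segment_max([5, 1, 2, 9, 9, 4], [0, 2, 5, 6]))  # ([5, 9, 4], [0, 1, 0])
```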
- jmin(dim: int = 0, keepdim: bool = False) → list[JaggedTensor]
Compute the minimum along a dimension of each tensor in the jagged structure.
Returns both the minimum values and the indices where they occur.
Example:
```python
import torch
from fvdb import JaggedTensor

t1, t2, t3 = torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)

# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])

# Compute the minimum along the jagged dimension (dim=0)
values, indices = jt.jmin(dim=0)

# values is now a jagged tensor containing the minimum values from each tensor
# along dim=0; this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).indices for t in [t1, t2, t3]])
```
- Parameters:
dim (int) – The dimension along which to compute min for each tensor. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
values (JaggedTensor) – A JaggedTensor containing the minimum values.
indices (JaggedTensor) – A JaggedTensor containing the indices of the minimum values.
- property joffsets: Tensor
Offsets marking boundaries between tensors.
Example:
```python
import torch
from fvdb import JaggedTensor

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``joffsets`` tensor would be: ``tensor([0, 2, 5, 6])``.
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.joffsets)  # Output: tensor([0, 2, 5, 6])

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the joffsets tensor would be: tensor([0, 2, 3, 7]).
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.joffsets)  # Output: tensor([0, 2, 3, 7])
```
- Returns:
torch.Tensor – The jagged offsets tensor.
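In both examples above, the offsets are a running sum of the tensor lengths with a leading 0. In the nested example they run over all tensors in storage order, ignoring list boundaries. Here is a pure-Python sketch using itertools.accumulate, with a hypothetical helper `offsets_from_lshape`, that reproduces both examples:

```python
from itertools import accumulate


def offsets_from_lshape(lshape):
    """Hypothetical helper (not fvdb API): compute offsets from the lengths in
    an lshape. Nested lshapes are flattened first, giving one boundary per tensor."""
    lengths = [n for item in lshape for n in (item if isinstance(item, list) else [item])]
    return [0, *accumulate(lengths)]  # leading 0, then cumulative lengths


print(offsets_from_lshape([2, 3, 1]))      # [0, 2, 5, 6]
print(offsets_from_lshape([[2, 1], [4]]))  # [0, 2, 3, 7]
```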
- jreshape(lshape: Sequence[int] | Sequence[Sequence[int]]) → JaggedTensor
Reshape the jagged dimensions to new sizes.
- Parameters:
lshape (Sequence[int] | Sequence[Sequence[int]]) – New shape(s) for jagged dimensions. Can be a single sequence of sizes or nested sequences for multi-level structure.
- Returns:
JaggedTensor – A new JaggedTensor with reshaped jagged structure.
- jreshape_as(other: JaggedTensor | Tensor) → JaggedTensor
Reshape the jagged structure to match another JaggedTensor or Tensor.
- Parameters:
other (JaggedTensor | torch.Tensor) – The target structure to match.
- Returns:
JaggedTensor – A new JaggedTensor with structure matching other.
- jsqueeze(dim: int | None = None) → JaggedTensor
Remove singleton dimensions from the jagged structure.
- Parameters:
dim (int | None) – Specific dimension to squeeze, or None to squeeze all singleton dimensions. Defaults to None.
- Returns:
JaggedTensor – A new JaggedTensor with singleton dimensions removed.
- jsum(dim: int = 0, keepdim: bool = False) → JaggedTensor
Sum along a jagged dimension.
- Parameters:
dim (int) – The jagged dimension along which to sum. Defaults to 0.
keepdim (bool) – Whether to keep the reduced dimension. Defaults to False.
- Returns:
JaggedTensor – A new JaggedTensor with values summed along the specified dimension.
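Assume jsum(dim=0) reduces each tensor along its jagged (first) dimension, analogous to the jmax and jmin examples. Under that assumption, the result for 2-D tensors can be modelled in pure Python as column-wise sums over each tensor's rows. The helper `segment_sum_rows` is hypothetical, and returning zeros for an empty tensor is a choice of this sketch.

```python
def segment_sum_rows(rows, offsets):
    """Hypothetical pure-Python analogue (not fvdb API): sum the rows of each
    tensor (rows[offsets[i]:offsets[i + 1]]) column by column."""
    width = len(rows[0]) if rows else 0
    sums = []
    for i in range(len(offsets) - 1):
        total = [0] * width                       # an empty tensor sums to zeros here
        for row in rows[offsets[i]:offsets[i + 1]]:
            total = [a + b for a, b in zip(total, row)]
        sums.append(total)
    return sums


rows = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
print(segment_sum_rows(rows, [0, 2, 5, 6]))  # [[4, 6], [21, 24], [11, 12]]
```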
- property ldim: int
Dimensionality of the jagged (leading) structure, i.e., the number of jagged levels.
If the JaggedTensor represents a simple jagged structure (a single list of tensors), then ldim will be 1. For nested jagged structures (lists of lists of tensors), ldim will be greater than 1.
- Returns:
int – The dimensionality of the jagged structure.
- long() → JaggedTensor
Convert elements to long (int64) dtype.
- Returns:
JaggedTensor – A new JaggedTensor with int64 dtype.
- property lshape: list[int] | list[list[int]]
List structure shape(s) of the jagged dimensions.
Example:
```python
import torch
from fvdb import JaggedTensor

# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``lshape`` will be: ``[2, 3, 1]`` (the first-dimension size of each tensor).
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.lshape)  # Output: [2, 3, 1]

# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the ``lshape`` will be: ``[[2, 1], [4]]``.
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.lshape)  # Output: [[2, 1], [4]]
```
- Returns:
list[int] | list[list[int]] – The jagged structure shapes.
- property num_tensors: int
Return the total number of tensors in the jagged sequence.
- Returns:
int – Number of tensors in this JaggedTensor.
- property requires_grad: bool
Whether this JaggedTensor requires gradient computation.
- Returns:
bool – True if gradients are tracked, False otherwise.
- requires_grad_(requires_grad: bool) → JaggedTensor
Set the requires_grad attribute in-place.
- Parameters:
requires_grad (bool) – Whether to track gradients for this tensor.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- rmask(mask: Tensor) → JaggedTensor
Apply a mask to filter elements along the regular (non-jagged) dimension.
- Parameters:
mask (torch.Tensor) – Boolean mask tensor to apply. Shape must be compatible with the regular dimensions.
- Returns:
JaggedTensor – A new JaggedTensor with masked elements.
- round(decimals: int = 0) → JaggedTensor
Round elements to the specified number of decimals.
- Parameters:
decimals (int) – Number of decimal places to round to. Defaults to 0.
- Returns:
JaggedTensor – A new JaggedTensor with rounded values.
- round_(decimals: int = 0) → JaggedTensor
Round elements to the specified number of decimals in-place.
- Parameters:
decimals (int) – Number of decimal places to round to. Defaults to 0.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- property rshape: tuple[int, ...]
Return the shape of the jdata tensor.
Note
rshape stands for “raw shape” and represents the full shape of the underlying data tensor, i.e. the flattened jagged dimension followed by the regular dimensions.
- Returns:
tuple[int, …] – Shape of the underlying data tensor.
- sqrt() → JaggedTensor
Compute the square root element-wise.
- Returns:
JaggedTensor – A new JaggedTensor with square root applied.
- sqrt_() → JaggedTensor
Compute the square root element-wise in-place.
- Returns:
JaggedTensor – The modified JaggedTensor (self).
- to(device_or_dtype: device | str | dtype) → JaggedTensor
Move the JaggedTensor to a device or convert it to a dtype.
- Parameters:
device_or_dtype (torch.device | str | torch.dtype) – Target torch.device or torch.dtype. Can be a device (“cpu”, “cuda”) or a dtype (torch.float32, etc.).
- Returns:
JaggedTensor – A new JaggedTensor on the specified device or with the specified dtype.
- type(dtype: dtype) → JaggedTensor
Convert the JaggedTensor to a specific dtype.
- Parameters:
dtype (torch.dtype) – Target data type (e.g. torch.float32, torch.int64).
- Returns:
JaggedTensor – A new JaggedTensor with the specified dtype.
- type_as(other: JaggedTensor | Tensor) → JaggedTensor
Convert the JaggedTensor to match the dtype of another tensor.
- Parameters:
other (JaggedTensor | torch.Tensor) – Reference torch.Tensor or JaggedTensor whose dtype to match.
- Returns:
JaggedTensor – A new JaggedTensor with dtype matching other.
- unbind() → list[Tensor] | list[list[Tensor]]
Unbind the JaggedTensor into its constituent tensors.
- Returns:
list[torch.Tensor] | list[list[torch.Tensor]] – A list of torch.Tensor (for a simple jagged structure) or a list of lists of torch.Tensor (for a nested structure).