# Copyright Contributors to the OpenVDB Project
# SPDX-License-Identifier: Apache-2.0
#
"""
Jagged Tensor data structure and operations for FVDB.
This module provides the JaggedTensor class, a specialized data structure for representing
sequences of tensors with varying lengths (jagged or ragged arrays) with efficient GPU support.
Classes:
- JaggedTensor: A jagged tensor data structure with support for efficient operations
Constructors:
- JaggedTensor(): Create from tensors, sequences, or sequences of sequences
- JaggedTensor.from_tensor(): Create from a single tensor
- JaggedTensor.from_list_of_tensors(): Create from a list of tensors
- JaggedTensor.from_list_of_lists_of_tensors(): Create from nested lists of tensors
- JaggedTensor.from_data_and_indices(): Create from flat data and indices
- JaggedTensor.from_data_and_offsets(): Create from flat data and offsets
- JaggedTensor.from_data_indices_and_list_ids(): Create with nested structure
- JaggedTensor.from_data_offsets_and_list_ids(): Create with nested structure using offsets
Module-level factory functions:
- jempty(): Create empty jagged tensor
- jrand(): Create jagged tensor with random values
- jrandn(): Create jagged tensor with normal distribution
- jones(): Create jagged tensor filled with ones
- jzeros(): Create jagged tensor filled with zeros
JaggedTensor supports PyTorch interoperability through __torch_function__, allowing
many torch operations to work seamlessly with jagged data structures.
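Example (a minimal sketch of typical usage; the top-level import path is assumed):
import torch
from fvdb import JaggedTensor
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(5, 4)])
jt = torch.relu(jt + 1.0)  # dispatched through __torch_function__
t0, t1 = jt.unbind()       # recover the individual tensors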
"""
import typing
from typing import TYPE_CHECKING, Any, Sequence, cast, overload
import numpy as np
import torch
from . import _parse_device_string
from ._Cpp import JaggedTensor as JaggedTensorCpp
from ._Cpp import jempty as jempty_cpp
from ._Cpp import jones as jones_cpp
from ._Cpp import jrand as jrand_cpp
from ._Cpp import jrandn as jrandn_cpp
from ._Cpp import jzeros as jzeros_cpp
from .types import (
DeviceIdentifier,
NumericMaxRank1,
NumericMaxRank2,
ValueConstraint,
resolve_device,
to_Vec3f,
to_Vec3fBatch,
to_Vec3fBatchBroadcastable,
to_Vec3fBroadcastable,
to_Vec3i,
to_Vec3iBatch,
to_Vec3iBatchBroadcastable,
to_Vec3iBroadcastable,
)
if TYPE_CHECKING:
from .grid import Grid
# --- JaggedTensor.__torch_function__ whitelist ---
# Whitelist of torch.<fn> names supported by JaggedTensor.__torch_function__.
# Only include ops that are elementwise or that *preserve* the primary (leading)
# dimension (i.e., the flattened jagged axis).
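# For example, torch.relu(jt) or torch.sum(jt, dim=-1) dispatch here, while a
# full reduction like torch.sum(jt) (no dim) would collapse the leading dim and
# is rejected when the result is wrapped back into a JaggedTensor.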
_JT_TORCH_WHITELIST: set[str] = {
# Unary, elementwise (and their in-place variants where applicable)
"abs",
"abs_",
"neg",
"relu",
"relu_",
"sigmoid",
"tanh",
"silu",
"gelu",
"exp",
"expm1",
"log",
"log1p",
"sqrt",
"rsqrt",
"ceil",
"floor",
"round",
"trunc",
"nan_to_num",
"clamp",
# Binary / ternary, elementwise
"add",
"sub",
"mul",
"div",
"true_divide",
"floor_divide",
"remainder",
"fmod",
"pow",
"maximum",
"minimum",
"fmax",
"fmin",
"eq",
"ne",
"lt",
"le",
"gt",
"ge",
"where",
"lerp",
# Reductions over *non-primary* dims (must keep the leading dim intact)
"sum",
"mean",
"prod",
"amax",
"amin",
"argmax",
"argmin",
"all",
"any",
"norm",
"var",
"std",
}
class JaggedTensor:
"""
A jagged (ragged) tensor data structure with support for efficient operations.
:class:`JaggedTensor` represents sequences of tensors with varying lengths, stored efficiently
in a flat contiguous format with associated index/offset structures. This is useful
for batch processing of variable-length sequences on the GPU while maintaining memory
efficiency and enabling vectorized operations.
A :class:`JaggedTensor` can represent:
1. A sequence of tensors with varying shapes along the first dimension.
These are usually written as ``[tensor_1, tensor_2, ..., tensor_N]`` where each
``tensor_i`` can have a different shape along the first dimension.
2. Nested sequences (list of lists) with varying lengths at multiple levels.
These are usually written as ``[[tensor_11, tensor_12, ...], [tensor_21, tensor_22, ...], ...]``
where both the outer and inner sequences can have varying lengths, and each ``tensor_ij``
can have a different shape along the first dimension.
The :class:`JaggedTensor` data structure consists of the following components:
- :attr:`jdata`: The flattened data tensor containing all elements
- Indexing structures (:attr:`jidx`, :attr:`joffsets`, :attr:`jlidx`) to track element boundaries
- Shape information (:attr:`lshape`, :attr:`eshape`, :attr:`rshape`) describing the structure
JaggedTensor integrates with PyTorch through ``__torch_function__``, allowing many
torch operations to work directly on jagged tensors while preserving the jagged
structure. Operations that preserve the leading (flattened) dimension work
seamlessly, while shape-changing operations require specialized j* methods.
Example usage:
.. code-block:: python
# Create a JaggedTensor from a list of tensors
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)])
# Perform element-wise operations
jt2 = jt + 1.0
jt3 = torch.relu(jt2)
# Access jagged data and structure
data = jt3.jdata
offsets = jt3.joffsets
# Get the first tensor in the jagged sequence
first_tensor = jt3[0]
# Get the last tensor in the jagged sequence
last_tensor = jt3[-1]
.. note::
The :class:`JaggedTensor` should be constructed using the explicit classmethods:
- :meth:`from_tensor()` for a single tensor
- :meth:`from_list_of_tensors()` for a list of tensors
- :meth:`from_list_of_lists_of_tensors()` for nested lists of tensors
- :meth:`from_data_and_indices()` for pre-computed flat format
- :meth:`from_data_and_offsets()` for pre-computed flat format with offsets
"""
def __init__(
self,
tensors: torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence[torch.Tensor]] | None = None,
*,
impl: JaggedTensorCpp | None = None,
) -> None:
"""
Create a JaggedTensor from various input formats.
This constructor accepts multiple input formats for flexibility. For clearer
code, prefer using the explicit from_* classmethods instead.
Args:
tensors (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence[torch.Tensor]] | None):
Input data in one of several formats:
- torch.Tensor: A single tensor (creates jagged tensor with one element)
- Sequence[torch.Tensor]: List/tuple of tensors with varying first dimension
- Sequence[Sequence[torch.Tensor]]: Nested sequences for multi-level jagging
Defaults to None when impl is provided.
impl (JaggedTensorCpp | None): Internal C++ implementation object.
Used internally, should not be provided by users. Defaults to None.
"""
if impl is not None:
if tensors is not None:
raise ValueError("Cannot provide both tensors and impl")
self._impl = impl
else:
if tensors is None:
raise ValueError("Must provide either tensors or impl")
if not isinstance(tensors, (torch.Tensor, list, tuple)):
raise TypeError(
"tensors must be a torch.Tensor or a sequence (or sequence of sequences) of torch.Tensor"
)
# Convert sequences to lists for C++ binding compatibility
if isinstance(tensors, torch.Tensor):
self._impl = JaggedTensorCpp(tensors)
elif tensors and isinstance(tensors[0], (list, tuple)):
# Nested sequences: convert each inner sequence to a list
converted: list[list[torch.Tensor]] = [
list(inner) if isinstance(inner, tuple) else cast(list[torch.Tensor], inner)
for inner in tensors
]
self._impl = JaggedTensorCpp(converted)
else:
# Simple sequence of tensors (possibly empty)
converted_flat: list[torch.Tensor] = (
list(tensors) if isinstance(tensors, tuple) else cast(list[torch.Tensor], tensors)  # type: ignore
)
self._impl = JaggedTensorCpp(converted_flat)
# ============================================================
# JaggedTensor from_* constructors
# ============================================================
@classmethod
def from_tensor(cls, data: torch.Tensor) -> "JaggedTensor":
"""
Create a :class:`JaggedTensor` from a single :class:`torch.Tensor`.
Args:
data (torch.Tensor): The input tensor.
Returns:
jagged_tensor (JaggedTensor): A new JaggedTensor wrapping the input tensor.
"""
return cls(tensors=data)
@classmethod
def from_list_of_tensors(cls, tensors: Sequence[torch.Tensor]) -> "JaggedTensor":
"""
Create a :class:`JaggedTensor` from a sequence of tensors with varying first dimensions.
All tensors must have the same shape except for the first dimension, which can vary.
*e.g.* ``[tensor_1, tensor_2, ..., tensor_N]`` where each ``tensor_i`` has shape
``(L_i, D_1, D_2, ...)`` with varying ``L_i``.
Args:
tensors (Sequence[torch.Tensor]): List or tuple of :class:`torch.Tensor` with compatible shapes.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` containing the sequence of tensors.
"""
return cls(tensors=tensors)
@classmethod
def from_list_of_lists_of_tensors(cls, tensors: Sequence[Sequence[torch.Tensor]]) -> "JaggedTensor":
"""
Create a :class:`JaggedTensor` from nested sequences of :class:`torch.Tensor` s.
Creates a multi-level jagged structure where both outer and inner sequences can
have varying lengths.
Args:
tensors (Sequence[Sequence[torch.Tensor]]): Nested list/tuple of :class:`torch.Tensor` s.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with nested jagged structure.
"""
return cls(tensors=tensors)
@classmethod
def from_data_and_indices(cls, data: torch.Tensor, indices: torch.Tensor, num_tensors: int) -> "JaggedTensor":
"""
Create a :class:`JaggedTensor` from flattened data and per-element indices.
Example:
.. code-block:: python
data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])
jt = JaggedTensor.from_data_and_indices(data, indices, num_tensors=3)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
Args:
data (torch.Tensor): Flattened data tensor containing all elements.
Shape: ``(total_elements, ...)``.
indices (torch.Tensor): Index tensor mapping each element to its parent tensor.
Shape: ``(total_elements,)``. Values in range ``[0, num_tensors)``.
num_tensors (int): Total number of tensors in the sequence.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` constructed from the data and indices.
"""
return cls(impl=JaggedTensorCpp.from_data_and_indices(data, indices, num_tensors))
@classmethod
def from_data_and_offsets(cls, data: torch.Tensor, offsets: torch.Tensor) -> "JaggedTensor":
"""
Create a :class:`JaggedTensor` from flattened data and offset array.
Offsets define boundaries between tensors in the flattened data array.
Tensor ``i`` contains elements ``data[offsets[i]:offsets[i+1]]``.
Example:
.. code-block:: python
data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6]) # 3 tensors: [0:2], [2:5], [5:6]
jt = JaggedTensor.from_data_and_offsets(data, offsets)
# jt represents:
# - tensor 0: [1, 2]
# - tensor 1: [3, 4, 5]
# - tensor 2: [6]
Args:
data (torch.Tensor): Flattened data tensor containing all elements.
Shape: ``(total_elements, ...)``.
offsets (torch.Tensor): Offset tensor marking tensor boundaries.
Shape: ``(num_tensors + 1,)``. Must be monotonically increasing.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` constructed from the ``data`` and ``offsets``.
"""
return cls(impl=JaggedTensorCpp.from_data_and_offsets(data, offsets))
@classmethod
def from_data_indices_and_list_ids(
cls, data: torch.Tensor, indices: torch.Tensor, list_ids: torch.Tensor, num_tensors: int
) -> "JaggedTensor":
"""
Create a nested JaggedTensor from data, indices, and list IDs.
Creates a multi-level jagged structure where list_ids provide an additional
level of grouping beyond the basic indices.
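Example:
.. code-block:: python
# A minimal sketch mirroring the from_data_and_indices example, following the
# shapes documented below: per-element list IDs group tensors 0 and 1 into
# list 0 and tensor 2 into list 1
data = torch.tensor([1, 2, 3, 4, 5, 6])
indices = torch.tensor([0, 0, 1, 1, 1, 2])
list_ids = torch.tensor([0, 0, 0, 0, 0, 1])
jt = JaggedTensor.from_data_indices_and_list_ids(data, indices, list_ids, num_tensors=3)
# jt represents [[[1, 2], [3, 4, 5]], [[6]]]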
Args:
data (torch.Tensor): Flattened data tensor containing all elements.
Shape: (total_elements, ...).
indices (torch.Tensor): Index tensor mapping each element to its tensor.
Shape: (total_elements,).
list_ids (torch.Tensor): List ID tensor for nested structure.
Shape: (total_elements,).
num_tensors (int): Total number of tensors.
Returns:
jagged_tensor (JaggedTensor): A new JaggedTensor with nested jagged structure.
"""
return cls(impl=JaggedTensorCpp.from_data_indices_and_list_ids(data, indices, list_ids, num_tensors))
@classmethod
def from_data_offsets_and_list_ids(
cls, data: torch.Tensor, offsets: torch.Tensor, list_ids: torch.Tensor
) -> "JaggedTensor":
"""
Create a nested :class:`JaggedTensor` from data, offsets, and list IDs.
The offsets are used to define boundaries between tensors in the flattened array,
and the list ids provide an additional level of grouping.
Example:
.. code-block:: python
data = torch.tensor([1, 2, 3, 4, 5, 6])
offsets = torch.tensor([0, 2, 5, 6]) # 3 tensors: [0:2], [2:5], [5:6]
list_ids = torch.tensor([[0, 0], [0, 1], [1, 0]]) # First two tensors in list 0, last in list 1
jt = JaggedTensor.from_data_offsets_and_list_ids(data, offsets, list_ids)
# jt represents the structure [[t_00, t_01], [t_10]]
# where t_00 = [1, 2], t_01 = [3, 4, 5], t_10 = [6]
Args:
data (torch.Tensor): Flattened data tensor containing all elements.
Shape: ``(total_elements, ...)``.
offsets (torch.Tensor): Offset tensor marking tensor boundaries.
Shape: ``(num_tensors + 1,)``.
list_ids (torch.Tensor): List ID tensor for nested structure.
Shape: ``(num_tensors, 2)``.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with nested jagged structure.
"""
return cls(impl=JaggedTensorCpp.from_data_offsets_and_list_ids(data, offsets, list_ids))
# ============================================================
# Regular Instance Methods Begin
# ============================================================
def abs(self) -> "JaggedTensor":
"""
Compute the absolute value element-wise.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with absolute values.
"""
return JaggedTensor(impl=self._impl.abs())
def abs_(self) -> "JaggedTensor":
"""
Compute the absolute value element-wise in-place.
Returns:
jagged_tensor (JaggedTensor): The modified :class:`JaggedTensor` (self).
"""
return JaggedTensor(impl=self._impl.abs_())
def ceil(self) -> "JaggedTensor":
"""
Round elements up to the nearest integer.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with ceiling applied.
"""
return JaggedTensor(impl=self._impl.ceil())
def ceil_(self) -> "JaggedTensor":
"""
Round elements up to the nearest integer in-place.
Returns:
jagged_tensor (JaggedTensor): The modified :class:`JaggedTensor` (self).
"""
return JaggedTensor(impl=self._impl.ceil_())
def clone(self) -> "JaggedTensor":
"""
Create a deep copy of the JaggedTensor.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with copied data and structure.
"""
return JaggedTensor(impl=self._impl.clone())
def cpu(self) -> "JaggedTensor":
"""
Move the JaggedTensor to CPU memory.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` on CPU device.
"""
return JaggedTensor(impl=self._impl.cpu())
def cuda(self) -> "JaggedTensor":
"""
Move the JaggedTensor to CUDA (GPU) memory.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` on CUDA device.
"""
return JaggedTensor(impl=self._impl.cuda())
def detach(self) -> "JaggedTensor":
"""
Detach the JaggedTensor from the autograd graph.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` detached from the computation graph.
"""
return JaggedTensor(impl=self._impl.detach())
def double(self) -> "JaggedTensor":
"""
Convert elements to double (float64) dtype.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with double precision.
"""
return JaggedTensor(impl=self._impl.double())
def float(self) -> "JaggedTensor":
"""
Convert elements to float (float32) dtype.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with float32 precision.
"""
return JaggedTensor(impl=self._impl.float())
def floor(self) -> "JaggedTensor":
"""
Round elements down to the nearest integer.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with floor applied.
"""
return JaggedTensor(impl=self._impl.floor())
def floor_(self) -> "JaggedTensor":
"""
Round elements down to the nearest integer in-place.
Returns:
jagged_tensor (JaggedTensor): The modified :class:`JaggedTensor` ``(self)``.
"""
return JaggedTensor(impl=self._impl.floor_())
def jagged_like(self, data: torch.Tensor) -> "JaggedTensor":
"""
Create a new JaggedTensor with the same structure but different data.
The new JaggedTensor will have the same jagged structure (``joffsets``, ``jidx``, etc.)
as the current one, but with new ``jdata`` values.
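Example:
.. code-block:: python
# A minimal sketch: swap in new values while keeping the jagged layout
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4)])
jt_zeros = jt.jagged_like(torch.zeros_like(jt.jdata))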
Args:
data (torch.Tensor): New data tensor with compatible shape.
Must have the same leading dimension as self.jdata.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with the same structure but new data.
"""
return JaggedTensor(impl=self._impl.jagged_like(data))
def jflatten(self, dim: int = 0) -> "JaggedTensor":
"""
Flatten the jagged dimensions starting from the specified dimension.
Example:
.. code-block:: python
# Original jagged tensor with 2 jagged dimensions, representing
# a list of lists [ [ t_00, t_01, ... ], [ t_b0, t_b1, ... ] ]
jt = JaggedTensor.from_list_of_lists_of_tensors(...)
# Flatten starting from dim=0
jt_flat = jt.jflatten(dim=0)
# jt_flat is now a jagged tensor with 1 jagged dimension and represents
# [ t_00, t_01, ..., t_b0, t_b1, ... ]
Args:
dim (int): The dimension from which to start flattening. Defaults to 0.
Returns:
jagged_tensor (JaggedTensor): A new :class:`JaggedTensor` with flattened jagged structure.
"""
return JaggedTensor(impl=self._impl.jflatten(dim))
def jmax(self, dim: int = 0, keepdim: bool = False) -> list["JaggedTensor"]:
"""
Compute the maximum along a dimension of each tensor in the jagged structure.
Returns both the maximum values and the indices where they occur.
Example:
.. code-block:: python
# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])
# Compute the maximum along the jagged dimension (dim=0)
values, indices = jt.jmax(dim=0)
# values is now a jagged tensor containing the maximum values from each tensor
# along dim=0
# this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.max(t, dim=0).indices for t in [t1, t2, t3]])
Args:
dim (int): The dimension along which to compute max for each tensor. Defaults to 0.
keepdim (bool): Whether to keep the reduced dimension. Defaults to False.
Returns:
values (JaggedTensor): A :class:`JaggedTensor` containing the maximum values.
indices (JaggedTensor): A :class:`JaggedTensor` containing the indices of the maximum values.
"""
return [JaggedTensor(impl=impl) for impl in self._impl.jmax(dim, keepdim)]
def jmin(self, dim: int = 0, keepdim: bool = False) -> list["JaggedTensor"]:
"""
Compute the minimum along a dimension of each tensor in the jagged structure.
Returns both the minimum values and the indices where they occur.
Example:
.. code-block:: python
# Create a jagged tensor from a list of tensors, each of shape (L_i, D)
jt = JaggedTensor.from_list_of_tensors([t1, t2, t3])
# Compute the minimum along the jagged dimension (dim=0)
values, indices = jt.jmin(dim=0)
# values is now a jagged tensor containing the minimum values from each tensor
# along dim=0
# this is equivalent to (but faster than):
# values = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).values for t in [t1, t2, t3]])
# indices = JaggedTensor.from_list_of_tensors([torch.min(t, dim=0).indices for t in [t1, t2, t3]])
Args:
dim (int): The dimension along which to compute min for each tensor. Defaults to 0.
keepdim (bool): Whether to keep the reduced dimension. Defaults to False.
Returns:
values (JaggedTensor): A :class:`JaggedTensor` containing the minimum values.
indices (JaggedTensor): A :class:`JaggedTensor` containing the indices of the minimum values.
"""
return [JaggedTensor(impl=impl) for impl in self._impl.jmin(dim, keepdim)]
def jreshape(self, lshape: Sequence[int] | Sequence[Sequence[int]]) -> "JaggedTensor":
"""
Reshape the jagged dimensions to new sizes.
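Example:
.. code-block:: python
# A minimal sketch, assuming the total number of leading elements is
# preserved: regroup six elements from lengths [2, 4] to [3, 3]
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(4, 4)])
jt2 = jt.jreshape([3, 3])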
Args:
lshape (Sequence[int] | Sequence[Sequence[int]]): New shape(s) for jagged dimensions.
Can be a single sequence of sizes or nested sequences for multi-level structure.
Returns:
JaggedTensor: A new JaggedTensor with reshaped jagged structure.
"""
lshape_cpp = _convert_to_list(lshape)
return JaggedTensor(impl=self._impl.jreshape(lshape_cpp))
def jreshape_as(self, other: "JaggedTensor | torch.Tensor") -> "JaggedTensor":
"""
Reshape the jagged structure to match another JaggedTensor or Tensor.
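Example:
.. code-block:: python
# A minimal sketch, assuming the target has the same total number of
# leading elements
jt_a = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(4, 4)])
jt_b = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(3, 4)])
jt_c = jt_a.jreshape_as(jt_b)  # now grouped as lengths [3, 3]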
Args:
other (JaggedTensor | torch.Tensor): The target structure to match.
Returns:
JaggedTensor: A new JaggedTensor with structure matching other.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl.jreshape_as(other._impl))
else:
if not isinstance(other, torch.Tensor):
raise TypeError("other must be a JaggedTensor or a torch.Tensor")
return JaggedTensor(impl=self._impl.jreshape_as(other))
def jsqueeze(self, dim: int | None = None) -> "JaggedTensor":
"""
Remove singleton dimensions from the jagged structure.
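Example:
.. code-block:: python
# A minimal sketch, assuming a nested structure whose inner lists all have
# length one, so that jagged level can be squeezed away
jt = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3)], [torch.randn(4, 3)]])
jt_sq = jt.jsqueeze()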
Args:
dim (int | None): Specific dimension to squeeze, or None to squeeze all
singleton dimensions. Defaults to None.
Returns:
JaggedTensor: A new JaggedTensor with singleton dimensions removed.
"""
return JaggedTensor(impl=self._impl.jsqueeze(dim))
def jsum(self, dim: int = 0, keepdim: bool = False) -> "JaggedTensor":
"""
Sum along a jagged dimension.
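Example:
.. code-block:: python
# A minimal sketch: per-tensor sums along the jagged dimension
jt = JaggedTensor.from_list_of_tensors([torch.ones(2, 4), torch.ones(3, 4)])
sums = jt.jsum()  # rows [2, 2, 2, 2] and [3, 3, 3, 3]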
Args:
dim (int): The jagged dimension along which to sum. Defaults to 0.
keepdim (bool): Whether to keep the reduced dimension. Defaults to False.
Returns:
JaggedTensor: A new JaggedTensor with values summed along the specified dimension.
"""
return JaggedTensor(impl=self._impl.jsum(dim, keepdim))
def long(self) -> "JaggedTensor":
"""
Convert elements to long (int64) dtype.
Returns:
JaggedTensor: A new JaggedTensor with int64 dtype.
"""
return JaggedTensor(impl=self._impl.long())
# FIXME(@chorvath, @fwilliams) Why is this here?
def requires_grad_(self, requires_grad: bool) -> "JaggedTensor":
"""
Set the requires_grad attribute in-place.
Args:
requires_grad (bool): Whether to track gradients for this tensor.
Returns:
JaggedTensor: The modified JaggedTensor (self).
"""
return JaggedTensor(impl=self._impl.requires_grad_(requires_grad))
def rmask(self, mask: torch.Tensor) -> "JaggedTensor":
"""
Apply a mask to filter elements along the regular (non-jagged) dimension.
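Example:
.. code-block:: python
# A minimal sketch, assuming a boolean mask over the leading elements of
# ``jdata`` (here jdata is 1-D, so the mask has shape (total_elements,))
jt = JaggedTensor.from_list_of_tensors([torch.randn(3), torch.randn(2)])
jt_pos = jt.rmask(jt.jdata > 0)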
Args:
mask (torch.Tensor): Boolean mask tensor to apply.
Shape must be compatible with the regular dimensions.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with masked elements.
"""
return JaggedTensor(impl=self._impl.rmask(mask))
def round(self, decimals: int = 0) -> "JaggedTensor":
"""
Round elements to the specified number of decimals.
Args:
decimals (int): Number of decimal places to round to. Defaults to 0.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with rounded values.
"""
return JaggedTensor(impl=self._impl.round(decimals))
def round_(self, decimals: int = 0) -> "JaggedTensor":
"""
Round elements to the specified number of decimals in-place.
Args:
decimals (int): Number of decimal places to round to. Defaults to 0.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` ``(self)``.
"""
return JaggedTensor(impl=self._impl.round_(decimals))
def sqrt(self) -> "JaggedTensor":
"""
Compute the square root element-wise.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with square root applied.
"""
return JaggedTensor(impl=self._impl.sqrt())
def sqrt_(self) -> "JaggedTensor":
"""
Compute the square root element-wise in-place.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` ``(self)``.
"""
return JaggedTensor(impl=self._impl.sqrt_())
def to(self, device_or_dtype: torch.device | str | torch.dtype) -> "JaggedTensor":
"""
Move the JaggedTensor to a device or convert to a dtype.
Args:
device_or_dtype (torch.device | str | torch.dtype): Target :class:`torch.device` or :class:`torch.dtype`.
Can be a device ("cpu", "cuda"), or a dtype (torch.float32, etc.).
Returns:
JaggedTensor: A new :class:`JaggedTensor` on the specified device or with specified dtype.
"""
return JaggedTensor(impl=self._impl.to(device_or_dtype))
def type(self, dtype: torch.dtype) -> "JaggedTensor":
"""
Convert the :class:`JaggedTensor` to a specific dtype.
Args:
dtype (torch.dtype): Target data type (*e.g.* ``torch.float32``, ``torch.int64``).
Returns:
JaggedTensor: A new :class:`JaggedTensor` with the specified dtype.
"""
return JaggedTensor(impl=self._impl.type(dtype))
def type_as(self, other: "JaggedTensor | torch.Tensor") -> "JaggedTensor":
"""
Convert the :class:`JaggedTensor` to match the dtype of another tensor.
Args:
other (JaggedTensor | torch.Tensor): Reference :class:`torch.Tensor` or :class:`JaggedTensor` whose dtype to match.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with dtype matching other.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl.type_as(other._impl))
else:
if not isinstance(other, torch.Tensor):
raise TypeError("other must be a JaggedTensor or a torch.Tensor")
return JaggedTensor(impl=self._impl.type_as(other))
def unbind(self) -> list[torch.Tensor] | list[list[torch.Tensor]]:
"""
Unbind the :class:`JaggedTensor` into its constituent tensors.
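Example:
.. code-block:: python
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4)])
t0, t1 = jt.unbind()  # t0.shape == (2, 4), t1.shape == (3, 4)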
Returns:
list[torch.Tensor] | list[list[torch.Tensor]]: A list of :class:`torch.Tensor` (for simple
jagged structure) or a list of lists of :class:`torch.Tensor` (for nested structure).
"""
return self._impl.unbind()
def __add__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Add another tensor or scalar element-wise to this :class:`JaggedTensor`.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to add.
Returns:
JaggedTensor: Result of element-wise addition between ``self`` and ``other``.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl + other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl + other)
def __eq__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise equality comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with element-wise comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl == other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl == other)
def __floordiv__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Floor division element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor.
Returns:
JaggedTensor: Result of floor division.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl // other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl // other)
def __ge__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise greater-than-or-equal comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl >= other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl >= other)
def __getitem__(self, index: Any) -> "JaggedTensor":
"""
Index or slice the JaggedTensor. This slices along the outer jagged dimension.
Example:
.. code-block:: python
jt = JaggedTensor.from_list_of_tensors([torch.randn(3, 4), torch.randn(2, 4), torch.randn(5, 4)])
jt0 = jt[0] # First tensor (shape: (3, 4))
jt1_2 = jt[1:3] # Second and third tensors (shape: (2, 4) and (5, 4))
# Masked selection: equivalent to building a JaggedTensor from
# [t[t > 0] for t in jt.unbind()]
jt_masked = jt[jt > 0]
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
jt_ll0 = jt_ll[0] # First list of tensors [torch.randn(2, 3), torch.randn(1, 3)]
jt_ll1 = jt_ll[1] # Second list of tensors [torch.randn(4, 3)]
jt_ll0_0 = jt_ll0[0] # First tensor in the first list (shape: (2, 3))
jt_ll0_1 = jt_ll0[1] # Second tensor in the first list (shape: (1, 3))
Args:
index (Any): Index, slice, or mask to apply. Can be a JaggedTensor for jagged indexing.
Returns:
JaggedTensor: The indexed/sliced :class:`JaggedTensor`.
"""
if isinstance(index, JaggedTensor):
return JaggedTensor(impl=self._impl[index._impl])
else:
return JaggedTensor(impl=self._impl[index])
def __gt__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise greater-than comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl > other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl > other)
def __iadd__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place addition element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to add.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl += other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl += other
return self
def __ifloordiv__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place floor division element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl //= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl //= other
return self
def __imod__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place modulo operation element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor for modulo.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl %= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl %= other
return self
def __imul__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place multiplication element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to multiply.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl *= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl *= other
return self
def __ipow__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place exponentiation element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Exponent.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl **= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl **= other
return self
def __isub__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place subtraction element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to subtract.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl -= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl -= other
return self
def __itruediv__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
In-place true division element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor.
Returns:
JaggedTensor: The modified :class:`JaggedTensor` (self).
"""
if isinstance(other, JaggedTensor):
self._impl /= other._impl
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
self._impl /= other
return self
def __le__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise less-than-or-equal comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl <= other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl <= other)
def __len__(self) -> int:
"""
Return the number of tensors in the jagged sequence.
Returns:
int: Number of tensors in the :class:`JaggedTensor`.
"""
return len(self._impl)
def __lt__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise less-than comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl < other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl < other)
def __mod__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Modulo operation element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor for modulo.
Returns:
JaggedTensor: Result of modulo operation.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl % other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl % other)
def __mul__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Multiply element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to multiply.
Returns:
JaggedTensor: Result of element-wise multiplication.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl * other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl * other)
def __ne__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Element-wise inequality comparison.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to compare.
Returns:
JaggedTensor: Boolean :class:`JaggedTensor` with comparison results.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl != other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl != other)
def __neg__(self) -> "JaggedTensor":
"""
Negate all elements.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with all elements negated.
"""
return JaggedTensor(impl=-self._impl)
def __pow__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Raise elements to a power element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Exponent.
Returns:
JaggedTensor: Result of exponentiation.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl**other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl**other)
def __sub__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
Subtract element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Value to subtract.
Returns:
JaggedTensor: Result of element-wise subtraction.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl - other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl - other)
def __truediv__(self, other: "torch.Tensor | JaggedTensor | int | float") -> "JaggedTensor":
"""
True division element-wise.
Args:
other (torch.Tensor | JaggedTensor | int | float): Divisor.
Returns:
JaggedTensor: Result of element-wise division.
"""
if isinstance(other, JaggedTensor):
return JaggedTensor(impl=self._impl / other._impl)
else:
if not isinstance(other, (torch.Tensor, int, float)):
raise TypeError("other must be a torch.Tensor, int, or float")
return JaggedTensor(impl=self._impl / other)
def __iter__(self) -> typing.Iterator["JaggedTensor"]:
"""
Iterate over the JaggedTensor, yielding each tensor in the sequence.
.. note::
This iterates over the outer jagged dimension, yielding individual
:class:`JaggedTensor` elements. If this :class:`JaggedTensor` represents a single list of
tensors, each yielded element will be a :class:`JaggedTensor` containing one tensor.
You can access the underlying tensor via the ``.jdata`` attribute of the yielded
:class:`JaggedTensor`.
Returns:
typing.Iterator[JaggedTensor]: Iterator yielding :class:`JaggedTensor` elements.
"""
for i in range(len(self)):
yield self[i]
# ============================================================
# PyTorch interop (__torch_function__)
# ============================================================
@classmethod
def __torch_function__(
cls,
func: Any,
types: tuple,
args: tuple = (),
kwargs: dict | None = None,
) -> Any:
"""
Intercept selected torch.<fn>(...) calls and forward them to the underlying
contiguous storage (`jdata`). The operation is allowed only if the result
preserves the JaggedTensor's primary (leading) dimension. The jagged
layout (offsets/indices) is *not* changed.
Examples:
torch.relu(jt) -> applies relu to jt.jdata (returns JaggedTensor)
torch.add(jt, 1.0) -> elementwise add on jt.jdata (returns JaggedTensor)
torch.sum(jt, dim=-1) -> reduces trailing dim(s) but preserves leading dim
torch.relu_(jt) -> in-place on jt.jdata, returns the mutated JaggedTensor
Unsupported:
- Any op that would change or reduce the leading dimension (e.g., torch.sum(jt) with dim=None)
- Shape-rearranging ops like reshape/permute/transpose/cat/stack, etc. (use the provided j* APIs)
"""
if kwargs is None:
kwargs = {}
# Only participate in dispatch when a JaggedTensor is present.
if not any(issubclass(t, JaggedTensor) for t in types):
return NotImplemented
name = getattr(func, "__name__", None)
if name is None or name not in _JT_TORCH_WHITELIST:
return NotImplemented
# Find a prototype JaggedTensor carrying the jagged structure.
def _find_proto(obj: Any) -> "JaggedTensor | None":
if isinstance(obj, JaggedTensor):
return obj
if isinstance(obj, (list, tuple)):
for x in obj:
jt = _find_proto(x)
if jt is not None:
return jt
return None
proto: "JaggedTensor | None" = None
for o in args:
proto = _find_proto(o)
if proto is not None:
break
if proto is None:
for o in kwargs.values():
proto = _find_proto(o)
if proto is not None:
break
if proto is None:
return NotImplemented
# Unwrap JaggedTensors -> their underlying torch.Tensor (jdata)
def _unwrap(obj: Any) -> Any:
if isinstance(obj, JaggedTensor):
return obj.jdata
if isinstance(obj, (list, tuple)):
typ = type(obj)
return typ(_unwrap(x) for x in obj)
return obj
conv_args = tuple(_unwrap(a) for a in args)
conv_kwargs = {k: _unwrap(v) for k, v in kwargs.items()}
# Handle out= if provided as a JaggedTensor
out_jt: "JaggedTensor | None" = None
if "out" in kwargs:
orig_out = kwargs["out"]
if isinstance(orig_out, JaggedTensor):
out_jt = orig_out
conv_kwargs["out"] = orig_out.jdata
elif isinstance(orig_out, (list, tuple)):
raise TypeError("JaggedTensor: tuple/list form of 'out=' is not supported.")
# Execute the torch operation on raw tensors.
result = func(*conv_args, **conv_kwargs)
N0 = int(proto.jdata.shape[0])
# Wrap torch.Tensor result(s) back into JaggedTensor, verifying the primary dim.
def _wrap(o: Any) -> Any:
if isinstance(o, torch.Tensor):
if o.ndim == 0 or int(o.shape[0]) != N0:
raise RuntimeError(
f"torch.{name} would change the primary jagged dimension "
f"(expected leading dim {N0}, got {tuple(o.shape)})."
)
return proto.jagged_like(o)
if isinstance(o, (list, tuple)):
items = [_wrap(x) for x in o]
if isinstance(o, tuple) and hasattr(o, "_fields"):
# namedtuple (e.g., values/indices from some reductions)
return type(o)(*items)
return type(o)(items)
return o
# In-place variant: mutate proto/out and return the mutated JaggedTensor.
if name.endswith("_"):
if out_jt is not None:
return out_jt
return proto
# If out= was a JaggedTensor, return it after the write.
if out_jt is not None:
return out_jt
return _wrap(result)
# ============================================================
# Properties
# ============================================================
@property
def jdata(self) -> torch.Tensor:
"""
Flattened data tensor containing all elements in this JaggedTensor.
For example, if this :class:`JaggedTensor` represents three tensors of shapes
``(2, 4)``, ``(3, 4)``, and ``(1, 4)``, then ``jdata`` will have shape ``(6, 4)``.
Returns:
torch.Tensor: The data tensor.
"""
return self._impl.jdata
@jdata.setter
def jdata(self, value: torch.Tensor) -> None:
"""
Set the flattened data tensor. The shape must be compatible with the jagged structure.
This operation does not modify the jagged layout (offsets/indices).
Example:
.. code-block:: python
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4)])
print(jt.jdata.shape)  # Output: torch.Size([5, 4])
# Updating with data of the same shape is fine
new_data = torch.randn(5, 4)
jt.jdata = new_data
print(jt.jdata)  # Output: the new_data tensor
# Updating with the same leading shape but a different element shape is also fine
new_data_2 = torch.randn(5, 2, 3)
jt.jdata = new_data_2
print(jt.jdata)  # Output: the new_data_2 tensor
# Updating with a different leading shape is not allowed
new_data_3 = torch.randn(4, 4)
jt.jdata = new_data_3  # This will raise an error
Args:
value (torch.Tensor): New ``jdata`` tensor to set.
"""
self._impl.jdata = value
@property
def requires_grad(self) -> bool:
"""
Whether this :class:`JaggedTensor` requires gradient computation.
Returns:
bool: ``True`` if gradients are tracked, False otherwise.
"""
return self._impl.requires_grad
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""
Set whether this :class:`JaggedTensor` requires gradient computation.
Args:
value (bool): ``True`` to require gradients, ``False`` otherwise.
"""
# self._impl.set_requires_grad(value)
self._impl.requires_grad = value
@property
def device(self) -> torch.device:
"""
Device where this :class:`JaggedTensor` is stored.
Returns:
torch.device: The device of this :class:`JaggedTensor`.
"""
return self._impl.device
@property
def dtype(self) -> torch.dtype:
"""
Data type of the elements in this :class:`JaggedTensor`.
Returns:
torch.dtype: The data type of this :class:`JaggedTensor`.
"""
return self._impl.dtype
@property
def edim(self) -> int:
"""
Dimensionality of the element (regular) structure.
For example, if each tensor in the jagged sequence has shape ``(?, 4, 5)``,
then ``edim`` will be ``2`` since there are two regular dimensions per element.
Returns:
int: The dimensionality of the element structure.
"""
return self._impl.edim
@property
def eshape(self) -> list[int]:
"""
Shape of the element dimensions.
For example, if each tensor in the jagged sequence has shape ``(?, 4, 5)``,
then ``eshape`` will be ``[4, 5]``.
Returns:
list[int]: The shape of the element dimensions.
"""
return self._impl.eshape
@property
def is_cpu(self) -> bool:
"""
Whether this :class:`JaggedTensor` is stored on the CPU.
Returns:
bool: ``True`` if on CPU, ``False`` otherwise.
"""
return self._impl.is_cpu
@property
def is_cuda(self) -> bool:
"""
Whether this :class:`JaggedTensor` is stored on a CUDA device.
Returns:
bool: ``True`` if on CUDA, ``False`` otherwise.
"""
return self._impl.is_cuda
@property
def jidx(self) -> torch.Tensor:
"""
Indices for each element in the jagged structure. This maps each element
in the ``jdata`` tensor to the index of the tensor it belongs to.
Example:
.. code-block:: python
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``jidx`` tensor would be: ``tensor([0, 0, 1, 1, 1, 2])``.
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.jidx)  # Output: tensor([0, 0, 1, 1, 1, 2])
Returns:
torch.Tensor: The jagged indices tensor.
"""
return self._impl.jidx
@property
def jlidx(self) -> torch.Tensor:
"""
List indices for nested jagged structures. This is a :class:`torch.Tensor` that
maps each tensor in the jagged layout to its position in the nested list
structure.
Example:
.. code-block:: python
# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the jlidx tensor would be: tensor([[0, 0], [0, 1], [1, 0]]).
jt = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt.jlidx) # Output: tensor([[0, 0], [0, 1], [1, 0]])
Returns:
torch.Tensor: The jagged list indices tensor.
"""
return self._impl.jlidx
@property
def joffsets(self) -> torch.Tensor:
"""
Offsets marking boundaries between tensors.
Example:
.. code-block:: python
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``joffsets`` tensor would be: ``tensor([0, 2, 5, 6])``.
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.joffsets) # Output: tensor([0, 2, 5, 6])
# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the joffsets tensor would be: tensor([0, 2, 3, 7]).
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.joffsets) # Output: tensor([0, 2, 3, 7])
Returns:
torch.Tensor: The jagged offsets tensor.
"""
return self._impl.joffsets
@property
def ldim(self) -> int:
"""
Dimensionality of the jagged (leading) structure. *i.e.* the number of jagged levels.
If the :class:`JaggedTensor` represents a simple jagged structure (a single list of tensors),
then ``ldim`` will be ``1``. For nested jagged structures (lists of lists of tensors), ``ldim``
will be greater than ``1``.
Returns:
int: The dimensionality of the jagged structure.
"""
return self._impl.ldim
@property
def lshape(self) -> list[int] | list[list[int]]:
"""
List structure shape(s) of the jagged dimensions.
Example:
.. code-block:: python
# For a JaggedTensor representing three tensors of shapes (2, 4), (3, 4), and (1, 4),
# the ``lshape`` will be: ``[2, 3, 1]`` (the leading lengths of the three tensors).
jt = JaggedTensor.from_list_of_tensors([torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)])
print(jt.lshape) # Output: [2, 3, 1]
# For a JaggedTensor representing two lists of tensors:
# List 0: tensors of shapes (2, 3) and (1, 3)
# List 1: tensor of shape (4, 3)
# the ``lshape`` will be: ``[[2, 1], [4]]``.
jt_ll = JaggedTensor.from_list_of_lists_of_tensors([[torch.randn(2, 3), torch.randn(1, 3)], [torch.randn(4, 3)]])
print(jt_ll.lshape) # Output: [[2, 1], [4]]
Returns:
list[int] | list[list[int]]: The jagged structure shapes.
"""
return self._impl.lshape
@property
def num_tensors(self) -> int:
"""
Return the total number of tensors in the jagged sequence.
Returns:
int: Number of tensors in this :class:`JaggedTensor`.
"""
return self._impl.num_tensors
@property
def rshape(self) -> tuple[int, ...]:
"""
Return the shape of the ``jdata`` tensor.
.. note::
``rshape`` stands for "raw shape" and represents the full shape of the
underlying data tensor, including both jagged and regular dimensions.
Returns:
tuple[int, ...]: Shape of the underlying data tensor.
"""
return self._impl.rshape
# Weirdly, unless we put this last, it messes up static type checking.
def int(self) -> "JaggedTensor":
"""
Convert elements to int (int32) dtype.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with int32 dtype.
"""
return JaggedTensor(impl=self._impl.int())
@overload
def _convert_to_list(seq: Sequence[int]) -> list[int]: ...
@overload
def _convert_to_list(seq: Sequence[Sequence[int]]) -> list[list[int]]: ...
def _convert_to_list(seq: Sequence[int] | Sequence[Sequence[int]]) -> list[int] | list[list[int]]:
"""Helper to convert Sequence types to list types for C++ binding compatibility."""
if isinstance(seq, (list, tuple)):
if seq and isinstance(seq[0], (list, tuple)):
# Nested sequence - convert inner sequences to lists
converted: list[list[int]] = [
list(inner) if isinstance(inner, tuple) else cast(list[int], inner) for inner in seq
]
return list(converted) if isinstance(seq, tuple) else converted
else:
# Simple sequence of ints
return list(seq) if isinstance(seq, tuple) else cast(list[int], seq) # type: ignore
else:
return cast(list[int], seq)
def jempty(
lsizes: Sequence[int] | Sequence[Sequence[int]],
rsizes: Sequence[int] | None = None,
*,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> JaggedTensor:
"""
Create a :class:`JaggedTensor` with uninitialized data.
Similar to :func:`torch.empty()`, creates a :class:`JaggedTensor` with allocated but uninitialized
memory, which is faster than initializing values when they will be immediately
overwritten.
Example:
.. code-block:: python
jt = jempty([2, 3, 4], rsizes=[5])
# jt contains three uninitialized tensors of shapes (2, 5), (3, 5), (4, 5)
jt = jempty([[2, 3], [4]], rsizes=[5, 6])
# jt contains uninitialized tensors of shapes (2, 5, 6), (3, 5, 6), (4, 5, 6),
# nested as [[(2, 5, 6), (3, 5, 6)], [(4, 5, 6)]]
Args:
lsizes (Sequence[int] | Sequence[Sequence[int]]): Sizes for the jagged dimensions.
Can be a sequence of integers for simple jagged structure, or nested sequences
for multi-level jagged structure.
rsizes (Sequence[int] | None): Sizes for the regular (trailing) dimensions.
Defaults to ``None`` *i.e.* scalar elements.
device (torch.device | str | None): Device to create the tensor on.
Defaults to ``None`` *i.e.* ``"cpu"``.
dtype (torch.dtype | None): Data type for the tensor elements.
Defaults to ``None`` *i.e.* ``torch.float32``.
requires_grad (bool): Whether to track gradients. Defaults to ``False``.
pin_memory (bool): Whether to use pinned memory. Defaults to ``False``.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with uninitialized data.
"""
lsizes_cpp: list[int] | list[list[int]] = _convert_to_list(lsizes)
rsizes_cpp: list[int] | None = _convert_to_list(rsizes) if rsizes is not None else None
return JaggedTensor(impl=jempty_cpp(lsizes_cpp, rsizes_cpp, dtype, device, requires_grad, pin_memory))
def jrand(
lsizes: Sequence[int] | Sequence[Sequence[int]],
rsizes: Sequence[int] | None = None,
*,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> JaggedTensor:
"""
Create a :class:`JaggedTensor` with random values from uniform distribution [0, 1).
Similar to :func:`torch.rand()`, creates a :class:`JaggedTensor` filled with random values sampled
from a uniform distribution on the interval [0, 1).
Example:
.. code-block:: python
jt = jrand([2, 3, 4], rsizes=[5])
# jt contains three tensors of shapes (2, 5), (3, 5), (4, 5) with uniform random values
jt = jrand([[2, 3], [4]], rsizes=[5, 6])
# jt contains tensors of shapes (2, 5, 6), (3, 5, 6), (4, 5, 6),
# nested as [[(2, 5, 6), (3, 5, 6)], [(4, 5, 6)]]
Args:
lsizes (Sequence[int] | Sequence[Sequence[int]]): Sizes for the jagged dimensions.
Can be a sequence of integers for simple jagged structure, or nested sequences
for multi-level jagged structure.
rsizes (Sequence[int] | None): Sizes for the regular (trailing) dimensions.
Defaults to ``None`` *i.e.* scalar elements.
device (torch.device | str | None): Device to create the tensor on.
Defaults to ``None`` *i.e.* ``"cpu"``.
dtype (torch.dtype | None): Data type for the tensor elements.
Defaults to ``None`` *i.e.* ``torch.float32``.
requires_grad (bool): Whether to track gradients. Defaults to ``False``.
pin_memory (bool): Whether to use pinned memory. Defaults to ``False``.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with random values in [0, 1).
"""
lsizes_cpp: list[int] | list[list[int]] = _convert_to_list(lsizes)
rsizes_cpp: list[int] | None = _convert_to_list(rsizes) if rsizes is not None else None
return JaggedTensor(impl=jrand_cpp(lsizes_cpp, rsizes_cpp, dtype, device, requires_grad, pin_memory))
def jrandn(
lsizes: Sequence[int] | Sequence[Sequence[int]],
rsizes: Sequence[int] | None = None,
*,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> JaggedTensor:
"""
Create a :class:`JaggedTensor` with random values from standard normal distribution.
Similar to :func:`torch.randn()`, creates a :class:`JaggedTensor` filled with random values sampled
from a standard normal distribution (mean=0, std=1).
Example:
.. code-block:: python
jt = jrandn([2, 3, 4], rsizes=[5])
# jt contains three tensors of shapes (2, 5), (3, 5), (4, 5) with normal random values
jt = jrandn([[2, 3], [4]], rsizes=[5, 6])
# jt contains tensors of shapes (2, 5, 6), (3, 5, 6), (4, 5, 6),
# nested as [[(2, 5, 6), (3, 5, 6)], [(4, 5, 6)]]
Args:
lsizes (Sequence[int] | Sequence[Sequence[int]]): Sizes for the jagged dimensions.
Can be a sequence of integers for simple jagged structure, or nested sequences
for multi-level jagged structure.
rsizes (Sequence[int] | None): Sizes for the regular (trailing) dimensions.
Defaults to ``None`` *i.e.* scalar elements.
device (torch.device | str | None): Device to create the tensor on.
Defaults to ``None`` *i.e.* ``"cpu"``.
dtype (torch.dtype | None): Data type for the tensor elements.
Defaults to ``None`` *i.e.* ``torch.float32``.
requires_grad (bool): Whether to track gradients. Defaults to ``False``.
pin_memory (bool): Whether to use pinned memory. Defaults to ``False``.
Returns:
JaggedTensor: A new :class:`JaggedTensor` with normal random values.
"""
lsizes_cpp: list[int] | list[list[int]] = _convert_to_list(lsizes)
rsizes_cpp: list[int] | None = _convert_to_list(rsizes) if rsizes is not None else None
return JaggedTensor(impl=jrandn_cpp(lsizes_cpp, rsizes_cpp, dtype, device, requires_grad, pin_memory))
def jones(
lsizes: Sequence[int] | Sequence[Sequence[int]],
rsizes: Sequence[int] | None = None,
*,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> JaggedTensor:
"""
Create a :class:`JaggedTensor` filled with ones.
Similar to :func:`torch.ones()`, creates a :class:`JaggedTensor` where all elements are initialized
to the value 1.
Example:
.. code-block:: python
jt = jones([2, 3, 4], rsizes=[5])
# jt contains three tensors of shapes (2, 5), (3, 5), (4, 5) filled with ones
jt = jones([[2, 3], [4]], rsizes=[5, 6])
# jt contains tensors of shapes (2, 5, 6), (3, 5, 6), (4, 5, 6),
# nested as [[(2, 5, 6), (3, 5, 6)], [(4, 5, 6)]]
Args:
lsizes (Sequence[int] | Sequence[Sequence[int]]): Sizes for the jagged dimensions.
Can be a sequence of integers for simple jagged structure, or nested sequences
for multi-level jagged structure.
rsizes (Sequence[int] | None): Sizes for the regular (trailing) dimensions.
Defaults to ``None`` *i.e.* scalar elements.
device (torch.device | str | None): Device to create the tensor on.
Defaults to ``None`` *i.e.* ``"cpu"``.
dtype (torch.dtype | None): Data type for the tensor elements.
Defaults to ``None`` *i.e.* ``torch.float32``.
requires_grad (bool): Whether to track gradients. Defaults to ``False``.
pin_memory (bool): Whether to use pinned memory. Defaults to ``False``.
Returns:
JaggedTensor: A new :class:`JaggedTensor` filled with ones.
"""
lsizes_cpp: list[int] | list[list[int]] = _convert_to_list(lsizes)
rsizes_cpp: list[int] | None = _convert_to_list(rsizes) if rsizes is not None else None
return JaggedTensor(impl=jones_cpp(lsizes_cpp, rsizes_cpp, dtype, device, requires_grad, pin_memory))
def jzeros(
lsizes: Sequence[int] | Sequence[Sequence[int]],
rsizes: Sequence[int] | None = None,
*,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> JaggedTensor:
"""
Create a :class:`JaggedTensor` filled with zeros.
Similar to :func:`torch.zeros()`, creates a :class:`JaggedTensor` where all elements are initialized
to the value 0.
Example:
.. code-block:: python
jt = jzeros([2, 3, 4], rsizes=[5])
# jt contains three tensors of shapes (2, 5), (3, 5), (4, 5) filled with zeros
jt = jzeros([[2, 3], [4]], rsizes=[5, 6])
# jt contains tensors of shapes (2, 5, 6), (3, 5, 6), (4, 5, 6),
# nested as [[(2, 5, 6), (3, 5, 6)], [(4, 5, 6)]]
Args:
lsizes (Sequence[int] | Sequence[Sequence[int]]): Sizes for the jagged dimensions.
Can be a sequence of integers for simple jagged structure, or nested sequences
for multi-level jagged structure.
rsizes (Sequence[int] | None): Sizes for the regular (trailing) dimensions.
Defaults to ``None`` *i.e.* scalar elements.
device (torch.device | str | None): Device to create the tensor on.
Defaults to ``None`` *i.e.* ``"cpu"``.
dtype (torch.dtype | None): Data type for the tensor elements.
Defaults to ``None`` *i.e.* ``torch.float32``.
requires_grad (bool): Whether to track gradients. Defaults to ``False``.
pin_memory (bool): Whether to use pinned memory. Defaults to ``False``.
Returns:
JaggedTensor: A new :class:`JaggedTensor` filled with zeros.
"""
lsizes_cpp: list[int] | list[list[int]] = _convert_to_list(lsizes)
rsizes_cpp: list[int] | None = _convert_to_list(rsizes) if rsizes is not None else None
return JaggedTensor(impl=jzeros_cpp(lsizes_cpp, rsizes_cpp, dtype, device, requires_grad, pin_memory))