unienv_data.integrations.pytorch

UniEnvAsPyTorchDataset

UniEnvAsPyTorchDataset(batch: BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType], include_metadata: bool = False)

Bases: Dataset

A PyTorch Dataset wrapper for UniEnvPy batches. UniEnv's BatchBase collates data automatically when it is indexed with a batch of indices, so the DataLoader does not need a custom collate function and you can set collate_fn=None.

Parameters:

Name     Type         Description                    Default
batch    BatchBase    The UniEnvPy batch to wrap.    required

batch instance-attribute

batch = batch

include_metadata instance-attribute

include_metadata = include_metadata
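
The collation note for UniEnvAsPyTorchDataset can be illustrated with plain PyTorch. This is a minimal sketch, not UniEnv's own code: `BatchIndexedDataset` is a hypothetical stand-in for a dataset that returns already-collated data when indexed with a list of indices. Passing `batch_size=None` and supplying a `BatchSampler` as the sampler is one way (an assumption about how you might wire this up) to hand whole index batches to the dataset, so that no custom collate function is involved.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchIndexedDataset(Dataset):
    """Hypothetical stand-in for a dataset that collates on batched indices."""

    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self) -> int:
        return self.data.shape[0]

    def __getitem__(self, idx):
        # idx may be a single int or a list of ints; tensor indexing
        # returns an already-stacked result in the second case.
        return self.data[idx]


dataset = BatchIndexedDataset(torch.arange(10, dtype=torch.float32).reshape(5, 2))

# The sampler yields lists of indices; batch_size=None disables
# PyTorch's automatic batching so each list goes straight to __getitem__.
sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=None)

batch_shapes = [tuple(batch.shape) for batch in loader]
```

With the real wrapper, the stand-in would presumably be replaced by `UniEnvAsPyTorchDataset(batch)`; whether it expects exactly this loader configuration is not confirmed by this page.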

PyTorchAsUniEnvDataset

PyTorchAsUniEnvDataset(dataset: Dataset)

Bases: BatchBase[BatchT, PyTorchArrayType, PyTorchDeviceType, PyTorchDtypeType, PyTorchRNGType]

A UniEnvPy BatchBase wrapper for PyTorch Datasets.

Parameters:

Name       Type       Description                     Default
dataset    Dataset    The PyTorch Dataset to wrap.    required

is_mutable class-attribute instance-attribute

is_mutable = False

dataset instance-attribute

dataset = dataset

single_space instance-attribute

single_space = single_space

single_metadata_space instance-attribute

single_metadata_space = single_metadata_space

backend property

backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

device property

device: Optional[BDeviceType]

get_at_with_metadata

get_at_with_metadata(idx)

get_at

get_at(idx)

get_flattened_at

get_flattened_at(idx: Union[IndexableType, BArrayType]) -> BArrayType

Fetch samples as flattened backend arrays.
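
To make "flattened" concrete, here is a generic sketch of turning one structured sample into a single 1-D tensor. The helper `flatten_sample`, the dictionary layout, and the sorted-key ordering are all assumptions for illustration; the actual layout UniEnv uses is determined by its own space definitions and is not described on this page.

```python
import torch


def flatten_sample(sample: dict) -> torch.Tensor:
    """Concatenate every field of a sample into one 1-D tensor.

    Fields are visited in sorted key order (an assumption made here so
    that the layout is deterministic).
    """
    return torch.cat([sample[key].reshape(-1) for key in sorted(sample)])


sample = {
    "action": torch.tensor([1.0, 2.0]),  # 2 values
    "obs": torch.zeros(2, 3),            # 6 values once flattened
}
flat = flatten_sample(sample)
```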

get_flattened_at_with_metadata

get_flattened_at_with_metadata(idx: Union[IndexableType, BArrayType]) -> Tuple[BArrayType, Optional[Dict[str, Any]]]

Fetch flattened samples together with optional per-sample metadata.

set_flattened_at

set_flattened_at(idx: Union[IndexableType, BArrayType], value: BArrayType) -> None

Overwrite existing samples using flattened data.

append_flattened

append_flattened(value: BArrayType) -> None

Append one flattened sample to the batch.

extend_flattened

extend_flattened(value: BArrayType) -> None

Append a batched block of flattened samples.

set_at

set_at(idx: Union[IndexableType, BArrayType], value: BatchT) -> None

Overwrite existing samples using structured data.

remove_at

remove_at(idx: Union[IndexableType, BArrayType]) -> None

Remove one or more samples from the batch.

append

append(value: BatchT) -> None

Append one structured sample to the batch.

extend

extend(value: BatchT) -> None

Append a batched block of structured samples.

extend_from

extend_from(other: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType], chunk_size: int = 8, tqdm: bool = False) -> None

Copy data from another batch in bounded-size chunks.
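
The idea of copying in bounded-size chunks can be sketched in plain Python. The function `extend_in_chunks` and the use of lists are hypothetical stand-ins, not the library's implementation; only the `chunk_size` parameter name and its default of 8 come from the signature above.

```python
from typing import List, Sequence


def extend_in_chunks(dest: List[int], source: Sequence[int], chunk_size: int = 8) -> int:
    """Copy source into dest at most chunk_size items at a time.

    Returns the number of chunks copied, so that peak extra memory is
    bounded by chunk_size rather than len(source).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    n_chunks = 0
    for start in range(0, len(source), chunk_size):
        dest.extend(source[start:start + chunk_size])
        n_chunks += 1
    return n_chunks


dest: List[int] = []
n_chunks = extend_in_chunks(dest, list(range(20)), chunk_size=8)
```

A smaller chunk size lowers peak memory use at the cost of more read and write calls, which matters most when the source batch is backed by disk.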

get_slice

get_slice(idx: Union[IndexableType, BArrayType]) -> BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]

Create a lazy view over a subset of indices.
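
A lazy view generally stores only the selected indices and defers reading until an element is requested. The class below is a hypothetical pure-Python illustration of that pattern; whether UniEnv's view reflects later changes to the underlying batch in the same way is an assumption, not something this page states.

```python
from typing import List, Sequence


class LazySliceView:
    """Hypothetical stand-in: keeps indices, reads from the base on access."""

    def __init__(self, base: List[int], indices: Sequence[int]):
        self.base = base
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> int:
        # No data was copied at construction time; look it up now.
        return self.base[self.indices[i]]


base = [10, 20, 30, 40, 50]
view = LazySliceView(base, [4, 0, 2])

before = view[0]   # reads base[4]
base[4] = 99       # change the underlying data
after = view[0]    # the view sees the new value
```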

get_column

get_column(nested_keys: Sequence[str]) -> BatchBase[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]

Create a lazy view over a nested field inside each sample.
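
The `nested_keys` argument can be read as a path into each sample. The sketch below shows that idea with plain dictionaries; `ColumnView` and the sample layout are hypothetical, and the real method returns a BatchBase rather than this stand-in.

```python
from functools import reduce
from typing import Any, Mapping, Sequence


class ColumnView:
    """Hypothetical stand-in: resolves a nested key path on each access."""

    def __init__(self, samples: Sequence[Mapping[str, Any]], nested_keys: Sequence[str]):
        self.samples = samples
        self.nested_keys = tuple(nested_keys)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, i: int) -> Any:
        # Walk the key path one level at a time, starting from the sample.
        return reduce(lambda node, key: node[key], self.nested_keys, self.samples[i])


samples = [
    {"obs": {"image": "img0", "state": [0.0, 1.0]}, "reward": 0.5},
    {"obs": {"image": "img1", "state": [2.0, 3.0]}, "reward": 1.0},
]

states = ColumnView(samples, ["obs", "state"])
column = [states[i] for i in range(len(states))]
```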

close

close() -> None