
unienv_data.batches.framestack_batch

FrameStackedBatch

FrameStackedBatch(batch: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType], prefetch_horizon: int = 0, postfetch_horizon: int = 0, get_valid_mask_function: Optional[Callable[[SliceStackedBatch, BArrayType, BatchT, Dict[str, Any]], BArrayType]] = None, fill_invalid_data: bool = True, stack_metadata: bool = False)

Bases: SliceStackedBatch[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]

A batch that stacks frames in a sliding-window manner. It supports prefetching and postfetching of frames, which is useful for training models that require temporal context (e.g., Diffusion Policy or ACT). This batch is read-only because it is a view of the original batch; to change the data, mutate the underlying (wrapped) batch instead.
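To make the sliding-window idea concrete, here is a minimal pure-Python sketch on a list of frames. `FrameStackView` is a hypothetical toy class, not part of `unienv_data`. It assumes that `prefetch_horizon` counts frames before each index and `postfetch_horizon` counts frames after it, and it marks positions outside the sequence as `None`. Like `FrameStackedBatch`, it keeps a reference to the underlying data instead of a copy, so changes made to the underlying list are visible through the view.

```python
from typing import List, Optional, Sequence


class FrameStackView:
    """Hypothetical read-only sliding-window view over a list of frames (not the unienv_data API)."""

    def __init__(self, frames: List[int], prefetch_horizon: int = 0, postfetch_horizon: int = 0):
        self.frames = frames  # a reference, not a copy: this is a view
        self.prefetch_horizon = prefetch_horizon
        self.postfetch_horizon = postfetch_horizon

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, i: int) -> List[Optional[int]]:
        n = len(self.frames)
        window: List[Optional[int]] = []
        # Assumed layout: prefetch_horizon frames before i, frame i, postfetch_horizon frames after.
        for offset in range(-self.prefetch_horizon, self.postfetch_horizon + 1):
            j = i + offset
            # Positions outside the sequence are marked None (a hypothetical convention).
            window.append(self.frames[j] if 0 <= j < n else None)
        return window

    def __setitem__(self, i: int, value: Sequence[int]) -> None:
        # Read-only, like FrameStackedBatch: write to the underlying list instead.
        raise TypeError("FrameStackView is read-only; mutate the underlying frames instead")
```

With `frames = [10, 11, 12, 13]`, `prefetch_horizon=2`, and `postfetch_horizon=1`, `view[1]` returns `[None, 10, 11, 12]`; after `frames[2] = 99`, the same call returns `[None, 10, 11, 99]`.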

is_mutable class-attribute instance-attribute

is_mutable = False

single_space instance-attribute

single_space = single_space

single_metadata_space instance-attribute

single_metadata_space = single_metadata_space

backend property

backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

device property

device: Optional[BDeviceType]

batch instance-attribute

batch = batch

fixed_offset instance-attribute

fixed_offset = fixed_offset

fill_invalid_data instance-attribute

fill_invalid_data = fill_invalid_data

get_valid_mask_function instance-attribute

get_valid_mask_function = get_valid_mask_function

stack_metadata instance-attribute

stack_metadata = stack_metadata

get_flattened_at

get_flattened_at(idx)

get_flattened_at_with_metadata

get_flattened_at_with_metadata(idx)

set_flattened_at

set_flattened_at(idx: Union[IndexableType, BArrayType], value: BArrayType) -> None

Overwrite existing samples using flattened data.

append_flattened

append_flattened(value: BArrayType) -> None

Append one flattened sample to the batch.

extend_flattened

extend_flattened(value: BArrayType) -> None

Append a batched block of flattened samples.

get_at

get_at(idx: Union[IndexableType, BArrayType]) -> BatchT

get_at_with_metadata

get_at_with_metadata(idx: Union[IndexableType, BArrayType]) -> Tuple[BatchT, Dict[str, Any]]

set_at

set_at(idx: Union[IndexableType, BArrayType], value: BatchT) -> None

Overwrite existing samples using structured data.

remove_at

remove_at(idx: Union[IndexableType, BArrayType]) -> None

Remove one or more samples from the batch.

append

append(value: BatchT) -> None

Append one structured sample to the batch.

extend

extend(value: BatchT) -> None

Append a batched block of structured samples.

extend_from

extend_from(other: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType], chunk_size: int = 8, tqdm: bool = False) -> None

Copy data from another batch in bounded-size chunks.

get_slice

get_slice(idx: Union[IndexableType, BArrayType]) -> BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]

Create a lazy view over a subset of indices.
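A "lazy view" here means that selecting indices does not copy any samples. The sketch below illustrates the idea on a plain Python sequence; `LazySliceView` is a hypothetical class, not the library's implementation. It stores only the parent and the selected indices and reads from the parent on every access.

```python
from typing import Generic, List, Sequence, TypeVar

T = TypeVar("T")


class LazySliceView(Generic[T]):
    """Hypothetical lazy view over a subset of a parent's indices; nothing is copied."""

    def __init__(self, parent: Sequence[T], indices: Sequence[int]):
        self.parent = parent  # kept by reference
        self.indices: List[int] = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> T:
        # Resolve the view position to a parent position at access time.
        return self.parent[self.indices[i]]
```

Because items are resolved at access time, a change to the parent shows up in the view without rebuilding it.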

get_column

get_column(nested_keys: Sequence[str]) -> BatchBase[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]

Create a lazy view over a nested field inside each sample.
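As a rough illustration of what `nested_keys` selects, the following hypothetical `LazyColumnView` (not part of `unienv_data`) walks a sequence of keys into each dictionary-shaped sample when it is accessed. Representing samples as nested dictionaries is an assumption made for the sketch.

```python
from typing import Any, Mapping, Sequence


class LazyColumnView:
    """Hypothetical lazy view over a nested field, e.g. nested_keys=("obs", "state")."""

    def __init__(self, parent: Sequence[Mapping[str, Any]], nested_keys: Sequence[str]):
        self.parent = parent  # kept by reference
        self.nested_keys = tuple(nested_keys)

    def __len__(self) -> int:
        return len(self.parent)

    def __getitem__(self, i: int) -> Any:
        value: Any = self.parent[i]
        # Descend one level per key to reach the requested field.
        for key in self.nested_keys:
            value = value[key]
        return value
```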

close

close() -> None

expand_index

expand_index(index: BArrayType) -> BArrayType

Expand indices to slice the data.

Args:
    index (BArrayType): A 1D tensor of indices to expand. Can be boolean or integer.

Returns:
    BArrayType: A 2D tensor of indices, where each row corresponds to the expanded indices for each batch item.
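The following pure-Python sketch shows one plausible way to turn a 1D index (boolean or integer) into a 2D (B, T) index table. `expand_index_sketch` is hypothetical. It assumes the per-row offsets are `range(-prefetch_horizon, postfetch_horizon + 1)` and that out-of-range positions are clamped to `[0, length - 1]` so they can always be gathered. The real method may treat boundaries differently, for example by leaving them to the valid mask.

```python
from typing import List, Sequence, Union


def expand_index_sketch(
    index: Sequence[Union[bool, int]],
    prefetch_horizon: int,
    postfetch_horizon: int,
    length: int,
) -> List[List[int]]:
    """Hypothetical stand-in for expand_index on plain Python lists."""
    # A boolean mask selects positions: convert it to integer indices first.
    # (bool is a subclass of int, so check for bool explicitly.)
    if index and all(isinstance(x, bool) for x in index):
        index = [i for i, keep in enumerate(index) if keep]
    # Assumed offsets: prefetch_horizon frames before, postfetch_horizon frames after.
    offsets = range(-prefetch_horizon, postfetch_horizon + 1)
    # One row per selected item, one column per offset: shape (B, T).
    # Assumption: clamp to valid positions so every entry can be gathered.
    return [[min(max(i + o, 0), length - 1) for o in offsets] for i in index]
```

For example, `expand_index_sketch([0, 3], 2, 1, 5)` returns `[[0, 0, 0, 1], [1, 2, 3, 4]]`.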

get_valid_mask_flattened

get_valid_mask_flattened(expanded_idx: BArrayType, data: BArrayType, metadata: Dict[str, Any]) -> BArrayType

Get the valid mask for the flattened data.

Args:
    expanded_idx (B, T): A 2D tensor of indices used to slice the data.
    data (B, T, *D): The data to slice.
    metadata (Dict[str, Any]): Metadata associated with the data.

Returns:
    BArrayType: The valid mask, of shape (B, T).

get_valid_mask

get_valid_mask(expanded_idx: BArrayType, data: BatchT, metadata: Dict[str, Any]) -> BArrayType

Get the valid mask for the data.

Args:
    expanded_idx (B, T): A 2D tensor of indices used to slice the data.
    data (BatchT): The data to slice.
    metadata (Dict[str, Any]): Metadata associated with the data.

Returns:
    BArrayType: The valid mask, of shape (B, T).

fill_data_with_stack_mask

fill_data_with_stack_mask(space: Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]], data: Union[BArrayType, BatchT, Any], valid_mask: BArrayType) -> Union[BArrayType, BatchT, Any]

Fill the data using the valid mask, as if the frames had been frame-stacked.

Args:
    space (Optional[Space]): The space describing the data. If None, the data is assumed to be a backend array.
    data (Union[BArrayType, BatchT, Any]): The data to fill, of shape (B, T, D).
    valid_mask (BArrayType): The valid mask used to fill the data, of shape (B, T).

Returns:
    Union[BArrayType, BatchT, Any]: The filled data, of shape (B, T, D).
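Conventional frame stacking repeats the first frame of an episode to pad the stack. The sketch below applies that idea to (B, T) windows, but it is only an assumption about what "as if the frames had been frame-stacked" means. `fill_with_stack_mask_sketch` and its `center` parameter (the column of the current frame, presumably `prefetch_horizon`) are hypothetical. Each invalid frame copies its neighbor on the side closer to the center, so runs of invalid frames repeat the nearest valid frame.

```python
from typing import List, TypeVar

T = TypeVar("T")


def fill_with_stack_mask_sketch(
    data: List[List[T]], valid_mask: List[List[bool]], center: int
) -> List[List[T]]:
    """Hypothetical fill rule: invalid frames repeat the nearest valid frame toward the center."""
    filled: List[List[T]] = []
    for row, mask in zip(data, valid_mask):
        if not mask[center]:
            raise ValueError("the center frame must be valid")
        out = list(row)  # copy so the input is not modified
        # Walk outward from the center; each invalid frame copies its (already filled) inner neighbor.
        for t in range(center - 1, -1, -1):
            if not mask[t]:
                out[t] = out[t + 1]
        for t in range(center + 1, len(out)):
            if not mask[t]:
                out[t] = out[t - 1]
        filled.append(out)
    return filled
```

For instance, a row `[1, 2, 3, 4, 5]` with mask `[False, False, True, True, False]` and `center=2` becomes `[3, 3, 3, 4, 4]`.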

get_valid_mask_function_with_episodeid_key staticmethod

get_valid_mask_function_with_episodeid_key(episode_id_key: Union[str, int] = 'episode_id', is_in_metadata: bool = False) -> Callable[[SliceStackedBatch, BArrayType, BatchT], BArrayType]
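The method name suggests it builds a `get_valid_mask_function` that reads an episode-id field, located by `episode_id_key` in the data or, presumably when `is_in_metadata=True`, in the metadata. A natural rule, assumed here and not confirmed by this page, is that a frame in a window is valid when it belongs to the same episode as the window's center frame. `episode_id_valid_mask_sketch` is a hypothetical helper that applies this rule to episode ids already gathered into a (B, T) table.

```python
from typing import Hashable, List


def episode_id_valid_mask_sketch(
    episode_ids: List[List[Hashable]], center: int
) -> List[List[bool]]:
    """Hypothetical rule: valid where a frame shares the center frame's episode id."""
    # `center` is the column of the current frame (assumed to be prefetch_horizon).
    return [[eid == row[center] for eid in row] for row in episode_ids]
```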

get_valid_mask_function_with_episode_end_key staticmethod

get_valid_mask_function_with_episode_end_key(episode_end_key: Union[str, int] = 'episode_end', is_in_metadata: bool = False) -> Callable[[SliceStackedBatch, BArrayType, BatchT], BArrayType]