unienv_data.batches.slicestack_batch
SliceStackedBatch
SliceStackedBatch(batch: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType], fixed_offset: BArrayType, get_valid_mask_function: Optional[Callable[[SliceStackedBatch, BArrayType, BatchT, Dict[str, Any]], BArrayType]] = None, fill_invalid_data: bool = True, stack_metadata: bool = False)
Bases: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]
A read-only batch that stacks frames taken at fixed offsets from each requested index. Because it is a view over the wrapped batch, its data cannot be modified through it; to change the data, mutate the underlying batch instead.
expand_index
expand_index(index: BArrayType) -> BArrayType
Expand indices for slicing the data.
Args:
    index (BArrayType): A 1D tensor of indices to expand. Can be boolean or integer.
Returns:
    BArrayType: A 2D tensor of indices, where each row holds the expanded indices for one batch item.
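A minimal pure-Python sketch of the expansion, assuming each row is the requested index plus every entry of `fixed_offset` (so `(B,)` indices and `(T,)` offsets give a `(B, T)` result) and that a boolean index selects the positions where it is True. Plain lists stand in for backend arrays; `expand_index_sketch` is a hypothetical name, not part of the library.

```python
from typing import List, Sequence, Union


def expand_index_sketch(
    index: Sequence[Union[bool, int]], fixed_offset: Sequence[int]
) -> List[List[int]]:
    """Hypothetical stand-in for SliceStackedBatch.expand_index.

    Assumption: row i of the result is index[i] + fixed_offset[t] for every t.
    """
    # A boolean mask is first converted to the integer positions where it is True.
    if len(index) > 0 and all(isinstance(x, bool) for x in index):
        index = [i for i, keep in enumerate(index) if keep]
    # Outer sum: one row per requested index, one column per offset -> (B, T).
    return [[int(i) + int(o) for o in fixed_offset] for i in index]
```

For example, `expand_index_sketch([5, 9], [-2, -1, 0])` gives `[[3, 4, 5], [7, 8, 9]]`.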
get_valid_mask_flattened
get_valid_mask_flattened(expanded_idx: BArrayType, data: BArrayType, metadata: Dict[str, Any]) -> BArrayType
Get the valid mask for the flattened data.
Args:
    expanded_idx (B, T): A 2D tensor of indices to slice the data.
    data (B, T, *D): The data to slice.
Returns:
    Valid mask, sized (B, T).
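Expanded indices near either end of the batch can fall outside the stored range. One plausible ingredient of such a mask, sketched here under that assumption, is a bounds check that marks those frames invalid and clamps their indices so the read itself does not fail. `bounds_mask_sketch` is hypothetical; the library may combine bounds with the user-supplied `get_valid_mask_function` differently.

```python
from typing import List, Tuple


def bounds_mask_sketch(
    expanded_idx: List[List[int]], length: int
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Hypothetical helper returning (clamped indices, valid mask), both (B, T)."""
    # A frame is valid only if its index lies inside [0, length).
    mask = [[0 <= j < length for j in row] for row in expanded_idx]
    # Clamp every index into range so that fetching never raises;
    # the mask records which of these reads are placeholders.
    clamped = [[min(max(j, 0), length - 1) for j in row] for row in expanded_idx]
    return clamped, mask
```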
get_valid_mask
get_valid_mask(expanded_idx: BArrayType, data: BatchT, metadata: Dict[str, Any]) -> BArrayType
Get the valid mask for the data.
Args:
    expanded_idx (B, T): A 2D tensor of indices to slice the data.
    data (BatchT): The data to slice.
Returns:
    Valid mask, sized (B, T).
fill_data_with_stack_mask
fill_data_with_stack_mask(space: Optional[Space[Any, BDeviceType, BDtypeType, BRNGType]], data: Union[BArrayType, BatchT, Any], valid_mask: BArrayType) -> Union[BArrayType, BatchT, Any]
Fill the data according to the valid mask, as if the frames had been frame-stacked.
Args:
    space (Optional[Space]): The space describing the data. If None, the data is assumed to be a backend array.
    data (Union[BArrayType, BatchT, Any]): The data to fill, sized (B, T, D).
    valid_mask (BArrayType): The mask to fill the data with, sized (B, T).
Returns:
    Union[BArrayType, BatchT, Any]: The filled data, sized (B, T, D).
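A sketch of one common frame-stacking convention, assumed here rather than taken from the library: every invalid frame is replaced by the nearest valid frame in the same row, so a stack that reaches back before an episode starts repeats that episode's first frame. Lists of frames stand in for `(B, T, D)` data, and `fill_with_stack_mask_sketch` is hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def fill_with_stack_mask_sketch(
    data: Sequence[Sequence[T]], valid_mask: Sequence[Sequence[bool]]
) -> List[List[T]]:
    """Hypothetical fill: invalid slots copy the nearest valid frame in their row."""
    filled: List[List[T]] = []
    for row, mask in zip(data, valid_mask):
        valid_pos = [t for t, ok in enumerate(mask) if ok]
        if not valid_pos:
            raise ValueError("each row needs at least one valid frame")
        new_row: List[T] = []
        for t, ok in enumerate(mask):
            if ok:
                new_row.append(row[t])
            else:
                # Nearest valid position; on a tie min() keeps the earlier one.
                nearest = min(valid_pos, key=lambda v: abs(v - t))
                new_row.append(row[nearest])
        filled.append(new_row)
    return filled
```

For example, `fill_with_stack_mask_sketch([["a", "b", "c", "d"]], [[False, False, True, True]])` returns `[["c", "c", "c", "d"]]`.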
get_at_with_metadata
get_at_with_metadata(idx: Union[IndexableType, BArrayType]) -> Tuple[BatchT, Dict[str, Any]]
get_valid_mask_function_with_episodeid_key
staticmethod
get_valid_mask_function_with_episodeid_key(episode_id_key: Union[str, int] = 'episode_id', is_in_metadata: bool = False) -> Callable[[SliceStackedBatch, BArrayType, BatchT], BArrayType]
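The factory name suggests a mask in which a frame is valid only if its episode id matches that of the anchor frame. The sketch below assumes the anchor is the column whose offset is 0 and that episode ids are available as a flat sequence; the returned closure takes only the expanded indices, a simplification of the `(batch, expanded_idx, data, metadata)` callable the constructor accepts. `make_episode_id_mask_fn` is hypothetical.

```python
from typing import Callable, List, Sequence


def make_episode_id_mask_fn(
    episode_ids: Sequence[int], anchor_col: int
) -> Callable[[List[List[int]]], List[List[bool]]]:
    """Hypothetical factory: valid frames share the anchor frame's episode id."""
    n = len(episode_ids)

    def mask_fn(expanded_idx: List[List[int]]) -> List[List[bool]]:
        masks: List[List[bool]] = []
        for row in expanded_idx:
            anchor_ep = episode_ids[row[anchor_col]]
            # Out-of-range frames short-circuit to False before indexing.
            masks.append([0 <= j < n and episode_ids[j] == anchor_ep for j in row])
        return masks

    return mask_fn
```

With `episode_ids = [0, 0, 0, 1, 1]` and offsets `(-2, -1, 0)` (so `anchor_col=2`), the rows `[[0, 1, 2], [2, 3, 4]]` map to `[[True, True, True], [False, True, True]]`.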
get_valid_mask_function_with_episode_end_key
staticmethod
get_valid_mask_function_with_episode_end_key(episode_end_key: Union[str, int] = 'episode_end', is_in_metadata: bool = False) -> Callable[[SliceStackedBatch, BArrayType, BatchT], BArrayType]
set_flattened_at
set_flattened_at(idx: Union[IndexableType, BArrayType], value: BArrayType) -> None
Overwrite existing samples using flattened data.
append_flattened
append_flattened(value: BArrayType) -> None
Append one flattened sample to the batch.
extend_flattened
extend_flattened(value: BArrayType) -> None
Append a batched block of flattened samples.
set_at
set_at(idx: Union[IndexableType, BArrayType], value: BatchT) -> None
Overwrite existing samples using structured data.
remove_at
remove_at(idx: Union[IndexableType, BArrayType]) -> None
Remove one or more samples from the batch.
extend_from
extend_from(other: BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType], chunk_size: int = 8, tqdm: bool = False) -> None
Copy data from another batch in bounded-size chunks.
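The chunking idea can be sketched with plain lists: the source is read in slices of at most `chunk_size` items and each slice is appended in one batched call, so only one chunk is held at a time. The `tqdm` flag presumably toggles a progress bar and is omitted; `chunk_ranges` and `extend_from_sketch` are hypothetical names.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chunk_ranges(total: int, chunk_size: int) -> Iterator[Tuple[int, int]]:
    """Yield half-open (start, stop) ranges covering [0, total)."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, total, chunk_size):
        yield start, min(start + chunk_size, total)


def extend_from_sketch(dest: List[T], source: Sequence[T], chunk_size: int = 8) -> None:
    """Hypothetical chunked copy: at most chunk_size items are in flight at once."""
    for start, stop in chunk_ranges(len(source), chunk_size):
        chunk = list(source[start:stop])  # one bounded read from the source
        dest.extend(chunk)                # one batched append to the destination
```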
get_slice
get_slice(idx: Union[IndexableType, BArrayType]) -> BatchBase[BatchT, BArrayType, BDeviceType, BDtypeType, BRNGType]
Create a lazy view over a subset of indices.
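A lazy view can be as small as a stored index list plus a reference to the parent: nothing is copied, and each read is translated into a parent read at access time, so later changes to the parent show through. `LazySliceView` is a hypothetical illustration, not the library's view class.

```python
from typing import Generic, List, Sequence, TypeVar

T = TypeVar("T")


class LazySliceView(Generic[T]):
    """Hypothetical view that stores indices only; reads go through to the parent."""

    def __init__(self, parent: Sequence[T], indices: Sequence[int]) -> None:
        self._parent = parent
        self._indices: List[int] = list(indices)

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, i: int) -> T:
        # Translate the view-local index into a parent index at read time.
        return self._parent[self._indices[i]]
```

With `parent = [10, 20, 30, 40]`, `LazySliceView(parent, [3, 1])[0]` is `40`, and becomes `99` after `parent[3] = 99`.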
get_column
get_column(nested_keys: Sequence[str]) -> BatchBase[Any, BArrayType, BDeviceType, BDtypeType, BRNGType]
Create a lazy view over a nested field inside each sample.
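In the same spirit, a column view can store only the key path and walk it on each read. This sketch assumes samples are nested mappings; `LazyColumnView` is hypothetical.

```python
from typing import Any, Mapping, Sequence, Tuple


class LazyColumnView:
    """Hypothetical view over one nested field of each sample."""

    def __init__(self, parent: Sequence[Mapping[str, Any]], nested_keys: Sequence[str]) -> None:
        self._parent = parent
        self._keys: Tuple[str, ...] = tuple(nested_keys)

    def __len__(self) -> int:
        return len(self._parent)

    def __getitem__(self, i: int) -> Any:
        value: Any = self._parent[i]
        # Walk down the nested keys only when an item is actually read.
        for key in self._keys:
            value = value[key]
        return value
```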