


Functions used to manipulate PyTorch tensors and NumPy arrays.


to_device_recursively(input: Any, device: Union[str, torch.device, int], inplace: bool = True)


Recursively places every tensor in a (possibly nested) data structure on the given device.
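A minimal sketch of the recursion, assuming the nested containers are plain dicts, lists and tuples. The helper name move_to_device is hypothetical, and unlike the documented function it ignores inplace and always builds new containers. It is not the library's implementation.

```python
from typing import Any, Union

import torch


def move_to_device(data: Any, device: Union[str, torch.device, int]) -> Any:
    """Hypothetical helper: return a copy of `data` with every tensor on `device`.

    Recurses into plain dicts, lists and tuples; any other leaf is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [move_to_device(v, device) for v in data]
    if isinstance(data, tuple):
        return tuple(move_to_device(v, device) for v in data)
    return data  # strings, numbers, None, ... are left as-is
```

For example, move_to_device({"rgb": torch.zeros(2, 3), "extras": [torch.ones(1), "label"]}, "cpu") returns the same structure with both tensors on the CPU and the string untouched.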


detach_recursively(input: Any, inplace=True)


Recursively detaches the tensors in a (possibly nested) data structure from their computation graph.
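A hedged sketch of what inplace=True could mean here, assuming it refers to mutating the containers rather than the tensors: dict and list entries are replaced by detached tensors in place, while tuples, being immutable, are rebuilt. The name detach_nested_ is hypothetical.

```python
from typing import Any

import torch


def detach_nested_(data: Any) -> Any:
    """Hypothetical helper: detach tensors, mutating dicts and lists in place."""
    if isinstance(data, torch.Tensor):
        return data.detach()  # new tensor sharing storage, no autograd history
    if isinstance(data, dict):
        for k in data:  # only values change, so iterating over keys is safe
            data[k] = detach_nested_(data[k])
        return data
    if isinstance(data, list):
        for i, v in enumerate(data):
            data[i] = detach_nested_(v)
        return data
    if isinstance(data, tuple):
        # tuples cannot be mutated, so a new tuple is returned
        return tuple(detach_nested_(v) for v in data)
    return data
```

Because the outer dict is mutated, callers holding a reference to it see the detached tensors without reassigning the return value.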


batch_observations(observations: List[Dict], device: Optional[torch.device] = None) -> Dict[str, Union[Dict, torch.Tensor]]


Transpose a batch of observation dicts to a dict of batched observations.


Arguments:

  • observations : List of dicts of observations.
  • device : The torch.device on which to place the resulting tensors. If None, the tensors are not moved.


Returns: Transposed dict of lists of observations.
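One way to perform this transposition, assuming every observation dict has the same keys and each leaf can be converted with torch.as_tensor and stacked along a new leading batch dimension (as the return type annotation suggests). Nested dicts are batched recursively. The name batch_obs_sketch is hypothetical; the documented function may handle more cases.

```python
from typing import Dict, List, Optional, Union

import torch


def batch_obs_sketch(
    observations: List[Dict], device: Optional[torch.device] = None
) -> Dict[str, Union[Dict, torch.Tensor]]:
    """Hypothetical helper: turn N per-sample dicts into one dict of (N, ...) tensors."""
    batch: Dict[str, Union[Dict, torch.Tensor]] = {}
    for key, first in observations[0].items():
        values = [obs[key] for obs in observations]
        if isinstance(first, dict):
            # nested observation group: batch its entries recursively
            batch[key] = batch_obs_sketch(values, device)
        else:
            stacked = torch.stack([torch.as_tensor(v) for v in values])
            # leave tensors where they are when no device is requested
            batch[key] = stacked if device is None else stacked.to(device)
    return batch
```

With four observations of the form {"rgb": torch.zeros(2, 2), "goal": {"pos": [1.0, 2.0]}}, the result holds an rgb tensor of shape (4, 2, 2) and a nested goal dict whose pos tensor has shape (4, 2).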


to_tensor(v) -> torch.Tensor


Return a torch.Tensor version of the input.


Arguments:

  • v : Input value(s) that can be coerced into a tensor.


Returns: A tensor version of the input.
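A hedged sketch of such a coercion, assuming the inputs are tensors, NumPy arrays, Python scalars or (nested) lists. The name as_torch is hypothetical. One design choice worth noting: torch.from_numpy shares memory with the array (and rejects arrays with negative strides), whereas torch.tensor always copies its input.

```python
import numpy as np
import torch


def as_torch(v) -> torch.Tensor:
    """Hypothetical helper: coerce tensors, NumPy arrays, scalars and lists to a tensor."""
    if isinstance(v, torch.Tensor):
        return v  # already a tensor: returned unchanged
    if isinstance(v, np.ndarray):
        return torch.from_numpy(v)  # zero-copy: shares memory with the array
    return torch.tensor(v)  # Python scalars and (nested) lists are copied
```

Since the array path is zero-copy, writing to the returned tensor also changes the original NumPy array.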


tile_images(images: List[np.ndarray]) -> np.ndarray


Tile multiple images into a single image.


Arguments:

  • images : List of images, each with shape (height x width x channels).


Returns: Tiled image (new_height x width x channels).
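A minimal NumPy sketch of one tiling scheme, assuming all images share the same shape. The images are laid out row by row in a near-square grid with ceil(sqrt(N)) rows, and empty cells are padded with zeros. The name tile_grid is hypothetical, and the output of this sketch has shape (rows * height x cols * width x channels).

```python
import math
from typing import List

import numpy as np


def tile_grid(images: List[np.ndarray]) -> np.ndarray:
    """Hypothetical helper: arrange equally sized HxWxC images in a near-square grid."""
    assert len(images) > 0, "need at least one image"
    stack = np.stack(images)  # (N, H, W, C)
    n, h, w, c = stack.shape
    rows = math.ceil(math.sqrt(n))
    cols = math.ceil(n / rows)
    # pad with black images so the grid is completely filled
    pad = np.zeros((rows * cols - n, h, w, c), dtype=stack.dtype)
    grid = np.concatenate([stack, pad])       # (rows*cols, H, W, C)
    grid = grid.reshape(rows, cols, h, w, c)  # (rows, cols, H, W, C)
    grid = grid.transpose(0, 2, 1, 3, 4)      # (rows, H, cols, W, C)
    return grid.reshape(rows * h, cols * w, c)
```

For five 2 x 3 images this produces a 3 x 2 grid, i.e. a 6 x 6 image whose bottom-right cell is padding.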


image(tag, tensor, rescale=1, dataformats="CHW")


Outputs a Summary protocol buffer with images. The images are built from tensor, which must be 3-D with shape [height, width, channels] once its axes are ordered according to dataformats, and where channels can be:

  • 1: tensor is interpreted as grayscale.
  • 3: tensor is interpreted as RGB.
  • 4: tensor is interpreted as RGBA.


Arguments:

  • tag: A name for the generated node. Will also serve as a series name in TensorBoard.
  • tensor: A 3-D uint8 or float32 Tensor of shape [height, width, channels] where channels is 1, 3, or 4. tensor can have values either in [0, 1] (float32) or in [0, 255] (uint8). The image() function scales the values to [0, 255] by applying a scale factor of 1 (uint8) or 255 (float32).
  • rescale: The rescaling factor (default 1).
  • dataformats: Order of the input tensor's dimensions, e.g. "CHW" (the default) or "HWC".


Returns: A scalar Tensor of type string: the serialized Summary protocol buffer.
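The preprocessing described above (reordering axes according to dataformats, then mapping float32 values in [0, 1] or uint8 values in [0, 255] onto uint8 [0, 255]) can be sketched in NumPy as follows. The name to_uint8_hwc is hypothetical; the sketch stops short of building the Summary protocol buffer and does not apply rescale. It also clips out-of-range float values, which is an assumption rather than documented behavior.

```python
import numpy as np


def to_uint8_hwc(tensor: np.ndarray, dataformats: str = "CHW") -> np.ndarray:
    """Hypothetical helper: reorder a 3-D image to HWC and map it to uint8 [0, 255]."""
    assert tensor.ndim == 3 and sorted(dataformats) == sorted("HWC")
    # permute axes so the result is height x width x channels
    hwc = np.transpose(tensor, [dataformats.index(ax) for ax in "HWC"])
    # uint8 input is assumed to be in [0, 255]; anything else is treated as floats in [0, 1]
    scale = 1 if hwc.dtype == np.uint8 else 255
    return np.clip(hwc.astype(np.float32) * scale, 0, 255).astype(np.uint8)
```

A float32 CHW array of shape (3, 4, 5) with its first channel set to 1.0 becomes a (4, 5, 3) uint8 image whose red channel is 255.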


class ScaleBothSides(object)


Rescales the input PIL.Image to the given width and height.

Attributes:

  • width : New width.
  • height : New height.
  • interpolation : Interpolation mode. Default: PIL.Image.BILINEAR.
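A minimal sketch of such a transform, assuming it is a callable applied to a PIL image, in the style of torchvision transforms. The class name ResizeExact is hypothetical. Because both sides are set explicitly, the aspect ratio is not preserved.

```python
from PIL import Image


class ResizeExact:
    """Hypothetical transform: resize a PIL image to exactly (width, height)."""

    def __init__(self, width: int, height: int, interpolation=Image.BILINEAR):
        self.width = width
        self.height = height
        self.interpolation = interpolation  # PIL resampling filter

    def __call__(self, img: Image.Image) -> Image.Image:
        # PIL expects the target size as (width, height)
        return img.resize((self.width, self.height), self.interpolation)
```

For example, ResizeExact(3, 7)(Image.new("RGB", (10, 4))).size is (3, 7).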