utils.tensor_utils#

Functions used to manipulate PyTorch tensors and NumPy arrays.

to_device_recursively#

to_device_recursively(input: Any, device: Union[str, torch.device, int], inplace: bool = True)

Recursively places tensors on the appropriate device.
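
As a rough illustration of the recursive traversal, the sketch below walks dicts, lists, and tuples and moves every tensor it finds. The name to_device_sketch is hypothetical, and the helper always returns a new structure; it does not reproduce the library's inplace option or its exact handling of container types.

```python
from typing import Any, Union

import torch


def to_device_sketch(value: Any, device: Union[str, torch.device, int]) -> Any:
    """Hypothetical helper: copy `value`, moving every tensor to `device`."""
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {k: to_device_sketch(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        # Rebuild the same container type (plain list or tuple).
        return type(value)(to_device_sketch(v, device) for v in value)
    # Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    return value


batch = {"rgb": torch.zeros(2, 3), "meta": [torch.ones(1), "label"]}
moved = to_device_sketch(batch, "cpu")
```

Returning a new structure keeps the example easy to reason about; an in-place variant would instead overwrite the entries of mutable containers.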

detach_recursively#

detach_recursively(input: Any, inplace=True)

Recursively detaches tensors in some data structure from their computation graph.
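
To show what detaching a nested structure means in practice, here is a minimal sketch. The name detach_sketch is an assumption made for illustration; it builds a new structure rather than modifying the input, so it is not a stand-in for the library's inplace behavior.

```python
from typing import Any

import torch


def detach_sketch(value: Any) -> Any:
    """Hypothetical helper: copy `value`, detaching every tensor from autograd."""
    if isinstance(value, torch.Tensor):
        return value.detach()
    if isinstance(value, dict):
        return {k: detach_sketch(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(detach_sketch(v) for v in value)
    return value


w = torch.ones(3, requires_grad=True)
outputs = {"loss": (w * 2).sum(), "hidden": [w * 3]}
detached = detach_sketch(outputs)
```

The detached tensors share storage with the originals but no longer track gradients, which is typically what is wanted when caching rollout data.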

batch_observations#

batch_observations(observations: List[Dict], device: Optional[torch.device] = None) -> Dict[str, Union[Dict, torch.Tensor]]

Transpose a batch of observation dicts to a dict of batched observations.

Arguments

  • observations : List of dicts of observations.
  • device : The torch.device to put the resulting tensors on. Will not move the tensors if None.

Returns

Transposed dict of batched observations: each key maps to a batched torch.Tensor, or to a nested dict for nested observations.
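
The transposition can be pictured with the following sketch, which assumes every observation dict has the same keys and that values under a key have matching shapes. The name batch_observations_sketch is hypothetical, and device placement is omitted.

```python
from typing import Dict, List

import torch


def batch_observations_sketch(observations: List[Dict]) -> Dict:
    """Hypothetical helper: turn a list of dicts into a dict of stacked tensors."""
    batched: Dict = {}
    for key in observations[0]:
        values = [obs[key] for obs in observations]
        if isinstance(values[0], dict):
            # Nested observation spaces are batched recursively.
            batched[key] = batch_observations_sketch(values)
        else:
            # Stack along a new leading batch dimension.
            batched[key] = torch.stack([torch.as_tensor(v) for v in values], dim=0)
    return batched


observations = [
    {"rgb": torch.zeros(4, 4, 3), "pose": {"xyz": [0.0, 1.0, 2.0]}},
    {"rgb": torch.ones(4, 4, 3), "pose": {"xyz": [3.0, 4.0, 5.0]}},
]
batch = batch_observations_sketch(observations)
```

Here batch["rgb"] has a leading batch dimension of 2, and batch["pose"] is itself a dict holding a batched tensor.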

to_tensor#

to_tensor(v) -> torch.Tensor

Return a torch.Tensor version of the input.

Parameters

  • v : Input values that can be coerced into being a tensor.

Returns

A tensor version of the input.
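
A minimal sketch of this kind of coercion is shown below. The name to_tensor_sketch is hypothetical, and the dtype choices are simply those of torch.as_tensor and torch.from_numpy; the library function may apply different dtype rules.

```python
import numpy as np
import torch


def to_tensor_sketch(v) -> torch.Tensor:
    """Hypothetical helper: coerce tensors, arrays, sequences, or scalars to a tensor."""
    if isinstance(v, torch.Tensor):
        return v  # already a tensor, returned as-is
    if isinstance(v, np.ndarray):
        return torch.from_numpy(v)  # shares memory with the array
    return torch.as_tensor(v)  # lists, tuples, Python scalars


from_list = to_tensor_sketch([1, 2, 3])
from_array = to_tensor_sketch(np.zeros((2, 2), dtype=np.float32))
from_scalar = to_tensor_sketch(3.5)
```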

tile_images#

tile_images(images: List[np.ndarray]) -> np.ndarray

Tile multiple images into a single image.

Parameters

  • images : list of images where each image has dimension (height x width x channels)

Returns

Tiled image (new_height x new_width x channels).
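
One common way to tile equally sized images is to lay them out on a near-square grid and pad the unused cells with blank images. The sketch below follows that approach; the grid layout and the name tile_images_sketch are assumptions for illustration and may differ from what the library does.

```python
import math
from typing import List

import numpy as np


def tile_images_sketch(images: List[np.ndarray]) -> np.ndarray:
    """Hypothetical helper: arrange (H x W x C) images on a rows x cols grid."""
    assert len(images) > 0, "empty list of images"
    n = len(images)
    height, width, channels = images[0].shape
    rows = math.ceil(math.sqrt(n))
    cols = math.ceil(n / rows)
    # Pad with blank images so the grid is completely filled.
    padded = list(images) + [np.zeros_like(images[0])] * (rows * cols - n)
    grid = np.stack(padded).reshape(rows, cols, height, width, channels)
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * height, cols * width, channels)


frames = [np.full((2, 3, 3), i, dtype=np.uint8) for i in range(5)]
tiled = tile_images_sketch(frames)
```

With five 2 x 3 frames the sketch uses a 3 x 2 grid, so the result has shape (6, 6, 3) and the last cell is blank.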

image#

image(tag, tensor, rescale=1, dataformats="CHW")

Outputs a Summary protocol buffer containing an image. The image is built from tensor, which must be 3-D with its axes ordered as given by dataformats (for example [height, width, channels] for "HWC", or [channels, height, width] for the default "CHW"), and where channels can be:

  • 1: tensor is interpreted as Grayscale.
  • 3: tensor is interpreted as RGB.
  • 4: tensor is interpreted as RGBA.

Arguments:

  • tag - A name for the generated node. Will also serve as a series name in TensorBoard.
  • tensor - A 3-D uint8 or float32 Tensor of shape [height, width, channels] where channels is 1, 3, or 4. 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8). The image() function will scale the image values to [0, 255] by applying a scale factor of either 1 (uint8) or 255 (float32).

Returns:

A scalar Tensor of type string. The serialized Summary protocol buffer.
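
The value scaling described for the tensor argument (factor 1 for uint8, factor 255 for float32) can be sketched in NumPy as follows. The name to_uint8_sketch is hypothetical, and the sketch covers only the scaling step, not the construction of the Summary.

```python
import numpy as np


def to_uint8_sketch(tensor: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map image values to the [0, 255] uint8 range."""
    if tensor.dtype == np.uint8:
        # Already in [0, 255]; scale factor of 1.
        return tensor
    # Float input is assumed to lie in [0, 1]; scale factor of 255.
    return (tensor * 255.0).astype(np.uint8)


float_pixels = np.array([0.0, 0.5, 1.0], dtype=np.float32)
byte_pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled_float = to_uint8_sketch(float_pixels)
scaled_bytes = to_uint8_sketch(byte_pixels)
```

Note that the cast truncates rather than rounds, so 0.5 maps to 127 in this sketch.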

ScaleBothSides#

class ScaleBothSides(object)

Rescales the input PIL.Image to the given width and height.

Attributes

  • width : New width.
  • height : New height.
  • interpolation : Interpolation method. Default: PIL.Image.BILINEAR.
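
A callable transform with these attributes might look like the sketch below. The class name ScaleBothSidesSketch and its constructor signature are assumptions for illustration; only PIL's Image.resize call is relied upon.

```python
from PIL import Image


class ScaleBothSidesSketch:
    """Hypothetical transform: resize a PIL image to a fixed width and height."""

    def __init__(self, width: int, height: int, interpolation=Image.BILINEAR):
        self.width = width
        self.height = height
        self.interpolation = interpolation

    def __call__(self, img: Image.Image) -> Image.Image:
        # PIL expects the target size as (width, height).
        return img.resize((self.width, self.height), self.interpolation)


source = Image.new("RGB", (640, 480))
resized = ScaleBothSidesSketch(width=224, height=224)(source)
```

Because both sides are set independently, the aspect ratio of the input is not preserved.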