
utils.model_utils

Functions used to initialize and manipulate PyTorch models.

Flatten

class Flatten(nn.Module)

Flattens the input tensor to shape (FLATTENED_BATCH x -1): the leading (flattened batch) dimension is kept and all remaining dimensions are merged into one.

Flatten.forward

forward(x)

Flatten input tensor.

Parameters

  • x : Tensor of size (FLATTENED_BATCH x ...) to flatten to size (FLATTENED_BATCH x -1).

Returns

Flattened tensor.
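The behavior described above can be illustrated with a minimal sketch. `FlattenSketch` is a hypothetical stand-in, not the library's implementation; it assumes the first dimension is FLATTENED_BATCH and merges everything after it. PyTorch's built-in `nn.Flatten()` (which flattens from `start_dim=1` by default) produces the same result.

```python
import torch
import torch.nn as nn


class FlattenSketch(nn.Module):
    """Hypothetical re-implementation: keep dim 0 (FLATTENED_BATCH), merge the rest."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # reshape (rather than view) also accepts non-contiguous inputs,
        # such as tensors produced by permute().
        return x.reshape(x.size(0), -1)


x = torch.randn(6, 3, 4, 5)  # FLATTENED_BATCH = 6, then 3 * 4 * 5 = 60 features
y = FlattenSketch()(x)
print(tuple(y.shape))  # (6, 60)
```

Using `reshape` instead of `view` is a deliberate choice in this sketch: CNN feature maps are often permuted before flattening, and `view` would raise an error on such non-contiguous tensors.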

init_linear_layer

init_linear_layer(module: nn.Linear, weight_init: Callable, bias_init: Callable, gain=1)

Initialize a torch.nn.Linear layer.

Parameters

  • module : A torch linear layer.
  • weight_init : Function used to initialize the weight parameters of the linear layer. Should take the weight data tensor and gain as input.
  • bias_init : Function used to initialize the bias parameters of the linear layer. Should take the bias data tensor and gain as input.
  • gain : The gain to apply, passed to the initialization functions.

Returns

The initialized linear layer.
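The parameter contract above can be sketched as follows. `init_linear_layer_sketch` is a hypothetical function written only to illustrate the documented behavior (each init function receives the parameter data tensor and the gain); it is not the library's code. The choice of `nn.init.orthogonal_` for weights and zeros for biases is an illustrative assumption.

```python
from typing import Callable

import torch
import torch.nn as nn


def init_linear_layer_sketch(
    module: nn.Linear, weight_init: Callable, bias_init: Callable, gain: float = 1
) -> nn.Linear:
    # Hypothetical sketch of the documented contract: both init functions
    # receive the raw parameter data tensor plus the gain and modify it in place.
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data, gain=gain)
    return module


layer = init_linear_layer_sketch(
    nn.Linear(8, 4),
    weight_init=nn.init.orthogonal_,  # accepts (tensor, gain=...)
    # Accepts an optional gain and ignores it; nn.init.constant_ would
    # misinterpret a gain as the fill value, so it is wrapped instead.
    bias_init=lambda b, gain=1.0: nn.init.zeros_(b),
    gain=nn.init.calculate_gain("relu"),  # sqrt(2)
)
print(layer.weight.shape)  # torch.Size([4, 8])
```

Because the bias lambda gives `gain` a default value, it works whether or not the caller passes the gain.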

compute_cnn_output

compute_cnn_output(cnn: nn.Module, cnn_input: torch.Tensor, permute_order: Optional[Tuple[int, ...]] = (
        0,  # FLAT_BATCH (flattening steps, samplers and agents)
        3,  # CHANNEL
        1,  # ROW
        2,  # COL
    ))

Computes CNN outputs for the given inputs.

Parameters

  • cnn : A torch CNN.
  • cnn_input : A torch Tensor with inputs.
  • permute_order : A permutation tuple used to reorder the flattened input into PyTorch's dimension order. The default, (0, 3, 1, 2), maps a flattened [FLAT_BATCH, ROW, COL, CHANNEL] input to [FLAT_BATCH, CHANNEL, ROW, COL], where FLAT_BATCH combines the step, sampler and agent dimensions.

Returns

CNN output with dimensions [STEP, SAMPLER, AGENT, CHANNEL, (HEIGHT, WIDTH)].
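The shape bookkeeping described above can be traced with a minimal sketch. `compute_cnn_output_sketch` is hypothetical and not the library's implementation; it assumes a 6-D channels-last input of shape [STEP, SAMPLER, AGENT, ROW, COL, CHANNEL], which is what the default `permute_order` and the documented return shape imply. The `nn.Conv2d` settings are arbitrary illustrative choices.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


def compute_cnn_output_sketch(
    cnn: nn.Module,
    cnn_input: torch.Tensor,  # assumed [STEP, SAMPLER, AGENT, ROW, COL, CHANNEL]
    permute_order: Optional[Tuple[int, ...]] = (0, 3, 1, 2),
) -> torch.Tensor:
    nsteps, nsamplers, nagents = cnn_input.shape[:3]
    # Merge STEP, SAMPLER and AGENT into a single FLAT_BATCH dimension.
    flat = cnn_input.reshape((-1,) + tuple(cnn_input.shape[3:]))
    # [FLAT_BATCH, ROW, COL, CHANNEL] -> [FLAT_BATCH, CHANNEL, ROW, COL]
    if permute_order is not None:
        flat = flat.permute(*permute_order)
    out = cnn(flat)
    # Restore the leading dims: [STEP, SAMPLER, AGENT, CHANNEL, (HEIGHT, WIDTH)]
    return out.reshape((nsteps, nsamplers, nagents) + tuple(out.shape[1:]))


cnn = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
obs = torch.randn(2, 4, 1, 32, 32, 3)  # 2 steps, 4 samplers, 1 agent, 32x32 RGB
out = compute_cnn_output_sketch(cnn, obs)
print(tuple(out.shape))  # (2, 4, 1, 16, 32, 32)
```

Flattening before the CNN call lets one batched forward pass process every step, sampler and agent at once. The final reshape is valid because the merge and the split use the same row-major ordering of the leading dimensions.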