allenact.utils.model_utils
Functions used to initialize and manipulate PyTorch models.
Flatten
class Flatten(nn.Module)
Flatten input tensor so that it is of shape (FLATTENED_BATCH x -1).
Flatten.forward
forward(x)
Flatten input tensor.
Parameters
- x : Tensor of size (FLATTENED_BATCH x ...) to flatten to size (FLATTENED_BATCH x -1).
Returns
Flattened tensor.
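The behaviour described above (keep the leading batch dimension, merge everything else) can be sketched as follows. This is a minimal illustration, not the library's code: the class name FlattenSketch is hypothetical, and it uses reshape rather than view so that it also accepts non-contiguous inputs.

```python
import torch
import torch.nn as nn


class FlattenSketch(nn.Module):
    """Hypothetical stand-in for Flatten: (FLATTENED_BATCH x ...) -> (FLATTENED_BATCH x -1)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep dimension 0 and collapse all remaining dimensions into one.
        return x.reshape(x.size(0), -1)


x = torch.arange(24).reshape(2, 3, 4)
flat = FlattenSketch()(x)  # shape (2, 12)
```

The result matches torch.flatten(x, start_dim=1), which is a reasonable built-in alternative when a module wrapper is not required.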
init_linear_layer
init_linear_layer(module: nn.Linear, weight_init: Callable, bias_init: Callable, gain=1)
Initialize a torch.nn.Linear layer.
Parameters
- module : A torch linear layer.
- weight_init : Function used to initialize the weight parameters of the linear layer. Should take the weight data tensor and gain as input.
- bias_init : Function used to initialize the bias parameters of the linear layer. Should take the bias data tensor and gain as input.
- gain : The gain to apply.
Returns
The initialized linear layer.
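To make the weight_init/bias_init contract concrete, here is a sketch under stated assumptions: init_linear_layer_sketch is a hypothetical stand-in that follows the parameter descriptions above literally (each init function receives the parameter's data tensor and the gain); the actual library function may pass the gain differently. The example pairs orthogonal weights with zero biases, a common choice for RL policy networks.

```python
import math
from typing import Callable

import torch
import torch.nn as nn


def init_linear_layer_sketch(
    module: nn.Linear, weight_init: Callable, bias_init: Callable, gain: float = 1.0
) -> nn.Linear:
    # Hypothetical stand-in: both init functions receive (data tensor, gain),
    # as described in the parameter list above.
    weight_init(module.weight.data, gain)
    bias_init(module.bias.data, gain)
    return module


layer = init_linear_layer_sketch(
    nn.Linear(8, 4),
    # Orthogonal init scaled by the gain; rows of the (4 x 8) weight become orthogonal.
    weight_init=lambda w, gain: nn.init.orthogonal_(w, gain=gain),
    # The bias ignores the gain and is simply set to zero.
    bias_init=lambda b, gain: nn.init.constant_(b, 0.0),
    gain=math.sqrt(2),
)
```

Because the (4 x 8) weight has orthonormal rows scaled by the gain, weight @ weight.T is approximately gain^2 times the identity.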
compute_cnn_output
compute_cnn_output(cnn: nn.Module, cnn_input: torch.Tensor, permute_order: Optional[Tuple[int, ...]] = (
0, # FLAT_BATCH (flattening steps, samplers and agents)
3, # CHANNEL
1, # ROW
2, # COL
))
Computes CNN outputs for given inputs.
Parameters
- cnn : A torch CNN.
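- cnn_input : A torch Tensor of inputs whose leading STEP, SAMPLER and AGENT dimensions are flattened into a single batch dimension before the CNN is applied. Under the default permute_order, each flattened input is expected in channels-last (ROW, COL, CHANNEL) order.
- permute_order : Permutation applied to the flattened input to obtain PyTorch's (FLAT_BATCH, CHANNEL, ROW, COL) dimension order. Defaults to (0, 3, 1, 2), where 0 corresponds to the flattened batch dimension (combining step, sampler and agent).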
Returns
CNN output with dimensions [STEP, SAMPLER, AGENT, CHANNEL, (HEIGHT, WIDTH)].
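The flatten, permute, run, and reshape steps described above can be sketched as below. This is a hypothetical stand-in (compute_cnn_output_sketch), not the library implementation; in particular, the explicit has_agent_dim flag is an assumption of this sketch, whereas the real function may infer the presence of the agent dimension itself.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


def compute_cnn_output_sketch(
    cnn: nn.Module,
    cnn_input: torch.Tensor,
    permute_order: Optional[Tuple[int, ...]] = (0, 3, 1, 2),
    has_agent_dim: bool = False,  # hypothetical flag, not part of the documented API
) -> torch.Tensor:
    # Number of leading dims to merge: STEP, SAMPLER (and AGENT if present).
    n_lead = 3 if has_agent_dim else 2
    lead_shape = tuple(cnn_input.shape[:n_lead])
    # Merge the leading dims into FLAT_BATCH: (FLAT_BATCH, ROW, COL, CHANNEL).
    flat = cnn_input.reshape((-1,) + tuple(cnn_input.shape[n_lead:]))
    # Reorder to PyTorch's (FLAT_BATCH, CHANNEL, ROW, COL) layout.
    if permute_order is not None:
        flat = flat.permute(*permute_order)
    out = cnn(flat)
    # Restore the leading dims: (STEP, SAMPLER, [AGENT,] CHANNEL, HEIGHT, WIDTH).
    return out.reshape(lead_shape + tuple(out.shape[1:]))


cnn = nn.Conv2d(4, 5, kernel_size=3)
x = torch.randn(2, 3, 8, 8, 4)  # STEP=2, SAMPLER=3, 8x8 image, 4 channels
out = compute_cnn_output_sketch(cnn, x)  # shape (2, 3, 5, 6, 6)
```

Flattening the leading dimensions lets a single batched CNN call process every step, sampler and agent at once, instead of looping over them.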
FeatureEmbedding
class FeatureEmbedding(nn.Module)
A wrapper around nn.Embedding that also supports a zero output size. Used for extracting features for actions/rewards.
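One way to picture "supports a zero output size" is sketched below, under assumptions: FeatureEmbeddingSketch is a hypothetical class, and returning a zero-width tensor of shape (*inputs.shape, 0) is a design choice of this sketch (the library may return a differently shaped empty tensor). A zero-width feature lets an optional action/reward embedding be switched off without changing downstream concatenation code.

```python
import torch
import torch.nn as nn


class FeatureEmbeddingSketch(nn.Module):
    """Hypothetical embedding wrapper that tolerates output_size == 0."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.output_size = output_size
        # Only build an nn.Embedding when there is something to embed.
        self.embed = nn.Embedding(input_size, output_size) if output_size > 0 else None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.embed is not None:
            return self.embed(inputs)
        # Zero-width features: concatenating them along the last dim is a no-op.
        return torch.zeros(*inputs.shape, 0, device=inputs.device)


actions = torch.tensor([[0, 2], [1, 3]])
emb8 = FeatureEmbeddingSketch(4, 8)(actions)   # shape (2, 2, 8)
emb0 = FeatureEmbeddingSketch(4, 0)(actions)   # shape (2, 2, 0)
other = torch.randn(2, 2, 16)
combined = torch.cat([other, emb0], dim=-1)    # shape (2, 2, 16)
```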