projects.objectnav_baselines.models.object_nav_models
Baseline models for use in the object navigation task.
Object navigation is currently available as a Task in AI2-THOR and Facebook's Habitat.
ObjectNavActorCritic
class ObjectNavActorCritic(VisualNavActorCritic)
Baseline recurrent actor-critic model for object navigation.
Attributes
action_space: The space of actions available to the agent. Currently only discrete actions are allowed, so this space will always be of type gym.spaces.Discrete.
observation_space: The observation space expected by the agent. It may optionally include 'rgb' and 'depth' images, and it must include a component, keyed by goal_sensor_uuid, that describes the goal.
goal_sensor_uuid: The uuid of the sensor that provides the goal object. See GoalObjectTypeThorSensor for an example of such a sensor.
hidden_size: The hidden size of the recurrent network (a GRU by default; see rnn_type).
object_type_embedding_dim: The dimensionality of the embedding of the goal object type.
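The attribute descriptions imply a few consistency requirements between the constructor arguments and the observation space: the action space is discrete, the goal component is mandatory, and the visual inputs are optional. The sketch below checks those requirements, using a plain dictionary from sensor uuid to shape as a stand-in for a gym SpaceDict. The function check_objectnav_inputs, the uuid strings, and the shapes are all hypothetical and are for illustration only.

```python
from typing import Dict, Optional, Tuple


def check_objectnav_inputs(
    observation_shapes: Dict[str, Tuple[int, ...]],
    num_actions: int,
    goal_sensor_uuid: str,
    rgb_uuid: Optional[str] = None,
    depth_uuid: Optional[str] = None,
) -> None:
    """Raise ValueError if the inputs violate the documented requirements.

    Hypothetical helper: a plain dict stands in for gym's SpaceDict.
    """
    # Only discrete action spaces are allowed, and they need at least one action.
    if num_actions < 1:
        raise ValueError("a discrete action space needs at least one action")
    # The goal component is mandatory.
    if goal_sensor_uuid not in observation_shapes:
        raise ValueError(f"missing goal sensor {goal_sensor_uuid!r}")
    # RGB and depth are optional, but any uuid that is given must be present.
    for uuid in (rgb_uuid, depth_uuid):
        if uuid is not None and uuid not in observation_shapes:
            raise ValueError(f"missing visual sensor {uuid!r}")


# Hypothetical uuids and shapes: an RGB frame plus a scalar goal index.
shapes = {"rgb": (224, 224, 3), "goal": ()}
check_objectnav_inputs(shapes, num_actions=6, goal_sensor_uuid="goal", rgb_uuid="rgb")
```

In a real model the same checks would run against the gym spaces themselves; the dictionary version only makes the requirements explicit.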
ObjectNavActorCritic.__init__
| __init__(action_space: gym.spaces.Discrete, observation_space: SpaceDict, goal_sensor_uuid: str, hidden_size=512, num_rnn_layers=1, rnn_type="GRU", add_prev_actions=False, action_embed_size=6, multiple_beliefs=False, beliefs_fusion: Optional[FusionType] = None, auxiliary_uuids: Optional[List[str]] = None, rgb_uuid: Optional[str] = None, depth_uuid: Optional[str] = None, object_type_embedding_dim=8, trainable_masked_hidden_state: bool = False, backbone="gnresnet18", resnet_baseplanes=32)
Initializer.
See class documentation for parameter definitions.
ObjectNavActorCritic.is_blind
| @property
| is_blind() -> bool
True if the model is blind (i.e. neither 'depth' nor 'rgb' is an input observation type).
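According to the docstring, blindness depends only on which visual inputs are configured. The following minimal sketch assumes the decision is made from the optional rgb_uuid and depth_uuid constructor arguments. The free function below is hypothetical and is not the class's property.

```python
from typing import Optional


def is_blind(rgb_uuid: Optional[str], depth_uuid: Optional[str]) -> bool:
    """Return True when neither an RGB nor a depth observation is used."""
    # The model sees pixels if at least one visual uuid is configured.
    return rgb_uuid is None and depth_uuid is None
```

With this reading, a model built with only the goal sensor (both uuids left at their default of None) is blind.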
ObjectNavActorCritic.get_object_type_encoding
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.
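Object types usually arrive from the goal sensor as integer indices. A natural reading of "object type encoding" is therefore a learned embedding lookup, in which each object type owns one row of size object_type_embedding_dim. The sketch below assumes that reading. So that it runs without dependencies, it replaces the learned table (in a PyTorch model this would typically be a torch.nn.Embedding) with a fixed list of lists. The function name and the table values are hypothetical.

```python
from typing import List, Sequence


def encode_object_types(
    goal_indices: Sequence[int], table: List[List[float]]
) -> List[List[float]]:
    """Map a batch of object-type indices to their embedding rows."""
    num_types = len(table)
    out = []
    for idx in goal_indices:
        # Each index must address a row of the (num_types x embedding_dim) table.
        if not 0 <= idx < num_types:
            raise IndexError(f"object type index {idx} out of range")
        out.append(list(table[idx]))  # copy so callers cannot mutate the table
    return out


# Hypothetical table: three object types, embedding dimension 2
# (object_type_embedding_dim defaults to 8 in the signature above).
table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
batch = encode_object_types([2, 0, 2], table)
# batch == [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
```

Repeated goal indices in a batch map to identical rows, which is what lets a single learned vector represent each target category.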
ResnetTensorObjectNavActorCritic
class ResnetTensorObjectNavActorCritic(VisualNavActorCritic)
ResnetTensorObjectNavActorCritic.is_blind
| @property
| is_blind() -> bool
True if the model is blind (i.e. neither 'depth' nor 'rgb' is an input observation type).
ResnetTensorGoalEncoder
class ResnetTensorGoalEncoder(nn.Module)
ResnetTensorGoalEncoder.get_object_type_encoding
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.
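As its name suggests, ResnetTensorGoalEncoder works on ResNet feature tensors rather than on flat vectors. A common way to condition a spatial tensor on a goal is to broadcast the object-type embedding to every spatial location and concatenate it with the feature channels. This is an assumption, not something stated on this page. The pure-Python sketch below shows that step for a single (C, H, W) feature map. The name tile_goal_and_concat is hypothetical.

```python
from typing import List

Tensor3D = List[List[List[float]]]  # indexed as [channel][row][col]


def tile_goal_and_concat(features: Tensor3D, goal_embedding: List[float]) -> Tensor3D:
    """Concatenate a goal embedding, tiled over H x W, onto a (C, H, W) feature map."""
    height = len(features[0])
    width = len(features[0][0])
    # Each embedding value becomes a constant H x W channel.
    goal_channels = [
        [[value] * width for _ in range(height)] for value in goal_embedding
    ]
    # Copy the feature channels, then append the goal channels -> (C + D, H, W).
    return [[list(row) for row in channel] for channel in features] + goal_channels


features = [[[1.0, 2.0], [3.0, 4.0]]]  # C=1, H=2, W=2
out = tile_goal_and_concat(features, [0.5, -0.5])  # D=2, so out has 3 channels
```

With batched tensors, the same idea amounts to expanding a (B, D) embedding to (B, D, H, W) and concatenating along the channel dimension.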
ResnetDualTensorGoalEncoder
class ResnetDualTensorGoalEncoder(nn.Module)
ResnetDualTensorGoalEncoder.get_object_type_encoding
| get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor
Get the object type encoding from input batched observations.