projects.objectnav_baselines.models.object_nav_models

[view_source]

Baseline models for use in the object navigation task.

Object navigation is currently available as a Task in AI2-THOR and Facebook's Habitat.

ObjectNavActorCritic

class ObjectNavActorCritic(VisualNavActorCritic)

[view_source]

Baseline recurrent actor-critic model for object navigation.

Attributes

  • action_space: The space of actions available to the agent. Currently only discrete actions are allowed (so this space will always be of type gym.spaces.Discrete).
  • observation_space: The observation space expected by the agent. This observation space may optionally include 'rgb' and/or 'depth' images and must contain a component keyed by goal_sensor_uuid.
  • goal_sensor_uuid: The uuid of the sensor providing the goal object type. See GoalObjectTypeThorSensor for an example of such a sensor.
  • hidden_size: The hidden size of the GRU RNN.
  • object_type_embedding_dim: The dimensionality of the embedding corresponding to the goal object type.

ObjectNavActorCritic.__init__

 | __init__(
 |     action_space: gym.spaces.Discrete,
 |     observation_space: SpaceDict,
 |     goal_sensor_uuid: str,
 |     hidden_size=512,
 |     num_rnn_layers=1,
 |     rnn_type="GRU",
 |     add_prev_actions=False,
 |     action_embed_size=6,
 |     multiple_beliefs=False,
 |     beliefs_fusion: Optional[FusionType] = None,
 |     auxiliary_uuids: Optional[List[str]] = None,
 |     rgb_uuid: Optional[str] = None,
 |     depth_uuid: Optional[str] = None,
 |     object_type_embedding_dim=8,
 |     trainable_masked_hidden_state: bool = False,
 |     backbone="gnresnet18",
 |     resnet_baseplanes=32,
 | )

[view_source]

Initializer.

See class documentation for parameter definitions.
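
For orientation, here is a sketch of constructing the model from the documented signature. The sensor uuids ('rgb', 'goal_object_type_ind'), the observation shapes, and the SpaceDict alias for gym.spaces.Dict are illustrative assumptions, not requirements of the class.

```python
import gym
import numpy as np
from gym.spaces import Dict as SpaceDict  # assumed alias for SpaceDict

from projects.objectnav_baselines.models.object_nav_models import (
    ObjectNavActorCritic,
)

observation_space = SpaceDict(
    {
        # RGB frames as HxWxC float arrays (illustrative shape).
        "rgb": gym.spaces.Box(
            low=0.0, high=1.0, shape=(224, 224, 3), dtype=np.float32
        ),
        # Goal object type as a categorical index over, say, 38 object types.
        "goal_object_type_ind": gym.spaces.Discrete(38),
    }
)

model = ObjectNavActorCritic(
    action_space=gym.spaces.Discrete(6),  # e.g. 6 discrete navigation actions
    observation_space=observation_space,
    goal_sensor_uuid="goal_object_type_ind",
    rgb_uuid="rgb",   # matches the key in observation_space above
    depth_uuid=None,  # no depth input in this sketch
    hidden_size=512,
    object_type_embedding_dim=8,
)
```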

ObjectNavActorCritic.is_blind

 | @property
 | is_blind() -> bool

[view_source]

True if the model is blind (i.e., neither 'depth' nor 'rgb' is an input observation type).
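
The check reduces to whether either visual uuid was configured. A minimal sketch of this logic, assuming the constructor stores its rgb_uuid and depth_uuid arguments as attributes (the stand-in class below is hypothetical, not the real model):

```python
from typing import Optional


class _Sketch:
    """Minimal stand-in illustrating the is_blind logic."""

    def __init__(self, rgb_uuid: Optional[str], depth_uuid: Optional[str]):
        self.rgb_uuid = rgb_uuid
        self.depth_uuid = depth_uuid

    @property
    def is_blind(self) -> bool:
        # Blind means neither visual modality was configured.
        return self.rgb_uuid is None and self.depth_uuid is None
```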

ObjectNavActorCritic.get_object_type_encoding

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from the batched input observations.
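
One plausible realization is an embedding lookup on the goal-type index stored under goal_sensor_uuid. The sketch below uses a hypothetical stand-in class and attribute names; it illustrates the interface, not the model's actual internals.

```python
from typing import Dict

import torch
import torch.nn as nn


class _GoalEncoderSketch(nn.Module):
    """Illustrative stand-in; attribute names are assumptions."""

    def __init__(
        self, goal_sensor_uuid: str, num_object_types: int, embedding_dim: int = 8
    ):
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid
        # One learned vector per goal object type.
        self.object_type_embedding = nn.Embedding(num_object_types, embedding_dim)

    def get_object_type_encoding(
        self, observations: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        # Look up the embedding for each goal-type index in the batch.
        return self.object_type_embedding(
            observations[self.goal_sensor_uuid].to(torch.int64)
        )


# Usage: a batch of 4 goal indices keyed by the goal sensor's uuid.
encoder = _GoalEncoderSketch("goal_object_type_ind", num_object_types=38)
out = encoder.get_object_type_encoding(
    {"goal_object_type_ind": torch.tensor([3, 17, 0, 5])}
)
# out.shape == (4, 8): one object_type_embedding_dim-sized vector per sample.
```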

ResnetTensorObjectNavActorCritic

class ResnetTensorObjectNavActorCritic(VisualNavActorCritic)

[view_source]

ResnetTensorObjectNavActorCritic.is_blind

 | @property
 | is_blind() -> bool

[view_source]

True if the model is blind (i.e., neither 'depth' nor 'rgb' is an input observation type).

ResnetTensorGoalEncoder

class ResnetTensorGoalEncoder(nn.Module)

[view_source]

ResnetTensorGoalEncoder.get_object_type_encoding

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from the batched input observations.

ResnetDualTensorGoalEncoder

class ResnetDualTensorGoalEncoder(nn.Module)

[view_source]

ResnetDualTensorGoalEncoder.get_object_type_encoding

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from the batched input observations.