projects.objectnav_baselines.models.object_nav_models#

[view_source]

Baseline models for use in the object navigation task.

Object navigation is currently available as a Task in AI2-THOR and Facebook's Habitat.

ObjectNavBaselineActorCritic#

class ObjectNavBaselineActorCritic(ActorCriticModel[CategoricalDistr])

[view_source]

Baseline recurrent actor critic model for object-navigation.

Attributes

  • action_space: The space of actions available to the agent. Currently only discrete actions are allowed (so this space will always be of type gym.spaces.Discrete).
  • observation_space: The observation space expected by the agent. This observation space may optionally include 'rgb' and/or 'depth' images and is required to have a component corresponding to goal_sensor_uuid.
  • goal_sensor_uuid: The uuid of the sensor of the goal object. See GoalObjectTypeThorSensor as an example of such a sensor.
  • hidden_size: The hidden size of the GRU RNN.
  • object_type_embedding_dim: The dimensionality of the embedding corresponding to the goal object type.

ObjectNavBaselineActorCritic.__init__#

 | __init__(action_space: gym.spaces.Discrete, observation_space: SpaceDict, goal_sensor_uuid: str, hidden_size=512, object_type_embedding_dim=8, trainable_masked_hidden_state: bool = False, num_rnn_layers=1, rnn_type="GRU")

[view_source]

Initializer.

See class documentation for parameter definitions.
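A minimal construction sketch follows. The spaces and the `goal_object_type_ind` uuid below are hypothetical placeholders; in practice they come from your experiment's task sampler and sensor configuration.

```python
import gym
from gym.spaces import Dict as SpaceDict

from projects.objectnav_baselines.models.object_nav_models import (
    ObjectNavBaselineActorCritic,
)

# Hypothetical spaces; real ones come from the task and its sensors.
action_space = gym.spaces.Discrete(6)  # e.g. move/rotate/look/end actions
observation_space = SpaceDict(
    {
        "rgb": gym.spaces.Box(low=0.0, high=1.0, shape=(224, 224, 3)),
        "goal_object_type_ind": gym.spaces.Discrete(38),  # goal object type index
    }
)

model = ObjectNavBaselineActorCritic(
    action_space=action_space,
    observation_space=observation_space,
    goal_sensor_uuid="goal_object_type_ind",
    hidden_size=512,
    object_type_embedding_dim=8,
)
```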

ObjectNavBaselineActorCritic.recurrent_hidden_state_size#

 | @property
 | recurrent_hidden_state_size() -> int

[view_source]

The recurrent hidden state size of the model.

ObjectNavBaselineActorCritic.is_blind#

 | @property
 | is_blind() -> bool

[view_source]

True if the model is blind (i.e. neither 'depth' nor 'rgb' is an input observation type).

ObjectNavBaselineActorCritic.num_recurrent_layers#

 | @property
 | num_recurrent_layers() -> int

[view_source]

Number of recurrent hidden layers.

ObjectNavBaselineActorCritic.get_object_type_encoding#

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from input batched observations.
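As a sketch of the output, assuming the `model` constructed above with its hypothetical `goal_object_type_ind` uuid:

```python
import torch

# Two hypothetical goal object type indices (e.g. one per sampler).
observations = {"goal_object_type_ind": torch.tensor([3, 17])}
goal_encoding = model.get_object_type_encoding(observations)
# Expected: a FloatTensor of shape (2, object_type_embedding_dim) == (2, 8),
# the learned embedding of each goal object type.
```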

ObjectNavBaselineActorCritic.forward#

 | forward(observations: ObservationType, memory: Memory, prev_actions: torch.Tensor, masks: torch.FloatTensor) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]

[view_source]

Processes input batched observations (along with prior hidden states, previous actions, and masks denoting which recurrent hidden states should be masked) and returns an ActorCriticOutput object containing the model's policy (a distribution over actions) and its evaluation of the current state (a value).

Parameters

  • observations : Batched input observations.
  • memory : The model's memory (recurrent hidden states) carried over from the previous step.
  • prev_actions : Tensor of previous actions taken.
  • masks : Masks applied to hidden states. See RNNStateEncoder.

Returns

Tuple of the ActorCriticOutput and the updated Memory (recurrent hidden states).
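A hedged single-step sketch of calling forward with the model above. The `Memory` import path and the "rnn" memory key are assumptions (both vary across releases; check the model's recurrent memory specification), and all shapes follow the (steps, samplers, ...) layout used by RNNStateEncoder.

```python
import torch
from allenact.base_abstractions.misc import Memory  # assumed import path

nsteps, nsamplers = 1, 2

observations = {
    "rgb": torch.rand(nsteps, nsamplers, 224, 224, 3),
    "goal_object_type_ind": torch.randint(38, (nsteps, nsamplers)),
}
# Assumed initial memory: zeros for each recurrent layer under the "rnn"
# key, with the sampler dimension at axis 1.
memory = Memory(
    {
        "rnn": (
            torch.zeros(
                model.num_recurrent_layers,
                nsamplers,
                model.recurrent_hidden_state_size,
            ),
            1,
        )
    }
)
prev_actions = torch.zeros(nsteps, nsamplers, dtype=torch.long)
masks = torch.ones(nsteps, nsamplers, 1)  # 0 resets state at episode starts

ac_output, memory = model.forward(observations, memory, prev_actions, masks)
actions = ac_output.distributions.sample()  # CategoricalDistr over actions
values = ac_output.values                   # critic's value estimates
```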

ResnetTensorObjectNavActorCritic#

class ResnetTensorObjectNavActorCritic(ActorCriticModel[CategoricalDistr])

[view_source]

ResnetTensorObjectNavActorCritic.recurrent_hidden_state_size#

 | @property
 | recurrent_hidden_state_size() -> int

[view_source]

The recurrent hidden state size of the model.

ResnetTensorObjectNavActorCritic.is_blind#

 | @property
 | is_blind() -> bool

[view_source]

True if the model is blind (i.e. neither 'depth' nor 'rgb' is an input observation type).

ResnetTensorObjectNavActorCritic.num_recurrent_layers#

 | @property
 | num_recurrent_layers() -> int

[view_source]

Number of recurrent hidden layers.

ResnetTensorObjectNavActorCritic.get_object_type_encoding#

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from input batched observations.

ResnetTensorGoalEncoder#

class ResnetTensorGoalEncoder(nn.Module)

[view_source]

ResnetTensorGoalEncoder.get_object_type_encoding#

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from input batched observations.

ResnetDualTensorGoalEncoder#

class ResnetDualTensorGoalEncoder(nn.Module)

[view_source]

ResnetDualTensorGoalEncoder.get_object_type_encoding#

 | get_object_type_encoding(observations: Dict[str, torch.FloatTensor]) -> torch.FloatTensor

[view_source]

Get the object type encoding from input batched observations.