core.algorithms.onpolicy_sync.losses.imitation

Defines imitation losses for actor-critic models.

Imitation

class Imitation(AbstractActorCriticLoss)

Expert imitation loss.

Imitation.loss

 | loss(step_count: int, batch: ObservationType, actor_critic_output: ActorCriticOutput[CategoricalDistr], *args, **kwargs)

Computes the imitation loss.

Parameters

  • batch : A batch of data collected by rolling out (possibly many) agents over a fixed number of steps. The batch must have the same format as the batches returned by RolloutStorage.recurrent_generator, and batch["observations"] must contain either "expert_action" or "expert_policy" observations. See ExpertActionSensor (or ExpertPolicySensor) for an example of a sensor that produces such observations.
  • actor_critic_output : The output of calling an ActorCriticModel on the observations in batch.
  • args : Extra args. Ignored.
  • kwargs : Extra kwargs. Ignored.

Returns

A 0-dimensional torch.FloatTensor containing the computed loss. .backward() is called on this tensor to compute a gradient update to the ActorCriticModel's parameters.
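
Example

The two kinds of expert observation mentioned above suggest two common forms of imitation loss. The following is a minimal, framework-free sketch of both, not the library's implementation: it assumes the loss is the mean negative log-likelihood of the expert's action (for "expert_action") or the mean cross-entropy against the expert's action distribution (for "expert_policy"). The helper names log_softmax, expert_action_loss, and expert_policy_loss are hypothetical, and plain Python lists stand in for tensors.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of action logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def expert_action_loss(
    logits_batch: Sequence[Sequence[float]], expert_actions: Sequence[int]
) -> float:
    """Mean negative log-likelihood of the expert's actions (hypothetical helper)."""
    nll = [
        -log_softmax(logits)[action]
        for logits, action in zip(logits_batch, expert_actions)
    ]
    return sum(nll) / len(nll)


def expert_policy_loss(
    logits_batch: Sequence[Sequence[float]],
    expert_policies: Sequence[Sequence[float]],
) -> float:
    """Mean cross-entropy between expert and model action distributions
    (hypothetical helper)."""
    ce = [
        -sum(p * lp for p, lp in zip(policy, log_softmax(logits)))
        for logits, policy in zip(logits_batch, expert_policies)
    ]
    return sum(ce) / len(ce)


# Two steps with three available actions each (made-up logits).
logits_batch = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

# Case 1: the expert supplies one action index per step.
print(expert_action_loss(logits_batch, [0, 2]))

# Case 2: the expert supplies a full distribution over actions per step.
print(expert_policy_loss(logits_batch, [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]))
```

When the expert policy puts all of its probability on a single action, the two losses coincide, which is why both calls above print the same value. In either case the result is a single scalar, matching the 0-dimensional return value described above.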