Defines imitation losses for actor-critic models.


class Imitation(AbstractActorCriticLoss)


Expert imitation loss.


loss(step_count: int, batch: ObservationType, actor_critic_output: ActorCriticOutput[CategoricalDistr], *args, **kwargs)


Computes the imitation loss.


  • batch : A batch of data corresponding to the information collected when rolling out (possibly many) agents over a fixed number of steps. In particular this batch should have the same format as that returned by RolloutStorage.recurrent_generator. Here batch["observations"] must contain "expert_action" observations or "expert_policy" observations. See ExpertActionSensor (or ExpertPolicySensor) for an example of a sensor producing such observations.
  • actor_critic_output : The output of calling an ActorCriticModel on the observations in batch.
  • args : Extra args. Ignored.
  • kwargs : Extra kwargs. Ignored.


Returns a 0-dimensional torch.FloatTensor holding the computed loss. Calling .backward() on this tensor computes the gradients used to update the ActorCriticModel's parameters.
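To make the idea of an expert imitation loss concrete, here is a minimal sketch of one common formulation for the "expert_action" case: the average negative log-likelihood of the expert's action under the policy, taken only over steps where an expert action is available. This is an illustration under stated assumptions, not the library's implementation. It uses plain Python lists instead of torch tensors to stay dependency-free, and the names `imitation_loss_sketch`, `log_softmax`, and `expert_mask`, as well as the masking convention, are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one step's action logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def imitation_loss_sketch(
    logits_per_step: Sequence[Sequence[float]],
    expert_actions: Sequence[int],
    expert_mask: Sequence[int],
) -> float:
    """Average negative log-likelihood of the expert's actions.

    Assumption (not taken from the documentation above): a mask value of 0
    marks a step with no usable expert action, and such steps are skipped.
    """
    total = 0.0
    counted = 0
    for logits, action, mask in zip(logits_per_step, expert_actions, expert_mask):
        if mask:
            # Penalize low policy probability on the expert's chosen action.
            total -= log_softmax(logits)[action]
            counted += 1
    # Guard against division by zero when no step has an expert action.
    return total / max(counted, 1)


# Two steps, four actions each; the second step is masked out.
logits = [[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]]
loss = imitation_loss_sketch(logits, expert_actions=[2, 1], expert_mask=[1, 0])
print(loss)  # about 1.386, i.e. log(4), since the policy is uniform on step one
```

In a tensor-based version, the same quantity would be a single scalar tensor, which is what allows .backward() to be called on it.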