allenact.algorithms.offpolicy_sync.losses.abstract_offpolicy_loss
Defines abstract loss classes for actor-critic models.
AbstractOffPolicyLoss
class AbstractOffPolicyLoss(Generic[ModelType], Loss)
Abstract class representing an off-policy loss function used to train a model.
AbstractOffPolicyLoss.loss
| @abc.abstractmethod
| loss(step_count: int, model: ModelType, batch: ObservationType, memory: Memory, *args, **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float], Memory, int]
Computes the loss after processing a batch of data with (part of) a model (possibly with memory).
Parameters
- model: model to run on data batch (both assumed to be on the same device)
- batch: data to use as input for model (already on the same device as model)
- memory: model memory before processing current data batch
Returns
A tuple with:
- current_loss: total loss
- current_info: additional information about the current loss
- memory: model memory after processing current data batch
- bsize: batch size
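
The contract above can be illustrated with a minimal sketch of a concrete subclass. To stay self-contained it does not import allenact or torch: `OffPolicyLossSketch`, `LinearModel`, and `MeanSquaredErrorLoss` are hypothetical stand-ins, `ObservationType` and `Memory` are replaced by plain dictionaries, and the loss is a Python float rather than a `torch.FloatTensor`. The batch keys `inputs` and `targets` and the `batches_seen` memory entry are likewise assumptions made for illustration.

```python
import abc
from typing import Any, Dict, Generic, List, Tuple, TypeVar

ModelType = TypeVar("ModelType")

# Stand-ins for illustration only; the real types come from allenact and torch.
ObservationType = Dict[str, List[float]]
Memory = Dict[str, Any]


class OffPolicyLossSketch(Generic[ModelType], abc.ABC):
    """Hypothetical stand-in that mirrors the documented `loss` contract."""

    @abc.abstractmethod
    def loss(
        self,
        step_count: int,
        model: ModelType,
        batch: ObservationType,
        memory: Memory,
        *args,
        **kwargs,
    ) -> Tuple[float, Dict[str, float], Memory, int]:
        raise NotImplementedError


class LinearModel:
    """Toy model that predicts weight * x."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float) -> float:
        return self.weight * x


class MeanSquaredErrorLoss(OffPolicyLossSketch[LinearModel]):
    """Toy loss: mean squared error between predictions and targets."""

    def loss(self, step_count, model, batch, memory, *args, **kwargs):
        inputs, targets = batch["inputs"], batch["targets"]
        bsize = len(inputs)
        errors = [(model(x) - y) ** 2 for x, y in zip(inputs, targets)]
        current_loss = sum(errors) / bsize
        current_info = {"mse": current_loss}
        # Return a new memory object instead of mutating the caller's.
        new_memory = dict(memory, batches_seen=memory.get("batches_seen", 0) + 1)
        return current_loss, current_info, new_memory, bsize


batch = {"inputs": [1.0, 2.0], "targets": [2.0, 2.0]}
current_loss, current_info, memory, bsize = MeanSquaredErrorLoss().loss(
    step_count=0, model=LinearModel(weight=1.5), batch=batch, memory={}
)
```

The subclass returns the four values in the documented order: the total loss, a dictionary of loggable scalars, the updated memory, and the batch size. A real implementation would return a tensor for the loss so that gradients can flow back into the model.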