rl_utils.coma

coma

coma_error

Overview:

Implementation of the COMA (Counterfactual Multi-Agent Policy Gradients) loss.
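For reference, COMA (Foerster et al., 2018) trains decentralised actors with a centralised critic. It scores each agent's action against a counterfactual baseline. The baseline marginalises out that agent's action under its own policy and keeps the other agents' actions fixed. The standard formulation is shown below. Reading the last axis \(N\) of `q_value` as \(Q(s, (\mathbf{u}^{-a}, u'^{a}))\) for every candidate action \(u'^{a}\) is our interpretation of the Shapes list below. It is not confirmed by this page.

```latex
A^{a}(s, \mathbf{u}) = Q(s, \mathbf{u})
  - \sum_{u'^{a}} \pi^{a}\!\left(u'^{a} \mid \tau^{a}\right) Q\!\left(s, \left(\mathbf{u}^{-a}, u'^{a}\right)\right)

g = \mathbb{E}_{\pi}\!\left[ \sum_{a} \nabla_{\theta} \log \pi^{a}\!\left(u^{a} \mid \tau^{a}\right) A^{a}(s, \mathbf{u}) \right]
```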

Arguments:
  • data (namedtuple): COMA input data, with the fields defined in coma_data
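The following is a minimal sketch of how an input with the documented fields and shapes might be assembled. NumPy arrays stand in for torch tensors so that the snippet runs without PyTorch. The field order of `coma_data` is an assumption based on the Shapes list. The helper `make_dummy_coma_data` and its default sizes are hypothetical.

```python
from collections import namedtuple

import numpy as np

# Field names taken from the Shapes list; the field order here is an assumption.
coma_data = namedtuple(
    "coma_data",
    ["logit", "action", "q_value", "target_q_value", "reward", "weight"],
)


def make_dummy_coma_data(T=4, B=2, A=3, N=5, seed=0):
    """Hypothetical helper: random inputs with the documented shapes."""
    rng = np.random.default_rng(seed)
    return coma_data(
        logit=rng.standard_normal((T, B, A, N)),           # (T, B, A, N)
        action=rng.integers(0, N, size=(T, B, A)),         # (T, B, A), values in [0, N)
        q_value=rng.standard_normal((T, B, A, N)),         # (T, B, A, N)
        target_q_value=rng.standard_normal((T, B, A, N)),  # (T, B, A, N)
        reward=rng.standard_normal((T, B)),                # (T, B): one shared team reward
        weight=None,                                       # optional (T, B, A) weight/mask
    )


data = make_dummy_coma_data()
print(data.logit.shape, data.action.shape, data.reward.shape)
```

The reward has no agent axis because COMA assumes a single shared team reward. The sketch leaves `weight` as `None`, which the Shapes list allows.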

Returns:
  • coma_loss (namedtuple): the COMA loss items, each of which is a differentiable 0-dim tensor
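The returned items are separate scalars, so a training loop has to combine them before backpropagation. The sketch below shows one common weighting. Its assumptions are as follows. The field names follow the Shapes list. The coefficients `value_weight` and `entropy_weight` are hypothetical. The entropy term is subtracted so that minimising the total encourages exploration.

```python
from collections import namedtuple

# Field names follow the Shapes list below; the real namedtuple may differ.
coma_loss = namedtuple("coma_loss", ["policy_loss", "value_loss", "entropy_loss"])


def total_loss(loss, value_weight=0.5, entropy_weight=0.01):
    """Combine the three COMA terms into one scalar (hypothetical weights).

    Subtracting the entropy term means that minimising the total
    pushes the policy towards higher entropy.
    """
    return (
        loss.policy_loss
        + value_weight * loss.value_loss
        - entropy_weight * loss.entropy_loss
    )


print(total_loss(coma_loss(policy_loss=1.0, value_loss=2.0, entropy_loss=0.5)))
```

With torch tensors in place of the floats, the result would be a 0-dim tensor on which `.backward()` can be called. Some setups instead optimise the actor and critic with separate optimisers.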

Shapes:
  • logit (torch.FloatTensor): \((T, B, A, N)\), where T is the trajectory length, B is the batch size, A is the number of agents, and N is the action dimension

  • action (torch.LongTensor): \((T, B, A)\)

  • q_value (torch.FloatTensor): \((T, B, A, N)\)

  • target_q_value (torch.FloatTensor): \((T, B, A, N)\)

  • reward (torch.FloatTensor): \((T, B)\)

  • weight (torch.FloatTensor or None): \((T, B, A)\)

  • policy_loss (torch.FloatTensor): \(()\), 0-dim tensor

  • value_loss (torch.FloatTensor): \(()\)

  • entropy_loss (torch.FloatTensor): \(()\)
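The following NumPy sketch shows how the three 0-dim losses could be computed from inputs of these shapes. It is an illustration, not this library's implementation. It makes several assumptions:

* The critic target is a one-step TD target \(r_t + \gamma Q_{\text{target}}(s_{t+1}, \mathbf{u}_{t+1})\). The actual function may use λ-returns instead.
* Episode termination is not masked.
* `weight` acts as a multiplicative mask.
* `coma_losses_sketch` and `gamma` are hypothetical names.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def coma_losses_sketch(logit, action, q_value, target_q_value, reward,
                       weight=None, gamma=0.99):
    """Hypothetical NumPy sketch of the COMA policy, value and entropy terms."""
    T, B, A, N = logit.shape
    if weight is None:
        weight = np.ones((T, B, A), dtype=logit.dtype)

    pi = softmax(logit)                     # (T, B, A, N) per-agent policies
    log_pi = np.log(pi + 1e-8)
    idx = action[..., None]                 # (T, B, A, 1) index of the taken action

    q_taken = np.take_along_axis(q_value, idx, axis=-1)[..., 0]            # (T, B, A)
    target_taken = np.take_along_axis(target_q_value, idx, axis=-1)[..., 0]
    log_pi_taken = np.take_along_axis(log_pi, idx, axis=-1)[..., 0]

    # Counterfactual baseline: expected Q under agent a's own policy,
    # with the other agents' actions held fixed.
    baseline = (pi * q_value).sum(-1)       # (T, B, A)
    # In a torch implementation the advantage would be detached from the graph.
    advantage = q_taken - baseline
    policy_loss = -(log_pi_taken * advantage * weight).mean()

    # Assumed one-step TD target for steps 0..T-2; the team reward is
    # broadcast to every agent.
    td_target = reward[:-1, :, None] + gamma * target_taken[1:]
    value_loss = (((q_taken[:-1] - td_target) ** 2) * weight[:-1]).mean()

    entropy = -(pi * log_pi).sum(-1)        # (T, B, A) per-agent policy entropy
    entropy_loss = (entropy * weight).mean()
    return policy_loss, value_loss, entropy_loss


rng = np.random.default_rng(0)
T, B, A, N = 4, 2, 3, 5
losses = coma_losses_sketch(
    logit=rng.standard_normal((T, B, A, N)),
    action=rng.integers(0, N, size=(T, B, A)),
    q_value=rng.standard_normal((T, B, A, N)),
    target_q_value=rng.standard_normal((T, B, A, N)),
    reward=rng.standard_normal((T, B)),
)
print([float(x) for x in losses])
```

The value loss covers only \(T-1\) steps because the last step has no successor in the sketch's one-step target. If `q_value` is the same for every candidate action, the advantage vanishes and so does the policy loss. This is the intended effect of the counterfactual baseline.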