rl_utils.coma
coma
coma_error
- Overview:
Implementation of the COMA (Counterfactual Multi-Agent policy gradients) loss
- Arguments:
data (namedtuple): COMA input data, with the fields shown in coma_data
- Returns:
coma_loss (namedtuple): the COMA loss items (policy_loss, value_loss, entropy_loss), each a differentiable 0-dim tensor
- Shapes:
logit (torch.FloatTensor): \((T, B, A, N)\), where T is the number of time steps, B is the batch size, A is the number of agents, and N is the action dimension
action (torch.LongTensor): \((T, B, A)\)
q_value (torch.FloatTensor): \((T, B, A, N)\)
target_q_value (torch.FloatTensor): \((T, B, A, N)\)
reward (torch.FloatTensor): \((T, B)\)
weight (torch.FloatTensor or None): \((T, B, A)\)
policy_loss (torch.FloatTensor): \(()\), 0-dim tensor
value_loss (torch.FloatTensor): \(()\), 0-dim tensor
entropy_loss (torch.FloatTensor): \(()\), 0-dim tensor
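The loss described above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the library's own implementation: the ComaData and ComaLoss namedtuples, the helper lambda_returns, the function name coma_error_sketch, and the gamma and lambda_ defaults are all hypothetical. It assumes the critic target is a TD(lambda) return built from target_q_value at the taken actions, that the single (T, B) reward is shared by all agents, and that q_value[..., a, :] holds agent a's Q for each of its own actions with the other agents' actions fixed. Under that last assumption, the counterfactual baseline is the policy-weighted average of those Q values.

```python
from collections import namedtuple

import torch
import torch.nn.functional as F

# Hypothetical containers mirroring the documented fields (not the library's own definitions).
ComaData = namedtuple("ComaData", ["logit", "action", "q_value", "target_q_value", "reward", "weight"])
ComaLoss = namedtuple("ComaLoss", ["policy_loss", "value_loss", "entropy_loss"])


def lambda_returns(target_q, reward, gamma, lambda_):
    """Hypothetical helper: TD(lambda) returns for steps 0..T-2.

    target_q: (T, B, A) target Q of the taken actions; reward: (T-1, B, A).
    Bootstraps from target_q[T-1] and recurses backwards in time.
    """
    returns = []
    next_return = target_q[-1]
    for t in reversed(range(reward.shape[0])):
        next_return = reward[t] + gamma * ((1 - lambda_) * target_q[t + 1] + lambda_ * next_return)
        returns.append(next_return)
    return torch.stack(returns[::-1])  # (T-1, B, A), in forward time order


def coma_error_sketch(data, gamma=0.99, lambda_=0.8):
    logit, action, q_value, target_q_value, reward, weight = data
    if weight is None:
        weight = torch.ones_like(action, dtype=torch.float)
    idx = action.unsqueeze(-1)
    q_taken = q_value.gather(-1, idx).squeeze(-1)                # (T, B, A)
    target_q_taken = target_q_value.gather(-1, idx).squeeze(-1)  # (T, B, A)

    # Critic: regress Q of the taken actions onto TD(lambda) targets (assumed shared team reward).
    team_reward = reward.unsqueeze(-1).expand_as(q_taken)        # (T, B, A)
    target = lambda_returns(target_q_taken, team_reward[:-1], gamma, lambda_).detach()
    value_loss = (F.mse_loss(q_taken[:-1], target, reduction="none") * weight[:-1]).mean()

    # Actor: counterfactual baseline = expected Q over the agent's own actions under its policy.
    dist = torch.distributions.Categorical(logits=logit)
    baseline = (torch.softmax(logit, dim=-1) * q_value).sum(-1)  # (T, B, A)
    advantage = (q_taken - baseline).detach()
    policy_loss = -(dist.log_prob(action) * advantage * weight).mean()

    # Mean policy entropy; how it is weighted in the total loss is left to the caller.
    entropy_loss = (dist.entropy() * weight).mean()
    return ComaLoss(policy_loss, value_loss, entropy_loss)


# Usage with random inputs of the documented shapes.
T, B, A, N = 5, 4, 3, 6
torch.manual_seed(0)
logit = torch.randn(T, B, A, N, requires_grad=True)
q_value = torch.randn(T, B, A, N, requires_grad=True)
data = ComaData(logit, torch.randint(0, N, (T, B, A)), q_value,
                torch.randn(T, B, A, N), torch.randn(T, B), None)
loss = coma_error_sketch(data)
(loss.policy_loss + loss.value_loss - 0.01 * loss.entropy_loss).backward()
```

Detaching the advantage keeps the actor term from updating the critic, and detaching the TD(lambda) target keeps the critic from chasing its own target. Both are common choices in this sketch rather than behaviour stated in the reference above.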