template.qmix

Please refer to ding/model/template/qmix.py for usage.

Mixer

class ding.model.template.qmix.Mixer(agent_num, state_dim, mixing_embed_dim, hypernet_embed=64)[source]
Overview:

Mixer network in QMIX, which mixes the independent q_value of each agent into a total q_value

Interface:

__init__, forward

__init__(agent_num, state_dim, mixing_embed_dim, hypernet_embed=64)[source]
Overview:

initialize pymarl mixer network

Arguments:
  • agent_num (int): the number of agents

  • state_dim (int): the dimension of global observation state

  • mixing_embed_dim (int): the dimension of mixing state embedding

  • hypernet_embed (int): the dimension of hypernet embedding, default to 64

forward(agent_qs, states)[source]
Overview:

forward computation graph of pymarl mixer network

Arguments:
  • agent_qs (torch.FloatTensor): the independent q_value of each agent

  • states (torch.FloatTensor): the embedding vector of global state

Returns:
  • q_tot (torch.FloatTensor): the total mixed q_value

Shapes:
  • agent_qs (torch.FloatTensor): \((B, N)\), where B is batch size and N is agent_num

  • states (torch.FloatTensor): \((B, M)\), where M is embedding_size

  • q_tot (torch.FloatTensor): \((B, )\)
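The mixing step described above can be sketched in plain PyTorch. This is a minimal illustration of QMIX-style monotonic mixing with hypernetworks, not the ding implementation: the class name SketchMixer and all layer names are hypothetical, and the exact layer layout is an assumption. The constructor arguments and the tensor shapes follow the documentation above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMixer(nn.Module):
    """Hypothetical stand-in for a QMIX-style mixer (not the ding class)."""

    def __init__(self, agent_num, state_dim, mixing_embed_dim, hypernet_embed=64):
        super().__init__()
        self.agent_num = agent_num
        self.embed_dim = mixing_embed_dim
        # Hypernetworks: map the global state to the weights of the mixing layers.
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, hypernet_embed), nn.ReLU(),
            nn.Linear(hypernet_embed, agent_num * mixing_embed_dim),
        )
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, hypernet_embed), nn.ReLU(),
            nn.Linear(hypernet_embed, mixing_embed_dim),
        )
        # State-dependent biases of the two mixing layers.
        self.hyper_b1 = nn.Linear(state_dim, mixing_embed_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, mixing_embed_dim), nn.ReLU(),
            nn.Linear(mixing_embed_dim, 1),
        )

    def forward(self, agent_qs, states):
        batch = agent_qs.shape[0]
        qs = agent_qs.view(batch, 1, self.agent_num)                 # (B, 1, N)
        # abs() keeps the mixing weights non-negative, so q_tot is
        # monotonically non-decreasing in every agent's q_value.
        w1 = torch.abs(self.hyper_w1(states)).view(batch, self.agent_num, self.embed_dim)
        b1 = self.hyper_b1(states).view(batch, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(qs, w1) + b1)                       # (B, 1, E)
        w2 = torch.abs(self.hyper_w2(states)).view(batch, self.embed_dim, 1)
        b2 = self.hyper_b2(states).view(batch, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2                           # (B, 1, 1)
        return q_tot.view(batch)                                     # (B, )


# Example values: B=4 samples, N=3 agents, M=10 state features.
torch.manual_seed(0)
mixer = SketchMixer(agent_num=3, state_dim=10, mixing_embed_dim=32)
agent_qs = torch.randn(4, 3)
states = torch.randn(4, 10)
q_tot = mixer(agent_qs, states)
```

Because the mixing weights are forced to be non-negative, raising any single agent's q_value can never lower q_tot in this sketch.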

QMix

class ding.model.template.qmix.QMix(agent_num: int, obs_shape: int, global_obs_shape: int, action_shape: int, hidden_size_list: list, mixer: bool = True, lstm_type: str = 'gru', dueling: bool = False)[source]
Overview:

QMIX network

Interface:

__init__, forward, _setup_global_encoder

__init__(agent_num: int, obs_shape: int, global_obs_shape: int, action_shape: int, hidden_size_list: list, mixer: bool = True, lstm_type: str = 'gru', dueling: bool = False) → None[source]
Overview:

initialize Qmix network

Arguments:
  • agent_num (int): the number of agents

  • obs_shape (int): the dimension of each agent’s observation state

  • global_obs_shape (int): the dimension of global observation state

  • action_shape (int): the dimension of action shape

  • hidden_size_list (list): the list of hidden size

  • mixer (bool): use mixer net or not, default to True

  • lstm_type (str): the type of recurrent cell to use, default to 'gru'

  • dueling (bool): use dueling head or not, default to False

_setup_global_encoder(global_obs_shape: int, embedding_size: int) → torch.nn.modules.module.Module[source]
Overview:

Used to encode the global observation.

Arguments:
  • global_obs_shape (int): the dimension of global observation state

  • embedding_size (int): the dimension of state embedding

Return:
  • outputs (torch.nn.Module): Global observation encoding network

forward(data: dict, single_step: bool = True) → dict[source]
Overview:

forward computation graph of qmix network

Arguments:
  • data (dict): input data dict with keys [‘obs’, ‘prev_state’, ‘action’]
    • agent_state (torch.Tensor): each agent local state(obs)

    • global_state (torch.Tensor): global state(obs)

    • prev_state (list): previous rnn state

    • action (torch.Tensor or None): if action is None, use argmax q_value index as action to calculate agent_q_act

  • single_step (bool): whether single_step forward, if so, add timestep dim before forward and remove it after forward

Returns:
  • ret (dict): output data dict with keys [‘total_q’, ‘logit’, ‘next_state’]
    • total_q (torch.Tensor): total q_value, which is the result of mixer network

    • agent_q (torch.Tensor): each agent q_value

    • next_state (list): next rnn state

Shapes:
  • agent_state (torch.Tensor): \((T, B, A, N)\), where T is timestep, B is batch_size, A is agent_num, and N is obs_shape

  • global_state (torch.Tensor): \((T, B, M)\), where M is global_obs_shape

  • prev_state (list): \((B, A)\), a list of length B, where each element is a list of length A

  • action (torch.Tensor): \((T, B, A)\)

  • total_q (torch.Tensor): \((T, B)\)

  • agent_q (torch.Tensor): \((T, B, A, P)\), where P is action_shape

  • next_state (list): \((B, A)\), a list of length B, where each element is a list of length A
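The action and single_step conventions described above can be illustrated with plain tensor operations. This is a sketch under stated assumptions, not the ding code: the helper select_agent_q_act is hypothetical, the random agent_q tensor stands in for the per-agent network output, and the observation size 7 is an arbitrary example value.

```python
import torch

T, B, A, P = 1, 4, 3, 5  # timestep, batch_size, agent_num, action_shape (example values)

# Stand-in for the per-agent q_values a network would produce: (T, B, A, P).
torch.manual_seed(0)
agent_q = torch.randn(T, B, A, P)


def select_agent_q_act(agent_q, action=None):
    """Hypothetical helper: pick the q_value of the chosen action per agent."""
    if action is None:
        # No action given: fall back to the greedy (argmax) action index.
        action = agent_q.argmax(dim=-1)                      # (T, B, A)
    # gather selects one entry along the action dim for every agent.
    return torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1)).squeeze(-1)


agent_q_act = select_agent_q_act(agent_q)                    # (T, B, A)

# single_step: add a leading timestep dim before the forward pass,
# then remove it again afterwards.
obs_single = torch.randn(B, A, 7)                            # (B, A, N), no timestep dim
obs_seq = obs_single.unsqueeze(0)                            # (1, B, A, N)
restored = obs_seq.squeeze(0)                                # (B, A, N)
```

With action=None the selected values coincide with the per-agent maximum over the action dimension, which is what the greedy fallback implies.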

CollaQMultiHeadAttention

class ding.model.template.qmix.CollaQMultiHeadAttention(n_head: int, d_model_q: int, d_model_v: int, d_k: int, d_v: int, d_out: int, dropout: float = 0.0)[source]
Overview:

The head of collaq attention module.

Interface:

__init__, forward

__init__(n_head: int, d_model_q: int, d_model_v: int, d_k: int, d_v: int, d_out: int, dropout: float = 0.0)[source]
Overview:

initialize the head of collaq attention module

Arguments:
  • n_head (int): the number of heads

  • d_model_q (int): the size of input q

  • d_model_v (int): the size of input v

  • d_k (int): the size of k, used by Scaled Dot Product Attention

  • d_v (int): the size of v, used by Scaled Dot Product Attention

  • d_out (int): the size of output q

  • dropout (float): the dropout rate, default to 0.0

forward(q, k, v, mask=None)[source]
Overview:

forward computation graph of collaQ multi head attention net.

Arguments:
  • q (torch.nn.Sequential): the transformer information q

  • k (torch.nn.Sequential): the transformer information k

  • v (torch.nn.Sequential): the transformer information v

Output:
  • q (torch.nn.Sequential): the transformer output q
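For orientation, the general multi-head scaled dot-product attention computation can be sketched as follows. This is a generic illustration that reuses the constructor argument names documented above; the class name SketchMultiHeadAttention, the layer names, and the mask convention (a boolean tensor where True means "may attend") are assumptions, and any additional layers the ding class may apply are not reproduced here.

```python
import math

import torch
import torch.nn as nn


class SketchMultiHeadAttention(nn.Module):
    """Hypothetical generic multi-head attention (not the ding class)."""

    def __init__(self, n_head, d_model_q, d_model_v, d_k, d_v, d_out, dropout=0.0):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v
        self.w_q = nn.Linear(d_model_q, n_head * d_k)
        self.w_k = nn.Linear(d_model_v, n_head * d_k)
        self.w_v = nn.Linear(d_model_v, n_head * d_v)
        self.out = nn.Linear(n_head * d_v, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch, len_q, len_k = q.shape[0], q.shape[1], k.shape[1]
        # Project the inputs and split them into heads: (B, H, len, d).
        q = self.w_q(q).view(batch, len_q, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch, len_k, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch, len_k, self.n_head, self.d_v).transpose(1, 2)
        # Scaled dot-product scores: (B, H, len_q, len_k).
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Assumed convention: True = may attend, False = masked out.
            score = score.masked_fill(~mask, float('-inf'))
        attn = self.dropout(torch.softmax(score, dim=-1))
        context = torch.matmul(attn, v)                               # (B, H, len_q, d_v)
        # Merge the heads back together and project to the output size.
        context = context.transpose(1, 2).reshape(batch, len_q, self.n_head * self.d_v)
        return self.out(context)                                      # (B, len_q, d_out)


# Example values: one query token attending over four key/value tokens.
torch.manual_seed(0)
attention = SketchMultiHeadAttention(n_head=2, d_model_q=16, d_model_v=24, d_k=8, d_v=8, d_out=8)
q_in = torch.randn(2, 1, 16)
kv_in = torch.randn(2, 4, 24)
q_out = attention(q_in, kv_in, kv_in)
```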

CollaQSMACAttentionModule

class ding.model.template.qmix.CollaQSMACAttentionModule(q_dim: int, v_dim: int, self_feature_range: List[int], ally_feature_range: List[int], attention_size: int)[source]
Overview:

CollaQ attention module, used to compute each agent’s attention observation. The result combines the agent’s own observation with the part of the observation that describes the allies the agent attends to.

Interface:

__init__, _cut_obs, forward

__init__(q_dim: int, v_dim: int, self_feature_range: List[int], ally_feature_range: List[int], attention_size: int)[source]
Overview:

initialize collaq attention module

Arguments:
  • q_dim (int): the dimension of transformer output q

  • v_dim (int): the dimension of transformer output v

  • self_feature_range (List[int]): the agent’s feature range

  • ally_feature_range (List[int]): the agent ally’s feature range

  • attention_size (int): the size of attention net layer

_cut_obs(obs: torch.Tensor)[source]
Overview:

cut the observed information into the agent’s own observation and its allies’ observation

Arguments:
  • obs (torch.Tensor): input each agent’s observation

Return:
  • self_features (torch.Tensor): output self agent’s attention observation

  • ally_features (torch.Tensor): output ally agent’s attention observation
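The split performed by _cut_obs can be pictured as two slices over the feature dimension. The sketch below is an illustration only: the function cut_obs is hypothetical, and it assumes that each feature range is a half-open [start, end) pair of indices into the last dimension of obs.

```python
import torch


def cut_obs(obs, self_feature_range, ally_feature_range):
    """Hypothetical helper: slice obs into self and ally features."""
    # Assumption: each range is [start, end) over the last (feature) dim.
    self_features = obs[..., self_feature_range[0]:self_feature_range[1]]
    ally_features = obs[..., ally_feature_range[0]:ally_feature_range[1]]
    return self_features, ally_features


# Example values: B=2 samples, A=3 agents, N=10 observation features.
obs = torch.arange(2 * 3 * 10, dtype=torch.float32).view(2, 3, 10)
self_f, ally_f = cut_obs(obs, [0, 4], [4, 10])
```

In this example the two ranges tile the whole feature dimension, so concatenating the slices recovers the original observation; real feature ranges need not cover every feature.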

forward(inputs: torch.Tensor)[source]
Overview:

forward computation to get agent’s attention observation information

Arguments:
  • inputs (torch.Tensor): each agent’s observation

Return:
  • obs (torch.Tensor): output agent’s attention observation

CollaQ

class ding.model.template.qmix.CollaQ(agent_num: int, obs_shape: int, alone_obs_shape: int, global_obs_shape: int, action_shape: int, hidden_size_list: list, attention: bool = False, self_feature_range: Optional[List[int]] = None, ally_feature_range: Optional[List[int]] = None, attention_size: int = 32, mixer: bool = True, lstm_type: str = 'gru', dueling: bool = False, use_pmixer: bool = False)[source]
Overview:

CollaQ network

Interface:

__init__, forward, _setup_global_encoder

__init__(agent_num: int, obs_shape: int, alone_obs_shape: int, global_obs_shape: int, action_shape: int, hidden_size_list: list, attention: bool = False, self_feature_range: Optional[List[int]] = None, ally_feature_range: Optional[List[int]] = None, attention_size: int = 32, mixer: bool = True, lstm_type: str = 'gru', dueling: bool = False, use_pmixer: bool = False) → None[source]
Overview:

initialize Collaq network

Arguments:
  • agent_num (int): the number of agents

  • obs_shape (int): the dimension of each agent’s observation state

  • alone_obs_shape (int): the dimension of each agent’s observation state without other agents

  • global_obs_shape (int): the dimension of global observation state

  • action_shape (int): the dimension of action shape

  • hidden_size_list (list): the list of hidden size

  • attention (bool): use attention module or not, default to False

  • self_feature_range (Union[List[int], None]): the agent’s feature range

  • ally_feature_range (Union[List[int], None]): the agent ally’s feature range

  • attention_size (int): the size of attention net layer

  • mixer (bool): use mixer net or not, default to True

_setup_global_encoder(global_obs_shape: int, embedding_size: int) → torch.nn.modules.module.Module[source]
Overview:

Used to encode the global observation.

Arguments:
  • global_obs_shape (int): the dimension of global observation state

  • embedding_size (int): the dimension of state embedding

Return:
  • outputs (torch.nn.Module): Global observation encoding network

forward(data: dict, single_step: bool = True) → dict[source]
Overview:

forward computation graph of collaQ network

Arguments:
  • data (dict): input data dict with keys [‘obs’, ‘prev_state’, ‘action’]
    • agent_state (torch.Tensor): each agent local state(obs)

    • agent_alone_state (torch.Tensor): each agent’s local state alone; in the SMAC setting this is the observation without ally features (obs_alone)

    • global_state (torch.Tensor): global state(obs)

    • prev_state (list): previous rnn state, which should include 3 parts: one hidden state of q_network, and two hidden states of q_alone_network for obs and obs_alone inputs

    • action (torch.Tensor or None): if action is None, use argmax q_value index as action to calculate agent_q_act

  • single_step (bool): whether single_step forward, if so, add timestep dim before forward and remove it after forward

Return:
  • ret (dict): output data dict with keys [‘total_q’, ‘logit’, ‘next_state’]
    • total_q (torch.Tensor): total q_value, which is the result of mixer network

    • agent_q (torch.Tensor): each agent q_value

    • next_state (list): next rnn state

Shapes:
  • agent_state (torch.Tensor): \((T, B, A, N)\), where T is timestep, B is batch_size, A is agent_num, and N is obs_shape

  • global_state (torch.Tensor): \((T, B, M)\), where M is global_obs_shape

  • prev_state (list): \((B, A)\), a list of length B, where each element is a list of length A

  • action (torch.Tensor): \((T, B, A)\)

  • total_q (torch.Tensor): \((T, B)\)

  • agent_q (torch.Tensor): \((T, B, A, P)\), where P is action_shape

  • next_state (list): \((B, A)\), a list of length B, where each element is a list of length A
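The nested-list layout documented for prev_state and next_state can be sketched in plain Python. Only the outer (B, A) nesting is shown; the use of None as a placeholder for each per-agent state and the helper check_state_layout are assumptions made for illustration, and the three-part structure mentioned for CollaQ is not modelled.

```python
B, A = 4, 3  # batch_size and agent_num (example values)

# One entry per (batch sample, agent); None is an assumed placeholder
# for a per-agent rnn state that has not been initialized yet.
prev_state = [[None for _ in range(A)] for _ in range(B)]


def check_state_layout(state, batch_size, agent_num):
    """Hypothetical helper: True if state is a list of length B of lists of length A."""
    return (
        isinstance(state, list)
        and len(state) == batch_size
        and all(isinstance(row, list) and len(row) == agent_num for row in state)
    )


layout_ok = check_state_layout(prev_state, B, A)
```

Building the inner lists with a comprehension, rather than multiplying one list, gives every batch sample its own independent list of agent states.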