COMA
COMAPolicy
- class ding.policy.coma.COMAPolicy(cfg: dict, model: Optional[Union[type, torch.nn.modules.module.Module]] = None, enable_field: Optional[List[str]] = None)
- Overview:
Policy class of the COMA (Counterfactual Multi-Agent policy gradients) algorithm. COMA is a multi-agent reinforcement learning algorithm.
- Interface:
- _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn,
_init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, _reset_eval, _get_train_sample, default_model, _monitor_vars_learn
- Config:

ID | Symbol | Type | Default Value | Description | Other (Shape)
---|---|---|---|---|---
1 | type | str | coma | RL policy register name, refer to registry POLICY_REGISTRY | This arg is optional, a placeholder
2 | cuda | bool | False | Whether to use CUDA for the network | This arg can differ between modes
3 | on_policy | bool | True | Whether the RL algorithm is on-policy or off-policy |
4 | priority | bool | False | Whether to use prioritized experience replay (PER) | Priority sample, update priority
5 | priority_IS_weight | bool | False | Whether to use importance sampling weights to correct the biased update | IS weight
6 | learn.update_per_collect | int | 1 | How many updates (iterations) to train after one collection by the collector. Only valid in serial training | This arg can vary across envs. A bigger value means more off-policy
7 | learn.target_update_theta | float | 0.001 | Target network update momentum parameter | Between [0, 1]
8 | learn.discount_factor | float | 0.99 | Reward's future discount factor, aka gamma | May be 1 in sparse-reward envs
9 | learn.td_lambda | float | 0.8 | The trade-off factor of TD(λ), which balances 1-step TD and Monte Carlo (MC) |
10 | learn.value_weight | float | 1.0 | The loss weight of the value network | The policy network weight is set to 1
11 | learn.entropy_weight | float | 0.01 | The loss weight of entropy regularization | The policy network weight is set to 1
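The table can be read as a nested dict in which each dotted key (e.g. learn.td_lambda) sits under a learn sub-dict. The sketch below assumes that layout and shows user overrides being layered on top of the defaults; DEFAULT_COMA_CONFIG and merge_config are hypothetical names for illustration, not DI-engine functions.

```python
import copy

# Hypothetical nested layout of the defaults in the config table;
# dotted keys such as learn.td_lambda become entries of a "learn" sub-dict.
DEFAULT_COMA_CONFIG = {
    "type": "coma",
    "cuda": False,
    "on_policy": True,
    "priority": False,
    "priority_IS_weight": False,
    "learn": {
        "update_per_collect": 1,
        "target_update_theta": 0.001,
        "discount_factor": 0.99,
        "td_lambda": 0.8,
        "value_weight": 1.0,
        "entropy_weight": 0.01,
    },
}


def merge_config(default: dict, user: dict) -> dict:
    """Recursively overlay user-provided values on a deep copy of the defaults."""
    merged = copy.deepcopy(default)
    for key, value in user.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


cfg = merge_config(DEFAULT_COMA_CONFIG, {"cuda": True, "learn": {"td_lambda": 0.9}})
print(cfg["learn"]["td_lambda"], cfg["learn"]["discount_factor"])  # 0.9 0.99
```

Deep-copying the defaults keeps the shared default dict unchanged when several policies are configured in one process.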
- _data_preprocess_learn(data: List[Any]) → dict
- Overview:
Preprocess the data to fit the format required for learning.
- Arguments:
- data (List[Dict[str, Any]]): The data collected by the collect function. Each Dict in data should contain at least the keys [‘obs’, ‘action’, ‘reward’].
- Returns:
- data (Dict[str, Any]): The processed data, including at least the keys [‘obs’, ‘action’, ‘reward’, ‘done’, ‘weight’].
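The preprocessing step turns a list of per-step dicts into one batch dict. A simplified sketch in plain Python, assuming that a missing 'done' means False and that a batch without any 'weight' gets weight None (both are assumptions; the real method collates tensors rather than lists):

```python
from typing import Any, Dict, List


def preprocess_learn_sketch(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate a list of transitions into one dict of lists (hypothetical helper).

    Assumptions: 'done' defaults to False; 'weight' becomes None when no
    transition carries one (i.e. no importance-sampling correction).
    """
    required = ("obs", "action", "reward")
    for i, item in enumerate(data):
        missing = [k for k in required if k not in item]
        if missing:
            raise KeyError(f"transition {i} is missing keys {missing}")
    batch: Dict[str, Any] = {k: [item[k] for item in data] for k in required}
    batch["done"] = [item.get("done", False) for item in data]
    weights = [item.get("weight") for item in data]
    batch["weight"] = None if all(w is None for w in weights) else weights
    return batch


batch = preprocess_learn_sketch([
    {"obs": 1, "action": 0, "reward": 1.0},
    {"obs": 2, "action": 1, "reward": 0.0, "done": True},
])
print(batch["done"], batch["weight"])  # [False, True] None
```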
- _forward_collect(data: dict, eps: float) → dict
- Overview:
Collect output according to the eps_greedy plugin.
- Arguments:
data (dict): Dict type data, including at least [‘obs’].
eps (float): The epsilon value used for eps-greedy exploration.
- Returns:
data (dict): The collected data.
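One common reading of eps-greedy *sampling* for a stochastic policy is: with probability eps pick a uniformly random available action, otherwise sample from the softmax of the policy logits. The sketch below applies this per agent; the per-agent availability mask and the helper names are assumptions for illustration, not DI-engine's wrapper API.

```python
import math
import random
from typing import List, Sequence


def eps_greedy_sample(logits: Sequence[float], avail: Sequence[int], eps: float,
                      rng: random.Random) -> int:
    """One agent's action: uniform over available actions with probability eps,
    otherwise sample from the softmax of the logits restricted to available actions."""
    legal = [a for a, ok in enumerate(avail) if ok]
    if rng.random() < eps:
        return rng.choice(legal)
    # Softmax over legal actions only; subtracting the max keeps exp() stable.
    top = max(logits[a] for a in legal)
    weights = [math.exp(logits[a] - top) for a in legal]
    return rng.choices(legal, weights=weights, k=1)[0]


def collect_actions(per_agent_logits: List[List[float]], per_agent_avail: List[List[int]],
                    eps: float, seed: int = 0) -> List[int]:
    """Apply eps-greedy sampling independently to each agent (hypothetical helper)."""
    rng = random.Random(seed)
    return [eps_greedy_sample(l, m, eps, rng) for l, m in zip(per_agent_logits, per_agent_avail)]


# With eps=0 and one dominant logit, the softmax puts all mass on that action.
print(collect_actions([[0.0, 1000.0, 0.0], [5.0, 0.0]], [[1, 1, 1], [0, 1]], eps=0.0))  # [1, 1]
```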
- _forward_eval(data: dict) → dict
- Overview:
Forward function of eval mode, similar to self._forward_collect.
- Arguments:
data (dict): Dict type data, including at least [‘obs’].
- Returns:
output (dict): Dict type data, including at least the action inferred from the input obs.
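Eval mode uses an argmax strategy instead of sampling. A minimal sketch, assuming each agent has logits and an availability mask (the mask and the function name are illustrative assumptions):

```python
from typing import List, Sequence


def argmax_actions(per_agent_logits: List[Sequence[float]],
                   per_agent_avail: List[Sequence[int]]) -> List[int]:
    """Greedy (argmax) action per agent, ignoring unavailable actions."""
    actions = []
    for logits, avail in zip(per_agent_logits, per_agent_avail):
        legal = [a for a, ok in enumerate(avail) if ok]
        # max() with a key returns the first legal action with the highest logit.
        actions.append(max(legal, key=lambda a: logits[a]))
    return actions


print(argmax_actions([[0.1, 2.0, 0.5], [3.0, 1.0, 2.0]], [[1, 1, 1], [0, 1, 1]]))  # [1, 2]
```

Because evaluation is deterministic, repeated eval runs on the same observations produce the same actions, which makes returns comparable across checkpoints.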
- _forward_learn(data: dict) → Dict[str, Any]
- Overview:
Forward and backward function of learn mode: acquire the data, calculate the loss and optimize the learner model.
- Arguments:
- data (Dict[str, Any]): Dict type data, a batch of data for training; values are torch.Tensor, np.ndarray, or dict/list combinations.
- Returns:
- info_dict (Dict[str, Any]): Dict type data, an info dict indicating the training result, which will be recorded in the text log and TensorBoard; values are Python scalars or lists of scalars.
- ArgumentsKeys:
- necessary: obs, action, reward, done, weight
- ReturnsKeys:
- necessary: cur_lr, total_loss, policy_loss, value_loss, entropy_loss
- cur_lr (float): Current learning rate
- total_loss (float): The calculated loss
- policy_loss (float): The policy (actor) loss of COMA
- value_loss (float): The value (critic) loss of COMA
- entropy_loss (float): The entropy loss
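The core of COMA is the counterfactual baseline: for agent i, the advantage is Q(s, u) minus the expectation of Q over agent i's own actions under its policy, with the other agents' actions held fixed. The single-agent, single-step sketch below computes that advantage and combines the three loss terms using the config weights (policy weight 1, value_weight, entropy_weight). Treating the critic target as given (e.g. a TD(λ) return) and subtracting the entropy term are assumptions about how the terms are combined, not a copy of the DI-engine code.

```python
import math
from typing import Sequence


def coma_losses(q_values: Sequence[float], probs: Sequence[float], action: int,
                target: float, value_weight: float = 1.0, entropy_weight: float = 0.01) -> dict:
    """Illustrative COMA loss terms for one agent at one timestep.

    q_values[a]: critic estimate Q(s, (a, u_-i)) with the other agents' actions fixed.
    probs[a]:    this agent's policy probability pi(a | history).
    target:      regression target for the critic (assumed precomputed).
    """
    # Counterfactual baseline: expected Q under the agent's own policy.
    baseline = sum(p * q for p, q in zip(probs, q_values))
    advantage = q_values[action] - baseline  # treated as a constant (no gradient)
    policy_loss = -math.log(probs[action]) * advantage
    value_loss = (q_values[action] - target) ** 2
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    # Policy weight is 1; entropy is encouraged, so it is subtracted.
    total_loss = policy_loss + value_weight * value_loss - entropy_weight * entropy
    return {"policy_loss": policy_loss, "value_loss": value_loss,
            "entropy_loss": entropy, "total_loss": total_loss}


losses = coma_losses([1.0, 2.0], [0.5, 0.5], action=1, target=2.0)
print(round(losses["total_loss"], 4))  # 0.3396
```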
- _get_train_sample(data: collections.deque) → Union[None, List[Any]]
- Overview:
Get the train samples from the trajectory.
- Arguments:
data (deque): The trajectory's cache
- Returns:
samples (List[Any]): The generated training samples
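Turning a trajectory cache into training samples usually means cutting it into fixed-length pieces of unroll_len steps. The sketch below is one hypothetical policy for the tail: a short final chunk is completed by reusing the preceding steps (so it overlaps the previous chunk) rather than being padded. The function name and the tail handling are assumptions, not DI-engine's exact behaviour.

```python
from collections import deque
from typing import Any, List


def split_trajectory(data: deque, unroll_len: int) -> List[List[Any]]:
    """Cut a trajectory into chunks of exactly unroll_len steps (hypothetical helper)."""
    steps = list(data)
    if len(steps) < unroll_len:
        return []  # assumption: too short to form even one sample
    samples = [steps[i:i + unroll_len]
               for i in range(0, len(steps) - unroll_len + 1, unroll_len)]
    if len(steps) % unroll_len:
        # Complete the tail with earlier steps so every sample has the same length.
        samples.append(steps[-unroll_len:])
    return samples


print(split_trajectory(deque(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [4, 5, 6]]
```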
- _init_collect() → None
- Overview:
Collect mode init method. Called by self.__init__. Initializes the trajectory length, the unroll length, and the collect model. The collect model has an eps_greedy_sample wrapper and a hidden state wrapper.
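A hidden state wrapper lets a recurrent model keep its state between calls: it stores one hidden state per environment, hands it to the model as the previous state, saves the new one, and resets it when an episode ends. The class below is a plain-Python sketch of that bookkeeping; its name and methods are hypothetical and not DI-engine's wrapper API.

```python
from typing import Any, Callable, Dict, List, Optional


class HiddenStateStore:
    """Keeps one recurrent hidden state per environment id (illustrative sketch)."""

    def __init__(self, env_num: int, init_fn: Callable[[], Any] = lambda: None):
        self._init_fn = init_fn
        self._states: Dict[int, Any] = {i: init_fn() for i in range(env_num)}

    def get(self, env_ids: List[int]) -> List[Any]:
        # These stored states are what the model receives as its previous state.
        return [self._states[i] for i in env_ids]

    def update(self, env_ids: List[int], new_states: List[Any]) -> None:
        for i, s in zip(env_ids, new_states):
            self._states[i] = s

    def reset(self, env_ids: Optional[List[int]] = None) -> None:
        # Called when an episode ends, so the next episode starts from a fresh state.
        for i in (env_ids if env_ids is not None else list(self._states)):
            self._states[i] = self._init_fn()


store = HiddenStateStore(2, init_fn=lambda: 0)
store.update([0, 1], [5, 7])
store.reset([0])
print(store.get([0, 1]))  # [0, 7]
```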
- _init_eval() → None
- Overview:
Evaluate mode init method. Called by self.__init__. Initializes the eval model with the argmax strategy and the hidden_state plugin.
- _init_learn() → None
- Overview:
Init the learner model of COMAPolicy.
- Arguments:
Note
The _init_learn method takes its arguments from self._cfg.learn in the config file.
learning_rate (float): The learning rate for the optimizer
gamma (float): The discount factor
lambda (float): The lambda factor, determining the mix of bootstrapping vs. further accumulation of multistep returns at each timestep
value_weight (float): The weight of the value loss in the total loss
entropy_weight (float): The weight of the entropy loss in the total loss
agent_num (int): Since this is a multi-agent algorithm, the number of agents must be provided
batch_size (int): Batch size information, needed to init the hidden_state plugins
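The gamma and lambda arguments (presumably learn.discount_factor and learn.td_lambda in the config table) define a TD(λ) target. The sketch below computes the standard λ-return with a backward recursion, where done cuts off bootstrapping; lambda=0 gives the 1-step TD target and lambda=1 gives a bootstrapped Monte Carlo return. The function name and argument layout are assumptions for illustration.

```python
from typing import List, Sequence


def lambda_returns(rewards: Sequence[float], values: Sequence[float], dones: Sequence[bool],
                   gamma: float, lam: float) -> List[float]:
    """TD(lambda) returns via the backward recursion
    G_t = r_t + gamma * (1 - done_t) * ((1 - lam) * V_{t+1} + lam * G_{t+1}),
    starting from G_T = V_T. values has one more entry than rewards (the bootstrap)."""
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have len(rewards) + 1 entries")
    returns = [0.0] * len(rewards)
    next_return = values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        next_return = rewards[t] + gamma * not_done * (
            (1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


# lambda = 0 reduces to r_t + gamma * V_{t+1}.
print(lambda_returns([1.0, 1.0], [0.0, 2.0, 4.0], [False, False], gamma=0.5, lam=0.0))  # [2.0, 3.0]
```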
- _monitor_vars_learn() → List[str]
- Overview:
Return the names of the variables to be used in the monitor.
- Returns:
vars (List[str]): List of variable names.
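A sketch of how the monitored names relate to the info dict returned by _forward_learn, assuming the list mirrors the ReturnsKeys of _forward_learn; select_monitored is a hypothetical logging helper.

```python
from typing import Any, Dict, List


def monitor_vars_learn() -> List[str]:
    # Assumed to mirror the ReturnsKeys of _forward_learn.
    return ["cur_lr", "total_loss", "policy_loss", "value_loss", "entropy_loss"]


def select_monitored(info_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the monitored entries of an info dict (hypothetical helper)."""
    return {k: info_dict[k] for k in monitor_vars_learn() if k in info_dict}


print(select_monitored({"cur_lr": 0.1, "total_loss": 1.0, "extra": 3}))  # {'cur_lr': 0.1, 'total_loss': 1.0}
```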
- _process_transition(obs: Any, model_output: dict, timestep: collections.namedtuple) → dict
- Overview:
Generate dict type transition data from inputs.
- Arguments:
obs (Any): Env observation
model_output (dict): Output of the collect model, including at least [‘action’, ‘prev_state’]
timestep (namedtuple): Output after the env step, including at least [‘obs’, ‘reward’, ‘done’] (here ‘obs’ indicates the obs after the env step).
- Returns:
transition (dict): Dict type transition data.
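The transition combines the observation before the step, the model output and the env timestep. The sketch below uses a hypothetical Timestep namedtuple and stores the post-step observation as 'next_obs'; the exact key set beyond the required inputs is an assumption.

```python
from collections import namedtuple
from typing import Any

# Hypothetical stand-in for the env's timestep namedtuple.
Timestep = namedtuple("Timestep", ["obs", "reward", "done", "info"])


def process_transition(obs: Any, model_output: dict, timestep: Timestep) -> dict:
    """Assemble one transition; 'next_obs' is the obs returned by the env step."""
    return {
        "obs": obs,
        "next_obs": timestep.obs,
        "action": model_output["action"],
        "prev_state": model_output["prev_state"],
        "reward": timestep.reward,
        "done": timestep.done,
    }


t = process_transition("o0", {"action": [1, 0], "prev_state": None}, Timestep("o1", 1.0, False, {}))
print(t["obs"], t["next_obs"])  # o0 o1
```

Keeping prev_state in the transition lets the learner restart the recurrent model from the state the collector actually used.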
- default_model() → Tuple[str, List[str]]
- Overview:
Return this algorithm's default model setting for demonstration.
- Returns:
model_info (Tuple[str, List[str]]): The model name and the model import_names
Note
The user can define and use a customized network model but must obey the same interface definition indicated by the import_names path. For COMA, this is ding.model.coma.coma.
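A plausible way to consume the returned tuple is to import every module path in import_names (so that any registration code in those modules runs) and then look the model up by name. The sketch below shows only the import step with importlib; load_default_model is hypothetical, and stdlib modules stand in for ding.model.coma.coma so the example runs without DI-engine installed.

```python
import importlib
from types import ModuleType
from typing import List, Tuple


def load_default_model(model_info: Tuple[str, List[str]]) -> Tuple[str, List[ModuleType]]:
    """Import every module listed in import_names, then return the model name.

    Assumption: importing a module is what makes its model class available by name
    (for example through a registry decorator executed at import time).
    """
    name, import_names = model_info
    modules = [importlib.import_module(path) for path in import_names]
    return name, modules


# Stdlib modules stand in for ding.model.coma.coma so the sketch runs anywhere.
name, modules = load_default_model(("demo_model", ["json", "collections"]))
print(name, [m.__name__ for m in modules])  # demo_model ['json', 'collections']
```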