rl_utils.td
Temporal Difference
dist_nstep_td_error
- Overview:
Multistep (1-step or n-step) TD error for distributional Q-learning based algorithms
- Arguments:
    - data (dist_nstep_td_data): the input data, dist_nstep_td_data to calculate loss
    - gamma (float): discount factor
    - nstep (int): number of steps n, defaults to 1
- Returns:
    - loss (torch.Tensor): n-step TD error, 0-dim tensor
- Shapes:
    - data (dist_nstep_td_data): the dist_nstep_td_data containing [‘dist’, ‘next_n_dist’, ‘act’, ‘next_n_act’, ‘reward’, ‘done’, ‘weight’]
        - dist (torch.FloatTensor): \((B, N, n_atom)\) i.e. [batch_size, action_dim, n_atom]
        - next_n_dist (torch.FloatTensor): \((B, N, n_atom)\)
        - act (torch.LongTensor): \((B, )\)
        - next_n_act (torch.LongTensor): \((B, )\)
        - reward (torch.FloatTensor): \((T, B)\), where T is the timestep (nstep)
        - done (torch.BoolTensor): \((B, )\), whether done in the last timestep
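The n-step distributional target described above can be illustrated for a single sample with a categorical projection, assuming a fixed support of `n_atom` evenly spaced atoms (as in C51). This is a minimal pure-Python sketch, not the library implementation: the support bounds `v_min`/`v_max` and the helper names `project_nstep_dist` and `dist_td_loss` are hypothetical, since the section above does not specify them.

```python
import math
from typing import List


def project_nstep_dist(
    next_dist: List[float],  # target distribution of the chosen next action, length n_atom
    rewards: List[float],  # rewards r_0 .. r_{n-1} of one sample (one column of reward)
    done: bool,
    gamma: float,
    v_min: float,  # hypothetical lower bound of the value support
    v_max: float,  # hypothetical upper bound of the value support
) -> List[float]:
    n_atom = len(next_dist)
    delta_z = (v_max - v_min) / (n_atom - 1)
    nstep = len(rewards)
    # Discounted n-step return: sum_t gamma^t * r_t
    ret = sum(gamma ** t * r for t, r in enumerate(rewards))
    # Bootstrap discount gamma^n, zeroed when the episode ended
    boot = 0.0 if done else gamma ** nstep
    proj = [0.0] * n_atom
    for j, p in enumerate(next_dist):
        z = v_min + j * delta_z
        tz = min(max(ret + boot * z, v_min), v_max)  # shifted atom, clipped to the support
        b = min(max((tz - v_min) / delta_z, 0.0), n_atom - 1.0)  # fractional atom index
        lo, hi = math.floor(b), math.ceil(b)
        if lo == hi:
            proj[lo] += p  # falls exactly on an atom
        else:
            # split the probability mass between the two neighbouring atoms
            proj[lo] += p * (hi - b)
            proj[hi] += p * (b - lo)
    return proj


def dist_td_loss(dist_act: List[float], proj: List[float], weight: float = 1.0) -> float:
    # Cross-entropy between the projected target and the online distribution of the taken action
    return -weight * sum(t * math.log(max(p, 1e-8)) for t, p in zip(proj, dist_act))


# Usage: 5 atoms on [-10, 10], all next-step mass on the atom at 0, one step with reward 1
proj = project_nstep_dist([0.0, 0.0, 1.0, 0.0, 0.0], [1.0], done=False, gamma=0.9, v_min=-10.0, v_max=10.0)
loss = dist_td_loss([0.1, 0.1, 0.4, 0.3, 0.1], proj)
```

The shifted atom at 1.0 lies between the atoms at 0 and 5, so its mass is split 0.8 / 0.2 between them; the projected target always sums to 1.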
q_nstep_td_error
- Overview:
Multistep (1-step or n-step) TD error for Q-learning based algorithms
- Arguments:
    - data (q_nstep_td_data): the input data, q_nstep_td_data to calculate loss
    - gamma (float): discount factor
    - cum_reward (bool): whether to use the cumulative n-step reward, which is computed when collecting data
    - value_gamma (torch.Tensor): gamma discount value for the target Q-value
    - criterion (torch.nn.modules): loss function criterion
    - nstep (int): number of steps n, defaults to 1
- Returns:
    - loss (torch.Tensor): n-step TD error, 0-dim tensor
    - td_error_per_sample (torch.Tensor): n-step TD error per sample, 1-dim tensor
- Shapes:
    - data (q_nstep_td_data): the q_nstep_td_data containing [‘q’, ‘next_n_q’, ‘action’, ‘next_n_action’, ‘reward’, ‘done’]
        - q (torch.FloatTensor): \((B, N)\) i.e. [batch_size, action_dim]
        - next_n_q (torch.FloatTensor): \((B, N)\)
        - action (torch.LongTensor): \((B, )\)
        - next_n_action (torch.LongTensor): \((B, )\)
        - reward (torch.FloatTensor): \((T, B)\), where T is the timestep (nstep)
        - done (torch.BoolTensor): \((B, )\), whether done in the last timestep
    - td_error_per_sample (torch.FloatTensor): \((B, )\)
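The n-step target behind this loss can be sketched in plain Python for the `cum_reward=False` case: discount the T per-step rewards, then bootstrap from the target Q-value of `next_n_action` unless the episode is done. This is an illustrative sketch, not the library code; the function name is hypothetical, the MSE criterion is an assumption, and treating `value_gamma` as a per-sample replacement for `gamma ** nstep` is also an assumption.

```python
from typing import List, Optional, Tuple


def q_nstep_td_error_sketch(
    q: List[List[float]],  # (B, N) online Q-values
    next_n_q: List[List[float]],  # (B, N) target Q-values n steps later
    action: List[int],  # (B,)
    next_n_action: List[int],  # (B,)
    reward: List[List[float]],  # (T, B), with T == nstep
    done: List[bool],  # (B,)
    gamma: float,
    value_gamma: Optional[List[float]] = None,  # assumed: replaces gamma ** nstep when given
) -> Tuple[float, List[float]]:
    nstep = len(reward)
    per_sample = []
    for b in range(len(action)):
        # Discounted sum of the n rewards: sum_t gamma^t * r_t
        ret = sum(gamma ** t * reward[t][b] for t in range(nstep))
        discount = gamma ** nstep if value_gamma is None else value_gamma[b]
        # No bootstrapping past a terminal state
        bootstrap = 0.0 if done[b] else discount * next_n_q[b][next_n_action[b]]
        target = ret + bootstrap
        # Squared error, i.e. an MSE criterion with no reduction
        per_sample.append((q[b][action[b]] - target) ** 2)
    return sum(per_sample) / len(per_sample), per_sample


# Usage: B = 2 samples, N = 2 actions, T = nstep = 2
loss, per_sample = q_nstep_td_error_sketch(
    q=[[1.0, 2.0], [0.5, 0.0]],
    next_n_q=[[3.0, 4.0], [1.0, 2.0]],
    action=[1, 0],
    next_n_action=[1, 0],
    reward=[[1.0, 0.0], [1.0, 1.0]],
    done=[False, True],
    gamma=0.5,
)
```

In real training the target is treated as a constant, so gradients flow only through `q`, not through `next_n_q`.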
q_nstep_td_error_with_rescale
- Overview:
Multistep (1-step or n-step) TD error with value rescaling
- Arguments:
    - data (q_nstep_td_data): the input data, q_nstep_td_data to calculate loss
    - gamma (float): discount factor
    - nstep (int): number of steps n, defaults to 1
    - criterion (torch.nn.modules): loss function criterion
    - trans_fn (Callable): value transform function, defaults to value_transform (refer to rl_utils/value_rescale.py)
    - inv_trans_fn (Callable): value inverse transform function, defaults to value_inv_transform (refer to rl_utils/value_rescale.py)
- Returns:
    - loss (torch.Tensor): n-step TD error, 0-dim tensor
- Shapes:
    - data (q_nstep_td_data): the q_nstep_td_data containing [‘q’, ‘next_n_q’, ‘action’, ‘next_n_action’, ‘reward’, ‘done’]
        - q (torch.FloatTensor): \((B, N)\) i.e. [batch_size, action_dim]
        - next_n_q (torch.FloatTensor): \((B, N)\)
        - action (torch.LongTensor): \((B, )\)
        - next_n_action (torch.LongTensor): \((B, )\)
        - reward (torch.FloatTensor): \((T, B)\), where T is the timestep (nstep)
        - done (torch.BoolTensor): \((B, )\), whether done in the last timestep
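Value rescaling computes the n-step target in the original value space and then maps it back: the bootstrap Q-value is passed through the inverse transform, the discounted rewards are added, and the sum is transformed again. The sketch below is a pure-Python illustration for one sample. It assumes the invertible function h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x with eps = 1e-2 (the form popularised by R2D2); whether rl_utils/value_rescale.py uses exactly this form and epsilon is an assumption, and `rescaled_nstep_target` is a hypothetical helper name.

```python
import math
from typing import List

EPS = 1e-2  # assumed epsilon of the rescaling function


def value_transform(x: float, eps: float = EPS) -> float:
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return math.copysign(math.sqrt(abs(x) + 1.0) - 1.0, x) + eps * x


def value_inv_transform(x: float, eps: float = EPS) -> float:
    # Closed-form inverse of h, from solving a quadratic in s = sqrt(|x| + 1)
    s = (math.sqrt(1.0 + 4.0 * eps * (abs(x) + 1.0 + eps)) - 1.0) / (2.0 * eps)
    return math.copysign(s * s - 1.0, x)


def rescaled_nstep_target(rewards: List[float], next_q: float, done: bool, gamma: float) -> float:
    # next_q lives in the rescaled space: map it back, add the n-step return, rescale again
    nstep = len(rewards)
    ret = sum(gamma ** t * r for t, r in enumerate(rewards))
    boot = 0.0 if done else gamma ** nstep * value_inv_transform(next_q)
    return value_transform(ret + boot)


# Usage: a 1-step target whose bootstrap value is 2.0 in the original (unscaled) space
target = rescaled_nstep_target([1.0], value_transform(2.0), done=False, gamma=0.9)
```

The loss is then the chosen criterion between the online Q-value and this target, both living in the rescaled space.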
td_lambda_error
- Overview:
Computes the TD(lambda) loss given a constant gamma and lambda. There is no special handling for terminal state values: if a state has reached the terminal, fill in zeros for the values and rewards beyond the terminal, including the terminal state itself (values[terminal] should also be 0).
- Arguments:
    - data (namedtuple): td_lambda input data with fields [‘value’, ‘reward’, ‘weight’]
    - gamma (float): constant discount factor gamma, should be in [0, 1], defaults to 0.9
    - lambda (float): constant lambda, should be in [0, 1], defaults to 0.8
- Returns:
    - loss (torch.Tensor): computed MSE loss, averaged over the batch
- Shapes:
    - value (torch.FloatTensor): \((T+1, B)\), where T is the trajectory length and B is the batch size; the estimation of the state value at steps 0 to T
    - reward (torch.FloatTensor): \((T, B)\), the rewards from time step 0 to T-1
    - weight (torch.FloatTensor or None): \((B, )\), the training sample weight
    - loss (torch.FloatTensor): \(()\), 0-dim tensor
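The TD(lambda) loss regresses each value estimate V_t onto its lambda-return, which can be built backwards from V_T. The following pure-Python sketch shows that computation for time-major lists; the function name is hypothetical, and averaging the weighted squared errors over all T * B entries is an assumed reduction (the library may scale or reduce differently).

```python
from typing import List, Optional


def td_lambda_loss_sketch(
    value: List[List[float]],  # (T+1, B): V_0 .. V_T for each sample
    reward: List[List[float]],  # (T, B): r_0 .. r_{T-1}
    weight: Optional[List[float]] = None,  # (B,) sample weights, None means all ones
    gamma: float = 0.9,
    lambda_: float = 0.8,
) -> float:
    T, B = len(reward), len(reward[0])
    if weight is None:
        weight = [1.0] * B
    total = 0.0
    for b in range(B):
        # Seeding with G_T := V_T makes the first step give G_{T-1} = r_{T-1} + gamma * V_T
        g = value[T][b]
        for t in reversed(range(T)):
            g = reward[t][b] + gamma * ((1.0 - lambda_) * value[t + 1][b] + lambda_ * g)
            # The lambda-return g is the regression target for V_t
            total += weight[b] * (g - value[t][b]) ** 2
    # Assumed reduction: mean over all T * B entries
    return total / (T * B)


# Usage: one trajectory of length 2 that bootstraps from V_2 = 2.0
loss = td_lambda_loss_sketch([[0.0], [0.0], [2.0]], [[1.0], [1.0]], gamma=0.5, lambda_=0.5)
```

Here the lambda-returns are 1.5 and 2.0 against value estimates of 0, giving a mean squared error of 3.125. Zero-filling values and rewards after a terminal state, as required above, makes the recursion stop propagating future returns without any special case.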
generalized_lambda_returns
- Overview:
Functional equivalent of trfl.value_ops.generalized_lambda_returns (https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74). A number can be passed instead of a tensor to make that value constant for all samples in the batch.
- Arguments:
    - bootstrap_values (torch.Tensor or float): estimation of the value at steps 0 to T, of size [T_traj+1, batchsize]
    - rewards (torch.Tensor): the rewards from step 0 to T-1, of size [T_traj, batchsize]
    - gammas (torch.Tensor or float): discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
    - lambda (torch.Tensor or float): determines the mix of bootstrapping vs. further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]
- Returns:
    - return (torch.Tensor): computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]
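Given the documented sizes, bootstrap_values here covers steps 0 to T while the forward view only bootstraps from steps 1 to T, so V_0 can be dropped before running the backward recursion; plain numbers are broadcast to every step and sample. The following pure-Python sketch illustrates that relationship under those assumptions. The names `generalized_lambda_returns_sketch`, `_broadcast` and `_forward_view` are hypothetical, and the slicing is inferred from the shapes rather than taken from the library source.

```python
from typing import List, Union

Grid = List[List[float]]  # time-major: rows are steps, columns are batch samples


def _broadcast(x: Union[float, Grid], rows: int, cols: int) -> Grid:
    # A plain number becomes the same value for every step and every sample
    if isinstance(x, (int, float)):
        return [[float(x)] * cols for _ in range(rows)]
    return x


def _forward_view(values_tp1: Grid, rewards: Grid, gammas: Grid, lambdas: Grid) -> Grid:
    # values_tp1[t] holds V_{t+1}; the last lambda is never read, i.e. treated as 0
    T, B = len(rewards), len(rewards[0])
    result = [[0.0] * B for _ in range(T)]
    for b in range(B):
        result[T - 1][b] = rewards[T - 1][b] + gammas[T - 1][b] * values_tp1[T - 1][b]
        for t in reversed(range(T - 1)):
            lam, gam = lambdas[t][b], gammas[t][b]
            result[t][b] = rewards[t][b] + gam * (lam * result[t + 1][b] + (1.0 - lam) * values_tp1[t][b])
    return result


def generalized_lambda_returns_sketch(
    bootstrap_values: Union[float, Grid],  # (T+1, B): V_0 .. V_T
    rewards: Grid,  # (T, B)
    gammas: Union[float, Grid],  # (T, B) or a constant
    lambda_: Union[float, Grid],  # (T, B) or a constant
) -> Grid:
    T, B = len(rewards), len(rewards[0])
    bootstrap_values = _broadcast(bootstrap_values, T + 1, B)
    gammas = _broadcast(gammas, T, B)
    lambdas = _broadcast(lambda_, T, B)
    # Drop V_0: the returns only bootstrap from steps 1..T
    return _forward_view(bootstrap_values[1:], rewards, gammas, lambdas)


# Usage: constant gamma and lambda passed as plain numbers
returns = generalized_lambda_returns_sketch([[0.0], [0.0], [2.0]], [[1.0], [1.0]], 0.5, 0.5)
```

With lambda = 0 every entry becomes a one-step target, and with lambda = 1 every entry becomes the full discounted return bootstrapped from V_T.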
multistep_forward_view
- Overview:
Same as trfl.sequence_ops.multistep_forward_view, implementing (12.18) in Sutton & Barto:
`result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]`
`for t in 0...T-2: result[t] = rewards[t] + gammas[t] * (lambdas[t] * result[t+1] + (1 - lambdas[t]) * bootstrap_values[t+1])`
Here bootstrap_values[k] denotes the value estimate at step k; because the bootstrap_values argument covers steps 1 to T, that estimate is stored in row k-1 of the tensor. The input tensors are time-major: the first dimension indexes time steps and the second indexes the batch, matching the sizes below. There is no special handling for terminal state values: if a state has reached the terminal, fill in zeros for the values and rewards beyond the terminal, including the terminal state itself (bootstrap_values[terminal] should also be 0).
- Arguments:
    - bootstrap_values (torch.Tensor): estimation of the value at steps 1 to T, of size [T_traj, batchsize]
    - rewards (torch.Tensor): the rewards from step 0 to T-1, of size [T_traj, batchsize]
    - gammas (torch.Tensor): discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
    - lambda (torch.Tensor): determines the mix of bootstrapping vs. further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]; the element for T-1 is ignored and effectively set to 0, as there is no information about future rewards
- Returns:
    - ret (torch.Tensor): computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]
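The recurrence in the overview can be rearranged to show why the last lambda has no effect. Writing \(G_t\) for result[t], \(V_{t+1}\) for the value estimate at step \(t+1\), and \(\gamma_t, \lambda_t\) for gammas[t] and lambdas[t] (notation chosen here for illustration):

\[
G_t = r_t + \gamma_t\bigl(\lambda_t G_{t+1} + (1-\lambda_t) V_{t+1}\bigr) = r_t + \gamma_t V_{t+1} + \gamma_t \lambda_t \bigl(G_{t+1} - V_{t+1}\bigr)
\]

With \(\lambda_{T-1} = 0\) the correction term vanishes at the last step, so \(G_{T-1} = r_{T-1} + \gamma_{T-1} V_T\) and no \(G_T\) is needed, which is exactly the first line of the recurrence. Setting every other \(\lambda_t = 0\) reduces each \(G_t\) to the one-step target \(r_t + \gamma_t V_{t+1}\), while setting them to 1 gives the discounted multistep return \(G_t = \sum_{k=t}^{T-1} \bigl(\prod_{j=t}^{k-1}\gamma_j\bigr) r_k + \bigl(\prod_{j=t}^{T-1}\gamma_j\bigr) V_T\).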