rl_utils.td

Temporal Difference

dist_nstep_td_error

Overview:

Multistep (1-step or n-step) TD error for distributional Q-learning based algorithms

Arguments:
  • data (dist_nstep_td_data): the input data used to calculate the loss

  • gamma (float): discount factor

  • nstep (int): number of steps n, defaults to 1

Returns:
  • loss (torch.Tensor): n-step TD error, a 0-dim tensor

Shapes:
  • data (dist_nstep_td_data): the dist_nstep_td_data containing

    [‘dist’, ‘next_n_dist’, ‘act’, ‘next_n_act’, ‘reward’, ‘done’, ‘weight’]

  • dist (torch.FloatTensor): \((B, N, n_atom)\) i.e. [batch_size, action_dim, n_atom]

  • next_n_dist (torch.FloatTensor): \((B, N, n_atom)\)

  • act (torch.LongTensor): \((B, )\)

  • next_n_act (torch.LongTensor): \((B, )\)

  • reward (torch.FloatTensor): \((T, B)\), where T is the number of timesteps (nstep)

  • done (torch.BoolTensor): \((B, )\), whether the episode terminated at the last timestep
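The Overview does not say how the distributional target is built. The following is a minimal sketch assuming the categorical (C51-style) projection onto a fixed support, followed by a cross-entropy loss. `v_min`, `v_max` and `n_atom` are hypothetical parameters that are not listed in the Arguments, and `dist_nstep_td_error_sketch` is an illustrative name, not the rl_utils API.

```python
import torch


def dist_nstep_td_error_sketch(dist, next_n_dist, act, next_n_act, reward, done,
                               gamma, nstep=1, weight=None,
                               v_min=-10.0, v_max=10.0, n_atom=51):
    # Hypothetical C51-style loss; v_min, v_max, n_atom are assumed, not documented.
    batch_idx = torch.arange(dist.shape[0])
    support = torch.linspace(v_min, v_max, n_atom, dtype=dist.dtype)  # atom values z_i
    delta_z = (v_max - v_min) / (n_atom - 1)

    # n-step discounted reward: sum_{k<nstep} gamma^k * r_k, shape (B,)
    discounts = gamma ** torch.arange(nstep, dtype=reward.dtype).unsqueeze(1)
    n_reward = (reward[:nstep] * discounts).sum(0)

    # Bellman-shifted atoms; no bootstrap after termination; clip to the support
    not_done = (~done).to(dist.dtype).unsqueeze(1)
    tz = n_reward.unsqueeze(1) + not_done * (gamma ** nstep) * support.unsqueeze(0)
    tz = tz.clamp(v_min, v_max)

    # Project each shifted atom onto its two neighbouring support atoms
    b = ((tz - v_min) / delta_z).clamp(0, n_atom - 1)  # fractional atom index
    lower, upper = b.floor().long(), b.ceil().long()
    exact = (lower == upper).to(dist.dtype)  # b landed exactly on an atom
    next_prob = next_n_dist[batch_idx, next_n_act]  # (B, n_atom)
    target = torch.zeros_like(next_prob)
    target.scatter_add_(1, lower, next_prob * (upper.to(dist.dtype) - b + exact))
    target.scatter_add_(1, upper, next_prob * (b - lower.to(dist.dtype)))

    # Cross-entropy between the projected target and the taken action's distribution
    cur_prob = dist[batch_idx, act]
    per_sample = -(target.detach() * torch.log(cur_prob + 1e-8)).sum(-1)
    if weight is None:
        weight = torch.ones_like(per_sample)
    return (per_sample * weight).mean()  # 0-dim tensor
```

When a shifted atom falls between two support atoms, its probability mass is split in proportion to distance, so the projected target still sums to 1 per sample.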

q_nstep_td_error

Overview:

Multistep (1-step or n-step) TD error for Q-learning based algorithms

Arguments:
  • data (q_nstep_td_data): the input data used to calculate the loss

  • gamma (float): discount factor

  • cum_reward (bool): whether to use the cumulative n-step reward, which is computed when collecting data

  • value_gamma (torch.Tensor): gamma discount value for the target q_value

  • criterion (torch.nn.Module): loss function criterion

  • nstep (int): number of steps n, defaults to 1

Returns:
  • loss (torch.Tensor): n-step TD error, a 0-dim tensor

  • td_error_per_sample (torch.Tensor): n-step TD error for each sample, a 1-dim tensor

Shapes:
  • data (q_nstep_td_data): the q_nstep_td_data containing [‘q’, ‘next_n_q’, ‘action’, ‘next_n_action’, ‘reward’, ‘done’]

  • q (torch.FloatTensor): \((B, N)\) i.e. [batch_size, action_dim]

  • next_n_q (torch.FloatTensor): \((B, N)\)

  • action (torch.LongTensor): \((B, )\)

  • next_n_action (torch.LongTensor): \((B, )\)

  • reward (torch.FloatTensor): \((T, B)\), where T is the number of timesteps (nstep)

  • done (torch.BoolTensor): \((B, )\), whether the episode terminated at the last timestep

  • td_error_per_sample (torch.FloatTensor): \((B, )\)
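A minimal sketch of the n-step Q target described above, assuming the target is \(\sum_{k<n} \gamma^k r_k + \gamma^n (1 - done) \cdot Q'(s_n, a_n)\) and that the criterion returns one loss per sample. The function name and the default `MSELoss(reduction='none')` are illustrative assumptions; `cum_reward` and `value_gamma` are not modelled here.

```python
import torch
import torch.nn as nn


def q_nstep_td_error_sketch(q, next_n_q, action, next_n_action, reward, done,
                            gamma, nstep=1, weight=None, criterion=None):
    # Assumption: a per-element criterion so that td_error_per_sample has shape (B,)
    if criterion is None:
        criterion = nn.MSELoss(reduction='none')
    batch_idx = torch.arange(q.shape[0])
    q_s_a = q[batch_idx, action]                            # Q(s_0, a_0), (B,)
    target_q_s_a = next_n_q[batch_idx, next_n_action]       # Q'(s_n, a_n), (B,)

    # Discounted sum of the n rewards: sum_k gamma^k * r_k
    discounts = gamma ** torch.arange(nstep, dtype=reward.dtype).unsqueeze(1)
    n_reward = (reward[:nstep] * discounts).sum(0)

    # Bootstrap only if the episode has not terminated within the n steps
    not_done = (~done).to(reward.dtype)
    target = n_reward + (gamma ** nstep) * not_done * target_q_s_a

    td_error_per_sample = criterion(q_s_a, target.detach())
    if weight is None:
        weight = torch.ones_like(td_error_per_sample)
    loss = (td_error_per_sample * weight).mean()
    return loss, td_error_per_sample
```

The target is detached so gradients flow only through `q`, which is the usual semi-gradient treatment of TD targets.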

q_nstep_td_error_with_rescale

Overview:

Multistep (1-step or n-step) TD error with value rescaling

Arguments:
  • data (q_nstep_td_data): the input data used to calculate the loss

  • gamma (float): discount factor

  • nstep (int): number of steps n, defaults to 1

  • criterion (torch.nn.Module): loss function criterion

  • trans_fn (Callable): value transform function, defaults to value_transform (refer to rl_utils/value_rescale.py)

  • inv_trans_fn (Callable): inverse value transform function, defaults to value_inv_transform (refer to rl_utils/value_rescale.py)

Returns:
  • loss (torch.Tensor): n-step TD error, a 0-dim tensor

Shapes:
  • data (q_nstep_td_data): the q_nstep_td_data containing [‘q’, ‘next_n_q’, ‘action’, ‘next_n_action’, ‘reward’, ‘done’]

  • q (torch.FloatTensor): \((B, N)\) i.e. [batch_size, action_dim]

  • next_n_q (torch.FloatTensor): \((B, N)\)

  • action (torch.LongTensor): \((B, )\)

  • next_n_action (torch.LongTensor): \((B, )\)

  • reward (torch.FloatTensor): \((T, B)\), where T is the number of timesteps (nstep)

  • done (torch.BoolTensor): \((B, )\), whether the episode terminated at the last timestep
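The sketch below assumes that `value_transform` is the invertible rescaling \(h(x) = \mathrm{sign}(x)(\sqrt{|x| + 1} - 1) + \epsilon x\) with \(\epsilon = 10^{-2}\), as used in R2D2-style agents; check rl_utils/value_rescale.py for the actual definition. The target is computed in the unscaled value space and then mapped back: \(h(\sum_k \gamma^k r_k + \gamma^n h^{-1}(Q'))\). All function names here are illustrative.

```python
import torch
import torch.nn as nn


def value_transform(x, eps=1e-2):
    # Assumed rescaling h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def value_inv_transform(x, eps=1e-2):
    # Closed-form inverse of h, obtained by solving a quadratic in sqrt(|x| + 1)
    s = (torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)
    return torch.sign(x) * (s ** 2 - 1)


def q_nstep_td_error_with_rescale_sketch(q, next_n_q, action, next_n_action, reward, done,
                                         gamma, nstep=1, criterion=None,
                                         trans_fn=value_transform,
                                         inv_trans_fn=value_inv_transform):
    if criterion is None:
        criterion = nn.MSELoss()
    batch_idx = torch.arange(q.shape[0])
    q_s_a = q[batch_idx, action]
    # Undo the rescaling on the bootstrap value before the Bellman backup
    target_q_s_a = inv_trans_fn(next_n_q[batch_idx, next_n_action])
    discounts = gamma ** torch.arange(nstep, dtype=reward.dtype).unsqueeze(1)
    n_reward = (reward[:nstep] * discounts).sum(0)
    not_done = (~done).to(reward.dtype)
    target = n_reward + (gamma ** nstep) * not_done * target_q_s_a
    # Re-apply the rescaling so the target lives in the network's output space
    return criterion(q_s_a, trans_fn(target).detach())
```

Rescaling compresses large returns so the network regresses onto values of moderate magnitude, while the inverse keeps the Bellman backup itself in the original reward scale.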

td_lambda_error

Overview:

Computes the TD(lambda) loss given constant gamma and lambda. There is no special handling for terminal state values: if a state is terminal, fill in zeros for the values and rewards beyond it (including the terminal state itself, i.e. values[terminal] should also be 0).

Arguments:
  • data (namedtuple): td_lambda input data with fields [‘value’, ‘reward’, ‘weight’]

  • gamma (float): constant discount factor gamma, should be in [0, 1], defaults to 0.9

  • lambda (float): constant lambda, should be in [0, 1], defaults to 0.8

Returns:
  • loss (torch.Tensor): Computed MSE loss, averaged over the batch

Shapes:
  • value (torch.FloatTensor): \((T+1, B)\), where T is the trajectory length and B is the batch size; the estimated state values at steps 0 to T

  • reward (torch.FloatTensor): \((T, B)\), the rewards from time step 0 to T-1

  • weight (torch.FloatTensor or None): \((B, )\), the training sample weight

  • loss (torch.FloatTensor): \(()\), 0-dim tensor
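A minimal sketch of the loss described above, assuming the lambda-return is built with the backward recursion \(G_{T-1} = r_{T-1} + \gamma V_T\), \(G_t = r_t + \gamma(\lambda G_{t+1} + (1-\lambda) V_{t+1})\) and regressed onto `value[:-1]` with a per-sample weight. The name `td_lambda_error_sketch` is hypothetical, and the argument is spelled `lambda_` because `lambda` is a Python keyword.

```python
import torch


def td_lambda_error_sketch(value, reward, weight=None, gamma=0.9, lambda_=0.8):
    # value: (T+1, B) estimates for steps 0..T; reward: (T, B) for steps 0..T-1
    T = reward.shape[0]
    with torch.no_grad():  # the lambda-return is a fixed regression target
        ret = torch.empty_like(reward)
        ret[T - 1] = reward[T - 1] + gamma * value[T]
        for t in reversed(range(T - 1)):
            ret[t] = reward[t] + gamma * (lambda_ * ret[t + 1] + (1 - lambda_) * value[t + 1])
    if weight is None:
        weight = torch.ones_like(reward[0])  # (B,)
    # Squared error per step and sample, weighted per sample, averaged to a 0-dim tensor
    return ((ret - value[:-1]) ** 2 * weight).mean()
```

Because terminal handling is left to the caller, zeroing `value` and `reward` after termination makes the recursion stop accumulating at that point.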

generalized_lambda_returns

Overview:

Functional equivalent of trfl.value_ops.generalized_lambda_returns (https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74). Pass in a number instead of a tensor to make that value constant for all samples in the batch.

Arguments:
  • bootstrap_values (torch.Tensor or float): estimated value at steps 0 to T, of size [T_traj+1, batchsize]

  • rewards (torch.Tensor): the rewards from step 0 to T-1, of size [T_traj, batchsize]

  • gammas (torch.Tensor or float): discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]

  • lambda (torch.Tensor or float): determines the mix of bootstrapping vs. further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]

Returns:
  • return (torch.Tensor): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]
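A minimal sketch of the scalar-or-tensor behaviour described above, assuming scalars are broadcast to the full [T_traj, batchsize] shape and that the step-0 bootstrap value is dropped (the recursion only bootstraps from steps 1 to T, which matches the T_traj+1 versus T_traj sizes). `generalized_lambda_returns_sketch` is a hypothetical name.

```python
import torch


def generalized_lambda_returns_sketch(bootstrap_values, rewards, gammas, lambda_):
    # Broadcast Python numbers so every sample in the batch shares the same value
    if not isinstance(gammas, torch.Tensor):
        gammas = torch.full_like(rewards, float(gammas))
    if not isinstance(lambda_, torch.Tensor):
        lambda_ = torch.full_like(rewards, float(lambda_))
    if not isinstance(bootstrap_values, torch.Tensor):
        bootstrap_values = torch.full((rewards.shape[0] + 1, rewards.shape[1]),
                                      float(bootstrap_values), dtype=rewards.dtype)
    next_values = bootstrap_values[1:]  # values at steps 1..T
    T = rewards.shape[0]
    result = torch.empty_like(rewards)
    result[T - 1] = rewards[T - 1] + gammas[T - 1] * next_values[T - 1]
    for t in reversed(range(T - 1)):
        result[t] = rewards[t] + gammas[t] * (
            lambda_[t] * result[t + 1] + (1 - lambda_[t]) * next_values[t])
    return result  # [T_traj, batchsize]
```

With lambda = 1 the result is the discounted Monte Carlo return plus the discounted final bootstrap; with lambda = 0 it reduces to one-step TD targets.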

multistep_forward_view

Overview:

Same as trfl.sequence_ops.multistep_forward_view. Implements (12.18) in Sutton & Barto:

    result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
    for t in 0...T-2:
        result[t] = rewards[t] + gammas[t] * (lambdas[t] * result[t+1] + (1 - lambdas[t]) * bootstrap_values[t+1])

In these formulas bootstrap_values[k] denotes the value estimate at step k.

Assumes the first dimension of the input tensors is the time index and the second is the batch index. There is no special handling for terminal state values: if a state is terminal, fill in zeros for the values and rewards beyond it (including the terminal state itself, i.e. bootstrap_values[terminal] should also be 0).

Arguments:
  • bootstrap_values (torch.Tensor): estimated value at steps 1 to T, of size [T_traj, batchsize]

  • rewards (torch.Tensor): the rewards from step 0 to T-1, of size [T_traj, batchsize]

  • gammas (torch.Tensor): discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]

  • lambda (torch.Tensor): determines the mix of bootstrapping vs. further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]. The element for T-1 is ignored and effectively set to 0, as there is no information about future rewards.

Returns:
  • ret (torch.Tensor): Computed lambda return value for each state from 0 to T-1, of size [T_traj, batchsize]
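A minimal sketch of the recursion above, assuming array index t of bootstrap_values holds the value at step t+1 (as stated in Arguments). It rewrites gammas[t] * (lambdas[t] * G + (1 - lambdas[t]) * V) as discounts[t] * G + (gammas[t] - discounts[t]) * V with discounts = gammas * lambda, which is algebraically identical. `multistep_forward_view_sketch` is a hypothetical name.

```python
import torch


def multistep_forward_view_sketch(bootstrap_values, rewards, gammas, lambda_):
    # All inputs are [T_traj, batchsize]; bootstrap_values[t] is the value at step t+1
    assert bootstrap_values.shape == rewards.shape == gammas.shape == lambda_.shape
    result = torch.empty_like(rewards)
    # Forced cutoff at the last step: lambda_[T-1] is never read
    result[-1] = rewards[-1] + gammas[-1] * bootstrap_values[-1]
    discounts = gammas * lambda_
    for t in reversed(range(rewards.shape[0] - 1)):
        result[t] = (rewards[t] + discounts[t] * result[t + 1]
                     + (gammas[t] - discounts[t]) * bootstrap_values[t])
    return result
```

For example, with rewards [1, 2], bootstrap values [10, 20] and gamma = lambda = 0.5, the last step gives 2 + 0.5 * 20 = 12 and the first gives 1 + 0.25 * 12 + 0.25 * 10 = 6.5.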