checkpoint_helper

build_checkpoint_helper

Overview:

Build a checkpoint helper from the given config.

Arguments:
  • cfg (dict): ckpt_helper config

Returns:
  • (CheckpointHelper): checkpoint_helper created by this function

CheckpointHelper

class ding.torch_utils.checkpoint_helper.CheckpointHelper
Overview:

Helper for saving and loading checkpoints with the given arguments.

Interface:

save, load

load(load_path: str, model: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer = None, last_iter: CountVar = None, last_epoch: CountVar = None, last_frame: CountVar = None, lr_schduler: Scheduler = None, dataset: torch.utils.data.dataset.Dataset = None, collector_info: torch.nn.modules.module.Module = None, prefix_op: str = None, prefix: str = None, strict: bool = True, logger_prefix: str = '', state_dict_mask: list = [])
Overview:

Load a checkpoint from the given path.

Arguments:
  • load_path (str): checkpoint’s path

  • model (torch.nn.Module): model definition

  • optimizer (torch.optim.Optimizer): optimizer obj

  • last_iter (CountVar): iter num, default None

  • last_epoch (CountVar): epoch num, default None

  • last_frame (CountVar): frame num, default None

  • lr_schduler (Scheduler): learning-rate scheduler object

  • dataset (torch.utils.data.Dataset): dataset, should be replaydataset

  • collector_info (torch.nn.Module): checkpoint attribute that holds the collector info

  • prefix_op (str): either 'remove' or 'add'; the operation applied to the keys of state_dict

  • prefix (str): the prefix that prefix_op adds to or removes from the keys of state_dict

  • strict (bool): the strict argument passed to model.load_state_dict

  • logger_prefix (str): prefix of logger

  • state_dict_mask (list): a list of state_dict keys that should not be loaded into the model (applied after the prefix op)

Note:

The checkpoint loaded from load_path is a dict whose format is like {'state_dict': OrderedDict(), …}
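
The interplay of the prefix_op, prefix and state_dict_mask arguments of load can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the library's implementation: the helper names apply_prefix_op and mask_state_dict are hypothetical, the state_dict holds placeholder integers instead of tensors, and matching masked keys by exact name is an assumption.

```python
from collections import OrderedDict


def apply_prefix_op(state_dict, prefix_op=None, prefix=None):
    """Hypothetical helper: copy state_dict, adding or removing `prefix` on every key."""
    if prefix_op is None:
        return OrderedDict(state_dict)
    if prefix_op not in ("remove", "add"):
        raise ValueError(f"prefix_op must be 'remove' or 'add', got {prefix_op!r}")
    if prefix is None:
        raise ValueError("prefix is required when prefix_op is set")
    result = OrderedDict()
    for key, value in state_dict.items():
        if prefix_op == "add":
            new_key = prefix + key
        elif key.startswith(prefix):
            new_key = key[len(prefix):]
        else:
            new_key = key  # key does not carry the prefix; keep it unchanged
        result[new_key] = value
    return result


def mask_state_dict(state_dict, state_dict_mask=()):
    """Hypothetical helper: drop masked keys (assumed exact-name match)."""
    masked = set(state_dict_mask)
    return OrderedDict((k, v) for k, v in state_dict.items() if k not in masked)


# A checkpoint in the documented layout: a dict with a 'state_dict' entry.
checkpoint = {
    "state_dict": OrderedDict([("module.encoder.weight", 1), ("module.head.bias", 2)])
}

# Strip the prefix first, then mask: mask entries name the keys after the prefix op.
stripped = apply_prefix_op(checkpoint["state_dict"], prefix_op="remove", prefix="module.")
to_load = mask_state_dict(stripped, state_dict_mask=["head.bias"])
print(list(to_load))
```

Because the mask is applied after the prefix op, its entries are written without the removed prefix ("head.bias" rather than "module.head.bias"). When keys are masked out, the model no longer receives every parameter, so strict=False is likely needed in that case.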

save(path: str, model: torch.nn.modules.module.Module, optimizer: Optional[torch.optim.optimizer.Optimizer] = None, last_iter: Optional[ding.torch_utils.checkpoint_helper.CountVar] = None, last_epoch: Optional[ding.torch_utils.checkpoint_helper.CountVar] = None, last_frame: Optional[ding.torch_utils.checkpoint_helper.CountVar] = None, dataset: Optional[torch.utils.data.dataset.Dataset] = None, collector_info: Optional[torch.nn.modules.module.Module] = None, prefix_op: Optional[str] = None, prefix: Optional[str] = None) → None
Overview:

Save a checkpoint built from the given arguments.

Arguments:
  • path (str): the path where the checkpoint is saved

  • model (torch.nn.Module): model to be saved

  • optimizer (torch.optim.Optimizer): optimizer obj

  • last_iter (CountVar): iter num, default None

  • last_epoch (CountVar): epoch num, default None

  • last_frame (CountVar): frame num, default None

  • dataset (torch.utils.data.Dataset): dataset, should be replaydataset

  • collector_info (torch.nn.Module): checkpoint attribute that holds the collector info

  • prefix_op (str): either 'remove' or 'add'; the operation applied to the keys of state_dict

  • prefix (str): the prefix that prefix_op adds to or removes from the keys of state_dict

CountVar

class ding.torch_utils.checkpoint_helper.CountVar(init_val: int)
Overview:

Number counter

Interface:

val, update, add

add(add_num: int)
Overview:

Add a number to the counter.

Arguments:
  • add_num (int): the number to add to the counter

update(val: int) → None
Overview:

Update the counter's value.

Arguments:
  • val (int): the value to update the counter with
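
The documented CountVar interface (val, update, add) can be illustrated with a small stand-in class. This is a sketch only: CountVarSketch is a hypothetical name, and it assumes that val is a read-only property, that update overwrites the stored value, and that add increments it.

```python
class CountVarSketch:
    """Hypothetical stand-in mirroring the documented interface: val, update, add."""

    def __init__(self, init_val: int) -> None:
        self._val = init_val

    @property
    def val(self) -> int:
        # Assumption: the current count is exposed as a read-only property.
        return self._val

    def update(self, val: int) -> None:
        # Assumption: update replaces the stored value.
        self._val = val

    def add(self, add_num: int) -> None:
        # Increment the stored value by add_num.
        self._val += add_num


last_iter = CountVarSketch(0)
last_iter.add(100)       # e.g. count training iterations as they finish
after_add = last_iter.val
last_iter.update(5000)   # e.g. restore a value read back from a checkpoint
after_update = last_iter.val
print(after_add, after_update)
```

A mutable counter object like this lets load write restored iteration, epoch and frame counts back into objects the caller already holds, which a plain int argument could not do.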

auto_checkpoint

Overview:

Create a wrapper around func that calls the save_checkpoint method whenever an exception is raised.

Arguments:
  • func (Callable): the function to be wrapped

Returns:
  • wrapper (Callable): the wrapped function
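
The wrapping behavior described above can be sketched as a decorator. This is a minimal sketch and not the library's code: auto_checkpoint_sketch and ToyLearner are hypothetical names, and it assumes that the wrapped function is a method of an object exposing save_checkpoint and that the exception is re-raised after saving.

```python
import functools
from typing import Callable


def auto_checkpoint_sketch(func: Callable) -> Callable:
    """Hypothetical decorator: save a checkpoint if the wrapped method raises."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            # Assumption: the owning object provides save_checkpoint().
            self.save_checkpoint()
            raise  # Assumption: the error still propagates to the caller.

    return wrapper


class ToyLearner:
    """Hypothetical class that counts how often a checkpoint was saved."""

    def __init__(self) -> None:
        self.saved = 0

    def save_checkpoint(self) -> None:
        self.saved += 1

    @auto_checkpoint_sketch
    def train(self, fail: bool) -> str:
        if fail:
            raise RuntimeError("simulated failure")
        return "ok"


learner = ToyLearner()
result = learner.train(fail=False)   # no exception, so nothing is saved
try:
    learner.train(fail=True)         # the exception triggers save_checkpoint()
except RuntimeError:
    pass
print(result, learner.saved)
```

Re-raising after the save keeps the failure visible to the caller while still preserving the training state at the point of the error.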