nn_test_helper

is_differentiable

Overview:

Judge whether the model (or list of models) is differentiable. First check that every module parameter's grad is None, then back-propagate the loss, and finally check that every module parameter's grad is a torch.Tensor.

Arguments:
  • loss (torch.Tensor): loss tensor of the model

  • model (Union[torch.nn.Module, List[torch.nn.Module]]): model or models to be checked

  • print_instead (bool): Whether to print each module's final grad result instead of asserting. Defaults to False.
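The three-step check described in the overview can be sketched as follows. This is a minimal, hypothetical re-implementation for illustration, not the library's own code: the name is_differentiable_sketch and the exact assertion and print messages are assumptions. It assumes PyTorch is installed and treats a single module and a list of modules the same way.

```python
from typing import List, Union

import torch
import torch.nn as nn


def is_differentiable_sketch(
    loss: torch.Tensor,
    model: Union[nn.Module, List[nn.Module]],
    print_instead: bool = False,
) -> None:
    """Hypothetical sketch of the differentiability check (not the library code)."""
    assert isinstance(loss, torch.Tensor), "loss must be a torch.Tensor"
    # Normalize a single module into a one-element list.
    models = model if isinstance(model, list) else [model]
    for m in models:
        if not isinstance(m, nn.Module):
            raise TypeError(f"expected torch.nn.Module, got {type(m)}")

    # Step 1: no parameter may hold a gradient before backward.
    for m in models:
        for name, p in m.named_parameters():
            assert p.grad is None, f"{name} already has a grad before backward"

    # Step 2: back-propagate the loss.
    loss.backward()

    # Step 3: every parameter should now hold a gradient tensor.
    for m in models:
        for name, p in m.named_parameters():
            if print_instead:
                print(name, "grad is:", p.grad)
            else:
                assert isinstance(p.grad, torch.Tensor), f"{name} received no grad"


# Usage: a plain linear layer is fully differentiable, so this passes silently.
layer = nn.Linear(4, 2)
is_differentiable_sketch(layer(torch.randn(3, 4)).sum(), layer)
```

Note that the check only passes on freshly created (or zero-grad-cleared) models, since step 1 rejects any parameter that already carries a grad. A frozen parameter (requires_grad set to False) keeps a grad of None after backward, so it fails step 3 unless print_instead is True.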