nn_test_helper
is_differentiable
- Overview:
Judge whether the model or models are differentiable. First check that every module parameter's grad is None, then back-propagate the loss, and finally check that every module parameter's grad is a torch.Tensor.
- Arguments:
    - loss (torch.Tensor): Loss tensor of the model.
    - model (Union[torch.nn.Module, List[torch.nn.Module]]): Model or models to be checked.
    - print_instead (bool): Whether to print the modules' final gradient results instead of asserting on them. Defaults to False.
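The three-step check described in the Overview can be sketched in plain PyTorch as follows. This is a hypothetical standalone function, `check_differentiable`, written only to illustrate the described behaviour; it is not the source of `is_differentiable`, and the assertion messages and the exact text printed when `print_instead=True` are assumptions.

```python
from typing import List, Union

import torch
import torch.nn as nn


def check_differentiable(
    loss: torch.Tensor,
    model: Union[nn.Module, List[nn.Module]],
    print_instead: bool = False,
) -> None:
    """Hypothetical sketch of the check described above (not the library source)."""
    # Accept a single module or a list of modules and treat both the same way.
    models = model if isinstance(model, list) else [model]
    for m in models:
        if not isinstance(m, nn.Module):
            raise TypeError(f"expected nn.Module, got {type(m).__name__}")

    # Step 1: no parameter may already carry a gradient from an earlier backward pass.
    for m in models:
        for name, p in m.named_parameters():
            assert p.grad is None, f"{name} already has a gradient"

    # Step 2: back-propagate the loss.
    loss.backward()

    # Step 3: every parameter should now hold a gradient tensor.
    for m in models:
        for name, p in m.named_parameters():
            if print_instead:
                # Report instead of failing, mirroring the print_instead argument.
                print(name, "grad is None" if p.grad is None else "ok")
            else:
                assert isinstance(p.grad, torch.Tensor), f"{name} received no gradient"


if __name__ == "__main__":
    torch.manual_seed(0)
    encoder, head = nn.Linear(4, 8), nn.Linear(8, 1)
    loss = head(torch.relu(encoder(torch.randn(2, 4)))).sum()
    # Both modules lie on the path to the loss, so every parameter gets a grad.
    check_differentiable(loss, [encoder, head])
```

Checking that the gradients are None before calling `backward()` keeps the test from being fooled by gradients left over from an earlier pass. In this sketch, passing a module that the loss does not depend on raises `AssertionError`, or, with `print_instead=True`, prints `grad is None` next to each of its parameter names.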