You have picked a rather unlucky example. torch.nn.functional.max_pool1d is not an instance of torch.autograd.Function, because it is a PyTorch built-in, defined in C++ code with an autogenerated Python binding. I am not sure whether it is possible to get the backward property through its interface.
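You can see this for yourself with a minimal sketch (my own illustration, not part of the original question): F.max_pool1d is a plain function, yet the autograd node it creates is still visible on its output tensor.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 6, requires_grad=True)
y = F.max_pool1d(x, kernel_size=2)

print(type(F.max_pool1d))  # a plain function, not a torch.autograd.Function subclass
print(y.grad_fn)           # the autograd node built behind the scenes
                           # (its exact name varies across PyTorch versions)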
Firstly, in case you haven't noticed, you don't need to write any custom code for the backpropagation of this formula, because both the power operation and max_pool1d already have it defined, so their composition is also covered by autograd; a quick demonstration of this is sketched right below. Assuming your goal is an exercise, I would suggest you do it more manually (without falling back to the backward of max_pool1d). An example follows the sketch.
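Here is that demonstration (a sketch of my own, using the same square-then-pool formula as the example below): autograd differentiates the composition out of the box.

import torch
import torch.nn.functional as F

# the composition of ** and max_pool1d is differentiated by autograd
# automatically, no custom Function needed
x = torch.randn(1, 1, 6, requires_grad=True)
y = F.max_pool1d(x ** 2, kernel_size=2)
y.sum().backward()
print(x.grad)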
import torch
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, **kwargs):
        # we're gonna need indices for backward. Currently SquareAnd...
        # never actually returns indices, I left it out for simplicity
        kwargs['return_indices'] = True

        input_sqr = input ** 2
        output, indices = F.max_pool1d(input_sqr, kernel_size, **kwargs)
        ctx.save_for_backward(input, indices)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, indices = ctx.saved_tensors

        # first we need to reconstruct the gradient of `max_pool1d`
        # by putting all the output gradient elements (corresponding to
        # input elements which made it through the max_pool1d) in their
        # respective places, the rest has gradient of 0. We do it by
        # scattering it against a tensor of 0s
        grad_output_unpooled = torch.zeros_like(input)
        grad_output_unpooled.scatter_(2, indices, grad_output)

        # then incorporate the gradient of the "square" part of your
        # operator
        grad_input = 2. * input * grad_output_unpooled

        # the docs for backward
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
        # say that "it should return as many tensors, as there were inputs
        # to forward()". It fails to mention that if an argument was not a
        # tensor, it should return None (I remember reading this somewhere,
        # but can't find it anymore). Anyway, we need to
        # return a (grad_input, None) tuple to avoid a complaint that two
        # outputs were expected
        return grad_input, None
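To make the scatter_ step in backward concrete, here is a tiny standalone illustration with toy numbers of my own: the pooled gradient is routed back to the argmax positions and every other input position gets gradient 0.

import torch

# toy example: input of length 4, kernel_size 2, so two pooling windows
grad_out = torch.tensor([[[10., 20.]]])  # gradient w.r.t. the two pooled outputs
indices  = torch.tensor([[[1, 2]]])      # argmax positions returned by max_pool1d
unpooled = torch.zeros(1, 1, 4)
unpooled.scatter_(2, indices, grad_out)
print(unpooled)  # tensor([[[ 0., 10., 20.,  0.]]])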
We can then use the numerical gradient checker to verify that the operation works as expected.
f = SquareAndMaxPool1d.apply
xT = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)
tag.gradcheck(lambda t: f(t, 2), xT)
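As an extra sanity check (my own addition, continuing from the snippets above), the custom operator's gradient can also be compared against the plain autograd composition; the two should agree.

x1 = xT.detach().clone().requires_grad_(True)
x2 = xT.detach().clone().requires_grad_(True)
f(x1, 2).sum().backward()                   # custom forward/backward
F.max_pool1d(x2 ** 2, 2).sum().backward()   # autograd composition
print(torch.allclose(x1.grad, x2.grad))     # expected: True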
I am sorry if this doesn't answer your question of how to get the backward of max_pool1d, but hopefully you find my answer useful enough.