pytorch中的torch.nn.Unfold和torch.nn.Fold
参考链接:https://blog.csdn.net/a3630623/article/details/120639367
参考链接2:https://blog.csdn.net/weixin_44076434/article/details/106545037
官方参考链接:https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html?highlight=nn%20unfold#torch.nn.Unfold
官方说明代码(Conv = Unfold + Matrix Multiplication + Fold):
# Official PyTorch example, written REPL-style (bare expressions show their
# value in an interactive session). It assumes `import torch` and
# `import torch.nn as nn` have already been executed.
unfold = nn.Unfold(kernel_size=(2, 3))
input = torch.randn(2, 5, 3, 4)
output = unfold(input)
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
# so the size shown below is (2, 30, 4): (batch, values per patch, number of patches)
output.size()
# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
# Extract every 4x5 sliding window: 3*4*5 = 60 values per window and
# (10-4+1)*(12-5+1) = 7*8 = 56 windows, giving shape (1, 60, 56).
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
# Flatten each of the 2 filters to 60 weights and take its dot product with
# every window; the transposes put windows on the matmul row axis and then
# restore the (batch, out_channels, windows) layout, i.e. (1, 2, 56).
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
# With a 1x1 kernel, fold places each column at its own output position,
# turning (1, 2, 56) into the (1, 2, 7, 8) feature map.
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
# Largest absolute deviation from the built-in convolution; expected to be
# zero up to floating-point rounding.
(torch.nn.functional.conv2d(inp, w) - out).abs().max()
代码:
import torch
import torch.nn as nn
def unfold_conv2d(inp, weight):
    """Compute a stride-1, unpadded 2-D convolution as unfold + matmul + fold.

    Args:
        inp: Input tensor of shape (N, C, H, W).
        weight: Filter tensor of shape (C_out, C, kH, kW).

    Returns:
        A tuple ``(inp_unfold, out_unfold, out)`` where
        ``inp_unfold`` has shape (N, C*kH*kW, L) with one column per sliding
        window, ``out_unfold`` has shape (N, C_out, L), and ``out`` has shape
        (N, C_out, H-kH+1, W-kW+1); L = (H-kH+1) * (W-kW+1).
    """
    out_channels, _, k_h, k_w = weight.shape
    out_h = inp.size(2) - k_h + 1
    out_w = inp.size(3) - k_w + 1
    # Each column of inp_unfold is one flattened kH x kW window of the input.
    inp_unfold = torch.nn.functional.unfold(inp, (k_h, k_w))
    # Flatten every filter to a row of weights and take its dot product with
    # every window; the transposes move the window axis in and out of the
    # matmul row position.
    flat_weight = weight.view(out_channels, -1)
    out_unfold = inp_unfold.transpose(1, 2).matmul(flat_weight.t()).transpose(1, 2)
    # Folding with a 1x1 kernel only rearranges the L columns into the
    # (out_h, out_w) grid.
    out = torch.nn.functional.fold(out_unfold, (out_h, out_w), (1, 1))
    return inp_unfold, out_unfold, out


if __name__ == '__main__':
    # torch.range is deprecated; arange with an exclusive end and an explicit
    # float32 dtype yields the same 1..16 tensor.
    inp = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)

    unfold = nn.Unfold(kernel_size=(2, 2))
    input_unfold = unfold(inp)
    # each patch contains 4 values (2x2 window, 1 channel)
    # 9 blocks (2x2 kernels) in total in the 4x4 input
    print("input")
    print(inp)
    print("input_unfold")
    print(input_unfold)

    # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
    w = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
    print("input")
    print(inp)
    print("w")
    print(w)
    input_unfold, out_unfold, out = unfold_conv2d(inp, w)
    print("input_unfold")
    print(input_unfold)
    print("out_unfold")
    print(out_unfold)
    # or equivalently (and avoiding a copy),
    # out = out_unfold.view(1, 1, 3, 3)
    conv_out = torch.nn.functional.conv2d(inp, w)
    print("out")
    print(out)
    print("conv_out")
    print(conv_out)
    # Maximum absolute difference between the two results; 0 means they match.
    diff = (conv_out - out).abs().max()
    print(diff)
结果
input
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]])
input_unfold
tensor([[[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.]]])
input
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]]]])
w
tensor([[[[1., 2.],
[3., 4.]]]])
input_unfold
tensor([[[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.]]])
out_unfold
tensor([[[ 44., 54., 64., 84., 94., 104., 124., 134., 144.]]])
out
tensor([[[[ 44., 54., 64.],
[ 84., 94., 104.],
[124., 134., 144.]]]])
conv_out
tensor([[[[ 44., 54., 64.],
[ 84., 94., 104.],
[124., 134., 144.]]]])
tensor(0.)