PyTorch 中的 torch.nn.Unfold 和 torch.nn.Fold

参考链接:https://blog.csdn.net/a3630623/article/details/120639367

参考链接2:https://blog.csdn.net/weixin_44076434/article/details/106545037

官方参考链接:https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html?highlight=nn%20unfold#torch.nn.Unfold

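官方文档中的说明代码（卷积 = Unfold + 矩阵乘法 + Fold，即 Conv = Unfold + Matrix Multiplication + Fold）。nn.Unfold 把形状为 (N, C, H, W) 的输入中每个滑动窗口展平成一列，输出形状为 (N, C×kH×kW, L)，其中 L 是滑动窗口的总个数；nn.Fold 则是反向操作，把这些列放回（重叠部分相加）到 (N, C, H, W) 的位置上：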

import torch
import torch.nn as nn

# nn.Unfold extracts every sliding (2, 3) block of a (N, C, H, W) tensor and
# flattens each block into one column: result is (N, C * 2 * 3, L).
unfold = nn.Unfold(kernel_size=(2, 3))
x = torch.randn(2, 5, 3, 4)
output = unfold(x)
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input: (3-2+1) * (4-3+1) = 4
print(output.size())  # torch.Size([2, 30, 4])

# Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)  # (C_out, C_in, kH, kW)
# (1, 3*4*5, L) = (1, 60, 56); L = 7 * 8 sliding positions.
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
# (1, 56, 60) @ (60, 2) -> (1, 56, 2), transposed back to (1, 2, 56).
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
# Output size of an unpadded stride-1 conv: (10-4+1, 12-5+1) = (7, 8).
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
# or equivalently (and avoiding a copy),
# out = out_unf.view(1, 2, 7, 8)
max_err = (torch.nn.functional.conv2d(inp, w) - out).abs().max()
print(max_err)  # small float32 rounding error only

代码（用 4×4 的输入和 2×2 的卷积核手动验证上面的等价关系）：

import torch
import torch.nn as nn

def unfold_example():
    """Extract every 2x2 sliding patch of a 4x4 ramp image with nn.Unfold.

    Returns:
        (x, x_unfold): ``x`` is the (1, 1, 4, 4) input holding 1..16 and
        ``x_unfold`` its unfolded form of shape (1, 4, 9) -- each column is
        one flattened 2x2 patch (1 channel * 2 * 2 = 4 values) and there are
        (4-2+1) * (4-2+1) = 9 patch positions.
    """
    # torch.range is deprecated (and end-inclusive); float bounds make arange
    # return the default float dtype, matching what torch.range produced.
    x = torch.arange(1.0, 17.0).view(1, 1, 4, 4)
    unfold = nn.Unfold(kernel_size=(2, 2))
    return x, unfold(x)


def conv_via_unfold_example():
    """Show that conv2d equals Unfold + matrix multiplication + Fold.

    Returns:
        (x, w, x_unfold, out_unfold, out, conv_out) so every intermediate
        step can be inspected; ``out`` and ``conv_out`` should match.
    """
    x = torch.arange(1.0, 17.0).view(1, 1, 4, 4)  # (N, C_in, H, W)
    w = torch.arange(1.0, 5.0).view(1, 1, 2, 2)   # (C_out, C_in, kH, kW)
    kernel_size = tuple(w.shape[-2:])
    # Spatial output size of an unpadded, stride-1 convolution: H - kH + 1.
    out_size = (x.shape[-2] - kernel_size[0] + 1,
                x.shape[-1] - kernel_size[1] + 1)

    # (N, C_in*kH*kW, L): one column per sliding-window position.
    x_unfold = torch.nn.functional.unfold(x, kernel_size)
    # (N, L, C_in*kH*kW) @ (C_in*kH*kW, C_out) -> (N, L, C_out) -> (N, C_out, L)
    out_unfold = x_unfold.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
    # With a 1x1 fold kernel nothing overlaps, so fold only lays the L columns
    # out on the out_size grid.
    out = torch.nn.functional.fold(out_unfold, out_size, (1, 1))
    # or equivalently (and avoiding a copy),
    # out = out_unfold.view(1, w.size(0), *out_size)
    conv_out = torch.nn.functional.conv2d(x, w)
    return x, w, x_unfold, out_unfold, out, conv_out


def main():
    """Run both examples and print every intermediate tensor."""
    x, x_unfold = unfold_example()
    print("input")
    print(x)
    print("input_unfold")
    print(x_unfold)

    x, w, x_unfold, out_unfold, out, conv_out = conv_via_unfold_example()
    print("input")
    print(x)
    print("w")
    print(w)
    print("input_unfold")
    print(x_unfold)
    print("out_unfold")
    print(out_unfold)
    print("out")
    print(out)
    print("conv_out")
    print(conv_out)
    diff = (conv_out - out).abs().max()
    print(diff)


if __name__ == '__main__':
    main()

运行结果：

input
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
input_unfold
tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.]]])
input
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
w
tensor([[[[1., 2.],
          [3., 4.]]]])
input_unfold
tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.]]])
out_unfold
tensor([[[ 44.,  54.,  64.,  84.,  94., 104., 124., 134., 144.]]])
out
tensor([[[[ 44.,  54.,  64.],
          [ 84.,  94., 104.],
          [124., 134., 144.]]]])
conv_out
tensor([[[[ 44.,  54.,  64.],
          [ 84.,  94., 104.],
          [124., 134., 144.]]]])
tensor(0.)

文章目录