torch.einsum

官方参考链接:https://pytorch.org/docs/stable/generated/torch.einsum.html

参考博客:https://blog.csdn.net/nihate/article/details/90480459

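复杂代码样例:分别用 NumPy 显式循环和 `torch.einsum` 复现 `nn.Bilinear` 的前向计算 $y_{nk} = x_{1,n}^\top A_k\, x_{2,n} + b_k$,并与 `nn.Bilinear` 的输出比较最大误差: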

import torch
import torch.nn as nn
import numpy as np

def bilinear_numpy_loop(x1, weight, x2, bias):
    """Compute an nn.Bilinear forward pass with an explicit NumPy loop.

    For every output feature ``k`` this evaluates
    ``y[n, k] = x1[n] @ weight[k] @ x2[n] + bias[k]``.

    Args:
        x1: array of shape (N, in1).
        weight: array of shape (out, in1, in2).
        x2: array of shape (N, in2).
        bias: array of shape (out,).

    Returns:
        float64 array of shape (N, out).
    """
    y = np.zeros((x1.shape[0], weight.shape[0]))
    for k, w_k in enumerate(weight):
        # (N, in1) @ (in1, in2) -> (N, in2); the row-wise dot with x2 then
        # gives one scalar per sample.
        y[:, k] = np.sum(np.dot(x1, w_k) * x2, axis=1)
    return y + bias


def bilinear_einsum(x1, weight, x2, bias):
    """Compute an nn.Bilinear forward pass with a single torch.einsum.

    Args:
        x1: tensor of shape (N, in1).
        weight: tensor of shape (out, in1, in2).
        x2: tensor of shape (N, in2).
        bias: tensor of shape (out,).

    Returns:
        Tensor of shape (N, out).
    """
    # 'ab,kbc,ac->ak': for each sample a and output k, contract x1's feature
    # axis b with weight's 2nd axis and x2's feature axis c with its 3rd axis.
    return torch.einsum('ab,kbc,ac->ak', x1, weight, x2) + bias


if __name__ == '__main__':

    print('learn nn.Bilinear')
    m = nn.Bilinear(20, 30, 40)
    input1 = torch.randn(128, 20)
    input2 = torch.randn(128, 30)
    # Only forward values are compared, so no autograd graph is needed.
    with torch.no_grad():
        output = m(input1, input2)
    print(output.size())

    # .detach() is the supported replacement for the legacy .data attribute.
    weight = m.weight.detach()
    bias = m.bias.detach()

    x1 = input1.cpu().numpy()
    x2 = input2.cpu().numpy()
    weight_np = weight.cpu().numpy()
    bias_np = bias.cpu().numpy()
    print(x1.shape, weight_np.shape, x2.shape, bias_np.shape)

    # Check 1: explicit NumPy loop vs nn.Bilinear.
    y = bilinear_numpy_loop(x1, weight_np, x2, bias_np)
    print(np.max(np.abs(y - output.cpu().numpy())))

    # Check 2: torch.einsum vs nn.Bilinear.
    einsum_output = bilinear_einsum(input1, weight, input2, bias)
    print(torch.max(torch.abs(einsum_output - output)).item())
文章目录