torch.einsum
Official docs: https://pytorch.org/docs/stable/generated/torch.einsum.html?highlight=einsum#torch.einsum
Reference blog post: https://blog.csdn.net/nihate/article/details/90480459
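torch.einsum evaluates a tensor contraction described by an Einstein-summation subscript string: any index that appears in the inputs but not after the "->" is summed over. Before the nn.Bilinear example below, here is a minimal sketch of the most common patterns (the tensor names a, b, v, x, y are made up for illustration):

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
v = torch.randn(4)

# matrix multiplication: the shared index j is summed out
mm = torch.einsum('ij,jk->ik', a, b)      # same result as a @ b

# matrix-vector product
mv = torch.einsum('ij,j->i', a, v)        # same result as a @ v

# trace: repeated index with an empty output sums the diagonal
tr = torch.einsum('ii->', torch.randn(4, 4))

# batched matrix multiplication: the batch index n is kept in the output
x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
bmm = torch.einsum('nij,njk->nik', x, y)  # same result as torch.bmm(x, y)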
A more involved code example (reproducing nn.Bilinear both with a numpy loop and with einsum):
import torch
import torch.nn as nn
import numpy as np
if __name__ == '__main__':
    print('learn nn.Bilinear')
    # Bilinear layer: y_k = x1 @ weight[k] @ x2 + bias[k] for each output feature k
    m = nn.Bilinear(20, 30, 40)
    input1 = torch.randn(128, 20)
    input2 = torch.randn(128, 30)
    output = m(input1, input2)
    print(output.size())  # torch.Size([128, 40])

    # Reproduce the layer in numpy, one output feature at a time
    arr_output = output.data.cpu().numpy()
    weight = m.weight.data.cpu().numpy()   # shape (40, 20, 30)
    bias = m.bias.data.cpu().numpy()       # shape (40,)
    x1 = input1.data.cpu().numpy()
    x2 = input2.data.cpu().numpy()
    print(x1.shape, weight.shape, x2.shape, bias.shape)
    y = np.zeros((x1.shape[0], weight.shape[0]))
    for k in range(weight.shape[0]):
        buff = np.dot(x1, weight[k])       # (128, 30)
        buff = buff * x2                   # elementwise product with input2
        buff = np.sum(buff, axis=1)        # sum over input2's features
        y[:, k] = buff
    y += bias
    dif = y - arr_output
    print(np.max(np.abs(dif.flatten())))   # should be close to 0

    # The same computation as a single einsum
    x1 = input1
    x2 = input2
    weight = m.weight
    bias = m.bias
    einsum_output = torch.einsum('ab,kbc,ac->ak', x1, weight, x2)
    einsum_output += bias.data
    print(torch.max(torch.abs(einsum_output - output)))  # should be close to 0
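In the subscript string 'ab,kbc,ac->ak', a is the batch index, b indexes input1's 20 features, c indexes input2's 30 features, and k indexes the 40 output features. b and c appear only on the input side, so they are summed out, which is exactly the bilinear form y[a, k] = sum over b, c of x1[a, b] * weight[k, b, c] * x2[a, c], with the bias added afterwards. Both the numpy loop and the einsum should match the layer's output up to floating-point error.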