torch.einsum详解
参考链接:https://zhuanlan.zhihu.com/p/434232512
爱因斯坦求和约定:用于简洁的表示乘积、点积、转置等方法。
假设矩阵 $$A=R^{I * K}$$ 矩阵 $$B=R^{K * J}$$ ,两个矩阵的乘积 C 的维度可以表示为 $$R^{I * J}$$
用爱因斯坦求和约定可以如下表示:
$$ C = AB $$
在代码中上面的式子可以表示为字符串:
'ik,kj->ij'
使用torch.einsum实现上述功能:
>>> import torch
>>> A = torch.randn(3, 4)
>>> B = torch.randn(4, 5)
>>> C = torch.einsum('ik,kj->ij', A, B)
>>> C.shape
torch.Size([3, 5])
也可以这样
>>> A = torch.randn(3, 4)
>>> B = torch.randn(5, 4)
>>> C = torch.einsum('ik,jk->ij', A, B)
>>> C.shape
torch.Size([3, 5])
求和
>>> C = torch.einsum('ij->', A)
>>> C
tensor(1.5575)
列求和
>>> C = torch.einsum('ij->j', A)
>>> C
tensor([-0.4369, 1.1548, 0.8594, -0.0198])
>>> C.shape
torch.Size([4])
行求和同理
点积求和
>>> C = torch.einsum('ij,ij->', A, A)
>>> C
tensor(11.6726)
转置
>>> C = torch.einsum('ij->ji', A)
>>> C.shape
torch.Size([4, 3])
多维矩阵乘法
>>> A = torch.randn(3, 4)
>>> B = torch.randn(3, 4, 5)
>>> C = torch.randn(4, 5)
>>> D = torch.einsum('ij,ijk,jk->ik', A, B, C)
>>> D.shape
torch.Size([3, 5])