torch.einsum详解

参考链接:https://zhuanlan.zhihu.com/p/434232512

爱因斯坦求和约定:用于简洁的表示乘积、点积、转置等方法。

假设矩阵 $$A=R^{I * K}$$ 矩阵 $$B=R^{K * J}$$ ,两个矩阵的乘积 C 的维度可以表示为 $$R^{I * J}$$

用爱因斯坦求和约定可以如下表示:

$$ C = AB $$

在代码中上面的式子可以表示为字符串:

'ik,kj->ij'

使用torch.einsum实现上述功能:

>>> import torch
>>> A = torch.randn(3, 4)
>>> B = torch.randn(4, 5)
>>> C = torch.einsum('ik,kj->ij', A, B)
>>> C.shape
torch.Size([3, 5])

也可以这样

>>> A = torch.randn(3, 4)
>>> B = torch.randn(5, 4)
>>> C = torch.einsum('ik,jk->ij', A, B)
>>> C.shape
torch.Size([3, 5])

求和

>>> C = torch.einsum('ij->', A)
>>> C
tensor(1.5575)

列求和

>>> C = torch.einsum('ij->j', A)
>>> C
tensor([-0.4369,  1.1548,  0.8594, -0.0198])
>>> C.shape
torch.Size([4])

行求和同理

点积求和

>>> C = torch.einsum('ij,ij->', A, A)
>>> C
tensor(11.6726)

转置

>>> C = torch.einsum('ij->ji', A)
>>> C.shape
torch.Size([4, 3])

多维矩阵乘法

>>> A = torch.randn(3, 4)
>>> B = torch.randn(3, 4, 5)
>>> C = torch.randn(4, 5)
>>> D = torch.einsum('ij,ijk,jk->ik', A, B, C)
>>> D.shape
torch.Size([3, 5])
文章目录