einsum()函数是基于爱因斯坦求和约定实现的多维数组运算工具,在NumPy、PyTorch等科学计算库中都有实现。它通过简洁的索引符号表达复杂的张量运算,是处理多维数组的强大工具。
爱因斯坦求和约定基础
爱因斯坦求和约定的核心原则是:当一个索引在表达式中出现两次,则默认对该索引进行求和。例如矩阵乘法可以表示为:
C = np.einsum("ij,jk->ik", A, B)
这等效于np.dot(A, B)
或A @ B
,但einsum()的真正价值在于它能推广到更复杂的多维张量运算。
einsum()的基本语法
einsum()的基本语法包含三个部分:
- 输入操作数:要参与运算的输入张量
- 下标字符串:描述操作的字符串
- 输出形状:箭头右侧指定输出张量的形状
例如,矩阵乘法的下标字符串"ij,jk->ik"表示:
- 第一个矩阵的第i行和第二个矩阵的第j列相乘
- 结果为输出矩阵的第k行
einsum()的常见应用场景
- 矩阵乘法:
einsum('ij,jk->ik', A, B)
- 向量内积:
einsum('i,i->', a, b)
- 矩阵转置:
einsum('ij->ji', A)
- 逐元素乘法:
einsum('ij,ij->ij', A, B)
- 迹运算:
einsum('ii->', A)
einsum()的高级特性
- 广播机制:einsum()支持NumPy的广播规则,可以对不同形状的数组进行操作
- 隐式输出:可以省略输出形状,此时函数会按照重复索引进行求和
- 多操作数:支持多个输入张量的复杂运算
- 性能优化:对于某些操作,einsum()比传统方法更高效
PyTorch中的einsum()
PyTorch也提供了torch.einsum()
函数,用法与NumPy类似:
import torch
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ij,jk->ik', A, B) # 结果是一个2x4的矩阵:ml-citation{ref="4,8" data="citationList"}