目录

Pytorch中矩阵乘法使用及案例

Pytorch中矩阵乘法使用及案例

六种矩阵乘法

torch 中包含许多矩阵乘法,大致可以分为以下几种:

  • * :即 a * b 按位相乘,要求 ab 的形状必须一致,支持广播操作
  • torch.matmul() :最广泛的矩阵乘法
  • @ :与 torch.matmul() 效果一样(等价),即 torch.matmul(a, b) == a @ b
  • torch.dot() :两个一维向量乘法,不支持广播
  • torch.mm() :两个二维矩阵的乘法,不支持广播
  • torch.bmm() :两个三维矩阵乘法(批次 batch 粒度),且两个矩阵必须是三维的,不支持广播操作

其中, torch.matmul() 中包含 torch.dot()torch.mm()torch.bmm()

代码验证

torch.dot()

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])

## 下面四个函数的结果是一样的  结果都是7
a.dot(b)
torch.dot(a, b)
a @ b
torch.matmul(a, b)

输出结果:

https://i-blog.csdnimg.cn/direct/8112e1ba543a4fb7a3be65a5f014d00f.png

torch.matmul()torch.dot() 的主要区别就是,当两个向量(矩阵)的维度不一致时, torch.matmul() 会进行 广播 ,而 torch.dot() 会报错

*

对向量 ab 进行 按位相乘

a = torch.tensor([2, 3])
b = torch.tensor([2, 1])

a * b  # [4, 3]

torch.mm()

用于二维矩阵的相乘——第一个向量的 和第二个向量的 必须 相等

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)

## 下面三个输出结果是一样的
torch.mm(mat1, mat2)
mat1.matmul(mat2)
mat1 @ mat2

输出结果:

https://i-blog.csdnimg.cn/direct/969893d91a8c4851bbba7cae60156dc7.png

torch.matmul()torch.mm() 的主要区别就是,当两个矩阵的维度不一致时, torch.matmul() 会进行 广播 ,而 torch.mm() 会报错

torch.bmm()

应用于三维矩阵,要求:

  • 两个矩阵的第一个维度的大小 必须相同
  • 必须满足第一个矩阵: (b × n × m) ,第二个矩阵: (b × m × p) ,即第一个矩阵的第三个维度必须和第二个矩阵的第二个维度相同
  • 输出大小: (b × n × p)

该函数相当于 分别对每个 batch 进行二维矩阵相乘

bmat1 = torch.randn(2, 1, 4)
bmat2 = torch.randn(2, 4, 2)

## 下面三个输出是一样的
torch.bmm(bmat1, bmat2)
bmat1.matmul(bmat2)
bmat1 @ bmat2

输出结果:

https://i-blog.csdnimg.cn/direct/65f1106768c749bf858adf82035c4c49.png

换一种角度想, torch.bmm() 就是相当于按照批次 batch 进行索引,然后将每个批次内的二维矩阵进行相乘

for i in range(bmat1.shape[0]):  # 索引出来批次bmat1.shape[0]
    temp =torch.mm(bmat1[i, :, :], bmat2[i, :, :])
    print(temp)

https://i-blog.csdnimg.cn/direct/21c2c55021b140189b72b8002468a717.png