admin 管理员组文章数量: 1086865
一文整理5个Pytorch张量乘法函数
~~欢迎关注#公众号:AI算法小喵,会有更多不错的文章分享~~
本文首发于:一文整理5个Pytorch张量乘法函数
最近整理了Pytorch中5个常用的张量乘法函数和用法,建议收藏学习。
1. 张量的维度
在开始今天的学习之前,我们需要先学习一个知识点,即张量的维度。
张量的维度包括两方面内容,其一是维度个数,其二是维度大小。维度个数可以通过张量.ndim
属性查看,维度大小可以通过.shape
或.size()
查看。
>>> a=torch.arange(6).reshape(2,3)
>>> a
tensor([[0, 1, 2],[3, 4, 5]])
>>> a.ndim
2
>>> a.shape
torch.Size([2, 3])
>>> a.size()
torch.Size([2, 3])
比如上面的张量a:维度个数为2,代表a是一个二维张量;维度大小为[2,3],代表第0维的维度大小为2,第1维为3。
2. torch.matmul
我们先学习最复杂也最灵活的torch.matmul函数[1]。
2.1 概览
功能:matmul函数实现的是矩阵乘法,更确切地说,是“混合”矩阵乘法。
参数:
-
input
(张量):第一个张量。 -
other
(张量):第二个张量。 -
out
(张量):结果张量,等同于torch.matmul函数的返回值。
返回值:张量。
2.2 示例代码
matmul函数的行为根据输入张量的不同大体可以分为5种情形(case),所以这里我们也通过5个case下的示例代码来学习这个函数。
(1) case1
若两个张量均为一维张量,则执行向量点积操作,等价于调用torch.dot函数。
比如,下面我们创建了两个一维张量a、b,维度大小均为2。
>>> a=torch.randn(1)
>>> b=torch.randn(1)
>>> a.ndim
1
>>> b.ndim
1
>>> a.size()
torch.Size([2])
>>> b.size()
torch.Size([2])
>>> a
tensor([0.8411])
>>> b
tensor([-1.1787])
然后分别对他们进行matmul操作,和dot操作。从结果比对来看,两个操作是等价的,最终生成的都是scalar标量。
>>> c1=torch.matmul(a,b)
>>> c2=torch.dot(a,b)
>>> c1.equal(c2)
True
>>> c1
tensor(-0.9914)
>>> c1.ndim
0
>>> c1.size()
torch.Size([])
(2) case2
若两个张量均为二维张量,则执行矩阵乘法,等价于调用torch.mm函数。
比如下面的例子,a、b均为2维张量,维度大小分别为[2,2]、[2,3]。即a.size()[1]=b.size()[0]
满足矩阵乘法约束,通过matmul函数或mm函数,我们将获得2维张量,维度大小为[2,3]。
>>> a=torch.randn(2,2)
>>> b=torch.randn(2,3)
>>> a.ndim
2
>>> b.ndim
2
>>> a.size()
torch.Size([2, 2])
>>> b.size()
torch.Size([2, 3])
>>> c1=torch.matmul(a,b)
>>> c2=torch.mm(a,b)
>>> c1.equal(c2)
True
>>> c1.size()
torch.Size([2, 3])
>>> c1.ndim
2
(3) case3
若第一个张量为一维张量,假设维度为[k],第二个张量为二维张量,假设维度为[k,p]。第一个张量会在左边进行维度扩展,维度变为[1,k],然后再进行矩阵乘法,获得维度为[1,p]的张量,然后再去掉扩展的维度,最后结果张量维度为[p]。
比如,a是维度大小为[3]的一维矩阵,b是维度大小为[3,4]的二维矩阵,结果张量c1是一维张量,维度大小为[4]。
>>> a=torch.arange(1,4)
>>> b=torch.arange(2,14).reshape((3,4))
>>> a.ndim
1
>>> a.size()
torch.Size([3])
>>> b.ndim
2
>>> b.size()
torch.Size([3, 4])
>>>
>>> c1=torch.matmul(a,b)
>>> c1.ndim
1
>>> c1.size()
torch.Size([4])
>>>
>>> a
tensor([1, 2, 3])
>>> b
tensor([[ 2, 3, 4, 5],[ 6, 7, 8, 9],[10, 11, 12, 13]])
>>> c1
tensor([44, 50, 56, 62])
更简单地记法,可以视为线性代数中的行向量乘矩阵,结果为第二个张量矩阵的行向量的线性组合,组合系数为第一个张量中相应的值。
>>> c2=1*b[0]+2*b[1]+3*b[2]
>>> c1.equal(c2)
True
>>> c2
tensor([44, 50, 56, 62])
还有一点需要注意,虽然matmul在进行维度扩展后,执行的是矩阵乘法,但在这种情形下,它与mm是不等价的,mm函数要求输入均为二维张量。
>>> torch.mm(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 1D, 2D tensors at ../aten/src/TH/generic/THTensorMath.cpp:131
(4) case4
若第一个张量为二维张量,假设维度为[k,n],第二个张量为一维张量,假设维度为[n]。第二个张量会在右边进行维度扩展,维度变为[n,1],然后再执行矩阵乘法,获得维度为[k,1]的张量,最后再去掉扩展的维度,获得维度为[k]的结果张量。
比如,a是维度大小为[3,2]的二维矩阵,b是维度大小为[2]的一维矩阵,结果张量c1是一维张量,维度大小为[4]。
>>> a=torch.arange(1,7).reshape(3,2)
>>> b=torch.arange(1,3)
>>> a.ndim
2
>>> >>> a.size()
torch.Size([3, 2])
>>> b.ndim
1
>>> b.size()
torch.Size([2])
>>> c1=torch.matmul(a,b)
>>> c1.ndim
1
>>> c1.size()
torch.Size([3])
>>>
>>> a
tensor([[1, 2],[3, 4],[5, 6]])
>>> b
tensor([1, 2])
>>> c1
tensor([ 5, 11, 17])
更简单地,可以视为线性代数中的矩阵乘列向量操作,结果为第一个张量矩阵的列向量的线性组合,组合系数为第二个张量中相应的值。
>>> c2=1*a[:,0]+2*a[:,-1]
>>> c1.equal(c2)
True
>>> c2
tensor([ 5, 11, 17])
当然同case3,这种情况torch.mm也是无法执行的。
>>>
>>> torch.mm(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
(5)case5
如果两个张量的维度均至少为1,且其中至少一个张量维度大于2,那么matmul将执行批矩阵乘法操作:默认使用两个张量的后两维度执行矩阵乘法,其他维度作为batch维。
-
等价于torch.bmm情形
若两个张量均为3维张量,矩阵个数相等(第0维大小相等)且后两维满足矩阵乘法约束,那么调用matmul等价于调用torch.bmm函数。
>>> a=torch.arange(12).reshape(2,2,3)
>>> a
tensor([[[ 0, 1, 2],[ 3, 4, 5]],[[ 6, 7, 8],[ 9, 10, 11]]])
>>> a.size()
torch.Size([2, 2, 3])
>>> a.ndim
3
>>> b=torch.arange(1,7).reshape(2,3,1)
>>> b
tensor([[[1],[2],[3]],[[4],[5],[6]]])
>>> b.size()
torch.Size([2, 3, 1])
>>> b.ndim
3
上面的例子中,a、b均为3维张量,维度大小分别为[2,2,3]、[2,3,1]。第0维大小相等为2,后两维满足矩阵乘法约束。这种情况下,两个函数等价,获得的结果为3维张量,维度大小为[2,2,1]。
>>> c1=torch.matmul(a,b)
>>> c2=torch.bmm(a,b)
>>> c1.equal(c2)
True
>>> c1
tensor([[[ 8],[ 26]],[[107],[152]]])
>>> c1.size()
torch.Size([2, 2, 1])
>>> c1.ndim
3
-
其他情形
其他情形,torch.bmm无法执行,但torch.matmul仍可执行:两个张量的后两维需满足矩阵乘法约束,不满足的情形会进行维度扩展(参考case3,case4),其他维则会通过广播操作对齐。
来看下面这个例子:张量a为1维张量,维度大小为[2];张量b为3维张量,维度大小为[3,2,1]。
>>> a=torch.arange(2)
>>> b=torch.arange(6).reshape(3,2,1)
>>> a.ndim
1
>>> a.size()
torch.Size([2])
>>> b.ndim
3
>>> b.size()
torch.Size([3, 2, 1])
>>> a
tensor([0, 1])
>>> b
tensor([[[0],[1]],[[2],[3]],[[4],[5]]])
为了进行批矩阵乘法,a经过变换(a的其他维经过广播操作、参与矩阵计算的维则进行类似case3的维度扩展)成为维度为[3,1,2]的张量,再与b(维度大小为[3,2,1])进行批矩阵乘法获得维度为[3,1,1]的张量。最终,去掉扩展维度后,结果的维度为[3,1]。
>>> c=torch.matmul(a,b)
>>> c.size()
torch.Size([3, 1])
>>> c.ndim
2
>>> c
tensor([[1],[3],[5]])
>>>
下面这个例子同理。b:[2]->[2,1] (维度扩展) ->[2,1,2,1] (广播操作) ,再与a进行批矩阵乘法,结果:[2,1,3,1]->[2,1,3] (去扩展维度)。
>>> a=torch.arange(12).reshape(2,1,3,2)
>>> b=torch.arange(2)
>>> c=torch.matmul(a,b)
>>> a.ndim,b.ndim,c.ndim
(4, 1, 3)
>>> a.shape
torch.Size([2, 1, 3, 2])
>>> b.shape
torch.Size([2])
>>> c.shape
torch.Size([2, 1, 3])
>>> a
tensor([[[[ 0, 1],[ 2, 3],[ 4, 5]]],[[[ 6, 7],[ 8, 9],[10, 11]]]])
>>> b
tensor([0, 1])
>>> c
tensor([[[ 1, 3, 5]],[[ 7, 9, 11]]])
3. torch.dot
前面在介绍torch.matmul的case1时,已经知道了torch.dot执行的是向量点积计算,这节我们来更细节地学习torch.dot函数[2]。
3.1 概览
功能:向量点积。
参数:
-
input
(张量):第一个张量。 -
other
(张量):第二个张量。 -
out
(张量):结果张量,等同于dot函数的返回值。
返回值:张量(标量)。
重点:只支持具有相同元素个数的两个一维张量做点积操作。
3.2 示例代码
(1)当a、b均为一维张量且维度大小相同时
>>> import torch
>>> a=torch.tensor([1,2])
>>> b=torch.tensor([3,4])
>>> a.ndim==b.ndim==1
True
>>> a.size()==b.size()==torch.Size([2])
True
>>>
>>> c=torch.dot(a,b)
>>> c.ndim
0
>>>
>>> a
tensor([1, 2])
>>> b
tensor([3, 4])
>>> c
tensor(11)
(2)当a、b均为一维张量,但维度大小不一致时
>>> a=torch.tensor([1,2])
>>> b=torch.tensor([3,4,5])
>>> a.ndim==b.ndim==1
True
>>> a.shape==b.shape
False
>>> c=torch.dot(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: inconsistent tensor size, expected tensor [2] and src [3] to have the same number of elements, but got 2 and 3 elements respectively
(3)当a或b不是一维张量时
>>> a=torch.arange(4).reshape(2,2)
>>> a
tensor([[0, 1],[2, 3]])
>>> b=torch.tensor([3,4,5,6]).reshape(2,2)
>>> b
tensor([[3, 4],[5, 6]])
>>>
>>> a.size()
torch.Size([2, 2])
>>> a.ndim
2
>>> torch.dot(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: 1D tensors expected, got 2D, 2D tensors at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:733
总结:从以上几个示例代码的学习,我们可以明确dot函数限制/约束输入的两个张量必须均为一维张量,且元素个数相同。即要求输入的两个张量a、b维度数满足a.ndim==b.ndim=1
,维度大小满足a.shape==b.shape
。
4. torch.mm
和case2下的torch.matmul相同,torch.mm执行的是矩阵乘法运算。这节,我们来一起学习torch.mm[3]函数。
4.1 概览
功能:矩阵乘法。
参数:
-
input
(张量):第一个矩阵,即2维张量。 -
mat2
(张量):第二个矩阵,即2维张量。 -
out
(张量):结果张量。
重点:torch.mm不会进行广播操作,它严格要求两个张量满足维度约束。即,假设两个张量分别为a、b,要求a.size()[1]=b.size()[0]
。
4.2 示例代码
(1)当a、b均为2维张量,且满足维度约束条件
下面的例子中,我们创建了维度为[1,3]和[3,2]的二维张量a、b。然后,通过torch.mm(a,b),获得了维度为[1,2]的二维张量c。
>>> import torch
>>> a=torch.arange(1,4).unsqueeze(0)
>>> b=torch.arange(1,7).reshape(3,2)
>>> c=torch.mm(a,b)
>>>
>>> a.size()
torch.Size([1, 3])
>>> b.size()
torch.Size([3, 2])
>>> c.size()
torch.Size([1, 2])
>>> a.ndim==b.ndim==c.ndim==2
True
>>>
>>> a
tensor([[1, 2, 3]])
>>> b
tensor([[1, 2],[3, 4],[5, 6]])
>>> c
tensor([[22, 28]])
(2)当某一个张量为非二维张量时
下面这个例子中,a是维度大小为[3,2]的二维张量,b是维度大小为[2]的一维张量。torch.mm不会进行广播操作(这里主要是指维度扩展),所以不会像case3中的torch.matmul可以成功执行。
>>> a = torch.arange(6).reshape(3,2)
>>> a
tensor([[0, 1],[2, 3],[4, 5]])
>>> a.size()
torch.Size([3, 2])
>>> a.ndim
2
>>> b = torch.arange(1,3)
>>> b
tensor([1, 2])
>>> b.size()
torch.Size([2])
>>> b.ndim
1
>>>
>>> torch.mm(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
>>>
>>>
>>> c=torch.matmul(a,b)
>>> c
tensor([ 2, 8, 14])
>>> c.size()
torch.Size([3])
>>> c.ndim
1
总结:torch.mm 函数限制/约束输入的两个张量必须均为二维张量,且维度满足矩阵乘法约束。即,要求输入的两个张量a、b维度数满足a.ndim==b.ndim=2
,维度大小则满足a.size()[1]==b.size()[0]
。
5. bmm
前面,通过对torch.matmul函数在case5下行为的学习,我们了解到torch.bmm实现的是批量矩阵乘法计算[4]。本节,我们来具体学习这个函数。
5.1 概览
功能:批量矩阵乘法。
参数:
-
input
(张量):第一批矩阵,即3维张量,第0维表示批大小。 -
mat2
(张量):第二批矩阵,即3维张量,第0维表示批大小。 -
out
(张量):结果张量,等同于torch.bmm函数返回值。也是3维张量,第0维表示批大小。
返回值:三维张量,第0维表示批大小。
重点:bmm不会进行广播操作,它严格要求两个张量均为三维张量,且第0维大小相等(表示有多少个矩阵),其他两维满足矩阵乘法约束。 即,假设两个张量分别为a、b,要求a.size()[0]=b.size()[0]
且a.size()[-1]==b.size()[1]
。
5.2 示例代码
(1) 当a、b均为3维张量,且严格满足约束条件。
比如,这里a、b均为3维张量,维度大小分别为[2,3,2]、[2,2,5]。a.size()[-1]==b.size()[1]==2
满足维度约束,a.size()[0]==b.size()[0]==2
说明批大小相同,具有相同个数的矩阵。
这种情形下,所以torch.bmm与torch.matmul完全等价。
>>> a=torch.arange(1,13).reshape(2,3,2)
>>> b=torch.arange(20).reshape(2,2,5)
>>> a.ndim==b.ndim==3
True
>>> a.size()[-1]==b.size()[1]
>>> c1=torch.bmm(a,b)
>>> c1
tensor([[[ 10, 13, 16, 19, 22],[ 20, 27, 34, 41, 48],[ 30, 41, 52, 63, 74]],[[190, 205, 220, 235, 250],[240, 259, 278, 297, 316],[290, 313, 336, 359, 382]]])
>>> c1.size()
torch.Size([2, 3, 5])
>>> c2=torch.matmul(a,b)
>>> c2
tensor([[[ 10, 13, 16, 19, 22],[ 20, 27, 34, 41, 48],[ 30, 41, 52, 63, 74]],[[190, 205, 220, 235, 250],[240, 259, 278, 297, 316],[290, 313, 336, 359, 382]]])
>>> c2.size()
torch.Size([2, 3, 5])
>>> c2.equal(c1)
True
>>> c1.ndim
3
>>> torch.mm(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 3D, 3D tensors at ../aten/src/TH/generic/THTensorMath.cpp:131
(2) 批大小不同时
a与b均为3维张量,且满足矩阵乘法约束,但是批大小不同。torch.bmm不会执行广播操作,所以这种情形下它无法成功执行。但支持广播操作的torch.matmul可以成功执行。
>>> a=torch.arange(1,13).reshape(2,3,2)
>>> b=torch.arange(10).reshape(2,5).unsqueeze(0)
>>> a.size()
torch.Size([2, 3, 2])
>>> b.size()
torch.Size([1, 2, 5])
>>> a.ndim==b.ndim==3
True
>>> torch.bmm(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor to have size 2 at dimension 0, but got size 1 for argument #2 'batch2' (while checking arguments for bmm)
>>> c=torch.matmul(a,b)
>>> c
tensor([[[ 10, 13, 16, 19, 22],[ 20, 27, 34, 41, 48],[ 30, 41, 52, 63, 74]],[[ 40, 55, 70, 85, 100],[ 50, 69, 88, 107, 126],[ 60, 83, 106, 129, 152]]])
>>> c.size()
torch.Size([2, 3, 5])
>>> c.ndim
3
总结:torch.bmm 函数限制/约束输入的两个张量必须均为三维张量,其中第0维大小相同,其他维满足矩阵乘法约束。
6. mul与*
本节我们来学习最后一个常用的张量乘法函数 torch.mul[5] ,它与 * 等价,实现的是逐元素(即element-wise)相乘。
6.1 概览
功能:逐元素相乘。
参数:
-
input
(张量):第一个张量。 -
other
(张量):第二个张量。 -
out(张量):结果张量,等同于mul函数的返回值。
返回值:张量。
重点:要求两个张量维度相同,即a.size()==b.size()
;若不同,则通过广播操作将相乘的两个张量的维度变得相同。同时,它的广播操作还会将两个张量类型统一。
6.2 示例代码
(1) 当a、b维度相同且类型相同时
下例中,我们先创建了两个类型为torch.LongTensor的张量a、b,他们的维度均为[2,3],然后执行了两个等价的计算操作:a*b
与torch.mul(a,b)
。
>>> a=torch.LongTensor(2,3)
>>> b=torch.LongTensor(2,3)
>>> c1=torch.mul(a,b)
>>> c2=a*b
>>> c1.equal(c2)
True
>>>
>>> a.size()==b.size()==c1.size()==torch.Size([2, 3])
True
>>> a.type()==b.type()==c1.type()=="torch.LongTensor"
True
(2) 当a、b向量维度相同,类型不同时
今天我们所学的5个张量乘法函数中,只有torch.matmul和torch.mul支持广播操作。torch.matmul的广播操作仅针对张量的维度,而torch.mul还支持张量的类型变换。
下面的例子中,我们首先创建了二维张量a与一维张量b。a的维度为[2,3],类型为torch.LongTensor;b的维度为[3],类型为torch.FloatTensor。
>>> a=torch.arange(2,8).reshape(2,3)
>>> a
tensor([[2, 3, 4],[5, 6, 7]])
>>> a.size()
torch.Size([2, 3])
>>> a.ndim
2
>>> a.type()
'torch.LongTensor'
>>>
>>> b=torch.randn(3)
>>> b
tensor([ 1.1250, 0.8435, -0.5835])
>>> b.size()
torch.Size([3])
>>> b.ndim
1
>>> b.type()
'torch.FloatTensor'
>>>
然后,我们尝试torch.matmul操作,执行失败。错误信息提示我们a与b类型不一致。也就是说,虽然torch.matmul支持广播操作,但是仅是针对张量的维度,而不包括张量类型。所以,即使两个张量满足torch.matmul在维度上的要求,但类型不一致,也是无法正确让torch.matmul函数执行的。
>>> torch.matmul(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'vec' in call to _th_mv
相反,torch.mul函数则是可以正确执行的。
>>> c=torch.mul(a,b)
tensor([[ 2.2500, 2.5306, -2.3339],[ 5.6249, 5.0612, -4.0844]])
>>> c.type()
'torch.FloatTensor'
如果想让torch.matmul函数正确执行,我们可以手动调整张量b的类型。
>>> b=b.type(torch.long)
>>> b
tensor([1, 0, 0])
>>> torch.matmul(a,b)
tensor([2, 5])
(3) 当a、b向量类型相同,维度不同时
-
示例
下例中,a、b均是类型为torch.LongTensor的张量。
>>> a=torch.arange(1,3).unsqueeze(1)
>>> b=torch.arange(1,4)
>>> a.type()==b.type()=='torch.LongTensor'
二维张量a的维度大小为[2,1],一维张量b的维度大小为[3]。
>>> a.size()
torch.Size([2, 1])
>>> a.ndim
2
>>> b.size()
torch.Size([3])
>>> b.ndim
1
torch.mul 通过广播操作将a、b拉伸至具有同样的shape,然后再执行逐元素乘法,最后获得二维张量C1,C1的维度大小为[2,3]。
>>> c1=torch.mul(a,b)
>>> c1.size()
torch.Size([2, 3])
>>> c1.ndim
2
>>> c1.type()
'torch.LongTensor'
-
维度上的广播操作
根据我们在第一节对维度的说明,我们知道广播操作包括两个层面:
1. 首先,若维度数不同,维度较少的张量需要在最左边进行维度扩展,使维度数相同。
2. 然后,若各维度的维度大小不同,维度大小为1的张量需要在该维上复制元素,扩展拉伸至维度大小和另一个张量在该维上的大小相同。
-
模拟广播操作
为了更好更直观地理解上述对广播操作的描述,我们接下来尝试手动复现广播操作。
首先,我们先看a、b、c1各自的值:
>>> a
tensor([[1],[2]])
>>> b
tensor([1, 2, 3])
>>> c1
tensor([[1, 2, 3],[2, 4, 6]])
然后,我们来分析维度数。a与b维度数不同,维度为1的b比维度为2的a少一个维度,所以b需要在最左边扩展一个维度。扩展后,a与b维度数相同,均为2。
>>> b=b.unsqueeze(0)
>>> b.ndim
2
>>> b.size()
torch.Size([1, 3])
>>> b
tensor([[1, 2, 3]])
接着我们来分析维度大小。第0维:a维度大小为2,b维度大小为1;第1维:a维度大小为1,b的维度大小为3。也就是说,a、b在两个维度上大小都不同。所以,a需要在第1维复制元素至维度大小为3,b则需要在第0维复制元素至维度大小为2。
>>> a=a.repeat_interleave(3,dim=-1)
>>> a
tensor([[1, 1, 1],[2, 2, 2]])
>>> a.size()
torch.Size([2, 3])
>>> b=b.repeat_interleave(2,dim=0)
>>> b.size()
torch.Size([2, 3])
>>> b
tensor([[1, 2, 3],[1, 2, 3]])
最后,我们验证下结果是否和之前的一致:
>>> a1*b1
tensor([[1, 2, 3],[2, 4, 6]])
还需要注意:当两个张量可以通过扩展维度使维度数相同时,若两个张量在相应的维度大小上相等,或者大小不同但其中较小的大小为1时,才可以执行计算。
比如,我们将b保持不变,将a换为维度大小为[2,2]的二维张量,torch.mul就无法正常执行了。
>>> a=torch.arange(4).reshape(2,2)
>>> a
tensor([[0, 1],[2, 3]])
>>> b
tensor([1, 2, 3])
>>> a.size()
torch.Size([2, 2])
>>> b.size()
torch.Size([3])
>>> torch.mul(a,b)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
>>> a*b
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
总结:不同于前面的那4个函数,torch.mul实现的是逐元素相乘。它可以通过广播操作将输入的两个张量扩展成具有相同的维度大小和维度数,还可以将两个张量变为相同类型。
7. 5个张量乘法函数
最后再总结下Pytorch中常用的5个张量乘法函数:
# 向量点积运算。要求输入为一维张量且类型相同、元素个数相同,输出为scaler标量。
torch.dot
# 矩阵乘法运算,不支持广播操作。要求输入为二维张量且类型相同,维度大小满足矩阵乘法约束。
torch.mm
# 批矩阵乘法运算,不支持广播操作。要求输入为三维张量且类型相同,第0维大小相等,后两维大小满足矩阵乘法约束。
torch.bmm
# 混合矩阵乘法运算,包括向量点积、矩阵乘法、批矩阵乘法,且支持广播操作(仅针对维度)。要求输入张量类型相同,具体行为根据维度可以分五种情况。
torch.matmul
# 逐元素乘法,等价于*,支持广播操作(包括维度及类型)。无特殊要求或约束。
torch.mul
参考资料
[1] torch.matmul: .matmul.html?highlight=torch%20matmul#torch.matmul
[2] torch.dot: .dot.html?highlight=torch%20dot#torch.dot
[3] torch.mm: .mm.html?highlight=torch%20mm#torch.mm
[4] torch.bmm: .bmm.html?highlight=torch%20bmm#torch.bmm
[5] torch.mul: .mul.html?highlight=torch%20mul#torch.mul
欢迎大家关注公众号#AI算法小喵,上面会不定期分享一些关于深度学习、机器学习、NLP的干货知识和实践笔记。
本文标签: 一文整理5个Pytorch张量乘法函数
版权声明:本文标题:一文整理5个Pytorch张量乘法函数 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.roclinux.cn/p/1697150691a262287.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论