python - tutorial - Cómo hacer producto de matrices en PyTorch.
seleccionar elementos de una matriz en python (3)
Estas buscando
torch.mm(a,b)
Tenga en cuenta que torch.dot()
comporta de manera diferente a np.dot()
. Se ha discutido sobre lo que sería deseable here . Específicamente, torch.dot()
trata a
y b
como vectores 1D (independientemente de su forma original) y calcula su producto interno. Se produce el error, porque este comportamiento hace que tu a
un vector de longitud 6 y tu b
a vector de longitud 2; por lo tanto, su producto interno no puede ser computado. Para la multiplicación de matrices en PyTorch, use torch.mm()
. El np.dot()
de np.dot()
en contraste, es más flexible; calcula el producto interno para matrices 1D y realiza la multiplicación de matrices para matrices 2D.
En numpy puedo hacer una simple multiplicación de matrices como esta:
a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))
Sin embargo, cuando estoy intentando esto con PyTorch Tensors, esto no funciona:
a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())
print(b)
print(b.size())
print(torch.dot(a, b))
Este código arroja el siguiente error:
RuntimeError: tamaño de tensor inconsistente en /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503
¿Alguna idea de cómo se puede realizar la multiplicación de matrices en PyTorch?
Sobre la base de la respuesta de mexmex, si desea realizar una multiplicación de matrices, puede hacerlo de tres maneras:
AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or even simpler
AB = A @ B # Python 3.5+
Para la multiplicación de elementos, simplemente puede hacer (si A y B tienen la misma forma)
A * B # element-wise matrix multiplication (Hadamard product)
Utilice torch.mm(a, b)
o torch.matmul(a, b)
Ambos son lo mismo.
>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>