neural network - autograd - Pytorch, ¿cuáles son los argumentos de gradiente?
autograd python (3)
Estoy leyendo la documentación de PyTorch y encontré un ejemplo donde escriben
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
donde x era una variable inicial, a partir de la cual se construyó y (un vector de 3). La pregunta es, ¿cuáles son los argumentos 0.1, 1.0 y 0.0001 del tensor de gradientes? La documentación no es muy clara al respecto.
Explicación
Para las redes neuronales, usualmente utilizamos la
loss
para evaluar qué tan bien la red ha aprendido a clasificar la imagen de entrada (u otras tareas).
El término de
loss
suele ser un valor escalar.
Para actualizar los parámetros de la red, necesitamos calcular el gradiente de
loss
wrt para los parámetros, que en realidad es un
leaf node
en el gráfico de cálculo (por cierto, estos parámetros son principalmente el peso y el sesgo de varias capas como la convolución , Lineal y así sucesivamente).
De acuerdo con la regla de la cadena, para calcular el gradiente de
loss
wrt a un nodo hoja, podemos calcular la derivada de
loss
wrt alguna variable intermedia, y el gradiente de la variable intermedia wrt a la variable hoja, hacer un producto de punto y sumar todo esto.
Los argumentos de
gradient
del método
backward()
de una
Variable
se usan para
calcular una suma ponderada de cada elemento de una Variable en la
Variable hoja
.
Este peso es solo el derivado de la
loss
final en cada elemento de la variable intermedia.
Un ejemplo concreto
Tomemos un ejemplo concreto y simple para entender esto.
from torch.autograd import Variable
import torch
x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True)
z = 2*x
loss = z.sum(dim=1)
# do backward for first element of z
z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_() #remove gradient in x.grad, or it will be accumulated
# do backward for second element of z
z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# do backward for all elements of z, with weight equal to the derivative of
# loss w.r.t z_1, z_2, z_3 and z_4
z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# or we can directly backprop using loss
loss.backward() # equivalent to loss.backward(torch.FloatTensor([1.0]))
print(x.grad.data)
En el ejemplo anterior, el resultado de la primera
print
es
2 0 0 0
[torch.FloatTensor de tamaño 1x4]
que es exactamente la derivada de z_1 wrt a x.
El resultado de la segunda
print
es:
0 2 0 0
[torch.FloatTensor de tamaño 1x4]
que es la derivada de z_2 wrt a x.
Ahora, si usa un peso de [1, 1, 1, 1] para calcular la derivada de z wrt a x, el resultado es
1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx
.
Así que no sorprende que la salida de la 3ra
print
sea:
2 2 2 2
[torch.FloatTensor de tamaño 1x4]
Cabe señalar que el vector de peso [1, 1, 1, 1] es exactamente derivado de la
loss
wrt a z_1, z_2, z_3 y z_4.
La derivada de la
loss
wrt a
x
se calcula como:
d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx
Entonces, la salida de la cuarta
print
es la misma que la tercera
print
:
2 2 2 2
[torch.FloatTensor de tamaño 1x4]
Aquí, la salida de forward (), es decir, y es aa 3-vector.
Los tres valores son los gradientes en la salida de la red. Por lo general, se establecen en 1.0 si y es la salida final, pero también pueden tener otros valores, especialmente si y es parte de una red más grande.
Por ej. si x es la entrada, y = [y1, y2, y3] es una salida intermedia que se usa para calcular la salida final z,
Entonces,
dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx
Así que aquí, los tres valores al revés son
[dz/dy1, dz/dy2, dz/dy3]
y luego hacia atrás () calcula dz / dx
Por lo general, su gráfico computacional tiene una salida escalar que dice
loss
.
Luego puede calcular el gradiente de
loss
wrt los pesos (
w
) por
loss.backward()
.
Donde el argumento predeterminado de
backward()
es
1.0
.
Si su salida tiene múltiples valores (por ejemplo,
loss=[loss1, loss2, loss3]
), puede calcular los gradientes de pérdida con los pesos por
loss.backward(torch.FloatTensor([1.0, 1.0, 1.0]))
.
Además, si desea agregar pesos o importancias a diferentes pérdidas, puede usar
loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001]))
.
Esto significa calcular
-0.1*d(loss1)/dw, d(loss2)/dw, 0.0001*d(loss3)/dw
simultáneamente.