what autograd neural-network gradient pytorch torch gradient-descent

neural network - autograd - Pytorch, ¿cuáles son los argumentos de gradiente?



autograd python (3)

Estoy leyendo la documentación de PyTorch y encontré un ejemplo donde escriben

gradients = torch.FloatTensor([0.1, 1.0, 0.0001]) y.backward(gradients) print(x.grad)

donde x era una variable inicial, a partir de la cual se construyó y (un vector de 3). La pregunta es, ¿cuáles son los argumentos 0.1, 1.0 y 0.0001 del tensor de gradientes? La documentación no es muy clara al respecto.


Explicación

Para las redes neuronales, usualmente utilizamos la loss para evaluar qué tan bien la red ha aprendido a clasificar la imagen de entrada (u otras tareas). El término de loss suele ser un valor escalar. Para actualizar los parámetros de la red, necesitamos calcular el gradiente de loss wrt para los parámetros, que en realidad es un leaf node en el gráfico de cálculo (por cierto, estos parámetros son principalmente el peso y el sesgo de varias capas como la convolución , Lineal y así sucesivamente).

De acuerdo con la regla de la cadena, para calcular el gradiente de loss wrt a un nodo hoja, podemos calcular la derivada de loss wrt alguna variable intermedia, y el gradiente de la variable intermedia wrt a la variable hoja, hacer un producto de punto y sumar todo esto.

Los argumentos de gradient del método backward() de una Variable se usan para calcular una suma ponderada de cada elemento de una Variable en la Variable hoja . Este peso es solo el derivado de la loss final en cada elemento de la variable intermedia.

Un ejemplo concreto

Tomemos un ejemplo concreto y simple para entender esto.

from torch.autograd import Variable import torch x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True) z = 2*x loss = z.sum(dim=1) # do backward for first element of z z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True) print(x.grad.data) x.grad.data.zero_() #remove gradient in x.grad, or it will be accumulated # do backward for second element of z z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True) print(x.grad.data) x.grad.data.zero_() # do backward for all elements of z, with weight equal to the derivative of # loss w.r.t z_1, z_2, z_3 and z_4 z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True) print(x.grad.data) x.grad.data.zero_() # or we can directly backprop using loss loss.backward() # equivalent to loss.backward(torch.FloatTensor([1.0])) print(x.grad.data)

En el ejemplo anterior, el resultado de la primera print es

2 0 0 0
[torch.FloatTensor de tamaño 1x4]

que es exactamente la derivada de z_1 wrt a x.

El resultado de la segunda print es:

0 2 0 0
[torch.FloatTensor de tamaño 1x4]

que es la derivada de z_2 wrt a x.

Ahora, si usa un peso de [1, 1, 1, 1] para calcular la derivada de z wrt a x, el resultado es 1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx . Así que no sorprende que la salida de la 3ra print sea:

2 2 2 2
[torch.FloatTensor de tamaño 1x4]

Cabe señalar que el vector de peso [1, 1, 1, 1] es exactamente derivado de la loss wrt a z_1, z_2, z_3 y z_4. La derivada de la loss wrt a x se calcula como:

d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx

Entonces, la salida de la cuarta print es la misma que la tercera print :

2 2 2 2
[torch.FloatTensor de tamaño 1x4]


Aquí, la salida de forward (), es decir, y es aa 3-vector.

Los tres valores son los gradientes en la salida de la red. Por lo general, se establecen en 1.0 si y es la salida final, pero también pueden tener otros valores, especialmente si y es parte de una red más grande.

Por ej. si x es la entrada, y = [y1, y2, y3] es una salida intermedia que se usa para calcular la salida final z,

Entonces,

dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx

Así que aquí, los tres valores al revés son

[dz/dy1, dz/dy2, dz/dy3]

y luego hacia atrás () calcula dz / dx


Por lo general, su gráfico computacional tiene una salida escalar que dice loss . Luego puede calcular el gradiente de loss wrt los pesos ( w ) por loss.backward() . Donde el argumento predeterminado de backward() es 1.0 .

Si su salida tiene múltiples valores (por ejemplo, loss=[loss1, loss2, loss3] ), puede calcular los gradientes de pérdida con los pesos por loss.backward(torch.FloatTensor([1.0, 1.0, 1.0])) .

Además, si desea agregar pesos o importancias a diferentes pérdidas, puede usar loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001])) .

Esto significa calcular -0.1*d(loss1)/dw, d(loss2)/dw, 0.0001*d(loss3)/dw simultáneamente.