python - raspberry - ¿Cómo puedo implementar un RNN personalizado(específicamente un ESN) en Tensorflow?
tensorflow raspberry pi (1)
Estoy tratando de definir mi propia RNNCell (Echo State Network) en Tensorflow, de acuerdo con la siguiente definición.
x (t + 1) = tanh (Win * u (t) + W * x (t) + Wfb * y (t))
y (t) = Wout * z (t)
z (t) = [x (t), u (t)]
x es estado, u es entrada, y es salida. Win, W y Wfb no se pueden entrenar. Todos los pesos se inicializan aleatoriamente, pero W se modifica así: "Establece un cierto porcentaje de elementos de W en 0, escala W para mantener su radio espectral por debajo de 1.0
Tengo este código para generar la ecuación.
x = tf.Variable(tf.reshape(tf.zeros([N]), [-1, N]), trainable=False, name="state_vector")
W = tf.Variable(tf.random_normal([N, N], 0.0, 0.05), trainable=False)
# TODO: setup W according to the ESN paper
W_x = tf.matmul(x, W)
u = tf.placeholder("float", [None, K], name="input_vector")
W_in = tf.Variable(tf.random_normal([K, N], 0.0, 0.05), trainable=False)
W_in_u = tf.matmul(u, W_in)
z = tf.concat(1, [x, u])
W_out = tf.Variable(tf.random_normal([K + N, L], 0.0, 0.05))
y = tf.matmul(z, W_out)
W_fb = tf.Variable(tf.random_normal([L, N], 0.0, 0.05), trainable=False)
W_fb_y = tf.matmul(y, W_fb)
x_next = tf.tanh(W_in_u + W_x + W_fb_y)
y_ = tf.placeholder("float", [None, L], name="train_output")
Mi problema es doble Primero, no sé cómo implementar esto como una superclase de RNNCell. En segundo lugar, no sé cómo generar un tensor W de acuerdo con la especificación anterior.
Cualquier ayuda sobre cualquiera de estas preguntas es muy apreciada. Tal vez pueda encontrar una manera de preparar W, pero estoy seguro de que no entiendo cómo implementar mi propio RNN como una superclase de RNNCell.
Para dar un resumen rápido:
Busque en el código fuente de TensorFlow bajo python/ops/rnn_cell.py
también vea cómo subclase RNNCell. Por lo general, es así:
class MyRNNCell(RNNCell):
def __init__(...):
@property
def output_size(self):
...
@property
def state_size(self):
...
def __call__(self, input_, state, name=None):
... your per-step iteration here ...