python - Ejemplo de regresión simple pyBrain
machine-learning neural-network (1)
Estoy tratando de hacer la regresión más simple en pyBrain, pero de alguna manera estoy fallando.
La red neuronal debe aprender la función Y = 3 * X
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.datasets import SupervisedDataSet
from pybrain.structure import FullConnection, FeedForwardNetwork, TanhLayer, LinearLayer, BiasUnit
import matplotlib.pyplot as plt
from numpy import *
n = FeedForwardNetwork()
n.addInputModule(LinearLayer(1, name = ''in''))
n.addInputModule(BiasUnit(name = ''bias''))
n.addModule(TanhLayer(1,name = ''tan''))
n.addOutputModule(LinearLayer(1, name = ''out''))
n.addConnection(FullConnection(n[''bias''], n[''tan'']))
n.addConnection(FullConnection(n[''in''], n[''tan'']))
n.addConnection(FullConnection(n[''tan''], n[''out'']))
n.sortModules()
# initialize the backprop trainer and train
t = BackpropTrainer(n, learningrate = 0.1, momentum = 0.0, verbose = True)
#DATASET
DS = SupervisedDataSet( 1, 1 )
X = random.rand(100,1)*100
Y = X*3+random.rand(100,1)*5
for r in xrange(X.shape[0]):
DS.appendLinked((X[r]),(Y[r]))
t.trainOnDataset(DS, 200)
plt.plot(X,Y,''.b'')
X=[[i] for i in arange(0,100,0.1)]
Y=map(n.activate,X)
plt.plot(X,Y,''-g'')
No aprende nada. Intenté eliminar la capa oculta (porque en este ejemplo ni siquiera necesitamos eso) y la red comenzó a predecir los NaN. ¿Que esta pasando?
EDITAR: Este es el código que resolvió mi problema:
#DATASET
DS = SupervisedDataSet( 1, 1 )
X = random.rand(100,1)*100
Y = X*3+random.rand(100,1)*5
maxy = float(max(Y))
maxx = 100.0
for r in xrange(X.shape[0]):
DS.appendLinked((X[r]/maxx),(Y[r]/maxy))
t.trainOnDataset(DS, 200)
plt.plot(X,Y,''.b'')
X=[[i] for i in arange(0,100,0.1)]
Y=map(lambda x: n.activate(array(x)/maxx)*maxy,X)
plt.plot(X,Y,''-g'')
Las neuronas básicas de los cerebros van a producir algo entre 0 y 1. Divida su Y por 300 (el valor máximo posible), y obtendrá mejores resultados.
De manera más general, encuentre la Y máxima para su conjunto de datos y escale todo con eso.