python - xlabel - Generando nuevas imágenes con PyTorch.
subplot title python (2)
El código de su ejemplo ( https://github.com/davidsonmizael/gan ) me hizo el mismo ruido que muestra. La pérdida del generador disminuyó demasiado rápido.
Hubo algunas cosas con errores, ni siquiera estoy seguro de qué, pero supongo que es fácil descubrir las diferencias por tu cuenta. Para una comparación, también vea este tutorial: GAN en 50 líneas de PyTorch
.... same as your code
print("# Starting generator and descriminator...")
netG = G()
netG.apply(weights_init)
netD = D()
netD.apply(weights_init)
if torch.cuda.is_available():
netG.cuda()
netD.cuda()
#training the DCGANs
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
epochs = 25
timeElapsed = []
for epoch in range(epochs):
print("# Starting epoch [%d/%d]..." % (epoch, epochs))
for i, data in enumerate(dataloader, 0):
start = time.time()
time.clock()
#updates the weights of the discriminator nn
netD.zero_grad()
#trains the discriminator with a real image
real, _ = data
if torch.cuda.is_available():
inputs = Variable(real.cuda()).cuda()
target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
else:
inputs = Variable(real)
target = Variable(torch.ones(inputs.size()[0]))
output = netD(inputs)
errD_real = criterion(output, target)
errD_real.backward() #retain_graph=True
#trains the discriminator with a fake image
if torch.cuda.is_available():
D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
target = Variable(torch.zeros(inputs.size()[0]).cuda()).cuda()
else:
D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
target = Variable(torch.zeros(inputs.size()[0]))
D_fake = netG(D_noise).detach()
D_fake_ouput = netD(D_fake)
errD_fake = criterion(D_fake_ouput, target)
errD_fake.backward()
# NOT:backpropagating the total error
# errD = errD_real + errD_fake
optimizerD.step()
#for i, data in enumerate(dataloader, 0):
#updates the weights of the generator nn
netG.zero_grad()
if torch.cuda.is_available():
G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
else:
G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
target = Variable(torch.ones(inputs.size()[0]))
fake = netG(G_noise)
G_output = netD(fake)
errG = criterion(G_output, target)
#backpropagating the error
errG.backward()
optimizerG.step()
if i % 50 == 0:
#prints the losses and save the real images and the generated images
print("# Progress: ")
print("[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f" % (epoch, epochs, i, len(dataloader), errD_real.data[0], errG.data[0]))
#calculates the remaining time by taking the avg seconds that every loop
#and multiplying by the loops that still need to run
timeElapsed.append(time.time() - start)
avg_time = (sum(timeElapsed) / float(len(timeElapsed)))
all_dtl = (epoch * len(dataloader)) + i
rem_dtl = (len(dataloader) - i) + ((epochs - epoch) * len(dataloader))
remaining = (all_dtl - rem_dtl) * avg_time
print("# Estimated remaining time: %s" % (time.strftime("%H:%M:%S", time.gmtime(remaining))))
if i % 100 == 0:
vutils.save_image(real, "%s/real_samples.png" % "./results", normalize = True)
vutils.save_image(fake.data, "%s/fake_samples_epoch_%03d.png" % ("./results", epoch), normalize = True)
print ("# Finished.")
Resultado después de 25 épocas (tamaño de lote 256) en CIFAR-10:
Estoy estudiando GAN. He completado el curso que me dio un ejemplo de un programa que genera imágenes basadas en ejemplos ingresados.
El ejemplo se puede encontrar aquí:
https://github.com/davidsonmizael/gan
Así que decidí usar eso para generar nuevas imágenes basadas en un conjunto de datos de fotos frontales de caras, pero no estoy teniendo ningún éxito. A diferencia del ejemplo anterior, el código solo genera ruido, mientras que la entrada tiene imágenes reales.
En realidad no tengo idea de qué debo cambiar para que el código apunte a la dirección correcta y aprenda de las imágenes. No he cambiado un solo valor en el código proporcionado en el ejemplo, pero no funciona.
Si alguien me puede ayudar a entender esto y señalarme la dirección correcta sería muy útil. Gracias por adelantado.
Mi discriminador:
class D(nn.Module):
def __init__(self):
super(D, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias = False),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(128, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(256, 512, 4, 2, 1, bias = False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(512, 1, 4, 1, 0, bias = False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input).view(-1)
Mi generador
class G(nn.Module):
def __init__(self):
super(G, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
Mi función para iniciar los pesos:
def weights_init(m):
classname = m.__class__.__name__
if classname.find(''Conv'') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find(''BatchNorm'') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
El código completo se puede ver aquí:
https://github.com/davidsonmizael/criminal-gan
El entrenamiento de GAN no es muy rápido. Supongo que no está utilizando un modelo pre-entrenado, sino que está aprendiendo desde cero. En la época 25 es bastante normal no ver ningún patrón significativo en las muestras. Me doy cuenta de que el proyecto github te muestra algo genial después de 25 épocas, pero eso también depende del tamaño del conjunto de datos. CIFAR-10 (el que se usó en la página de github) tiene 60000 imágenes. 25 épocas significa que la red las ha visto todas 25 veces.
No sé qué conjunto de datos está utilizando, pero si es más pequeño, podría tomar más épocas hasta que vea los resultados, ya que la red puede ver menos imágenes en total. Si las imágenes en su conjunto de datos tienen una resolución más alta, también puede tomar más tiempo.
Debes verificar nuevamente después de al menos unos pocos cientos, si no unos pocos miles de épocas.
Por ejemplo, en la foto frontal de la cara de datos después de 25 épocas: