• Notifications

/dcgan.torch

Permalink
Switch branches/tags
Nothing to show
Find file
caf5a9b on 3 Jul 2016
@soumith @szagoruyko @gwern
247 lines (215 sloc) 8.49 KB
require 'torch'
require 'nn'
require 'optim'
-- Default training options. Every key listed here can be overridden by an
-- environment variable of the same name (see the loop right after the table).
opt = {
   dataset = 'lsun',       -- imagenet / lsun / folder
   batchSize = 64,
   loadSize = 96,
   fineSize = 64,
   nz = 100,               -- # of dim for Z
   ngf = 64,               -- # of gen filters in first conv layer
   ndf = 64,               -- # of discrim filters in first conv layer
   nThreads = 4,           -- # of data loading threads to use
   niter = 25,             -- # of iter at starting learning rate
   lr = 0.0002,            -- initial learning rate for adam
   beta1 = 0.5,            -- momentum term of adam
   ntrain = math.huge,     -- # of examples per epoch. math.huge for full dataset
   display = 1,            -- display samples while training. 0 = false
   display_id = 10,        -- display window id.
   gpu = 1,                -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = 'experiment1',
   noise = 'normal',       -- uniform / normal
}

-- Minimal argument parser: environment variables override the defaults.
-- A value that parses as a number is stored as a number; anything else is
-- stored as the raw string. Keys without a matching variable are left alone.
for key in pairs(opt) do
   local env_value = os.getenv(key)
   if env_value ~= nil then
      opt[key] = tonumber(env_value) or env_value
   end
end
print(opt)

-- 0 is truthy in Lua, so turn it into a real boolean for later `if` checks
if opt.display == 0 then
   opt.display = false
end

-- draw a seed, report it, and use it for this run
opt.manualSeed = torch.random(1, 10000)
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)

torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')
-- Load the data-loader module from data/data.lua, build a loader for the
-- selected dataset, and report how many examples it exposes.
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print(
   "Dataset: " .. opt.dataset,
   " Size: ",
   data:size()
)
----------------------------------------------------------------------------
--- Re-initialise one module; meant to be passed to `net:apply(...)`.
-- Modules whose type name contains "Convolution" get weights drawn from
-- normal(0.0, 0.02) and have `noBias()` called on them. Modules whose type
-- name contains "BatchNormalization" get weights drawn from normal(1.0, 0.02)
-- and a zero-filled bias, each only if present. Other modules are untouched.
-- @param m the module to initialise
local function weights_init(m)
   local type_name = torch.type(m)

   if type_name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m:noBias()
      return
   end

   if type_name:find('BatchNormalization') then
      local w, b = m.weight, m.bias
      if w then w:normal(1.0, 0.02) end
      if b then b:fill(0) end
   end
end
-- nc: plane count of the generator's last layer and the discriminator's first
-- nz: size of the latent vector Z fed to the generator
local nc, nz = 3, opt.nz
-- base filter counts for the discriminator (ndf) and the generator (ngf)
local ndf, ngf = opt.ndf, opt.ngf
-- target values written into the label buffer for real and generated batches
local real_label, fake_label = 1, 0

-- short aliases for the layer constructors used to build both networks
local SpatialBatchNormalization, SpatialConvolution, SpatialFullConvolution =
   nn.SpatialBatchNormalization, nn.SpatialConvolution, nn.SpatialFullConvolution
-- Generator: a stack of full convolutions that takes Z and ends in Tanh.
-- The state sizes noted on each stage are the ones given by the original
-- layer comments.
local netG = nn.Sequential()
do
   -- Appends full-convolution -> batch-norm -> ReLU to netG. Arguments after
   -- n_out are forwarded to SpatialFullConvolution following the 4x4 kernel.
   local function add_up_stage(n_in, n_out, ...)
      netG:add(SpatialFullConvolution(n_in, n_out, 4, 4, ...))
      netG:add(SpatialBatchNormalization(n_out))
      netG:add(nn.ReLU(true))
   end

   -- input is Z, going into a convolution
   add_up_stage(nz, ngf * 8)                    -- state size: (ngf*8) x 4 x 4
   add_up_stage(ngf * 8, ngf * 4, 2, 2, 1, 1)   -- state size: (ngf*4) x 8 x 8
   add_up_stage(ngf * 4, ngf * 2, 2, 2, 1, 1)   -- state size: (ngf*2) x 16 x 16
   add_up_stage(ngf * 2, ngf, 2, 2, 1, 1)       -- state size: (ngf) x 32 x 32

   -- output stage: no batch-norm, Tanh instead of ReLU
   netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
   netG:add(nn.Tanh())                          -- state size: (nc) x 64 x 64
end
netG:apply(weights_init)
-- Discriminator: a stack of strided convolutions ending in a Sigmoid, whose
-- output is reshaped by nn.View(1). The state sizes noted on each stage are
-- the ones given by the original layer comments.
local netD = nn.Sequential()
do
   -- Appends convolution (4x4, stride 2, pad 1) -> batch-norm -> LeakyReLU.
   local function add_down_stage(n_in, n_out)
      netD:add(SpatialConvolution(n_in, n_out, 4, 4, 2, 2, 1, 1))
      netD:add(SpatialBatchNormalization(n_out))
      netD:add(nn.LeakyReLU(0.2, true))
   end

   -- input is (nc) x 64 x 64; the first stage has no batch-norm
   netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
   netD:add(nn.LeakyReLU(0.2, true))            -- state size: (ndf) x 32 x 32

   add_down_stage(ndf, ndf * 2)                 -- state size: (ndf*2) x 16 x 16
   add_down_stage(ndf * 2, ndf * 4)             -- state size: (ndf*4) x 8 x 8
   add_down_stage(ndf * 4, ndf * 8)             -- state size: (ndf*8) x 4 x 4

   netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
   netD:add(nn.Sigmoid())                       -- state size: 1 x 1 x 1
   netD:add(nn.View(1):setNumInputDims(3))      -- state size: 1
end
netD:apply(weights_init)

-- loss used by both the discriminator and the generator updates
local criterion = nn.BCECriterion()
---------------------------------------------------------------------------
-- Separate optimiser state tables for the generator and the discriminator,
-- both seeded with the configured learning rate and beta1. They are passed
-- to optim.adam in the training loop.
do
   local function new_adam_state()
      return {
         learningRate = opt.lr,
         beta1 = opt.beta1,
      }
   end

   optimStateG = new_adam_state()
   optimStateD = new_adam_state()
end
----------------------------------------------------------------------------
-- Pre-allocated buffers that are reused for every batch.
-- The image buffer takes its channel count from `nc`, the same value the
-- generator's last layer and the discriminator's first layer are built with,
-- rather than a literal 3, so the three stay in agreement if `nc` changes.
local input = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
-- latent batch fed to the generator: one nz x 1 x 1 code per sample
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
-- one target value per sample, filled with real_label or fake_label
local label = torch.Tensor(opt.batchSize)
-- latest discriminator / generator losses; nil until the first update runs
local errD, errG
-- timers for a whole epoch, a single iteration, and batch fetching
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()
----------------------------------------------------------------------------
-- GPU mode (opt.gpu > 0): select the device, move the batch buffers, the
-- networks and the criterion to CUDA, and switch to cuDNN when the package
-- can be loaded.
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)

   input = input:cuda()
   noise = noise:cuda()
   label = label:cuda()

   -- pcall keeps a missing cudnn package from aborting the script
   local has_cudnn = pcall(require, 'cudnn')
   if has_cudnn then
      cudnn.benchmark = true
      cudnn.convert(netG, cudnn)
      cudnn.convert(netD, cudnn)
   end

   netD:cuda()
   netG:cuda()
   criterion:cuda()
end
-- Parameter and gradient tensors of each network, as handed to optim.adam
-- and zeroed / returned by the closures below.
local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()

-- the display package is only loaded when visualisation is enabled
if opt.display then
   disp = require 'display'
end

-- Noise batch used only for visualisation: sampled once here, according to
-- opt.noise, and passed to netG each time samples are displayed.
noise_vis = noise:clone()
if opt.noise == 'normal' then
   noise_vis:normal(0, 1)
elseif opt.noise == 'uniform' then
   noise_vis:uniform(-1, 1)
end
-- Closure handed to optim.adam for the discriminator update.
-- Zeroes gradParametersD, runs netD forward/backward on a real batch (target
-- real_label) and then on a freshly generated batch (target fake_label).
-- Stores the summed loss in the shared `errD` and returns it together with
-- gradParametersD. The argument `x` is not used.
local fDx = function(x)
   gradParametersD:zero()

   -- real batch; data_tm measures only the time spent fetching it
   data_tm:reset(); data_tm:resume()
   local real_batch = data:getBatch()
   data_tm:stop()
   input:copy(real_batch)
   label:fill(real_label)
   local out_real = netD:forward(input)
   local errD_real = criterion:forward(out_real, label)
   local grad_real = criterion:backward(out_real, label)
   netD:backward(input, grad_real)

   -- generated batch: resample the noise, then run the generator
   if opt.noise == 'normal' then
      noise:normal(0, 1)
   elseif opt.noise == 'uniform' then
      noise:uniform(-1, 1)
   end
   local fake_batch = netG:forward(noise)
   input:copy(fake_batch)
   label:fill(fake_label)
   local out_fake = netD:forward(input)
   local errD_fake = criterion:forward(out_fake, label)
   local grad_fake = criterion:backward(out_fake, label)
   netD:backward(input, grad_fake)

   errD = errD_real + errD_fake
   return errD, gradParametersD
end
-- Closure handed to optim.adam for the generator update.
-- Must run right after fDx: it reuses what fDx left behind, namely the
-- sampled `noise`, the generated images already copied into `input`, and
-- netD's output for them (netD.output). That saves regenerating the noise and
-- repeating the netG / netD forward passes.
-- Stores the loss in the shared `errG` and returns it together with
-- gradParametersG. The argument `x` is not used.
local fGx = function(x)
   gradParametersG:zero()

   -- fake labels are real for generator cost
   label:fill(real_label)

   local d_out = netD.output
   errG = criterion:forward(d_out, label)
   local grad_out = criterion:backward(d_out, label)

   -- only the gradient w.r.t. netD's input is requested here; that gradient
   -- is what gets propagated back through the generator
   local grad_fake = netD:updateGradInput(input, grad_out)
   netG:backward(noise, grad_fake)

   return errG, gradParametersG
end
-- Training loop: opt.niter epochs. Each iteration does one discriminator
-- step and one generator step. Both networks are checkpointed at the end of
-- every epoch.
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0

   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()

      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)
      -- (2) Update G network: maximize log(D(G(z)))
      -- fGx relies on state left by fDx, so this order must be kept
      optim.adam(fGx, parametersG, optimStateG)

      -- every 10th iteration, show generated samples next to a real batch
      counter = counter + 1
      if opt.display and counter % 10 == 0 then
         local fake = netG:forward(noise_vis)
         local real = data:getBatch()
         disp.image(fake, {win=opt.display_id, title=opt.name})
         disp.image(real, {win=opt.display_id * 3, title=opt.name})
      end

      -- logging: i advances by batchSize, so this is the 0-based batch index
      local batch_idx = (i - 1) / opt.batchSize
      if batch_idx % 1 == 0 then
         local log_fmt = 'Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
            .. ' Err_G: %.4f Err_D: %.4f'
         print(log_fmt:format(
            epoch, batch_idx,
            math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
            tm:time().real, data_tm:time().real,
            errG or -1, errD or -1))
      end
   end

   -- checkpoint both networks
   paths.mkdir('checkpoints')
   -- nil them to avoid spiking memory
   parametersD, gradParametersD = nil, nil
   parametersG, gradParametersG = nil, nil
   local ckpt_prefix = 'checkpoints/' .. opt.name .. '_' .. epoch
   torch.save(ckpt_prefix .. '_net_G.t7', netG:clearState())
   torch.save(ckpt_prefix .. '_net_D.t7', netD:clearState())
   -- reflatten the params and get them
   parametersD, gradParametersD = netD:getParameters()
   parametersG, gradParametersG = netG:getParameters()

   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
      epoch, opt.niter, epoch_tm:time().real))
end