Browse Source

Add model option and 12 layers net

nagadomi 9 years ago
parent
commit
7af5c9443d
5 changed files with 101 additions and 43 deletions
  1. 3 16
      lib/reconstruct.lua
  2. 1 0
      lib/settings.lua
  3. 91 22
      lib/srcnn.lua
  4. 4 3
      tools/rebuild_model.lua
  5. 2 2
      train.lua

+ 3 - 16
lib/reconstruct.lua

@@ -1,5 +1,6 @@
 require 'image'
 local iproc = require 'iproc'
+local srcnn = require 'srcnn'
 
 local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then
@@ -50,7 +51,7 @@ local function reconstruct_rgb(model, x, offset, block_size)
 end
 local reconstruct = {}
 function reconstruct.is_rgb(model)
-   if model:get(model:size() - 1).weight:size(1) == 3 then
+   if srcnn.channels(model) == 3 then
       -- 3ch RGB
       return true
    else
@@ -59,21 +60,7 @@ function reconstruct.is_rgb(model)
    end
 end
 function reconstruct.offset_size(model)
-   local conv = model:findModules("nn.SpatialConvolutionMM")
-   if #conv > 0 then
-      local offset = 0
-      for i = 1, #conv do
-	 offset = offset + (conv[i].kW - 1) / 2
-      end
-      return math.floor(offset)
-   else
-      conv = model:findModules("cudnn.SpatialConvolution")
-      local offset = 0
-      for i = 1, #conv do
-	 offset = offset + (conv[i].kW - 1) / 2
-      end
-      return math.floor(offset)
-   end
+   return srcnn.offset_size(model)
 end
 function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128

+ 1 - 0
lib/settings.lua

@@ -24,6 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", 'method to training (noise|scale)')
+cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
 cmd:option("-noise_level", 1, '(1|2|3)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')

+ 91 - 22
lib/srcnn.lua

@@ -30,63 +30,132 @@ end
 function srcnn.channels(model)
    return model:get(model:size() - 1).weight:size(1)
 end
-function srcnn.waifu2x_cunn(ch)
+function srcnn.backend(model)
+   local conv = model:findModules("cudnn.SpatialConvolution")
+   if #conv > 0 then
+      return "cudnn"
+   else
+      return "cunn"
+   end
+end
+function srcnn.color(model)
+   local ch = srcnn.channels(model)
+   if ch == 3 then
+      return "rgb"
+   else
+      return "y"
+   end
+end
+function srcnn.name(model)
+   local backend_cudnn = false
+   local conv = model:findModules("nn.SpatialConvolutionMM")
+   if #conv == 0 then
+      backend_cudnn = true
+      conv = model:findModules("cudnn.SpatialConvolution")
+   end
+   if #conv == 7 then
+      return "vgg_7"
+   elseif #conv == 12 then
+      return "vgg_12"
+   else
+      return nil
+   end
+end
+function srcnn.offset_size(model)
+   local conv = model:findModules("nn.SpatialConvolutionMM")
+   if #conv == 0 then
+      conv = model:findModules("cudnn.SpatialConvolution")
+   end
+   local offset = 0
+   for i = 1, #conv do
+      offset = offset + (conv[i].kW - 1) / 2
+   end
+   return math.floor(offset)
+end
+
+local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+
+-- VGG style net(7 layers)
+function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
-   model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
    return model
 end
-function srcnn.waifu2x_cudnn(ch)
+-- VGG style net(12 layers)
+function srcnn.vgg_12(backend, ch)
    local model = nn.Sequential()
-   model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
    return model
 end
+
 function srcnn.create(model_name, backend, color)
+   model_name = model_name or "vgg_7"
+   backend = backend or "cunn"
+   color = color or "rgb"
    local ch = 3
    if color == "rgb" then
       ch = 3
    elseif color == "y" then
       ch = 1
    else
-      error("unsupported color: " + color)
+      error("unsupported color: " .. color)
    end
-   if backend == "cunn" then
-      return srcnn.waifu2x_cunn(ch)
-   elseif backend == "cudnn" then
-      return srcnn.waifu2x_cudnn(ch)
+   if model_name == "vgg_7" then
+      return srcnn.vgg_7(backend, ch)
+   elseif model_name == "vgg_12" then
+      return srcnn.vgg_12(backend, ch)
    else
-      error("unsupported backend: " +  backend)
+      error("unsupported model_name: " .. model_name)
    end
 end
 return srcnn

+ 4 - 3
tools/rebuild_model.lua

@@ -5,8 +5,8 @@ require 'os'
 require 'w2nn'
 local srcnn = require 'srcnn'
 
-local function rebuild(old_model)
-   local new_model = srcnn.waifu2x_cunn(srcnn.channels(old_model))
+local function rebuild(old_model, model)
+   local new_model = srcnn.create(model, srcnn.backend(old_model), srcnn.color(old_model))
    local weight_from = old_model:findModules("nn.SpatialConvolutionMM")
    local weight_to = new_model:findModules("nn.SpatialConvolutionMM")
 
@@ -30,6 +30,7 @@ cmd:text("waifu2x rebuild cunn model")
 cmd:text("Options:")
 cmd:option("-i", "", 'Specify the input model')
 cmd:option("-o", "", 'Specify the output model')
+cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12)')
 cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
 cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
 
@@ -39,5 +40,5 @@ if not path.isfile(opt.i) then
    os.exit(-1)
 end
 local old_model = torch.load(opt.i, opt.iformat)
-local new_model = rebuild(old_model)
+local new_model = rebuild(old_model, opt.model)
 torch.save(opt.o, new_model, opt.oformat)

+ 2 - 2
train.lua

@@ -192,7 +192,7 @@ local function train()
    local hist_train = {}
    local hist_valid = {}
    local LR_MIN = 1.0e-5
-   local model = srcnn.create(settings.method, settings.backend, settings.color)
+   local model = srcnn.create(settings.model, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
    local pairwise_func = function(x, is_validation, n)
       return transformer(x, is_validation, n, offset)
@@ -200,7 +200,7 @@ local function train()
    local criterion = create_criterion(model)
    local eval_metric = nn.MSECriterion():cuda()
    local x = torch.load(settings.images)
-   local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
+   local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,