
add support for RGB color space reconstruction

- add a new RGB model (models/anime_style_art_rgb).
- the RGB model can reduce color noise.
- waifu2x uses this RGB model by default.

You can use the Y model with:
$ th waifu2x.lua -model_dir models/anime_style_art -i input.png -o output.png
$ th train.lua -color y ...
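
With the RGB model now the default, the equivalent RGB invocations need no extra flags (a minimal sketch; input.png is illustrative):
$ th waifu2x.lua -i input.png -o output.png
$ th train.lua -color rgb ...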
nagadomi 10 years ago
commit 5b4d692f03

+ 1 - 1
cudnn2cunn.lua

@@ -5,7 +5,7 @@ require './lib/LeakyReLU'
 local srcnn = require 'lib/srcnn'
 
 local function cudnn2cunn(cudnn_model)
-   local cunn_model = srcnn.waifu2x()
+   local cunn_model = srcnn.waifu2x("y")
    local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
    local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
 

Binary
images/lena_waifu2x.png


Binary
images/miku_CC_BY-NC_noisy_waifu2x.png


Binary
images/miku_noisy_waifu2x.png


Binary
images/miku_small_noisy_waifu2x.png


Binary
images/miku_small_waifu2x.png


Binary
images/slide.odp


Binary
images/slide.png


Binary
images/slide_noise_reduction.png


Binary
images/slide_result.png


Binary
images/slide_upscaling.png


+ 27 - 15
lib/pairwise_transform.lua

@@ -52,7 +52,7 @@ local function flip_augment(x, y)
 end
 local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_augment = true, random_half = true}
+   options = options or {color_augment = true, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -81,8 +81,12 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    x = iproc.scale(x, y:size(3), y:size(2))
    y = y:float():div(255)
    x = x:float():div(255)
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+
+   if not options.rgb then
+      -- Y model: train on the luminance channel only
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
 
   y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
    x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
@@ -90,7 +94,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    return x, y
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_augment = true, random_half = true}
+   options = options or {color_augment = true, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -106,6 +110,7 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
+      x:samplingFactors({1.0, 1.0, 1.0})
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
@@ -117,9 +122,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    x = x:float():div(255)
    x, y = flip_augment(x, y)
    
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
-
+   if not options.rgb then
+      -- Y model: train on the luminance channel only
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
+
    return x, image.crop(y, offset, offset, size - offset, size - offset)
 end
 function pairwise_transform.jpeg(src, level, size, offset, options)
@@ -159,12 +167,12 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    local down_scale = 1.0 / scale
    local filters = {
       "Box",        -- 0.012756949974688
-      --"Blackman",   -- 0.013191924552285
+      "Blackman",   -- 0.013191924552285
       --"Cartom",     -- 0.013753536746706
       --"Hanning",    -- 0.013761314529647
       --"Hermite",    -- 0.013850225205266
-      --"SincFast",   -- 0.014095824314306
-      --"Jinc",       -- 0.014244299255442
+      "SincFast",   -- 0.014095824314306
+      "Jinc",       -- 0.014244299255442
    }
    local downscale_filter = filters[torch.random(1, #filters)]
    local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
@@ -180,6 +188,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
+      x:samplingFactors({1.0, 1.0, 1.0})
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
@@ -194,10 +203,13 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    x = x:float():div(255)
    y = y:float():div(255)
    x, y = flip_augment(x, y)
-   
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
 
+   if not options.rgb then
+      -- Y model: train on the luminance channel only
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
+
    return x, image.crop(y, offset, offset, size - offset, size - offset)
 end
 function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
@@ -247,7 +259,7 @@ local function test_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true})
+      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
       image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
@@ -272,8 +284,8 @@ local function test_jpeg_scale()
       --print(x:mean(), y:mean())
    end
 end
---test_jpeg()
 --test_scale()
+--test_jpeg()
 --test_jpeg_scale()
 
 return pairwise_transform
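
Two details of this file are easy to miss. With rgb = true, the training pair stays 3-channel RGB; the rgb2yuv() reduction to a single Y plane only runs for Y models. And samplingFactors({1.0, 1.0, 1.0}) forces 4:4:4 JPEG encoding (no chroma subsampling), so the synthetic noise keeps full-resolution chroma for the RGB model to learn from. A minimal sketch of driving the transform directly, modeled on this file's test_scale() (paths assume the repository root):

local pairwise_transform = require './lib/pairwise_transform'
local loader = require './lib/image_loader'

local src = loader.load_byte("images/miku_CC_BY-NC.jpg")
-- returns a 3xHxW float pair with rgb = true, a 1xHxW Y pair otherwise
local x, y = pairwise_transform.scale(src, 2.0, 128, 7,
                                      {color_augment = true,
                                       random_half = true,
                                       rgb = true})
print(x:size(), y:size())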

+ 104 - 5
lib/reconstruct.lua

@@ -1,7 +1,7 @@
 require 'image'
 local iproc = require './iproc'
 
-local function reconstruct_layer(model, x, offset, block_size)
+local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -26,8 +26,40 @@ local function reconstruct_layer(model, x, offset, block_size)
    end
    return new_x
 end
+local function reconstruct_rgb(model, x, offset, block_size)
+   local new_x = torch.Tensor():resizeAs(x):zero()
+   local output_size = block_size - offset * 2
+   local input = torch.CudaTensor(1, 3, block_size, block_size)
+   
+   for i = 1, x:size(2), output_size do
+      for j = 1, x:size(3), output_size do
+	 if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
+	    local index = {{},
+			   {i, i + block_size - 1},
+			   {j, j + block_size - 1}}
+	    input:copy(x[index])
+	    local output = model:forward(input):float():view(3, output_size, output_size)
+	    local output_index = {{},
+				  {i + offset, offset + i + output_size - 1},
+				  {offset + j, offset + j + output_size - 1}}
+	    new_x[output_index]:copy(output)
+	 end
+      end
+   end
+   return new_x
+end
+function model_is_rgb(model)
+   if model:get(model:size() - 1).weight:size(1) == 3 then
+      -- 3ch RGB
+      return true
+   else
+      -- 1ch Y
+      return false
+   end
+end
+
 local reconstruct = {}
-function reconstruct.image(model, x, offset, block_size)
+function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local output_size = block_size - offset * 2
    local h_blocks = math.floor(x:size(2) / output_size) +
@@ -42,7 +74,7 @@ function reconstruct.image(model, x, offset, block_size)
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv[1], offset, block_size)
+   local y = reconstruct_y(model, yuv[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv[1]:copy(y)
@@ -55,7 +87,7 @@ function reconstruct.image(model, x, offset, block_size)
    
    return output
 end
-function reconstruct.scale(model, scale, x, offset, block_size)
+function reconstruct.scale_y(model, scale, x, offset, block_size)
    block_size = block_size or 128
    local x_jinc = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Jinc")
    x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
@@ -74,7 +106,7 @@ function reconstruct.scale(model, scale, x, offset, block_size)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
    local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv_nn[1], offset, block_size)
+   local y = reconstruct_y(model, yuv_nn[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv_jinc[1]:copy(y)
@@ -87,5 +119,72 @@ function reconstruct.scale(model, scale, x, offset, block_size)
    
    return output
 end
+function reconstruct.image_rgb(model, x, offset, block_size)
+   block_size = block_size or 128
+   local output_size = block_size - offset * 2
+   local h_blocks = math.floor(x:size(2) / output_size) +
+      ((x:size(2) % output_size == 0 and 0) or 1)
+   local w_blocks = math.floor(x:size(3) / output_size) +
+      ((x:size(3) % output_size == 0 and 0) or 1)
+   
+   local h = offset + h_blocks * output_size + offset
+   local w = offset + w_blocks * output_size + offset
+   local pad_h1 = offset
+   local pad_w1 = offset
+   local pad_h2 = (h - offset) - x:size(2)
+   local pad_w2 = (w - offset) - x:size(3)
+   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   local y = reconstruct_rgb(model, input, offset, block_size)
+   local output = image.crop(y,
+			     pad_w1, pad_h1,
+			     y:size(3) - pad_w2, y:size(2) - pad_h2)
+   collectgarbage()
+   output[torch.lt(output, 0)] = 0
+   output[torch.gt(output, 1)] = 1
+   
+   return output
+end
+function reconstruct.scale_rgb(model, scale, x, offset, block_size)
+   block_size = block_size or 128
+   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
+
+   local output_size = block_size - offset * 2
+   local h_blocks = math.floor(x:size(2) / output_size) +
+      ((x:size(2) % output_size == 0 and 0) or 1)
+   local w_blocks = math.floor(x:size(3) / output_size) +
+      ((x:size(3) % output_size == 0 and 0) or 1)
+   
+   local h = offset + h_blocks * output_size + offset
+   local w = offset + w_blocks * output_size + offset
+   local pad_h1 = offset
+   local pad_w1 = offset
+   local pad_h2 = (h - offset) - x:size(2)
+   local pad_w2 = (w - offset) - x:size(3)
+   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   local y = reconstruct_rgb(model, input, offset, block_size)
+   local output = image.crop(y,
+			     pad_w1, pad_h1,
+			     y:size(3) - pad_w2, y:size(2) - pad_h2)
+   output[torch.lt(output, 0)] = 0
+   output[torch.gt(output, 1)] = 1
+   collectgarbage()
+   
+   return output
+end
+
+function reconstruct.image(model, x, offset, block_size)
+   if model_is_rgb(model) then
+      return reconstruct.image_rgb(model, x, offset, block_size)
+   else
+      return reconstruct.image_y(model, x, offset, block_size)
+   end
+end
+function reconstruct.scale(model, scale, x, offset, block_size)
+   if model_is_rgb(model) then
+      return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+   else
+      return reconstruct.scale_y(model, scale, x, offset, block_size)
+   end
+end
 
 return reconstruct
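
The net effect is that callers keep using reconstruct.image() and reconstruct.scale() unchanged: model_is_rgb() reads the output channel count of the last convolution (3 for RGB, 1 for Y) and dispatches accordingly. A minimal usage sketch, assuming a CUDA-enabled setup and the ascii-serialized model added by this commit (input.png is illustrative):

require 'image'
require 'cunn'
local reconstruct = require './lib/reconstruct'

-- ascii-serialized model, loaded the same way web.lua does
local model = torch.load("models/anime_style_art_rgb/scale2.0x_model.t7", "ascii")
local x = image.load("input.png")   -- 3xHxW float in [0, 1]
-- dispatches to reconstruct.scale_rgb() here; 7 is this architecture's
-- block offset, and the block size defaults to 128
local y = reconstruct.scale(model, 2.0, x, 7)
image.save("output.png", y)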

+ 4 - 0
lib/settings.lua

@@ -22,6 +22,7 @@ cmd:option("-test", "images/miku_small.png", 'test image file')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", '(noise|scale|noise_scale)')
 cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
@@ -46,6 +47,9 @@ elseif settings.method == "noise_scale" then
 else
    error("unknown method: " .. settings.method)
 end
+if not (settings.color == "rgb" or settings.color == "y") then
+   error("color must be y or rgb")
+end
 if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end

+ 27 - 6
lib/srcnn.lua

@@ -6,10 +6,22 @@ function nn.SpatialConvolutionMM:reset(stdv)
    self.bias:fill(0)
 end
 local srcnn = {}
-function srcnn.waifu2x()
+function srcnn.waifu2x(color)
    local model = nn.Sequential()
+   local ch = nil
+   if color == "rgb" then
+      ch = 3
+   elseif color == "y" then
+      ch = 1
+   else
+      if color then
+	 error("unknown color: " .. color)
+      else
+	 error("unknown color: nil")
+      end
+   end
    
-   model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
@@ -21,7 +33,7 @@ function srcnn.waifu2x()
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
 --model:cuda()
 --print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
@@ -30,10 +42,19 @@ function srcnn.waifu2x()
 end
 
 -- current 4x is worse than 2x * 2
-function srcnn.waifu4x()
+function srcnn.waifu4x(color)
    local model = nn.Sequential()
+
+   local ch = nil
+   if color == "rgb" then
+      ch = 3
+   elseif color == "y" then
+      ch = 1
+   else
+      error("unknown color: " .. color)
+   end
    
-   model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
@@ -45,7 +66,7 @@ function srcnn.waifu4x()
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    
    return model, 13
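
With the explicit color argument, the channel count is fixed at construction time, and a nil color now fails loudly instead of silently building a 1-channel network; that is why cudnn2cunn.lua above had to be updated to srcnn.waifu2x("y"). A quick sketch:

local srcnn = require 'lib/srcnn'

local rgb_model = srcnn.waifu2x("rgb")  -- 3-channel input/output convolutions
local y_model = srcnn.waifu2x("y")      -- 1-channel, as before
-- srcnn.waifu2x() with no argument raises "unknown color: nil"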

File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/noise1_model.json


File diff suppressed because it is too large
+ 85 - 0
models/anime_style_art_rgb/noise1_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/noise2_model.json


File diff suppressed because it is too large
+ 85 - 0
models/anime_style_art_rgb/noise2_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/scale2.0x_model.json


File diff suppressed because it is too large
+ 85 - 0
models/anime_style_art_rgb/scale2.0x_model.t7


+ 17 - 6
train.lua

@@ -59,7 +59,7 @@ local function validate(model, criterion, data)
 end
 
 local function train()
-   local model, offset = settings.create_model()
+   local model, offset = settings.create_model(settings.color)
    assert(offset == settings.block_offset)
    local criterion = nn.MSECriterion():cuda()
    local x = torch.load(settings.images)
@@ -72,6 +72,12 @@ local function train()
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
    }
+   local ch = nil
+   if settings.color == "y" then
+      ch = 1
+   elseif settings.color == "rgb" then
+      ch = 3
+   end
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       if settings.method == "scale" then
@@ -79,20 +85,25 @@ local function train()
 					 settings.scale,
 					 settings.crop_size, offset,
 					 { color_augment = not is_validation,
-					   random_half = settings.random_half})
+					   random_half = settings.random_half,
+					   rgb = (settings.color == "rgb")
+					 })
       elseif settings.method == "noise" then
 	 return pairwise_transform.jpeg(x,
 					settings.noise_level,
 					settings.crop_size, offset,
 					{ color_augment = not is_validation,
-					  random_half = settings.random_half})
+					  random_half = settings.random_half,
+					  rgb = (settings.color == "rgb")
+					})
       elseif settings.method == "noise_scale" then
 	 return pairwise_transform.jpeg_scale(x,
 					      settings.scale,
 					      settings.noise_level,
 					      settings.crop_size, offset,
 					      { color_augment = not is_validation,
-						random_half = settings.random_half
+						random_half = settings.random_half,
+						rgb = (settings.color == "rgb")
 					      })
       end
    end
@@ -109,8 +120,8 @@ local function train()
       print("# " .. epoch)
       print(minibatch_adam(model, criterion, train_x, adam_config,
 			   transformer,
-			   {1, settings.crop_size, settings.crop_size},
-			   {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
+			   {ch, settings.crop_size, settings.crop_size},
+			   {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
 			  ))
       model:evaluate()
       print("# validation")

+ 6 - 6
train.sh

@@ -1,10 +1,10 @@
 #!/bin/sh
 
-th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii
+th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
 
-th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii
+th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
 
-th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test images/miku_small.png
-th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii
+th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
+th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii

+ 1 - 1
waifu2x.lua

@@ -105,7 +105,7 @@ local function waifu2x()
    cmd:option("-l", "", 'path of the image-list')
    cmd:option("-scale", 2, 'scale factor')
    cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models/anime_style_art", 'model directory')
+   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')

+ 1 - 1
web.lua

@@ -15,7 +15,7 @@ local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
 local image_loader = require './lib/image_loader'
 
-local MODEL_DIR = "./models/anime_style_art"
+local MODEL_DIR = "./models/anime_style_art_rgb"
 
 local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
 local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")

Some files were not shown because too many files changed in this diff