瀏覽代碼

Merge pull request #26 from nagadomi/merge_develop

remove support for cuDNN
nagadomi 10 年之前
父節點
當前提交
dd629dd015

+ 21 - 30
README.md

@@ -23,24 +23,20 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
 ## Dependencies
 
 ### Hardware
-- NVIDIA GPU (Compute Capability 3.0 or later)
+- NVIDIA GPU
 
 ### Platform
 - [Torch7](http://torch.ch/)
 - [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit)
-- [NVIDIA cuDNN](https://developer.nvidia.com/cuDNN)
 
 ### Packages (luarocks)
 - cutorch
 - cunn
-- [cudnn](https://github.com/soumith/cudnn.torch)
 - [graphicsmagick](https://github.com/clementfarabet/graphicsmagick)
 - [turbo](https://github.com/kernelsauce/turbo)
 - md5
 - uuid
 
-NOTE: Turbo 1.1.3 has bug in file uploading. Please install from the master branch on github.
-
 ## Installation
 
 ### Setting Up the Command Line Tool Environment
@@ -54,16 +50,15 @@ curl -s https://raw.githubusercontent.com/torch/ezinstall/master/install-all | s
 ```
 see [Torch (easy) install](https://github.com/torch/ezinstall)
 
-#### Install CUDA and cuDNN.
+#### Install CUDA
 
-Google! Search keyword is "install cuda ubuntu" and "install cudnn ubuntu"
+Google! Search keyword: "install cuda ubuntu"
 
 #### Install packages
 
 ```
 sudo luarocks install cutorch
 sudo luarocks install cunn
-sudo luarocks install cudnn
 sudo apt-get install graphicsmagick libgraphicsmagick-dev
 sudo luarocks install graphicsmagick
 ```
@@ -91,21 +86,10 @@ Install luarocks packages.
 ```
 sudo luarocks install md5
 sudo luarocks install uuid
-```
-
-Install turbo.
-```
-git clone https://github.com/kernelsauce/turbo.git
-cd turbo
-sudo luarocks make rockspecs/turbo-dev-1.rockspec 
+sudo luarocks install turbo
 ```
 
 ## Web Application
-
-Please edit the first line in `web.lua`.
-```
-local ROOT = '/path/to/waifu2x/dir'
-```
 Run.
 ```
 th web.lua
@@ -173,7 +157,7 @@ Genrating a file list.
 ```
 find /path/to/image/dir -name "*.png" > data/image_list.txt
 ```
-(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-beautiful-PNG images.)
+(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.)
 
 Converting training data.
 ```
@@ -183,23 +167,30 @@ th convert_data.lua
 ### Training a Noise Reduction(level1) model
 
 ```
-th train.lua -method noise -noise_level 1 -test images/miku_noisy.png
-th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii
+mkdir models/my_model
+th train.lua -model_dir models/my_model -method noise -noise_level 1 -test images/miku_noisy.png
+th cleanup_model.lua -model models/my_model/noise1_model.t7 -oformat ascii
+# usage
+th waifu2x.lua -model_dir models/my_model -m noise -noise_level 1 -i images/miku_noisy.png -o output.png
 ```
-You can check the performance of model with `models/noise1_best.png`.
+You can check the performance of model with `models/my_model/noise1_best.png`.
 
 ### Training a Noise Reduction(level2) model
 
 ```
-th train.lua -method noise -noise_level 2 -test images/miku_noisy.png
-th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii
+th train.lua -model_dir models/my_model -method noise -noise_level 2 -test images/miku_noisy.png
+th cleanup_model.lua -model models/my_model/noise2_model.t7 -oformat ascii
+# usage
+th waifu2x.lua -model_dir models/my_model -m noise -noise_level 2 -i images/miku_noisy.png -o output.png
 ```
-You can check the performance of model with `models/noise2_best.png`.
+You can check the performance of model with `models/my_model/noise2_best.png`.
 
 ### Training a 2x UpScaling model
 
 ```
-th train.lua -method scale -scale 2 -test images/miku_small.png
-th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii
+th train.lua -model_dir models/my_model -method scale -scale 2 -test images/miku_small.png
+th cleanup_model.lua -model models/my_model/scale2.0x_model.t7 -oformat ascii
+# usage
+th waifu2x.lua -model_dir models/my_model -m scale -scale 2 -i images/miku_small.png -o output.png
 ```
-You can check the performance of model with `models/scale2.0x_best.png`.
+You can check the performance of model with `models/my_model/scale2.0x_best.png`.

+ 1 - 2
cleanup_model.lua

@@ -1,5 +1,4 @@
-require 'cunn'
-require 'cudnn'
+require './lib/portable'
 require './lib/LeakyReLU'
 
 torch.setdefaulttensortype("torch.FloatTensor")

+ 10 - 3
convert_data.lua

@@ -1,4 +1,5 @@
-require 'torch'
+require './lib/portable'
+require 'image'
 local settings = require './lib/settings'
 local image_loader = require './lib/image_loader'
 
@@ -13,15 +14,21 @@ local function count_lines(file)
    return count
 end
 
+local function crop_4x(x)
+   local w = x:size(3) % 4
+   local h = x:size(2) % 4
+   return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
+end
+
 local function load_images(list)
    local count = count_lines(list)
    local fp = io.open(list, "r")
    local x = {}
    local c = 0
    for line in fp:lines() do
-      local im = image_loader.load_byte(line)
+      local im = crop_4x(image_loader.load_byte(line))
       if im then
-	 if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then
+	 if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then
 	    table.insert(x, im)
 	 end
       else

+ 34 - 0
cudnn2cunn.lua

@@ -0,0 +1,34 @@
+require 'cunn'
+require 'cudnn'
+require 'cutorch'
+require './lib/LeakyReLU'
+local srcnn = require 'lib/srcnn'
+
+local function cudnn2cunn(cudnn_model)
+   local cunn_model = srcnn.waifu2x()
+   local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
+   local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
+
+   for i = 1, #from_seq do
+      local from = from_seq[i]
+      local to = to_seq[i]
+      to.weight:copy(from.weight)
+      to.bias:copy(from.bias)
+   end
+   cunn_model:cuda()
+   cunn_model:evaluate()
+   return cunn_model
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("convert cudnn model to cunn model ")
+cmd:text("Options:")
+cmd:option("-model", "./model.t7", 'path of cudnn model file')
+cmd:option("-iformat", "ascii", 'input format')
+cmd:option("-oformat", "ascii", 'output format')
+
+local opt = cmd:parse(arg)
+local cudnn_model = torch.load(opt.model, opt.iformat)
+local cunn_model = cudnn2cunn(cudnn_model)
+torch.save(opt.model, cunn_model, opt.oformat)

+ 23 - 0
export_model.lua

@@ -0,0 +1,23 @@
+-- adapted from https://github.com/marcan/cl-waifu2x
+require './lib/portable'
+require './lib/LeakyReLU'
+local cjson = require "cjson"
+
+local model = torch.load(arg[1], "ascii")
+
+local jmodules = {}
+local modules = model:findModules("nn.SpatialConvolutionMM")
+for i = 1, #modules, 1 do
+   local module = modules[i]
+   local jmod = {
+      kW = module.kW,
+      kH = module.kH,
+      nInputPlane = module.nInputPlane,
+      nOutputPlane = module.nOutputPlane,
+      bias = torch.totable(module.bias:float()),
+      weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
+   }
+   table.insert(jmodules, jmod)
+end
+
+io.write(cjson.encode(jmodules))

二進制
images/lena_waifu2x.png


二進制
images/miku_CC_BY-NC_noisy_waifu2x.png


二進制
images/miku_noisy_waifu2x.png


二進制
images/miku_small_noisy_waifu2x.png


二進制
images/miku_small_waifu2x.png


+ 195 - 90
lib/pairwise_transform.lua

@@ -4,84 +4,103 @@ local iproc = require './iproc'
 local reconstruct = require './reconstruct'
 local pairwise_transform = {}
 
+local function random_half(src, p, min_size)
+   p = p or 0.5
+   local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
+   if p > torch.uniform() then
+      return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
+   else
+      return src
+   end
+end
+local function color_augment(x)
+   local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
+   x = x:float():div(255)
+   for i = 1, 3 do
+      x[i]:mul(color_scale[i])
+   end
+   x[torch.lt(x, 0.0)] = 0.0
+   x[torch.gt(x, 1.0)] = 1.0
+   return x:mul(255):byte()
+end
+local function flip_augment(x, y)
+   local flip = torch.random(1, 4)
+   if y then
+      if flip == 1 then
+	 x = image.hflip(x)
+	 y = image.hflip(y)
+      elseif flip == 2 then
+	 x = image.vflip(x)
+	 y = image.vflip(y)
+      elseif flip == 3 then
+	 x = image.hflip(image.vflip(x))
+	 y = image.hflip(image.vflip(y))
+      elseif flip == 4 then
+      end
+      return x, y
+   else
+      if flip == 1 then
+	 x = image.hflip(x)
+      elseif flip == 2 then
+	 x = image.vflip(x)
+      elseif flip == 3 then
+	 x = image.hflip(image.vflip(x))
+      elseif flip == 4 then
+      end
+      return x
+   end
+end
+local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {}
-   local yi = torch.random(0, src:size(2) - size - 1)
-   local xi = torch.random(0, src:size(3) - size - 1)
+   options = options or {color_augment = true, random_half = true}
+   if options.random_half then
+      src = random_half(src)
+   end
+   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
+   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
    local down_scale = 1.0 / scale
-   local y = image.crop(src, xi, yi, xi + size, yi + size)
-   local flip = torch.random(1, 4)
-   local nega = torch.random(0, 1)
+   local y = image.crop(src,
+			xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
+			xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
    local filters = {
       "Box",        -- 0.012756949974688
       "Blackman",   -- 0.013191924552285
       --"Cartom",     -- 0.013753536746706
       --"Hanning",    -- 0.013761314529647
       --"Hermite",    -- 0.013850225205266
-      --"SincFast",   -- 0.014095824314306
-      --"Jinc",       -- 0.014244299255442
+      "SincFast",   -- 0.014095824314306
+      "Jinc",       -- 0.014244299255442
    }
    local downscale_filter = filters[torch.random(1, #filters)]
    
-   if flip == 1 then
-      y = image.hflip(y)
-   elseif flip == 2 then
-      y = image.vflip(y)
-   elseif flip == 3 then
-      y = image.hflip(image.vflip(y))
-   elseif flip == 4 then
-      -- none
-   end
+   y = flip_augment(y)
    if options.color_augment then
-      y = y:float():div(255)
-      local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
-      for i = 1, 3 do
-	 y[i]:mul(color_scale[i])
-      end
-      y[torch.lt(y, 0)] = 0
-      y[torch.gt(y, 1.0)] = 1.0
-      y = y:mul(255):byte()
+      y = color_augment(y)
    end
    local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
-   if options.noise and (options.noise_ratio or 0.5) > torch.uniform() then
-      -- add noise
-      local quality = {torch.random(70, 90)}
-      for i = 1, #quality do
-	 x = gm.Image(x, "RGB", "DHW")
-	 x:format("jpeg")
-	 local blob, len = x:toBlob(quality[i])
-	 x:fromBlob(blob, len)
-	    x = x:toTensor("byte", "RGB", "DHW")
-      end
-   end
-   if options.denoise_model and (options.denoise_ratio or 0.5) > torch.uniform() then
-      x = reconstruct(options.denoise_model, x:float():div(255), offset):mul(255):byte()
-   end
    x = iproc.scale(x, y:size(3), y:size(2))
    y = y:float():div(255)
    x = x:float():div(255)
    y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
    x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+
+   y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset -	INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
+   x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
    
-   return x, image.crop(y, offset, offset, size - offset, size - offset)
+   return x, y
 end
-function pairwise_transform.jpeg_(src, quality, size, offset, color_augment)
-   if color_augment == nil then color_augment = true end
+function pairwise_transform.jpeg_(src, quality, size, offset, options)
+   options = options or {color_augment = true, random_half = true}
+   if options.random_half then
+      src = random_half(src)
+   end
    local yi = torch.random(0, src:size(2) - size - 1)
    local xi = torch.random(0, src:size(3) - size - 1)
    local y = src
    local x
-   local flip = torch.random(1, 4)
 
-   if color_augment then
-      local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
-      y = y:float():div(255)
-      for i = 1, 3 do
-	 y[i]:mul(color_scale[i])
-      end
-      y[torch.lt(y, 0)] = 0
-      y[torch.gt(y, 1.0)] = 1.0
-      y = y:mul(255):byte()
+   if options.color_augment then
+      y = color_augment(y)
    end
    x = y
    for i = 1, #quality do
@@ -94,48 +113,115 @@ function pairwise_transform.jpeg_(src, quality, size, offset, color_augment)
    
    y = image.crop(y, xi, yi, xi + size, yi + size)
    x = image.crop(x, xi, yi, xi + size, yi + size)
-   x = x:float():div(255)
    y = y:float():div(255)
+   x = x:float():div(255)
+   x, y = flip_augment(x, y)
    
-   if flip == 1 then
-      y = image.hflip(y)
-      x = image.hflip(x)
-   elseif flip == 2 then
-      y = image.vflip(y)
-      x = image.vflip(x)
-   elseif flip == 3 then
-      y = image.hflip(image.vflip(y))
-      x = image.hflip(image.vflip(x))
-   elseif flip == 4 then
-      -- none
-   end
    y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
    x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
 
    return x, image.crop(y, offset, offset, size - offset, size - offset)
 end
-function pairwise_transform.jpeg(src, level, size, offset, color_augment)
+function pairwise_transform.jpeg(src, level, size, offset, options)
    if level == 1 then
       return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
 				      size, offset,
-				      color_augment)
+				      options)
    elseif level == 2 then
       local r = torch.uniform()
       if r > 0.6 then
-	 return pairwise_transform.jpeg_(src, {torch.random(27, 80)},
+	 return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
 					 size, offset,
-					 color_augment)
+					 options)
       elseif r > 0.3 then
-	 local quality1 = torch.random(32, 40)
-	 local quality2 = quality1 - 5
+	 local quality1 = torch.random(37, 70)
+	 local quality2 = quality1 - torch.random(5, 10)
 	 return pairwise_transform.jpeg_(src, {quality1, quality2},
-					 size, offset,
-					 color_augment)
+					    size, offset,
+					    options)
       else
-	 local quality1 = torch.random(47, 70)
-	 return pairwise_transform.jpeg_(src, {quality1, quality1 - 10, quality1 - 20},
+	 local quality1 = torch.random(52, 70)
+	 return pairwise_transform.jpeg_(src,
+					 {quality1,
+					  quality1 - torch.random(5, 15),
+					  quality1 - torch.random(15, 25)},
 					 size, offset,
-					 color_augment)
+					 options)
+      end
+   else
+      error("unknown noise level: " .. level)
+   end
+end
+function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
+   if options.random_half then
+      src = random_half(src)
+   end
+   local down_scale = 1.0 / scale
+   local filters = {
+      "Box",        -- 0.012756949974688
+      --"Blackman",   -- 0.013191924552285
+      --"Cartom",     -- 0.013753536746706
+      --"Hanning",    -- 0.013761314529647
+      --"Hermite",    -- 0.013850225205266
+      --"SincFast",   -- 0.014095824314306
+      --"Jinc",       -- 0.014244299255442
+   }
+   local downscale_filter = filters[torch.random(1, #filters)]
+   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
+   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
+   local y = src
+   local x
+   
+   if options.color_augment then
+      y = color_augment(y)
+   end
+   x = y
+   x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
+   for i = 1, #quality do
+      x = gm.Image(x, "RGB", "DHW")
+      x:format("jpeg")
+      local blob, len = x:toBlob(quality[i])
+      x:fromBlob(blob, len)
+      x = x:toTensor("byte", "RGB", "DHW")
+   end
+   x = iproc.scale(x, y:size(3), y:size(2))
+   y = image.crop(y,
+		  xi, yi,
+		  xi + size, yi + size)
+   x = image.crop(x,
+		  xi, yi,
+		  xi + size, yi + size)
+   x = x:float():div(255)
+   y = y:float():div(255)
+   x, y = flip_augment(x, y)
+   
+   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+
+   return x, image.crop(y, offset, offset, size - offset, size - offset)
+end
+function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
+   options = options or {color_augment = true, random_half = true}
+   if level == 1 then
+      return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
+					    size, offset, options)
+   elseif level == 2 then
+      local r = torch.uniform()
+      if r > 0.6 then
+	 return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
+					       size, offset, options)
+      elseif r > 0.3 then
+	 local quality1 = torch.random(37, 70)
+	 local quality2 = quality1 - torch.random(5, 10)
+	 return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
+					       size, offset, options)
+      else
+	 local quality1 = torch.random(52, 70)
+	    return pairwise_transform.jpeg_scale_(src, scale,
+						  {quality1,
+						   quality1 - torch.random(5, 15),
+						   quality1 - torch.random(15, 25)},
+						  size, offset, options)
       end
    else
       error("unknown noise level: " .. level)
@@ -143,32 +229,51 @@ function pairwise_transform.jpeg(src, level, size, offset, color_augment)
 end
 
 local function test_jpeg()
-   local loader = require 'image_loader'
-   local src = loader.load_byte("a.jpg")
-
+   local loader = require './image_loader'
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
+   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
+   image.display({image = y, legend = "y:0"})
+   image.display({image = x, legend = "x:0"})
    for i = 2, 9 do
-      local y, x = pairwise_transform.jpeg_(src, {i * 10}, 128, 0, false)
+      local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
+					    {i * 10}, 128, 0, {color_augment = false, random_half = true})
       image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
       image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
       --print(x:mean(), y:mean())
    end
 end
-local function test_scale()
-   require 'nn'
-   require 'cudnn'
-   require './LeakyReLU'
-   
-   local loader = require 'image_loader'
-   local src = loader.load_byte("e.jpg")
 
+local function test_scale()
+   local loader = require './image_loader'
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
+   for i = 1, 9 do
+      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true})
+      image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
+      image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
+      print(y:size(), x:size())
+      --print(x:mean(), y:mean())
+   end
+end
+local function test_jpeg_scale()
+   local loader = require './image_loader'
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
+   for i = 1, 9 do
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_augment = true, random_half = true})
+      image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
+      image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
+      print(y:size(), x:size())
+      --print(x:mean(), y:mean())
+   end
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, "Box", 128, 7, {noise = true, denoise_model = torch.load("models/noise1_model.t7")})
-      image.display({image = y, legend = "y:" .. (i * 10)})
-      image.display({image = x, legend = "x:" .. (i * 10)})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_augment = true, random_half = true})
+      image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
+      image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
+      print(y:size(), x:size())
       --print(x:mean(), y:mean())
    end
 end
 --test_jpeg()
 --test_scale()
+--test_jpeg_scale()
 
 return pairwise_transform

+ 15 - 0
lib/portable.lua

@@ -0,0 +1,15 @@
+local function load_cuda()
+   require 'cunn'
+end
+
+if pcall(load_cuda) then
+   require 'cunn'
+else
+   --[[ TODO: fakecuda does not work.
+      
+   io.stderr:write("use FakeCUDA; if you have NVIDIA GPU, Please install cutorch and cunn. FakeCuda will be extremely slow.\n")
+   require 'torch'
+   require 'nn'
+   require('fakecuda').init(true)
+   --]]
+end

+ 3 - 3
lib/reconstruct.lua

@@ -1,7 +1,7 @@
 require 'image'
 local iproc = require './iproc'
 
-local function reconstruct_layer(model, x, block_size, offset)
+local function reconstruct_layer(model, x, offset, block_size)
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -42,7 +42,7 @@ function reconstruct.image(model, x, offset, block_size)
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv[1], block_size, offset)
+   local y = reconstruct_layer(model, yuv[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv[1]:copy(y)
@@ -74,7 +74,7 @@ function reconstruct.scale(model, scale, x, offset, block_size)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
    local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv_nn[1], block_size, offset)
+   local y = reconstruct_layer(model, yuv_nn[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv_jinc[1]:copy(y)

+ 23 - 7
lib/settings.lua

@@ -1,5 +1,3 @@
-require 'torch'
-require 'cutorch'
 require 'xlua'
 require 'pl'
 
@@ -22,10 +20,11 @@ cmd:option("-seed", 11, 'fixed input seed')
 cmd:option("-data_dir", "./data", 'data directory')
 cmd:option("-test", "images/miku_small.png", 'test image file')
 cmd:option("-model_dir", "./models", 'model directory')
-cmd:option("-method", "scale", '(noise|scale)')
+cmd:option("-method", "scale", '(noise|scale|noise_scale)')
 cmd:option("-noise_level", 1, '(1|2)')
 cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
+cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
 cmd:option("-crop_size", 128, 'crop size')
 cmd:option("-batch_size", 2, 'mini batch size')
 cmd:option("-epoch", 200, 'epoch')
@@ -36,16 +35,25 @@ for k, v in pairs(opt) do
    settings[k] = v
 end
 if settings.method == "noise" then
-   settings.model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level)
+   settings.model_file = string.format("%s/noise%d_model.t7",
+				       settings.model_dir, settings.noise_level)
 elseif settings.method == "scale" then
-   settings.model_file = string.format("%s/scale%.1fx_model.t7", settings.model_dir, settings.scale)
-   settings.denoise_model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level)
+   settings.model_file = string.format("%s/scale%.1fx_model.t7",
+				       settings.model_dir, settings.scale)
+elseif settings.method == "noise_scale" then
+   settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
+				       settings.model_dir, settings.noise_level, settings.scale)
 else
    error("unknown method: " .. settings.method)
 end
 if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end
+if settings.random_half == 1 then
+   settings.random_half = true
+else
+   settings.random_half = false
+end
 torch.setnumthreads(settings.core)
 
 settings.images = string.format("%s/images.t7", settings.data_dir)
@@ -53,6 +61,14 @@ settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
 settings.validation_ratio = 0.1
 settings.validation_crops = 40
-settings.block_offset = 7 -- see srcnn.lua
+
+local srcnn = require './srcnn'
+if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
+   settings.create_model = srcnn.waifu4x
+   settings.block_offset = 13
+else
+   settings.create_model = srcnn.waifu2x
+   settings.block_offset = 7
+end
 
 return settings

+ 35 - 14
lib/srcnn.lua

@@ -1,32 +1,53 @@
-require 'cunn'
-require 'cudnn'
 require './LeakyReLU'
 
-function cudnn.SpatialConvolution:reset(stdv)
+function nn.SpatialConvolutionMM:reset(stdv)
    stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, stdv)
    self.bias:fill(0)
 end
-local function create_model()
-   local model = nn.Sequential() 
+local srcnn = {}
+function srcnn.waifu2x()
+   local model = nn.Sequential()
    
-   model:add(cudnn.SpatialConvolution(1, 32, 3, 3, 1, 1, 0, 0):fastest())
-   model:add(nn.LeakyReLU(0.1))   
-   model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(cudnn.SpatialConvolution(128, 1, 3, 3, 1, 1, 0, 0):fastest())
+   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
 --model:cuda()
 --print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
    
    return model, 7
 end
-return create_model
+
+-- current 4x is worse then 2x * 2
+function srcnn.waifu4x()
+   local model = nn.Sequential()
+   
+   model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0))
+   model:add(nn.View(-1):setNumInputDims(3))
+   
+   return model, 13
+end
+return srcnn

文件差異過大導致無法顯示
+ 0 - 0
models/anime_style_art/noise1_model.json


文件差異過大導致無法顯示
+ 20 - 5
models/anime_style_art/noise1_model.t7


文件差異過大導致無法顯示
+ 0 - 0
models/anime_style_art/noise2_model.json


文件差異過大導致無法顯示
+ 24 - 9
models/anime_style_art/noise2_model.t7


文件差異過大導致無法顯示
+ 0 - 0
models/anime_style_art/scale2.0x_model.json


文件差異過大導致無法顯示
+ 24 - 9
models/anime_style_art/scale2.0x_model.t7


+ 48 - 31
train.lua

@@ -1,5 +1,4 @@
-require 'cutorch'
-require 'cunn'
+require './lib/portable'
 require 'optim'
 require 'xlua'
 require 'pl'
@@ -7,7 +6,6 @@ require 'pl'
 local settings = require './lib/settings'
 local minibatch_adam = require './lib/minibatch_adam'
 local iproc = require './lib/iproc'
-local create_model = require './lib/srcnn'
 local reconstruct = require './lib/reconstruct'
 local pairwise_transform = require './lib/pairwise_transform'
 local image_loader = require './lib/image_loader'
@@ -61,10 +59,11 @@ local function validate(model, criterion, data)
 end
 
 local function train()
-   local model, offset = create_model()
+   local model, offset = settings.create_model()
    assert(offset == settings.block_offset)
    local criterion = nn.MSECriterion():cuda()
    local x = torch.load(settings.images)
+   local lrd_count = 0
    local train_x, valid_x = split_data(x,
 				       math.floor(settings.validation_ratio * #x),
 				       settings.validation_crops)
@@ -78,16 +77,23 @@ local function train()
       if settings.method == "scale" then
 	 return pairwise_transform.scale(x,
 					 settings.scale,
-					 settings.crop_size,
-					 offset,
-					 {color_augment = not is_validation,
-					  noise = false,
-					  denoise_model = nil
-					 })
+					 settings.crop_size, offset,
+					 { color_augment = not is_validation,
+					   random_half = settings.random_half})
       elseif settings.method == "noise" then
-	 return pairwise_transform.jpeg(x, settings.noise_level,
+	 return pairwise_transform.jpeg(x,
+					settings.noise_level,
 					settings.crop_size, offset,
-					   not is_validation)
+					{ color_augment = not is_validation,
+					  random_half = settings.random_half})
+      elseif settings.method == "noise_scale" then
+	 return pairwise_transform.jpeg_scale(x,
+					      settings.scale,
+					      settings.noise_level,
+					      settings.crop_size, offset,
+					      { color_augment = not is_validation,
+						random_half = settings.random_half
+					      })
       end
    end
    local best_score = 100000.0
@@ -106,27 +112,38 @@ local function train()
 			   {1, settings.crop_size, settings.crop_size},
 			   {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
 			  ))
-      if epoch % 1 == 0 then
-	 collectgarbage()
-	 model:evaluate()
-	 print("# validation")
-	 local score = validate(model, criterion, valid_xy)
-	 if score < best_score then
-	    best_score = score
-	    print("* update best model")
-	    torch.save(settings.model_file, model)
-	    if settings.method == "noise" then
-	       local log = path.join(settings.model_dir,
-				     ("noise%d_best.png"):format(settings.noise_level))
-	       save_test_jpeg(model, test, log)
-	    elseif settings.method == "scale" then
-	       local log = path.join(settings.model_dir,
-				     ("scale%.1f_best.png"):format(settings.scale))
-	       save_test_scale(model, test, log)
-	    end
+      model:evaluate()
+      print("# validation")
+      local score = validate(model, criterion, valid_xy)
+      if score < best_score then
+	 lrd_count = 0
+	 best_score = score
+	 print("* update best model")
+	 torch.save(settings.model_file, model)
+	 if settings.method == "noise" then
+	    local log = path.join(settings.model_dir,
+				  ("noise%d_best.png"):format(settings.noise_level))
+	    save_test_jpeg(model, test, log)
+	 elseif settings.method == "scale" then
+	    local log = path.join(settings.model_dir,
+				  ("scale%.1f_best.png"):format(settings.scale))
+	    save_test_scale(model, test, log)
+	 elseif settings.method == "noise_scale" then
+	    local log = path.join(settings.model_dir,
+				  ("noise%d_scale%.1f_best.png"):format(settings.noise_level,
+									settings.scale))
+	    save_test_scale(model, test, log)
+	 end
+      else
+	 lrd_count = lrd_count + 1
+	 if lrd_count > 5 then
+	    lrd_count = 0
+	    adam_config.learningRate = adam_config.learningRate * 0.8
+	    print("* learning rate decay: " .. adam_config.learningRate)
 	 end
-	 print("current: " .. score .. ", best: " .. best_score)
       end
+      print("current: " .. score .. ", best: " .. best_score)
+      collectgarbage()
    end
 end
 torch.manualSeed(settings.seed)

+ 6 - 6
train.sh

@@ -1,10 +1,10 @@
 #!/bin/sh
 
-th train.lua -method noise -noise_level 1 -test images/miku_noisy.png
-th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii
+th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii
 
-th train.lua -method noise -noise_level 2 -test images/miku_noisy.png
-th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii
+th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii
 
-th train.lua -method scale -scale 2 -test images/miku_small.png
-th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii
+th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test images/miku_small.png
+th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii

+ 9 - 9
waifu2x.lua

@@ -1,4 +1,4 @@
-require 'cudnn'
+require './lib/portable'
 require 'sys'
 require 'pl'
 require './lib/LeakyReLU'
@@ -24,18 +24,18 @@ local function convert_image(opt)
    if opt.m == "noise" then
       local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
       model:evaluate()
-      new_x = reconstruct.image(model, x, BLOCK_OFFSET)
+      new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
    elseif opt.m == "scale" then
       local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
       model:evaluate()
-      new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET)
+      new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
    elseif opt.m == "noise_scale" then
       local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
       local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
       noise_model:evaluate()
       scale_model:evaluate()
       x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
-      new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET)
+      new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
    else
       error("undefined method:" .. opt.method)
    end
@@ -63,17 +63,17 @@ local function convert_frames(opt)
 	      local x = image_loader.load_float(lines[i])
 	      local new_x = nil
 	      if opt.m == "noise" and opt.noise_level == 1 then
-		 new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
+		 new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
 	      elseif opt.m == "noise" and opt.noise_level == 2 then
 		 new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
 	      elseif opt.m == "scale" then
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET)
+		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
 	      elseif opt.m == "noise_scale" and opt.noise_level == 1 then
 		 x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET)
+		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
 	      elseif opt.m == "noise_scale" and opt.noise_level == 2 then
 		 x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET)
+		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
 	      else
 		 error("undefined method:" .. opt.method)
 	      end
@@ -106,7 +106,7 @@ local function waifu2x()
    cmd:option("-l", "", 'path of the image-list')
    cmd:option("-scale", 2, 'scale factor')
    cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models", 'model directory')
+   cmd:option("-model_dir", "./models/anime_style_art", 'model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')

+ 18 - 21
web.lua

@@ -1,37 +1,34 @@
-local ROOT = '/home/ubuntu/waifu2x'
-
-_G.TURBO_SSL = true -- Enable SSL
 local turbo = require 'turbo'
 local uuid = require 'uuid'
 local ffi = require 'ffi'
 local md5 = require 'md5'
-require 'torch'
-require 'cudnn'
 require 'pl'
 
 torch.setdefaulttensortype('torch.FloatTensor')
 torch.setnumthreads(4)
 
-package.path = package.path .. ";" .. path.join(ROOT, 'lib', '?.lua')
+require './lib/portable'
+require './lib/LeakyReLU'
+
+local iproc = require './lib/iproc'
+local reconstruct = require './lib/reconstruct'
+local image_loader = require './lib/image_loader'
 
-require 'LeakyReLU'
-local iproc = require 'iproc'
-local reconstruct = require 'reconstruct'
-local image_loader = require 'image_loader'
+local MODEL_DIR = "./models/anime_style_art"
 
-local noise1_model = torch.load(path.join(ROOT, "models", "noise1_model.t7"), "ascii")
-local noise2_model = torch.load(path.join(ROOT, "models", "noise2_model.t7"), "ascii")
-local scale20_model = torch.load(path.join(ROOT, "models", "scale2.0x_model.t7"), "ascii")
+local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
+local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
+local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
 
 local USE_CACHE = true
-local CACHE_DIR = path.join(ROOT, "cache")
+local CACHE_DIR = "./cache"
 local MAX_NOISE_IMAGE = 2560 * 2560
 local MAX_SCALE_IMAGE = 1280 * 1280
 local CURL_OPTIONS = {
-   request_timeout = 10,
-   connect_timeout = 5,
+   request_timeout = 15,
+   connect_timeout = 10,
    allow_redirects = true,
-   max_redirects = 1
+   max_redirects = 2
 }
 local CURL_MAX_SIZE = 2 * 1024 * 1024
 local BLOCK_OFFSET = 7 -- see srcnn.lua
@@ -171,8 +168,8 @@ function APIHandler:post()
    collectgarbage()
 end
 local FormHandler = class("FormHandler", turbo.web.RequestHandler)
-local index_ja = file.read(path.join(ROOT, "assets/index.ja.html"))
-local index_en = file.read(path.join(ROOT, "assets/index.html"))
+local index_ja = file.read("./assets/index.ja.html")
+local index_en = file.read("./assets/index.html")
 function FormHandler:get()
    local lang = self.request.headers:get("Accept-Language")
    if lang then
@@ -193,8 +190,8 @@ end
 local app = turbo.web.Application:new(
    {
       {"^/$", FormHandler},
-      {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
-      {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
+      {"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
+      {"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
       {"^/api$", APIHandler},
    }
 )

部分文件因文件數量過多而無法顯示