nagadomi 9 éve
szülő
commit
3abc5a03e3

+ 10 - 4
.gitignore

@@ -1,8 +1,14 @@
 *~
-/*.png
-/*.mp4
-/*.jpg
+work/
 cache/*.png
-models/*.png
+data/
+!data/.gitkeep
+
+models/
+!models/anime_style_art
+!models/anime_style_art_rgb
+!models/ukbench
 models/*/*.png
+
 waifu2x.log
+

+ 0 - 280
benchmark.lua

@@ -1,280 +0,0 @@
-require './lib/portable'
-require './lib/mynn'
-require 'xlua'
-require 'pl'
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
-local gm = require 'graphicsmagick'
-
-local cmd = torch.CmdLine()
-cmd:text()
-cmd:text("waifu2x-benchmark")
-cmd:text("Options:")
-
-cmd:option("-seed", 11, 'fixed input seed')
-cmd:option("-test_dir", "./test", 'test image directory')
-cmd:option("-jpeg_quality", 50, 'jpeg quality')
-cmd:option("-jpeg_times", 3, 'number of jpeg compression ')
-cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times')
-cmd:option("-core", 4, 'threads')
-
-local opt = cmd:parse(arg)
-torch.setnumthreads(opt.core)
-torch.setdefaulttensortype('torch.FloatTensor')
-
-local function MSE(x1, x2)
-   return (x1 - x2):pow(2):mean()
-end
-local function YMSE(x1, x2)
-   local x1_2 = x1:clone()
-   local x2_2 = x2:clone()
-
-   x1_2[1]:mul(0.299 * 3)
-   x1_2[2]:mul(0.587 * 3)
-   x1_2[3]:mul(0.114 * 3)
-   
-   x2_2[1]:mul(0.299 * 3)
-   x2_2[2]:mul(0.587 * 3)
-   x2_2[3]:mul(0.114 * 3)
-   
-   return (x1_2 - x2_2):pow(2):mean()
-end
-local function PSNR(x1, x2)
-   local mse = MSE(x1, x2)
-   return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
-end
-local function YPSNR(x1, x2)
-   local mse = YMSE(x1, x2)
-   return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
-end
-
-local function transform_jpeg(x)
-   for i = 1, opt.jpeg_times do
-      jpeg = gm.Image(x, "RGB", "DHW")
-      jpeg:format("jpeg")
-      jpeg:samplingFactors({1.0, 1.0, 1.0})
-      blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
-      jpeg:fromBlob(blob, len)
-      x = jpeg:toTensor("byte", "RGB", "DHW")
-   end
-   return x
-end
-
-local function noise_benchmark(x, v1_noise, v2_noise)
-   local v1_mse = 0
-   local v2_mse = 0
-   local jpeg_mse = 0
-   local v1_psnr = 0
-   local v2_psnr = 0
-   local jpeg_psnr = 0
-   local v1_time = 0
-   local v2_time = 0
-   
-   for i = 1, #x do
-      local ground_truth = x[i]
-      local jpg, blob, len, input, v1_out, v2_out, t, mse
-
-      input = transform_jpeg(ground_truth)
-      input = input:float():div(255)
-      ground_truth = ground_truth:float():div(255)
-      
-      jpeg_mse = jpeg_mse + MSE(ground_truth, input)
-      jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)
-      
-      t = sys.clock()
-      v1_output = reconstruct.image(v1_noise, input)
-      v1_time = v1_time + (sys.clock() - t)
-      v1_mse = v1_mse + MSE(ground_truth, v1_output)
-      v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
-      
-      t = sys.clock()
-      v2_output = reconstruct.image(v2_noise, input)
-      v2_time = v2_time + (sys.clock() - t)
-      v2_mse = v2_mse + MSE(ground_truth, v2_output)
-      v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
-      
-      io.stdout:write(
-	 string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
-		       i, #x,
-		       v1_time / i, v2_time / i,
-		       jpeg_mse / i,
-		       v1_mse / i, v2_mse / i,
-		       jpeg_psnr / i,
-		       v1_psnr / i, v2_psnr / i
-	 )
-      )
-      io.stdout:flush()
-   end
-   io.stdout:write("\n")
-end
-local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
-   local v1_mse = 0
-   local v2_mse = 0
-   local jinc_mse = 0
-   local v1_time = 0
-   local v2_time = 0
-   
-   for i = 1, #x do
-      local ground_truth = x[i]
-      local downscale = iproc.scale(ground_truth,
-				    ground_truth:size(3) * 0.5,
-				    ground_truth:size(2) * 0.5,
-				    params[i].filter)
-      local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
-      
-      jpeg = gm.Image(downscale, "RGB", "DHW")
-      jpeg:format("jpeg")
-      blob, len = jpeg:toBlob(params[i].quality)
-      jpeg:fromBlob(blob, len)
-      input = jpeg:toTensor("byte", "RGB", "DHW")
-
-      input = input:float():div(255)
-      ground_truth = ground_truth:float():div(255)
-
-      jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
-      jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean()
-      
-      t = sys.clock()
-      v1_output = reconstruct.image(v1_noise, input)
-      v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
-      v1_time = v1_time + (sys.clock() - t)
-      mse = (ground_truth - v1_output):pow(2):mean()
-      v1_mse = v1_mse + mse
-      
-      t = sys.clock()
-      v2_output = reconstruct.image(v2_noise, input)
-      v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
-      v2_time = v2_time + (sys.clock() - t)
-      mse = (ground_truth - v2_output):pow(2):mean()
-      v2_mse = v2_mse + mse
-      
-      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
-				    i, #x,
-				    v1_time / i, v2_time / i,
-				    (v1_time / i) / (v2_time / i),
-				    jinc_mse / i,
-				    v1_mse / i, (v1_mse/i) / (jinc_mse/i),
-				    v2_mse / i, (v2_mse/i) / (jinc_mse/i),
-				    (v1_mse / i) / (v2_mse / i)))
-				    
-      io.stdout:flush()
-   end
-   io.stdout:write("\n")
-end
-local function scale_benchmark(x, params, v1_scale, v2_scale)
-   local v1_mse = 0
-   local v2_mse = 0
-   local jinc_mse = 0
-   local v1_psnr = 0
-   local v2_psnr = 0
-   local jinc_psnr = 0
-   
-   local v1_time = 0
-   local v2_time = 0
-   
-   for i = 1, #x do
-      local ground_truth = x[i]
-      local downscale = iproc.scale(ground_truth,
-				    ground_truth:size(3) * 0.5,
-				    ground_truth:size(2) * 0.5,
-				    params[i].filter)
-      local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
-      input = downscale
-
-      input = input:float():div(255)
-      ground_truth = ground_truth:float():div(255)
-
-      jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
-      mse = (ground_truth - jinc_output):pow(2):mean()
-      jinc_mse = jinc_mse + mse
-      jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
-      
-      t = sys.clock()
-      v1_output = reconstruct.scale(v1_scale, 2.0, input)
-      v1_time = v1_time + (sys.clock() - t)
-      mse = (ground_truth - v1_output):pow(2):mean()
-      v1_mse = v1_mse + mse
-      v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
-      
-      t = sys.clock()
-      v2_output = reconstruct.scale(v2_scale, 2.0, input)
-      v2_time = v2_time + (sys.clock() - t)
-      mse = (ground_truth - v2_output):pow(2):mean()
-      v2_mse = v2_mse + mse
-      v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
-      
-      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
-				    i, #x,
-				    v1_time / i, v2_time / i,
-				    (v1_time / i) / (v2_time / i),
-				    jinc_psnr / i,
-				    v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i),
-				    v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i),
-				    (v1_psnr / i) / (v2_psnr / i)))
-				    
-      io.stdout:flush()
-   end
-   io.stdout:write("\n")
-end
-
-local function split_data(x, test_size)
-   local index = torch.randperm(#x)
-   local train_size = #x - test_size
-   local train_x = {}
-   local valid_x = {}
-   for i = 1, train_size do
-      train_x[i] = x[index[i]]
-   end
-   for i = 1, test_size do
-      valid_x[i] = x[index[train_size + i]]
-   end
-   return train_x, valid_x
-end
-local function crop_4x(x)
-   local w = x:size(3) % 4
-   local h = x:size(2) % 4
-   return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
-end
-local function load_data(valid_dir)
-   local valid_x = {}
-   local files = dir.getfiles(valid_dir, "*.png")
-   for i = 1, #files do
-      table.insert(valid_x, crop_4x(image_loader.load_byte(files[i])))
-      xlua.progress(i, #files)
-   end
-   return valid_x
-end
-
-local function noise_main(valid_dir, level)
-   local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii")
-   local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii")
-   local valid_x = load_data(valid_dir)
-   noise_benchmark(valid_x, v1_noise, v2_noise)
-end
-local function scale_main(valid_dir)
-   local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
-   local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
-   local valid_x = load_data(valid_dir)
-   local params = random_params(valid_x, 2)
-   scale_benchmark(valid_x, params, v1, v2)
-end
-local function noise_scale_main(valid_dir)
-   local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii")
-   local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
-   local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii")
-   local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
-   local valid_x = load_data(valid_dir)
-   local params = random_params(valid_x, 2)
-   noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale)
-end
-
-V1_DIR = "models/anime_style_art_rgb"
-V2_DIR = "models/anime_style_art_rgb5"
-
-torch.manualSeed(opt.seed)
-cutorch.manualSeed(opt.seed)
-noise_main("./test", 2)
---scale_main("./test")
---noise_scale_main("./test")

+ 19 - 29
convert_data.lua

@@ -1,22 +1,14 @@
-local ffi = require 'ffi'
-require './lib/portable'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
+
+require 'pl'
 require 'image'
-require 'snappy'
-local settings = require './lib/settings'
-local image_loader = require './lib/image_loader'
+local compression = require 'compression'
+local settings = require 'settings'
+local image_loader = require 'image_loader'
 
 local MAX_SIZE = 1440
 
-local function count_lines(file)
-   local fp = io.open(file, "r")
-   local count = 0
-   for line in fp:lines() do
-      count = count + 1
-   end
-   fp:close()
-   
-   return count
-end
 local function crop_if_large(src, max_size)
    if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then
       local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3)))
@@ -36,40 +28,38 @@ end
 
 local function load_images(list)
    local MARGIN = 32
-   local count = count_lines(list)
-   local fp = io.open(list, "r")
+   local lines = utils.split(file.read(list), "\n")
    local x = {}
-   local c = 0
-   for line in fp:lines() do
+   for i = 1, #lines do
+      local line = lines[i]
       local im, alpha = image_loader.load_byte(line)
-      im = crop_if_large(im, settings.max_size)
-      im = crop_4x(im)
-      
       if alpha then
-	 io.stderr:write(string.format("%s: skip: reason: alpha channel.", line))
+	 io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
       else
+	 im = crop_if_large(im, settings.max_size)
+	 im = crop_4x(im)
 	 local scale = 1.0
 	 if settings.random_half then
 	    scale = 2.0
 	 end
 	 if im then
 	    if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
-	       table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
+	       table.insert(x, compression.compress(im))
 	    else
-	       io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
+	       io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
 	    end
 	 else
-	    io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
+	    io.stderr:write(string.format("\n%s: skip: load error.\n", line))
 	 end
       end
-      c = c + 1
-      xlua.progress(c, count)
-      if c % 10 == 0 then
+      xlua.progress(i, #lines)
+      if i % 10 == 0 then
 	 collectgarbage()
       end
    end
    return x
 end
+
 torch.manualSeed(settings.seed)
 print(settings)
 local x = load_images(settings.image_list)

+ 6 - 4
lib/DepthExpand2x.lua

@@ -1,7 +1,7 @@
-if mynn.DepthExpand2x then
-   return mynn.DepthExpand2x
+if w2nn.DepthExpand2x then
+   return w2nn.DepthExpand2x
 end
-local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module')
+local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module')
  
 function DepthExpand2x:__init()
    parent:__init()
@@ -67,9 +67,11 @@ function DepthExpand2x.test()
    end
    show(x)
    
-   local de2x = mynn.DepthExpand2x()
+   local de2x = w2nn.DepthExpand2x()
    out = de2x:forward(x)
    show(out)
    out = de2x:updateGradInput(x, out)
    show(out)
 end
+
+return DepthExpand2x

+ 3 - 3
lib/LeakyReLU.lua

@@ -1,8 +1,8 @@
-if mynn.LeakyReLU then
-   return mynn.LeakyReLU
+if w2nn and w2nn.LeakyReLU then
+   return w2nn.LeakyReLU
 end
 
-local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module')
+local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module')
  
 function LeakyReLU:__init(negative_scale)
    parent.__init(self)

+ 4 - 5
lib/RGBWeightedMSECriterion.lua → lib/WeightedMSECriterion.lua

@@ -1,13 +1,13 @@
-local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')
+local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
 
-function RGBWeightedMSECriterion:__init(w)
+function WeightedMSECriterion:__init(w)
    parent.__init(self)
    self.weight = w:clone()
    self.diff = torch.Tensor()
    self.loss = torch.Tensor()
 end
 
-function RGBWeightedMSECriterion:updateOutput(input, target)
+function WeightedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
    for i = 1, input:size(1) do
       self.diff[i]:add(-1, target[i]):cmul(self.weight)
@@ -18,8 +18,7 @@ function RGBWeightedMSECriterion:updateOutput(input, target)
    return self.output
 end
 
-function RGBWeightedMSECriterion:updateGradInput(input, target)
+function WeightedMSECriterion:updateGradInput(input, target)
    self.gradInput:resizeAs(input):copy(self.diff)
    return self.gradInput
 end
-

+ 17 - 0
lib/compression.lua

@@ -0,0 +1,17 @@
+-- snapply compression for ByteTensor
+require 'snappy'
+
+local compression = {}
+compression.compress = function (bt)
+   local enc = snappy.compress(bt:storage():string())
+   return {bt:size(), torch.ByteStorage():string(enc)}
+end
+compression.decompress = function(data)
+   local size = data[1]
+   local dec = snappy.decompress(data[2]:string())
+   local bt = torch.ByteTensor(unpack(torch.totable(size)))
+   bt:storage():string(dec)
+   return bt
+end
+
+return compression

+ 3 - 3
lib/image_loader.lua

@@ -17,7 +17,7 @@ function image_loader.encode_png(rgb, alpha)
    end
    if alpha then
       if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
-	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
+	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
       end
       local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
       rgba[1]:copy(rgb[1])
@@ -50,8 +50,8 @@ function image_loader.decode_byte(blob)
       if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
 	 -- split alpha channel
 	 im = im:toTensor('float', 'RGBA', 'DHW')
-	 local sum_alpha = (im[4] - 1):sum()
-	 if sum_alpha > 0 or sum_alpha < 0 then
+	 local sum_alpha = (im[4] - 1.0):sum()
+	 if sum_alpha < 0 then
 	    alpha = im[4]:reshape(1, im:size(2), im:size(3))
 	 end
 	 local new_im = torch.FloatTensor(3, im:size(2), im:size(3))

+ 0 - 1
lib/iproc.lua

@@ -22,5 +22,4 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2]:add(-w1)
    return image.warp(img, flow, "simple", false, "clamp")
 end
-
 return iproc

+ 0 - 20
lib/mynn.lua

@@ -1,20 +0,0 @@
-local function load_cunn()
-   require 'nn'
-   require 'cunn'
-end
-local function load_cudnn()
-   require 'cudnn'
-   cudnn.fastest = true
-end
-if mynn then
-   return mynn
-else
-   load_cunn()
-   --load_cudnn()
-   mynn = {}
-   require './LeakyReLU'
-   require './LeakyReLU_deprecated'
-   require './DepthExpand2x'
-   require './RGBWeightedMSECriterion'
-   return mynn
-end

+ 12 - 20
lib/pairwise_transform.lua

@@ -1,7 +1,7 @@
 require 'image'
 local gm = require 'graphicsmagick'
-local iproc = require './iproc'
-local reconstruct = require './reconstruct'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
 local pairwise_transform = {}
 
 local function random_half(src, p)
@@ -81,6 +81,11 @@ local function color_noise(src)
    
    return x:mul(255):byte()
 end
+local function shift_1px(src)
+   -- reducing the even/odd issue in nearest neighbor.
+   local r = torch.random(1, 4)
+   
+end
 local function flip_augment(x, y)
    local flip = torch.random(1, 4)
    if y then
@@ -138,17 +143,16 @@ local function data_augment(y, options)
    return y
 end
 
-
 local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, n, options)
    local filters = {
-      "Box","Box","Box",  -- 0.012756949974688
+      "Box","Box",  -- 0.012756949974688
       "Blackman",   -- 0.013191924552285
       --"Cartom",     -- 0.013753536746706
       --"Hanning",    -- 0.013761314529647
       --"Hermite",    -- 0.013850225205266
       "SincFast",   -- 0.014095824314306
-      "Jinc",       -- 0.014244299255442
+      --"Jinc",       -- 0.014244299255442
    }
    if options.random_half then
       src = random_half(src)
@@ -176,26 +180,14 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
    return batch
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
-   if options.random_half then
-      src = random_half(src)
-   end
-   src = crop_if_large(src, math.max(size * 4, 512))
-   local y = src
-   local x
-
-   if options.color_noise then
-      y = color_noise(y)
-   end
-   if options.overlay then
-      y = overlay_augment(y)
-   end
-   x = y
+   local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options)   
+   local x = y
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
       if options.jpeg_sampling_factors == 444 then
 	 x:samplingFactors({1.0, 1.0, 1.0})
-      else -- 422
+      else -- 420
 	 x:samplingFactors({2.0, 1.0, 1.0})
       end
       local blob, len = x:toBlob(quality[i])

+ 0 - 17
lib/portable.lua

@@ -1,17 +0,0 @@
-require 'torch'
-require 'nn'
-
-local function load_cuda()
-   require 'cutorch'
-   require 'cunn'
-end
-local function load_cudnn()
-   require 'cudnn'
-   --cudnn.fastest = true
-end
-
-if pcall(load_cuda) then
-else
-end
-if pcall(load_cudnn) then
-end

+ 1 - 1
lib/reconstruct.lua

@@ -1,5 +1,5 @@
 require 'image'
-local iproc = require './iproc'
+local iproc = require 'iproc'
 
 local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then

+ 4 - 2
lib/settings.lua

@@ -35,7 +35,7 @@ cmd:option("-crop_size", 128, 'crop size')
 cmd:option("-max_size", -1, 'crop if image size larger then this value.')
 cmd:option("-batch_size", 2, 'mini batch size')
 cmd:option("-epoch", 200, 'epoch')
-cmd:option("-core", 2, 'cpu core')
+cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
 cmd:option("-validation_ratio", 0.1, 'validation ratio')
 cmd:option("-validation_crops", 40, 'number of crop region in validation')
@@ -84,7 +84,9 @@ else
    settings.overlay = false
 end
 
-torch.setnumthreads(settings.core)
+if settings.thread > 0 then
+   torch.setnumthreads(tonumber(settings.thread))
+end
 
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)

+ 13 - 14
lib/srcnn.lua

@@ -1,5 +1,4 @@
-
-require './mynn'
+require 'w2nn'
 
 -- ref: http://arxiv.org/abs/1502.01852
 -- ref: http://arxiv.org/abs/1501.00092
@@ -7,17 +6,17 @@ local srcnn = {}
 function srcnn.waifu2x_cunn(ch)
    local model = nn.Sequential()
    model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    --model:cuda()
@@ -28,17 +27,17 @@ end
 function srcnn.waifu2x_cudnn(ch)
    local model = nn.Sequential()
    model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(mynn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    --model:cuda()

+ 24 - 0
lib/w2nn.lua

@@ -0,0 +1,24 @@
+local function load_nn()
+   require 'torch'
+   require 'nn'
+end
+local function load_cunn()
+   require 'cutorch'
+   require 'cunn'
+end
+local function load_cudnn()
+   require 'cudnn'
+   cudnn.fastest = true
+end
+if w2nn then
+   return w2nn
+else
+   pcall(load_cunn)
+   pcall(load_cudnn)
+   w2nn = {}
+   require 'LeakyReLU'
+   require 'LeakyReLU_deprecated'
+   require 'DepthExpand2x'
+   require 'WeightedMSECriterion'
+   return w2nn
+end

+ 148 - 0
tools/benchmark.lua

@@ -0,0 +1,148 @@
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'xlua'
+require 'pl'
+
+require 'w2nn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
+local gm = require 'graphicsmagick'
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x-benchmark")
+cmd:text("Options:")
+
+cmd:option("-seed", 11, 'fixed input seed')
+cmd:option("-dir", "./data/test", 'test image directory')
+cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
+cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
+cmd:option("-method", "scale", '(scale|noise)')
+cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-color_weight", "y", '(y|rgb)')
+cmd:option("-jpeg_quality", 75, 'jpeg quality')
+cmd:option("-jpeg_times", 1, 'jpeg compression times')
+cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
+
+local opt = cmd:parse(arg)
+torch.setdefaulttensortype('torch.FloatTensor')
+
+local function MSE(x1, x2)
+   return (x1 - x2):pow(2):mean()
+end
+local function YMSE(x1, x2)
+   local x1_2 = x1:clone()
+   local x2_2 = x2:clone()
+
+   x1_2[1]:mul(0.299 * 3)
+   x1_2[2]:mul(0.587 * 3)
+   x1_2[3]:mul(0.114 * 3)
+   
+   x2_2[1]:mul(0.299 * 3)
+   x2_2[2]:mul(0.587 * 3)
+   x2_2[3]:mul(0.114 * 3)
+   
+   return (x1_2 - x2_2):pow(2):mean()
+end
+local function PSNR(x1, x2)
+   local mse = MSE(x1, x2)
+   return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
+end
+local function YPSNR(x1, x2)
+   local mse = YMSE(x1, x2)
+   return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
+end
+
+local function transform_jpeg(x)
+   for i = 1, opt.jpeg_times do
+      jpeg = gm.Image(x, "RGB", "DHW")
+      jpeg:format("jpeg")
+      jpeg:samplingFactors({1.0, 1.0, 1.0})
+      blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
+      jpeg:fromBlob(blob, len)
+      x = jpeg:toTensor("byte", "RGB", "DHW")
+   end
+   return x
+end
+local function transform_scale(x)
+   return iproc.scale(x,
+		      x:size(3) * 0.5,
+		      x:size(2) * 0.5,
+		      "Box")
+end
+
+local function benchmark(color_weight, x, input_func, v1_noise, v2_noise)
+   local v1_mse = 0
+   local v2_mse = 0
+   local v1_psnr = 0
+   local v2_psnr = 0
+   
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local input, v1_output, v2_output
+
+      input = input_func(ground_truth)
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+      
+      t = sys.clock()
+      if input:size(3) == ground_truth:size(3) then
+	 v1_output = reconstruct.image(v1_noise, input)
+	 v2_output = reconstruct.image(v2_noise, input)
+      else
+	 v1_output = reconstruct.scale(v1_noise, 2.0, input)
+	 v2_output = reconstruct.scale(v2_noise, 2.0, input)
+      end
+      if color_weight == "y" then
+	 v1_mse = v1_mse + YMSE(ground_truth, v1_output)
+	 v1_psnr = v1_psnr + YPSNR(ground_truth, v1_output)
+	 v2_mse = v2_mse + YMSE(ground_truth, v2_output)
+	 v2_psnr = v2_psnr + YPSNR(ground_truth, v2_output)
+      elseif color_weight == "rgb" then
+	 v1_mse = v1_mse + MSE(ground_truth, v1_output)
+	 v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
+	 v2_mse = v2_mse + MSE(ground_truth, v2_output)
+	 v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
+      end
+      
+      io.stdout:write(
+	 string.format("%d/%d; v1_mse=%f, v2_mse=%f, v1_psnr=%f, v2_psnr=%f \r",
+		       i, #x,
+		       v1_mse / i, v2_mse / i,
+		       v1_psnr / i, v2_psnr / i
+	 )
+      )
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+local function crop_4x(x)
+   local w = x:size(3) % 4
+   local h = x:size(2) % 4
+   return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
+end
+local function load_data(test_dir)
+   local test_x = {}
+   local files = dir.getfiles(test_dir, "*.*")
+   for i = 1, #files do
+      table.insert(test_x, crop_4x(image_loader.load_byte(files[i])))
+      xlua.progress(i, #files)
+   end
+   return test_x
+end
+
+print(opt)
+torch.manualSeed(opt.seed)
+cutorch.manualSeed(opt.seed)
+if opt.method == "scale" then
+   local v1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
+   local v2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
+   local test_x = load_data(opt.dir)
+   benchmark(opt.color_weight, test_x, transform_scale, v1, v2)
+elseif opt.method == "noise" then
+   local v1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
+   local v2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
+   local test_x = load_data(opt.dir)
+   benchmark(opt.color_weight, test_x, transform_jpeg, v1, v2)
+end

+ 4 - 3
cleanup_model.lua → tools/cleanup_model.lua

@@ -1,6 +1,7 @@
-require './lib/portable'
-require './lib/mynn'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
 
+require 'w2nn'
 torch.setdefaulttensortype("torch.FloatTensor")
 
 -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
@@ -27,7 +28,7 @@ local function cleanupModel(node)
    if node.finput ~= nil then
       node.finput = zeroDataSize(node.finput)
    end
-   if tostring(node) == "nn.LeakyReLU" then
+   if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
       if node.negative ~= nil then
 	 node.negative = zeroDataSize(node.negative)
       end

+ 3 - 2
export_model.lua → tools/export_model.lua

@@ -1,6 +1,7 @@
 -- adapted from https://github.com/marcan/cl-waifu2x
-require './lib/portable'
-require './lib/LeakyReLU'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'w2nn'
 local cjson = require "cjson"
 
 local model = torch.load(arg[1], "ascii")

+ 13 - 16
train.lua

@@ -1,17 +1,18 @@
-require './lib/portable'
-require './lib/mynn'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 require 'optim'
 require 'xlua'
 require 'pl'
-require 'snappy'
 
-local settings = require './lib/settings'
-local srcnn = require './lib/srcnn'
-local minibatch_adam = require './lib/minibatch_adam'
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local pairwise_transform = require './lib/pairwise_transform'
-local image_loader = require './lib/image_loader'
+require 'w2nn'
+local settings = require 'settings'
+local srcnn = require 'srcnn'
+local minibatch_adam = require 'minibatch_adam'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local compression = require 'compression'
+local pairwise_transform = require 'pairwise_transform'
+local image_loader = require 'image_loader'
 
 local function save_test_scale(model, rgb, file)
    local up = reconstruct.scale(model, settings.scale, rgb)
@@ -73,17 +74,13 @@ local function create_criterion(model)
       weight[1]:fill(0.299 * 3) -- R
       weight[2]:fill(0.587 * 3) -- G
       weight[3]:fill(0.114 * 3) -- B
-      return mynn.RGBWeightedMSECriterion(weight):cuda()
+      return w2nn.WeightedMSECriterion(weight):cuda()
    else
       return nn.MSECriterion():cuda()
    end
 end
 local function transformer(x, is_validation, n, offset)
-   local size = x[1]
-   local dec = snappy.decompress(x[2]:string())
-   x = torch.ByteTensor(size[1], size[2], size[3])
-   x:storage():string(dec)
-   
+   x = compression.decompress(x)
    n = n or settings.batch_size;
    if is_validation == nil then is_validation = false end
    local color_noise = nil 

+ 11 - 7
waifu2x.lua

@@ -1,11 +1,11 @@
-require './lib/portable'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 require 'sys'
 require 'pl'
-require './lib/mynn'
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
+require 'w2nn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
@@ -111,8 +111,12 @@ local function waifu2x()
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
-   
+   cmd:option("-thread", -1, "number of CPU threads")
+
    local opt = cmd:parse(arg)
+   if opt.thread > 0 then
+      torch.setnumthreads(opt.thread)
+   end
    if string.len(opt.l) == 0 then
       convert_image(opt)
    else

+ 13 - 11
web.lua

@@ -1,11 +1,16 @@
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 _G.TURBO_SSL = true
+
+require 'pl'
+require 'w2nn'
 local turbo = require 'turbo'
 local uuid = require 'uuid'
 local ffi = require 'ffi'
 local md5 = require 'md5'
-require 'pl'
-require 'lib.portable'
-require 'lib.mynn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
 
 local cmd = torch.CmdLine()
 cmd:text()
@@ -13,18 +18,15 @@ cmd:text("waifu2x-api")
 cmd:text("Options:")
 cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-core", 2, 'number of CPU cores')
+cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
 torch.setdefaulttensortype('torch.FloatTensor')
-torch.setnumthreads(opt.core)
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
-
-local MODEL_DIR = "./models/anime_style_art_rgb3"
+if opt.thread > 0 then
+   torch.setnumthreads(opt.thread)
+end
 
+local MODEL_DIR = "./models/anime_style_art_rgb"
 local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
 local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
 local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")