فهرست منبع

sync from internal repo

- Memory compression by snappy (lua-csnappy)
- Use RGB-wise Weighted MSE(R*0.299, G*0.587, B*0.114) instead of MSE
- Aggressive cropping for edge region
and some change.
nagadomi 9 سال پیش
والد
کامیت
8dea362bed

+ 4 - 0
.gitignore

@@ -1,4 +1,8 @@
 *~
+/*.png
+/*.mp4
+/*.jpg
 cache/*.png
 models/*.png
+models/*/*.png
 waifu2x.log

+ 280 - 0
benchmark.lua

@@ -0,0 +1,280 @@
+require './lib/portable'
+require './lib/mynn'
+require 'xlua'
+require 'pl'
+
+local iproc = require './lib/iproc'
+local reconstruct = require './lib/reconstruct'
+local image_loader = require './lib/image_loader'
+local gm = require 'graphicsmagick'
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x-benchmark")
+cmd:text("Options:")
+
+cmd:option("-seed", 11, 'fixed input seed')
+cmd:option("-test_dir", "./test", 'test image directory')
+cmd:option("-jpeg_quality", 50, 'jpeg quality')
+cmd:option("-jpeg_times", 3, 'number of jpeg compression ')
+cmd:option("-jpeg_quality_down", 5, 'reducing jpeg quality each times')
+cmd:option("-core", 4, 'threads')
+
+local opt = cmd:parse(arg)
+torch.setnumthreads(opt.core)
+torch.setdefaulttensortype('torch.FloatTensor')
+
+local function MSE(x1, x2)
+   return (x1 - x2):pow(2):mean()
+end
+local function YMSE(x1, x2)
+   local x1_2 = x1:clone()
+   local x2_2 = x2:clone()
+
+   x1_2[1]:mul(0.299 * 3)
+   x1_2[2]:mul(0.587 * 3)
+   x1_2[3]:mul(0.114 * 3)
+   
+   x2_2[1]:mul(0.299 * 3)
+   x2_2[2]:mul(0.587 * 3)
+   x2_2[3]:mul(0.114 * 3)
+   
+   return (x1_2 - x2_2):pow(2):mean()
+end
+local function PSNR(x1, x2)
+   local mse = MSE(x1, x2)
+   return 20 * (math.log(1.0 / math.sqrt(mse)) / math.log(10))
+end
+local function YPSNR(x1, x2)
+   local mse = YMSE(x1, x2)
+   return 20 * (math.log((0.587 * 3) / math.sqrt(mse)) / math.log(10))
+end
+
+local function transform_jpeg(x)
+   for i = 1, opt.jpeg_times do
+      jpeg = gm.Image(x, "RGB", "DHW")
+      jpeg:format("jpeg")
+      jpeg:samplingFactors({1.0, 1.0, 1.0})
+      blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
+      jpeg:fromBlob(blob, len)
+      x = jpeg:toTensor("byte", "RGB", "DHW")
+   end
+   return x
+end
+
+local function noise_benchmark(x, v1_noise, v2_noise)
+   local v1_mse = 0
+   local v2_mse = 0
+   local jpeg_mse = 0
+   local v1_psnr = 0
+   local v2_psnr = 0
+   local jpeg_psnr = 0
+   local v1_time = 0
+   local v2_time = 0
+   
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local jpg, blob, len, input, v1_out, v2_out, t, mse
+
+      input = transform_jpeg(ground_truth)
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+      
+      jpeg_mse = jpeg_mse + MSE(ground_truth, input)
+      jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)
+      
+      t = sys.clock()
+      v1_output = reconstruct.image(v1_noise, input)
+      v1_time = v1_time + (sys.clock() - t)
+      v1_mse = v1_mse + MSE(ground_truth, v1_output)
+      v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)
+      
+      t = sys.clock()
+      v2_output = reconstruct.image(v2_noise, input)
+      v2_time = v2_time + (sys.clock() - t)
+      v2_mse = v2_mse + MSE(ground_truth, v2_output)
+      v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)
+      
+      io.stdout:write(
+	 string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
+		       i, #x,
+		       v1_time / i, v2_time / i,
+		       jpeg_mse / i,
+		       v1_mse / i, v2_mse / i,
+		       jpeg_psnr / i,
+		       v1_psnr / i, v2_psnr / i
+	 )
+      )
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
+   local v1_mse = 0
+   local v2_mse = 0
+   local jinc_mse = 0
+   local v1_time = 0
+   local v2_time = 0
+   
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local downscale = iproc.scale(ground_truth,
+				    ground_truth:size(3) * 0.5,
+				    ground_truth:size(2) * 0.5,
+				    params[i].filter)
+      local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
+      
+      jpeg = gm.Image(downscale, "RGB", "DHW")
+      jpeg:format("jpeg")
+      blob, len = jpeg:toBlob(params[i].quality)
+      jpeg:fromBlob(blob, len)
+      input = jpeg:toTensor("byte", "RGB", "DHW")
+
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+
+      jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
+      jinc_mse = jinc_mse + (ground_truth - jinc_output):pow(2):mean()
+      
+      t = sys.clock()
+      v1_output = reconstruct.image(v1_noise, input)
+      v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
+      v1_time = v1_time + (sys.clock() - t)
+      mse = (ground_truth - v1_output):pow(2):mean()
+      v1_mse = v1_mse + mse
+      
+      t = sys.clock()
+      v2_output = reconstruct.image(v2_noise, input)
+      v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
+      v2_time = v2_time + (sys.clock() - t)
+      mse = (ground_truth - v2_output):pow(2):mean()
+      v2_mse = v2_mse + mse
+      
+      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
+				    i, #x,
+				    v1_time / i, v2_time / i,
+				    (v1_time / i) / (v2_time / i),
+				    jinc_mse / i,
+				    v1_mse / i, (v1_mse/i) / (jinc_mse/i),
+				    v2_mse / i, (v2_mse/i) / (jinc_mse/i),
+				    (v1_mse / i) / (v2_mse / i)))
+				    
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+local function scale_benchmark(x, params, v1_scale, v2_scale)
+   local v1_mse = 0
+   local v2_mse = 0
+   local jinc_mse = 0
+   local v1_psnr = 0
+   local v2_psnr = 0
+   local jinc_psnr = 0
+   
+   local v1_time = 0
+   local v2_time = 0
+   
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local downscale = iproc.scale(ground_truth,
+				    ground_truth:size(3) * 0.5,
+				    ground_truth:size(2) * 0.5,
+				    params[i].filter)
+      local jpg, blob, len, input, v1_output, v2_output, jinc_output, t, mse
+      input = downscale
+
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+
+      jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
+      mse = (ground_truth - jinc_output):pow(2):mean()
+      jinc_mse = jinc_mse + mse
+      jinc_psnr = jinc_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
+      
+      t = sys.clock()
+      v1_output = reconstruct.scale(v1_scale, 2.0, input)
+      v1_time = v1_time + (sys.clock() - t)
+      mse = (ground_truth - v1_output):pow(2):mean()
+      v1_mse = v1_mse + mse
+      v1_psnr = v1_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
+      
+      t = sys.clock()
+      v2_output = reconstruct.scale(v2_scale, 2.0, input)
+      v2_time = v2_time + (sys.clock() - t)
+      mse = (ground_truth - v2_output):pow(2):mean()
+      v2_mse = v2_mse + mse
+      v2_psnr = v2_psnr + (10 * (math.log(1.0 / mse) / math.log(10)))
+      
+      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
+				    i, #x,
+				    v1_time / i, v2_time / i,
+				    (v1_time / i) / (v2_time / i),
+				    jinc_psnr / i,
+				    v1_psnr / i, (v1_psnr/i) / (jinc_psnr/i),
+				    v2_psnr / i, (v2_psnr/i) / (jinc_psnr/i),
+				    (v1_psnr / i) / (v2_psnr / i)))
+				    
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+
+local function split_data(x, test_size)
+   local index = torch.randperm(#x)
+   local train_size = #x - test_size
+   local train_x = {}
+   local valid_x = {}
+   for i = 1, train_size do
+      train_x[i] = x[index[i]]
+   end
+   for i = 1, test_size do
+      valid_x[i] = x[index[train_size + i]]
+   end
+   return train_x, valid_x
+end
+local function crop_4x(x)
+   local w = x:size(3) % 4
+   local h = x:size(2) % 4
+   return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
+end
+local function load_data(valid_dir)
+   local valid_x = {}
+   local files = dir.getfiles(valid_dir, "*.png")
+   for i = 1, #files do
+      table.insert(valid_x, crop_4x(image_loader.load_byte(files[i])))
+      xlua.progress(i, #files)
+   end
+   return valid_x
+end
+
+local function noise_main(valid_dir, level)
+   local v1_noise = torch.load(path.join(V1_DIR, string.format("noise%d_model.t7", level)), "ascii")
+   local v2_noise = torch.load(path.join(V2_DIR, string.format("noise%d_model.t7", level)), "ascii")
+   local valid_x = load_data(valid_dir)
+   noise_benchmark(valid_x, v1_noise, v2_noise)
+end
+local function scale_main(valid_dir)
+   local v1 = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
+   local v2 = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
+   local valid_x = load_data(valid_dir)
+   local params = random_params(valid_x, 2)
+   scale_benchmark(valid_x, params, v1, v2)
+end
+local function noise_scale_main(valid_dir)
+   local v1_noise = torch.load(path.join(V1_DIR, "noise2_model.t7"), "ascii")
+   local v1_scale = torch.load(path.join(V1_DIR, "scale2.0x_model.t7"), "ascii")
+   local v2_noise = torch.load(path.join(V2_DIR, "noise2_model.t7"), "ascii")
+   local v2_scale = torch.load(path.join(V2_DIR, "scale2.0x_model.t7"), "ascii")
+   local valid_x = load_data(valid_dir)
+   local params = random_params(valid_x, 2)
+   noise_scale_benchmark(valid_x, params, v1_noise, v1_scale, v2_noise, v2_scale)
+end
+
+V1_DIR = "models/anime_style_art_rgb"
+V2_DIR = "models/anime_style_art_rgb5"
+
+torch.manualSeed(opt.seed)
+cutorch.manualSeed(opt.seed)
+noise_main("./test", 2)
+--scale_main("./test")
+--noise_scale_main("./test")

+ 1 - 1
cleanup_model.lua

@@ -1,5 +1,5 @@
 require './lib/portable'
-require './lib/LeakyReLU'
+require './lib/mynn'
 
 torch.setdefaulttensortype("torch.FloatTensor")
 

+ 35 - 8
convert_data.lua

@@ -1,8 +1,12 @@
+local ffi = require 'ffi'
 require './lib/portable'
 require 'image'
+require 'snappy'
 local settings = require './lib/settings'
 local image_loader = require './lib/image_loader'
 
+local MAX_SIZE = 1440
+
 local function count_lines(file)
    local fp = io.open(file, "r")
    local count = 0
@@ -13,7 +17,17 @@ local function count_lines(file)
    
    return count
 end
-
+local function crop_if_large(src, max_size)
+   if max_size > 0 and (src:size(2) > max_size or src:size(3) > max_size) then
+      local sx = torch.random(0, src:size(3) - math.min(max_size, src:size(3)))
+      local sy = torch.random(0, src:size(2) - math.min(max_size, src:size(2)))
+      return image.crop(src, sx, sy,
+			math.min(sx + max_size, src:size(3)),
+			math.min(sy + max_size, src:size(2)))
+   else
+      return src
+   end
+end
 local function crop_4x(x)
    local w = x:size(3) % 4
    local h = x:size(2) % 4
@@ -27,13 +41,26 @@ local function load_images(list)
    local x = {}
    local c = 0
    for line in fp:lines() do
-      local im = crop_4x(image_loader.load_byte(line))
-      if im then
-	 if im:size(2) > (settings.crop_size * 2 + MARGIN) and im:size(3) > (settings.crop_size * 2 + MARGIN) then
-	    table.insert(x, im)
-	 end
+      local im, alpha = image_loader.load_byte(line)
+      im = crop_if_large(im, settings.max_size)
+      im = crop_4x(im)
+      
+      if alpha then
+	 io.stderr:write(string.format("%s: skip: reason: alpha channel.", line))
       else
-	 print("error:" .. line)
+	 local scale = 1.0
+	 if settings.random_half then
+	    scale = 2.0
+	 end
+	 if im then
+	    if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
+	       table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
+	    else
+	       io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
+	    end
+	 else
+	    io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
+	 end
       end
       c = c + 1
       xlua.progress(c, count)
@@ -43,7 +70,7 @@ local function load_images(list)
    end
    return x
 end
+torch.manualSeed(settings.seed)
 print(settings)
 local x = load_images(settings.image_list)
 torch.save(settings.images, x)
-

+ 0 - 0
data/.gitkeep


+ 1 - 2
images/gen.sh

@@ -1,8 +1,7 @@
 #!/bin/sh
 
-th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
+th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
-th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
 th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png

BIN
images/lena_waifu2x.png


BIN
images/miku_CC_BY-NC_noisy_waifu2x.png


BIN
images/miku_noisy_waifu2x.png


BIN
images/miku_small.png


BIN
images/miku_small_lanczos3.png


BIN
images/miku_small_noisy_waifu2x.png


BIN
images/miku_small_waifu2x.png


BIN
images/slide.odp


BIN
images/slide.png


BIN
images/slide_noise_reduction.png


BIN
images/slide_result.png


BIN
images/slide_upscaling.png


+ 75 - 0
lib/DepthExpand2x.lua

@@ -0,0 +1,75 @@
+if mynn.DepthExpand2x then
+   return mynn.DepthExpand2x
+end
+local DepthExpand2x, parent = torch.class('mynn.DepthExpand2x','nn.Module')
+ 
+function DepthExpand2x:__init()
+   parent:__init()
+end
+
+function DepthExpand2x:updateOutput(input)
+   local x = input
+   -- (batch_size, depth, height, width)
+   self.shape = x:size()
+
+   assert(self.shape:size() == 4, "input must be 4d tensor")
+   assert(self.shape[2] % 4 == 0, "depth must be depth % 4 = 0")
+   -- (batch_size, width, height, depth)
+   x = x:transpose(2, 4)
+   -- (batch_size, width, height * 2, depth / 2)
+   x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
+   -- (batch_size, height * 2, width, depth / 2)
+   x = x:transpose(2, 3)
+   -- (batch_size, height * 2, width * 2, depth / 4)
+   x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
+   -- (batch_size, depth / 4, height * 2, width * 2)
+   x = x:transpose(2, 4)
+   x = x:transpose(3, 4)
+   self.output:resizeAs(x):copy(x) -- contiguous
+   
+   return self.output
+end
+
+function DepthExpand2x:updateGradInput(input, gradOutput)
+   -- (batch_size, depth / 4, height * 2, width * 2)
+   local x = gradOutput
+   -- (batch_size, height * 2, width * 2, depth / 4)
+   x = x:transpose(2, 4)
+   x = x:transpose(2, 3)
+   -- (batch_size, height * 2, width, depth / 2)
+   x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
+   -- (batch_size, width, height * 2, depth / 2)
+   x = x:transpose(2, 3)
+   -- (batch_size, width, height, depth)
+   x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
+   -- (batch_size, depth, height, width)
+   x = x:transpose(2, 4)
+   
+   self.gradInput:resizeAs(x):copy(x)
+   
+   return self.gradInput
+end
+
+function DepthExpand2x.test()
+   require 'image'
+   local function show(x)
+      local img = torch.Tensor(3, x:size(3), x:size(4))
+      img[1]:copy(x[1][1])
+      img[2]:copy(x[1][2])
+      img[3]:copy(x[1][3])
+      image.display(img)
+   end
+   local img = image.lena()
+   local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
+   for i = 0, img:size(1) * 4 - 1 do
+      src_index = ((i % 3) + 1)
+      x[1][i + 1]:copy(img[src_index])
+   end
+   show(x)
+   
+   local de2x = mynn.DepthExpand2x()
+   out = de2x:forward(x)
+   show(out)
+   out = de2x:updateGradInput(x, out)
+   show(out)
+end

+ 4 - 3
lib/LeakyReLU.lua

@@ -1,7 +1,8 @@
-if nn.LeakyReLU then
-   return
+if mynn.LeakyReLU then
+   return mynn.LeakyReLU
 end
-local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
+
+local LeakyReLU, parent = torch.class('mynn.LeakyReLU','nn.Module')
  
 function LeakyReLU:__init(negative_scale)
    parent.__init(self)

+ 31 - 0
lib/LeakyReLU_deprecated.lua

@@ -0,0 +1,31 @@
+if nn.LeakyReLU then
+   return nn.LeakyReLU
+end
+
+local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
+ 
+function LeakyReLU:__init(negative_scale)
+   parent.__init(self)
+   self.negative_scale = negative_scale or 0.333
+   self.negative = torch.Tensor()
+end
+ 
+function LeakyReLU:updateOutput(input)
+   self.output:resizeAs(input):copy(input):abs():add(input):div(2)
+   self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
+   self.output:add(self.negative)
+   
+   return self.output
+end
+ 
+function LeakyReLU:updateGradInput(input, gradOutput)
+   self.gradInput:resizeAs(gradOutput)
+   -- filter positive
+   self.negative:sign():add(1)
+   torch.cmul(self.gradInput, gradOutput, self.negative)
+   -- filter negative
+   self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
+   self.gradInput:add(self.negative)
+   
+   return self.gradInput
+end

+ 25 - 0
lib/RGBWeightedMSECriterion.lua

@@ -0,0 +1,25 @@
+local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')
+
+function RGBWeightedMSECriterion:__init(w)
+   parent.__init(self)
+   self.weight = w:clone()
+   self.diff = torch.Tensor()
+   self.loss = torch.Tensor()
+end
+
+function RGBWeightedMSECriterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   for i = 1, input:size(1) do
+      self.diff[i]:add(-1, target[i]):cmul(self.weight)
+   end
+   self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
+   self.output = self.loss:mean()
+   
+   return self.output
+end
+
+function RGBWeightedMSECriterion:updateGradInput(input, target)
+   self.gradInput:resizeAs(input):copy(self.diff)
+   return self.gradInput
+end
+

+ 4 - 3
lib/image_loader.lua

@@ -13,7 +13,7 @@ function image_loader.decode_float(blob)
 end
 function image_loader.encode_png(rgb, alpha)
    if rgb:type() == "torch.ByteTensor" then
-      error("expect FloatTensor")
+      rgb = rgb:float():div(255)
    end
    if alpha then
       if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
@@ -26,11 +26,11 @@ function image_loader.encode_png(rgb, alpha)
       rgba[4]:copy(alpha)
       local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
       im:format("png")
-      return im:toBlob()
+      return im:toBlob(9)
    else
       local im = gm.Image(rgb, "RGB", "DHW")
       im:format("png")
-      return im:toBlob()
+      return im:toBlob(9)
    end
 end
 function image_loader.save_png(filename, rgb, alpha)
@@ -64,6 +64,7 @@ function image_loader.decode_byte(blob)
       end
       return {im, alpha}
    end
+   load_image()
    local state, ret = pcall(load_image)
    if state then
       return ret[1], ret[2]

+ 5 - 8
lib/minibatch_adam.lua

@@ -22,15 +22,12 @@ local function minibatch_adam(model, criterion,
    local targets_tmp = torch.Tensor(batch_size,
 				    target_size[1] * target_size[2] * target_size[3])
    
-   for t = 1, #train_x, batch_size do
-      if t + batch_size > #train_x then
-	 break
-      end
+   for t = 1, #train_x do
       xlua.progress(t, #train_x)
-      for i = 1, batch_size do
-	 local x, y = transformer(train_x[shuffle[t + i - 1]])
-         inputs_tmp[i]:copy(x)
-	 targets_tmp[i]:copy(y)
+      local xy = transformer(train_x[shuffle[t]], false, batch_size)
+      for i = 1, #xy do
+         inputs_tmp[i]:copy(xy[i][1])
+	 targets_tmp[i]:copy(xy[i][2])
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)

+ 20 - 0
lib/mynn.lua

@@ -0,0 +1,20 @@
+local function load_cunn()
+   require 'nn'
+   require 'cunn'
+end
+local function load_cudnn()
+   require 'cudnn'
+   cudnn.fastest = true
+end
+if mynn then
+   return mynn
+else
+   load_cunn()
+   --load_cudnn()
+   mynn = {}
+   require './LeakyReLU'
+   require './LeakyReLU_deprecated'
+   require './DepthExpand2x'
+   require './RGBWeightedMSECriterion'
+   return mynn
+end

+ 161 - 97
lib/pairwise_transform.lua

@@ -4,10 +4,11 @@ local iproc = require './iproc'
 local reconstruct = require './reconstruct'
 local pairwise_transform = {}
 
-local function random_half(src, p, min_size)
-   p = p or 0.5
-   local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
-   if p > torch.uniform() then
+local function random_half(src, p)
+   p = p or 0.25
+   --local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
+   local filter = "Box"
+   if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
       return src
@@ -21,6 +22,48 @@ local function pcacov(x)
    local ce, cv = torch.symeig(c, 'V')
    return ce, cv
 end
+local function crop_if_large(src, max_size)
+   if src:size(2) > max_size and src:size(3) > max_size then
+      local yi = torch.random(0, src:size(2) - max_size)
+      local xi = torch.random(0, src:size(3) - max_size)
+      return image.crop(src, xi, yi, xi + max_size, yi + max_size)
+   else
+      return src
+   end
+end
+local function active_cropping(x, y, size, offset, p, tries)
+   assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   local r = torch.uniform()
+   if p < r then
+      local xi = torch.random(offset, y:size(3) - (size + offset + 1))
+      local yi = torch.random(offset, y:size(2) - (size + offset + 1))
+      local xc = image.crop(x, xi, yi, xi + size, yi + size)
+      local yc = image.crop(y, xi, yi, xi + size, yi + size)
+      yc = yc:float():div(255)
+      xc = xc:float():div(255)
+      return xc, yc
+   else
+      local samples = {}
+      local sum_mse = 0
+      for i = 1, tries do
+	 local xi = torch.random(offset, y:size(3) - (size + offset + 1))
+	 local yi = torch.random(offset, y:size(2) - (size + offset + 1))
+	 local xc = image.crop(x, xi, yi, xi + size, yi + size):float():div(255)
+	 local yc = image.crop(y, xi, yi, xi + size, yi + size):float():div(255)
+	 local mse = (xc - yc):pow(2):mean()
+	 sum_mse = sum_mse + mse
+	 table.insert(samples, {xc = xc, yc = yc, mse = mse})
+      end
+      if sum_mse > 0 then
+	 table.sort(samples,
+		    function (a, b)
+		       return a.mse > b.mse
+		    end)
+      end
+      return samples[1].xc, samples[1].yc
+   end
+end
+
 local function color_noise(src)
    local p = 0.1
    src = src:float():div(255)
@@ -84,20 +127,22 @@ local function overlay_augment(src, p)
       return src
    end
 end
-local INTERPOLATION_PADDING = 16
-function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
-   if options.random_half then
-      src = random_half(src)
+local function data_augment(y, options)
+   y = flip_augment(y)
+   if options.color_noise then
+      y = color_noise(y)
    end
-   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
-   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
-   local down_scale = 1.0 / scale
-   local y = image.crop(src,
-			xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
-			xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
+   if options.overlay then
+      y = overlay_augment(y)
+   end
+   return y
+end
+
+
+local INTERPOLATION_PADDING = 16
+function pairwise_transform.scale(src, scale, size, offset, n, options)
    local filters = {
-      "Box",        -- 0.012756949974688
+      "Box","Box","Box",  -- 0.012756949974688
       "Blackman",   -- 0.013191924552285
       --"Cartom",     -- 0.013753536746706
       --"Hanning",    -- 0.013761314529647
@@ -105,38 +150,36 @@ function pairwise_transform.scale(src, scale, size, offset, options)
       "SincFast",   -- 0.014095824314306
       "Jinc",       -- 0.014244299255442
    }
-   local downscale_filter = filters[torch.random(1, #filters)]
-   
-   y = flip_augment(y)
-   if options.color_noise then
-      y = color_noise(y)
-   end
-   if options.overlay then
-      y = overlay_augment(y)
+   if options.random_half then
+      src = random_half(src)
    end
-   local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
-   x = iproc.scale(x, y:size(3), y:size(2))
-   y = y:float():div(255)
-   x = x:float():div(255)
-
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   local downscale_filter = filters[torch.random(1, #filters)]
+   local y = data_augment(crop_if_large(src, math.max(size * 4, 512)), options)
+   local down_scale = 1.0 / scale
+   local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
+				     y:size(2) * down_scale, downscale_filter),
+			 y:size(3), y:size(2))
+   local batch = {}
+   for i = 1, n do
+      local xc, yc = active_cropping(x, y,
+				     size,
+				     INTERPOLATION_PADDING,
+				     options.active_cropping_rate,
+				     options.active_cropping_tries)
+      if options.rgb then
+      else
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
    end
-
-   y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset -	INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
-   x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
-   
-   return x, y
+   return batch
 end
-function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
+function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    if options.random_half then
       src = random_half(src)
    end
-   local yi = torch.random(0, src:size(2) - size - 1)
-   local xi = torch.random(0, src:size(3) - size - 1)
+   src = crop_if_large(src, math.max(size * 4, 512))
    local y = src
    local x
 
@@ -150,63 +193,64 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
-      x:samplingFactors({1.0, 1.0, 1.0})
+      if options.jpeg_sampling_factors == 444 then
+	 x:samplingFactors({1.0, 1.0, 1.0})
+      else -- 422
+	 x:samplingFactors({2.0, 1.0, 1.0})
+      end
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
    end
    
-   y = image.crop(y, xi, yi, xi + size, yi + size)
-   x = image.crop(x, xi, yi, xi + size, yi + size)
-   y = y:float():div(255)
-   x = x:float():div(255)
-   x, y = flip_augment(x, y)
-   
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   local batch = {}
+   for i = 1, n do
+      local xc, yc = active_cropping(x, y, size, 0,
+				     options.active_cropping_rate,
+				     options.active_cropping_tries)
+      xc, yc = flip_augment(xc, yc)
+      
+      if options.rgb then
+      else
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      table.insert(batch, {xc, image.crop(yc, offset, offset, size - offset, size - offset)})
    end
-   
-   return x, image.crop(y, offset, offset, size - offset, size - offset)
+   return batch
 end
-function pairwise_transform.jpeg(src, category, level, size, offset, options)
+function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
    if category == "anime_style_art" then
       if level == 1 then
-	 if torch.uniform() > 0.7 then
+	 if torch.uniform() > 0.8 then
 	    return pairwise_transform.jpeg_(src, {},
-					    size, offset,
-					    options)
+					    size, offset, n, options)
 	 else
 	    return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-					    size, offset,
-					    options)
+					    size, offset, n, options)
 	 end
       elseif level == 2 then
-	 if torch.uniform() > 0.7 then	 
+	 local r = torch.uniform()
+	 if torch.uniform() > 0.8 then
 	    return pairwise_transform.jpeg_(src, {},
-					    size, offset,
-					    options)
+					    size, offset, n, options)
 	 else
-	    local r = torch.uniform()
 	    if r > 0.6 then
 	       return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
-					       size, offset,
-					    options)
+					       size, offset, n, options)
 	    elseif r > 0.3 then
 	       local quality1 = torch.random(37, 70)
 	       local quality2 = quality1 - torch.random(5, 10)
 	       return pairwise_transform.jpeg_(src, {quality1, quality2},
-					    size, offset,
-					    options)
+					       size, offset, n, options)
 	    else
 	       local quality1 = torch.random(52, 70)
-	       return pairwise_transform.jpeg_(src,
-					       {quality1,
-						quality1 - torch.random(5, 15),
-						quality1 - torch.random(15, 25)},
-					       size, offset,
-					       options)
+	       local quality2 = quality1 - torch.random(5, 15)
+	       local quality3 = quality1 - torch.random(15, 25)
+	       
+	       return pairwise_transform.jpeg_(src, 
+					       {quality1, quality2, quality3},
+					       size, offset, n, options)
 	    end
 	 end
       else
@@ -216,23 +260,25 @@ function pairwise_transform.jpeg(src, category, level, size, offset, options)
       if level == 1 then
 	 if torch.uniform() > 0.7 then
 	    return pairwise_transform.jpeg_(src, {},
-					    size, offset,
+					    size, offset, n,
 					    options)
 	 else
 	    return pairwise_transform.jpeg_(src, {torch.random(80, 95)},
-					    size, offset,
+					    size, offset, n,
 					    options)
 	 end
       elseif level == 2 then
 	 if torch.uniform() > 0.7 then
 	    return pairwise_transform.jpeg_(src, {},
-					    size, offset,
+					    size, offset, n,
 					    options)
 	 else
 	    return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-					    size, offset,
+					    size, offset, n,
 					    options)
 	 end
+      else
+	 error("unknown noise level: " .. level)
       end
    else
       error("unknown category: " .. category)
@@ -242,6 +288,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    if options.random_half then
       src = random_half(src)
    end
+   src = crop_if_large(src, math.max(size * 4, 512))
    local down_scale = 1.0 / scale
    local filters = {
       "Box",        -- 0.012756949974688
@@ -270,7 +317,11 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
-      x:samplingFactors({1.0, 1.0, 1.0})
+      if options.jpeg_sampling_factors == 444 then
+	 x:samplingFactors({1.0, 1.0, 1.0})
+      else -- 422
+	 x:samplingFactors({2.0, 1.0, 1.0})
+      end
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
@@ -321,10 +372,11 @@ function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset
 						     size, offset, options)
 	    else
 	       local quality1 = torch.random(52, 70)
+	       local quality2 = quality1 - torch.random(5, 15)
+	       local quality3 = quality1 - torch.random(15, 25)
+	       
 	       return pairwise_transform.jpeg_scale_(src, scale,
-						     {quality1,
-						      quality1 - torch.random(5, 15),
-						      quality1 - torch.random(15, 25)},
+						     {quality1, quality2, quality3 },
 						     size, offset, options)
 	    end
 	 end
@@ -354,14 +406,13 @@ end
 local function test_jpeg()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
-   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, {})
-   image.display({image = y, legend = "y:0"})
-   image.display({image = x, legend = "x:0"})
    for i = 2, 9 do
-      local y, x = pairwise_transform.jpeg_(random_half(src),
-					    {i * 10}, 128, 0, {color_noise = false, random_half = true, overlay = true, rgb = true})
-      image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
-      image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
+      local xy = pairwise_transform.jpeg_(random_half(src),
+					  {i * 10}, 128, 0, 2, {color_noise = false, random_half = true, overlay = true, rgb = true})
+      for i = 1, #xy do
+	 image.display({image = xy[i][1], legend = "y:" .. (i * 10), max=1,min=0})
+	 image.display({image = xy[i][2], legend = "x:" .. (i * 10),max=1,min=0})
+      end
       --print(x:mean(), y:mean())
    end
 end
@@ -370,27 +421,40 @@ local function test_scale()
    torch.setdefaulttensortype('torch.FloatTensor')
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
+   local options = {color_noise = true,
+		    random_half = true,
+		    overlay = false,
+		    active_cropping_rate = 1.5,
+		    active_cropping_tries = 10,
+		    rgb = true
+   }
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true, overlay = true})
-      image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
+      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
+      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
+      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
+      print(xy[1][1]:size(), xy[1][2]:size())
       --print(x:mean(), y:mean())
    end
 end
 local function test_jpeg_scale()
    torch.setdefaulttensortype('torch.FloatTensor')
    local loader = require './image_loader'
-   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
+   local options = {color_noise = true,
+		    random_half = true,
+		    overlay = true,
+		    active_cropping_ratio = 0.5,
+		    active_cropping_times = 10
+   }
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true, overlay = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
       image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
       --print(x:mean(), y:mean())
    end
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true, overlay = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
       image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())

+ 7 - 1
lib/portable.lua

@@ -1,9 +1,13 @@
 local function load_cuda()
+   require 'nn'
    require 'cunn'
 end
+local function load_cudnn()
+   require 'cudnn'
+   --cudnn.fastest = true
+end
 
 if pcall(load_cuda) then
-   require 'cunn'
 else
    --[[ TODO: fakecuda does not work.
       
@@ -13,3 +17,5 @@ else
    require('fakecuda').init(true)
    --]]
 end
+if pcall(load_cudnn) then
+end

+ 31 - 11
lib/reconstruct.lua

@@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
    end
    return new_x
 end
-function model_is_rgb(model)
+local reconstruct = {}
+function reconstruct.is_rgb(model)
    if model:get(model:size() - 1).weight:size(1) == 3 then
       -- 3ch RGB
       return true
@@ -57,8 +58,23 @@ function model_is_rgb(model)
       return false
    end
 end
-
-local reconstruct = {}
+function reconstruct.offset_size(model)
+   local conv = model:findModules("nn.SpatialConvolutionMM")
+   if #conv > 0 then
+      local offset = 0
+      for i = 1, #conv do
+	 offset = offset + (conv[i].kW - 1) / 2
+      end
+      return math.floor(offset)
+   else
+      conv = model:findModules("cudnn.SpatialConvolution")
+      local offset = 0
+      for i = 1, #conv do
+	 offset = offset + (conv[i].kW - 1) / 2
+      end
+      return math.floor(offset)
+   end
+end
 function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local output_size = block_size - offset * 2
@@ -172,18 +188,22 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
    return output
 end
 
-function reconstruct.image(model, x, offset, block_size)
-   if model_is_rgb(model) then
-      return reconstruct.image_rgb(model, x, offset, block_size)
+function reconstruct.image(model, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return reconstruct.image_rgb(model, x,
+				   reconstruct.offset_size(model), block_size)
    else
-      return reconstruct.image_y(model, x, offset, block_size)
+      return reconstruct.image_y(model, x,
+				 reconstruct.offset_size(model), block_size)
    end
 end
-function reconstruct.scale(model, scale, x, offset, block_size)
-   if model_is_rgb(model) then
-      return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+function reconstruct.scale(model, scale, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return reconstruct.scale_rgb(model, scale, x,
+				   reconstruct.offset_size(model), block_size)
    else
-      return reconstruct.scale_y(model, scale, x, offset, block_size)
+      return reconstruct.scale_y(model, scale, x,
+				 reconstruct.offset_size(model), block_size)
    end
 end
 

+ 9 - 11
lib/settings.lua

@@ -1,5 +1,6 @@
 require 'xlua'
 require 'pl'
+require 'trepl'
 
 -- global settings
 
@@ -18,6 +19,7 @@ cmd:text("waifu2x")
 cmd:text("Options:")
 cmd:option("-seed", 11, 'fixed input seed')
 cmd:option("-data_dir", "./data", 'data directory')
+-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
 cmd:option("-test", "images/miku_small.png", 'test image file')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", '(noise|scale|noise_scale)')
@@ -30,9 +32,15 @@ cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
 cmd:option("-crop_size", 128, 'crop size')
+cmd:option("-max_size", -1, 'crop if image size larger then this value.')
 cmd:option("-batch_size", 2, 'mini batch size')
 cmd:option("-epoch", 200, 'epoch')
 cmd:option("-core", 2, 'cpu core')
+cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
+cmd:option("-validation_ratio", 0.1, 'validation ratio')
+cmd:option("-validation_crops", 40, 'number of crop region in validation')
+cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
+cmd:option("-active_cropping_tries", 20, 'active cropping tries')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
@@ -81,16 +89,6 @@ torch.setnumthreads(settings.core)
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-settings.validation_ratio = 0.1
-settings.validation_crops = 30
-
-local srcnn = require './srcnn'
-if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
-   settings.create_model = srcnn.waifu4x
-   settings.block_offset = 13
-else
-   settings.create_model = srcnn.waifu2x
-   settings.block_offset = 7
-end
+settings.backend = "cunn"
 
 return settings

+ 42 - 53
lib/srcnn.lua

@@ -1,77 +1,66 @@
-require './LeakyReLU'
 
--- ref: http://arxiv.org/abs/1502.01852
-function nn.SpatialConvolutionMM:reset(stdv)
-   stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
-end
+require './mynn'
 
+-- ref: http://arxiv.org/abs/1502.01852
 -- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
-function srcnn.waifu2x(color)
+function srcnn.waifu2x_cunn(ch)
    local model = nn.Sequential()
-   local ch = nil
-   if color == "rgb" then
-      ch = 3
-   elseif color == "y" then
-      ch = 1
-   else
-      if color then
-	 error("unknown color: " .. color)
-      else
-	 error("unknown color: nil")
-      end
-   end
-   -- very deep model
    model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(mynn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
---model:cuda()
---print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
-   return model, 7
+   return model
 end
-
--- current 4x is worse then 2x * 2
-function srcnn.waifu4x(color)
+function srcnn.waifu2x_cudnn(ch)
    local model = nn.Sequential()
-
-   local ch = nil
+   model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(mynn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(nn.View(-1):setNumInputDims(3))
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+   
+   return model
+end
+function srcnn.create(model_name, backend, color)
+   local ch = 3
    if color == "rgb" then
       ch = 3
    elseif color == "y" then
       ch = 1
    else
-      error("unknown color: " .. color)
+      error("unsupported color: " + color)
+   end
+   if backend == "cunn" then
+      return srcnn.waifu2x_cunn(ch)
+   elseif backend == "cudnn" then
+      return srcnn.waifu2x_cudnn(ch)
+   else
+      error("unsupported backend: " +  backend)
    end
-   
-   model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
-   model:add(nn.View(-1):setNumInputDims(3))
-   
-   return model, 13
 end
 return srcnn

+ 106 - 53
train.lua

@@ -1,9 +1,12 @@
 require './lib/portable'
+require './lib/mynn'
 require 'optim'
 require 'xlua'
 require 'pl'
+require 'snappy'
 
 local settings = require './lib/settings'
+local srcnn = require './lib/srcnn'
 local minibatch_adam = require './lib/minibatch_adam'
 local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
@@ -11,11 +14,11 @@ local pairwise_transform = require './lib/pairwise_transform'
 local image_loader = require './lib/image_loader'
 
 local function save_test_scale(model, rgb, file)
-   local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset)
+   local up = reconstruct.scale(model, settings.scale, rgb)
    image.save(file, up)
 end
 local function save_test_jpeg(model, rgb, file)
-   local im, count = reconstruct.image(model, rgb, settings.block_offset)
+   local im, count = reconstruct.image(model, rgb)
    image.save(file, im)
 end
 local function split_data(x, test_size)
@@ -35,10 +38,14 @@ local function make_validation_set(x, transformer, n)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, n do
-	 local x, y = transformer(x[i], true)
-	 table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
-			     y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+      for k = 1, math.max(n / 8, 1) do
+	 local xy = transformer(x[i], true, 8)
+	 for j = 1, #xy do
+	    local x = xy[j][1]
+	    local y = xy[j][2]
+	    table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
+				y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+	 end
       end
       xlua.progress(i, #x)
       collectgarbage()
@@ -58,15 +65,96 @@ local function validate(model, criterion, data)
    return loss / #data
 end
 
+local function create_criterion(model)
+   if reconstruct.is_rgb(model) then
+      local offset = reconstruct.offset_size(model)
+      local output_w = settings.crop_size - offset * 2
+      local weight = torch.Tensor(3, output_w * output_w)
+      weight[1]:fill(0.299 * 3) -- R
+      weight[2]:fill(0.587 * 3) -- G
+      weight[3]:fill(0.114 * 3) -- B
+      return mynn.RGBWeightedMSECriterion(weight):cuda()
+   else
+      return nn.MSECriterion():cuda()
+   end
+end
+local function transformer(x, is_validation, n, offset)
+   local size = x[1]
+   local dec = snappy.decompress(x[2]:string())
+   x = torch.ByteTensor(size[1], size[2], size[3])
+   x:storage():string(dec)
+   
+   n = n or settings.batch_size;
+   if is_validation == nil then is_validation = false end
+   local color_noise = nil 
+   local overlay = nil
+   local active_cropping_ratio = nil
+   local active_cropping_tries = nil
+   
+   if is_validation then
+      active_cropping_rate = 0.0
+      active_cropping_tries = 0
+      color_noise = false
+      overlay = false
+   else
+      active_cropping_rate = settings.active_cropping_rate
+      active_cropping_tries = settings.active_cropping_tries
+      color_noise = settings.color_noise
+      overlay = settings.overlay
+   end
+   
+   if settings.method == "scale" then
+      return pairwise_transform.scale(x,
+				      settings.scale,
+				      settings.crop_size, offset,
+				      n,
+				      { color_noise = color_noise,
+					overlay = overlay,
+					random_half = settings.random_half,
+					active_cropping_rate = active_cropping_rate,
+					active_cropping_tries = active_cropping_tries,
+					rgb = (settings.color == "rgb")
+				      })
+   elseif settings.method == "noise" then
+      return pairwise_transform.jpeg(x,
+				     settings.category,
+				     settings.noise_level,
+				     settings.crop_size, offset,
+				     n,
+				     { color_noise = color_noise,
+				       overlay = overlay,
+				       active_cropping_rate = active_cropping_rate,
+				       active_cropping_tries = active_cropping_tries,
+				       random_half = settings.random_half,
+				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
+				       rgb = (settings.color == "rgb")
+				     })
+   elseif settings.method == "noise_scale" then
+      return pairwise_transform.jpeg_scale(x,
+					   settings.scale,
+					   settings.category,
+					   settings.noise_level,
+					   settings.crop_size, offset,
+					   n,
+					   { color_noise = color_noise,
+					     overlay = overlay,
+					     jpeg_sampling_factors = settings.jpeg_sampling_factors,
+					     random_half = settings.random_half,
+					     rgb = (settings.color == "rgb")
+					   })
+   end
+end
+
 local function train()
-   local model, offset = settings.create_model(settings.color)
-   assert(offset == settings.block_offset)
-   local criterion = nn.MSECriterion():cuda()
+   local model = srcnn.create(settings.method, settings.backend, settings.color)
+   local offset = reconstruct.offset_size(model)
+   local pairwise_func = function(x, is_validation, n)
+      return transformer(x, is_validation, n, offset)
+   end
+   local criterion = create_criterion(model)
    local x = torch.load(settings.images)
    local lrd_count = 0
-   local train_x, valid_x = split_data(x,
-				       math.floor(settings.validation_ratio * #x))
-   local test = image_loader.load_float(settings.test)
+   local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
@@ -77,45 +165,9 @@ local function train()
    elseif settings.color == "rgb" then
       ch = 3
    end
-   local transformer = function(x, is_validation)
-      if is_validation == nil then is_validation = false end
-      local color_noise = (not is_validation) and settings.color_noise
-      local overlay = (not is_validation) and settings.overlay
-      if settings.method == "scale" then
-	 return pairwise_transform.scale(x,
-					 settings.scale,
-					 settings.crop_size, offset,
-					 { color_noise = color_noise,
-					   overlay = overlay,
-					   random_half = settings.random_half,
-					   rgb = (settings.color == "rgb")
-					 })
-      elseif settings.method == "noise" then
-	 return pairwise_transform.jpeg(x,
-					settings.category,
-					settings.noise_level,
-					settings.crop_size, offset,
-					{ color_noise = color_noise,
-					  overlay = overlay,
-					  random_half = settings.random_half,
-					  rgb = (settings.color == "rgb")
-					})
-      elseif settings.method == "noise_scale" then
-	 return pairwise_transform.jpeg_scale(x,
-					      settings.scale,
-					      settings.category,
-					      settings.noise_level,
-					      settings.crop_size, offset,
-					      { color_noise = color_noise,
-						overlay = overlay,
-						random_half = settings.random_half,
-						rgb = (settings.color == "rgb")
-					      })
-      end
-   end
    local best_score = 100000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, transformer, settings.validation_crops)
+   local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
    valid_x = nil
    
    collectgarbage()
@@ -125,7 +177,7 @@ local function train()
       model:training()
       print("# " .. epoch)
       print(minibatch_adam(model, criterion, train_x, adam_config,
-			   transformer,
+			   pairwise_func,
 			   {ch, settings.crop_size, settings.crop_size},
 			   {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
 			  ))
@@ -133,6 +185,7 @@ local function train()
       print("# validation")
       local score = validate(model, criterion, valid_xy)
       if score < best_score then
+	 local test_image = image_loader.load_float(settings.test) -- reload
 	 lrd_count = 0
 	 best_score = score
 	 print("* update best model")
@@ -140,16 +193,16 @@ local function train()
 	 if settings.method == "noise" then
 	    local log = path.join(settings.model_dir,
 				  ("noise%d_best.png"):format(settings.noise_level))
-	    save_test_jpeg(model, test, log)
+	    save_test_jpeg(model, test_image, log)
 	 elseif settings.method == "scale" then
 	    local log = path.join(settings.model_dir,
 				  ("scale%.1f_best.png"):format(settings.scale))
-	    save_test_scale(model, test, log)
+	    save_test_scale(model, test_image, log)
 	 elseif settings.method == "noise_scale" then
 	    local log = path.join(settings.model_dir,
 				  ("noise%d_scale%.1f_best.png"):format(settings.noise_level,
 									settings.scale))
-	    save_test_scale(model, test, log)
+	    save_test_scale(model, test_image, log)
 	 end
       else
 	 lrd_count = lrd_count + 1

+ 6 - 3
train.sh

@@ -1,10 +1,13 @@
 #!/bin/sh
 
-th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th convert_data.lua
+
+th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8  -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
 th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
 
-th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8  -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
 th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
 
-th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
+th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
 th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
+

+ 11 - 0
train_ukbench.sh

@@ -0,0 +1,11 @@
+#!/bin/sh
+
+th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
+th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
+
+th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
+th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
+
+th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
+th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
+

+ 14 - 13
waifu2x.lua

@@ -1,12 +1,11 @@
 require './lib/portable'
 require 'sys'
 require 'pl'
-require './lib/LeakyReLU'
+require './lib/mynn'
 
 local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
 local image_loader = require './lib/image_loader'
-local BLOCK_OFFSET = 7
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
@@ -22,19 +21,21 @@ local function convert_image(opt)
    end
    if opt.m == "noise" then
       local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
+      --local srcnn = require 'lib/srcnn'
+      --local model = srcnn.waifu2x("rgb"):cuda()
       model:evaluate()
-      new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
+      new_x = reconstruct.image(model, x, opt.crop_size)
    elseif opt.m == "scale" then
       local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
       model:evaluate()
-      new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+      new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
    elseif opt.m == "noise_scale" then
       local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
       local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
       noise_model:evaluate()
       scale_model:evaluate()
-      x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
-      new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+      x = reconstruct.image(noise_model, x)
+      new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
    else
       error("undefined method:" .. opt.method)
    end
@@ -62,17 +63,17 @@ local function convert_frames(opt)
 	 local x, alpha = image_loader.load_float(lines[i])
 	 local new_x = nil
 	 if opt.m == "noise" and opt.noise_level == 1 then
-	    new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
+	    new_x = reconstruct.image(noise1_model, x, opt.crop_size)
 	 elseif opt.m == "noise" and opt.noise_level == 2 then
-	    new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+	    new_x = reconstruct.image(noise2_model, x)
 	 elseif opt.m == "scale" then
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-	    x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    x = reconstruct.image(noise1_model, x)
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
-	    x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    x = reconstruct.image(noise2_model, x)
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
 	 else
 	    error("undefined method:" .. opt.method)
 	 end

+ 6 - 7
web.lua

@@ -4,8 +4,8 @@ local uuid = require 'uuid'
 local ffi = require 'ffi'
 local md5 = require 'md5'
 require 'pl'
-require './lib/portable'
-require './lib/LeakyReLU'
+require 'lib.portable'
+require 'lib.mynn'
 
 local cmd = torch.CmdLine()
 cmd:text()
@@ -23,7 +23,7 @@ local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
 local image_loader = require './lib/image_loader'
 
-local MODEL_DIR = "./models/anime_style_art_rgb"
+local MODEL_DIR = "./models/anime_style_art_rgb3"
 
 local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
 local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
@@ -40,7 +40,6 @@ local CURL_OPTIONS = {
    max_redirects = 2
 }
 local CURL_MAX_SIZE = 2 * 1024 * 1024
-local BLOCK_OFFSET = 7 -- see srcnn.lua
 
 local function valid_size(x, scale)
    if scale == 0 then
@@ -80,13 +79,13 @@ local function get_image(req)
 end
 
 local function apply_denoise1(x)
-   return reconstruct.image(noise1_model, x, BLOCK_OFFSET)
+   return reconstruct.image(noise1_model, x)
 end
 local function apply_denoise2(x)
-   return reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+   return reconstruct.image(noise2_model, x)
 end
 local function apply_scale2x(x)
-   return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET)
+   return reconstruct.scale(scale20_model, 2.0, x)
 end
 local function cache_do(cache, x, func)
    if path.exists(cache) then