nagadomi 7 年之前
父節點
當前提交
fa6ee00624

+ 34 - 3
convert_data.lua

@@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
       return x, y
    end
 end
-
+local function padding_x(x, pad)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+   end
+   return x
+end
+local function padding_xy(x, y, pad, y_zero)
+   local scale = y:size(2) / x:size(2)
+   if pad > 0 then
+      x = iproc.padding(x, pad, pad, pad, pad)
+      if y_zero then
+	 y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      else
+	 y = iproc.padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
+      end
+   end
+   return x, y
+end
 local function load_images(list)
    local MARGIN = 32
    local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
@@ -78,6 +95,7 @@ local function load_images(list)
       if csv_meta and csv_meta.filters then
 	 filters = csv_meta.filters
       end
+      local basename_y = path.basename(filename)
       local im, meta = image_loader.load_byte(filename)
       local skip = false
       local alpha_color = torch.random(0, 1)
@@ -100,25 +118,38 @@ local function load_images(list)
 	       -- method == user
 	       local yy = im
 	       local xx, meta2 = image_loader.load_byte(csv_meta.x)
+	       if settings.invert_x then
+		  xx = (-(xx:long()) + 255):byte()
+	       end
+
 	       if xx then
 		  if meta2 and meta2.alpha then
 		     xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
 		  end
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
+		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+		  if settings.grayscale then
+		     xx = iproc.rgb2y(xx)
+		     yy = iproc.rgb2y(yy)
+		  end
 		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
-				  {data = {filters = filters, has_x = true}}})
+				  {data = {filters = filters, has_x = true, basename = basename_y}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
 	       end
 	    else
 	       im = crop_if_large(im, settings.max_training_image_size)
 	       im = iproc.crop_mod4(im)
+	       im = padding_x(im, settings.padding)
 	       local scale = 1.0
 	       if settings.random_half_rate > 0.0 then
 		  scale = 2.0
 	       end
 	       if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
-		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
+		  if settings.grayscale then
+		     im = iproc.rgb2y(im)
+		  end
+		  table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
 	       end

+ 42 - 0
lib/ShakeShakeTable.lua

@@ -0,0 +1,42 @@
+local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')
+
+function ShakeShakeTable:__init()
+   parent.__init(self)
+   self.alpha = torch.Tensor()
+   self.beta = torch.Tensor()
+   self.first = torch.Tensor()
+   self.second = torch.Tensor()
+   self.train = true
+end
+function ShakeShakeTable:updateOutput(input)
+   local batch_size = input[1]:size(1)
+   if self.train then
+      self.alpha:resize(batch_size):uniform()
+      self.beta:resize(batch_size):uniform()
+      self.second:resizeAs(input[1]):copy(input[2])
+      for i = 1, batch_size do
+	 self.second[i]:mul(self.alpha[i])
+      end
+      self.output:resizeAs(input[1]):copy(input[1])
+      for i = 1, batch_size do
+	 self.output[i]:mul(1.0 - self.alpha[i])
+      end
+      self.output:add(self.second):mul(2)
+   else
+      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
+   end
+   return self.output
+end
+function ShakeShakeTable:updateGradInput(input, gradOutput)
+   local batch_size = input[1]:size(1)
+   self.first:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.first[i]:mul(self.beta[i])
+   end
+   self.second:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.second[i]:mul(1.0 - self.beta[i])
+   end
+   self.gradOutput = {self.first, self.second}
+   return self.gradOutput
+end

+ 33 - 1
lib/data_augmentation.lua

@@ -102,7 +102,9 @@ function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
       local scale = torch.uniform(scale_min, scale_max)
       local h = math.floor(x:size(2) * scale)
       local w = math.floor(x:size(3) * scale)
-      x = iproc.scale(x, w, h, "Triangle")
+      local filters = {"Lanczos", "Catrom"}
+      local x_filter = filters[torch.random(1, 2)]
+      x = iproc.scale(x, w, h, x_filter)
       y = iproc.scale(y, w, h, "Triangle")
       return x, y
    else
@@ -139,6 +141,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
       return x, y
    end
 end
+function data_augmentation.pairwise_flip(x, y)
+   local flip = torch.random(1, 4)
+   local tr = torch.random(1, 2)
+   local x, conversion = iproc.byte2float(x)
+   y = iproc.byte2float(y)
+   x = x:contiguous()
+   y = y:contiguous()
+   if tr == 1 then
+      -- pass
+   elseif tr == 2 then
+      x = x:transpose(2, 3):contiguous()
+      y = y:transpose(2, 3):contiguous()
+   end
+   if flip == 1 then
+      x = iproc.hflip(x)
+      y = iproc.hflip(y)
+   elseif flip == 2 then
+      x = iproc.vflip(x)
+      y = iproc.vflip(y)
+   elseif flip == 3 then
+      x = iproc.hflip(iproc.vflip(x))
+      y = iproc.hflip(iproc.vflip(y))
+   elseif flip == 4 then
+   end
+   if conversion then
+      x = iproc.float2byte(x)
+      y = iproc.float2byte(y)
+   end
+   return x, y
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)

+ 15 - 2
lib/iproc.lua

@@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "clamp")
+   local dest = image.warp(img, flow, "simple", false, "clamp")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "pad", 0)
+   local dest = image.warp(img, flow, "simple", false, "pad", 0)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
@@ -217,6 +229,7 @@ function iproc.rgb2y(src)
    src, conversion = iproc.byte2float(src)
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   dest:clamp(0, 1)
    if conversion then
       dest = iproc.float2byte(dest)
    end

+ 4 - 2
lib/pairwise_transform_jpeg.lua

@@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       if torch.uniform() < options.nr_rate then
 	 -- reducing noise

+ 4 - 2
lib/pairwise_transform_scale.lua

@@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 16 - 5
lib/pairwise_transform_user.lua

@@ -1,4 +1,5 @@
 local pairwise_utils = require 'pairwise_transform_utils'
+local data_augmentation = require 'data_augmentation'
 local iproc = require 'iproc'
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
@@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    if options.active_cropping_rate > 0 then
       lowres_y = pairwise_utils.low_resolution(y)
    end
-   if options.pairwise_flip then
+   if options.pairwise_flip and n == 1 then
+      xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
+   elseif options.pairwise_flip then
       xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    end
    assert(#xs == #ys)
+   local perm = torch.randperm(#xs)
    for i = 1, n do
-      local t = (i % #xs) + 1
+      local t = perm[(i % #xs) + 1]
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
 						    options.active_cropping_rate,
 						    options.active_cropping_tries)
@@ -34,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       if options.gcn then
 	 local mean = xc:mean()
@@ -46,7 +52,12 @@ function pairwise_transform.user(x, y, size, offset, n, options)
 	    xc:add(-mean)
 	 end
       end
-      table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      yc = iproc.crop(yc, offset, offset, size - offset, size - offset)
+      if options.pairwise_y_binary then
+	 yc[torch.lt(yc, 0.5)] = 0
+	 yc[torch.gt(yc, 0)] = 1
+      end
+      table.insert(batch, {xc, yc})
    end
 
    return batch

+ 19 - 12
lib/pairwise_transform_utils.lua

@@ -108,12 +108,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
-
-   if options.pairwise_y_binary then
-      y[torch.lt(y, 128)] = 0
-      y[torch.gt(y, 0)] = 255
-   end
-
    return x, y
 end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
@@ -125,8 +119,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       t = "byte"
    end
    if p < r then
-      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
-      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
+      local xi = 0
+      local yi = 0
+      if x:size(3) > size + 1 then
+	 xi = torch.random(0, x:size(3) - (size + 1)) * scale
+      end
+      if x:size(2) > size + 1 then
+	 yi = torch.random(0, x:size(2) - (size + 1)) * scale
+      end
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       return xc, yc
@@ -273,10 +273,17 @@ function pairwise_transform_utils.low_resolution(src)
 	    toTensor("byte", "RGB", "DHW")
    end
 --]]
-   return gm.Image(src, "RGB", "DHW"):
-      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
-      size(src:size(3), src:size(2), "Box"):
-      toTensor("byte", "RGB", "DHW")
+   if src:size(1) == 1 then
+      return gm.Image(src, "I", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "I", "DHW")
+   else
+      return gm.Image(src, "RGB", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "RGB", "DHW")
+   end
 end
 
 return pairwise_transform_utils

+ 22 - 4
lib/settings.lua

@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
 cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-gpu", -1, 'GPU Device ID')
 cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
 cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
@@ -74,9 +73,14 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
-cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
+cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
+cmd:option("-padding", 0, 'replication padding size')
+cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
+cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
+cmd:option("-invert_x", 0, 'invert x image in convert_lua')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -95,6 +99,10 @@ to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
+to_bool(settings, "padding_y_zero")
+to_bool(settings, "grayscale")
+to_bool(settings, "validation_filename_split")
+to_bool(settings, "invert_x")
 
 if settings.plot then
    require 'gnuplot'
@@ -168,10 +176,20 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-cutorch.setDevice(opt.gpu)
 -- patch for lua52
 if not math.log10 then
    math.log10 = function(x) return math.log(x, 10) end
 end
+if settings.gpu:len() > 0 then
+   local gpus = {}
+   local gpu_string = utils.split(settings.gpu, ",")
+   for i = 1, #gpu_string do
+      table.insert(gpus, tonumber(gpu_string[i]))
+   end
+   settings.gpu = gpus
+else
+   settings.gpu = {1}
+end
+cutorch.setDevice(settings.gpu[1])
 
 return settings

+ 74 - 20
lib/srcnn.lua

@@ -4,34 +4,52 @@ require 'w2nn'
 -- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
 
-function nn.SpatialConvolutionMM:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
+local function msra_filler(mod)
+   local fin = mod.kW * mod.kH * mod.nInputPlane
+   local fout = mod.kW * mod.kH * mod.nOutputPlane
    stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   mod.weight:normal(0, stdv)
+   mod.bias:zero()
+end
+local function identity_filler(mod)
+   assert(mod.nInputPlane <= mod.nOutputPlane)
+   mod.weight:normal(0, 0.01)
+   mod.bias:zero()
+   local num_groups = mod.nInputPlane -- fixed
+   local filler_value = num_groups / mod.nOutputPlane
+   local in_group_size = math.floor(mod.nInputPlane / num_groups)
+   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
+   local x = math.floor(mod.kW / 2)
+   local y = math.floor(mod.kH / 2)
+   for i = 0, num_groups - 1 do
+      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
+	 for k = i * in_group_size, (i + 1) * in_group_size - 1 do
+	    mod.weight[j+1][k+1][y+1][x+1] = filler_value
+	 end
+      end
+   end
+end
+function nn.SpatialConvolutionMM:reset(stdv)
+   msra_filler(self)
 end
 function nn.SpatialFullConvolution:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
-   stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   msra_filler(self)
+end
+function nn.SpatialDilatedConvolution:reset(stdv)
+   identity_filler(self)
 end
+
 if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
+   end
+   if cudnn.SpatialDilatedConvolution then
+      function cudnn.SpatialDilatedConvolution:reset(stdv)
+	 identity_filler(self)
+      end
    end
 end
 function nn.SpatialConvolutionMM:clearState()
@@ -127,6 +145,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialConvolution = SpatialConvolution
+
 local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
       return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@@ -136,6 +156,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialFullConvolution = SpatialFullConvolution
+
 local function ReLU(backend)
    if backend == "cunn" then
       return nn.ReLU(true)
@@ -145,6 +167,8 @@ local function ReLU(backend)
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.ReLU = ReLU
+
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@@ -154,6 +178,35 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
       error("unsupported backend:" .. backend)
    end
 end
+srcnn.SpatialMaxPooling = SpatialMaxPooling
+
+local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialAveragePooling = SpatialAveragePooling
+
+local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)      
+   if backend == "cunn" then
+      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   elseif backend == "cudnn" then
+      if cudnn.SpatialDilatedConvolution then
+	 -- cudnn v 6
+	 return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      else
+	 return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      end
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
+
 
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -548,6 +601,7 @@ function srcnn.create(model_name, backend, color)
       error("unsupported model_name: " .. model_name)
    end
 end
+
 --[[
 local model = srcnn.fcn_v1("cunn", 3):cuda()
 print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())

+ 42 - 0
lib/w2nn.lua

@@ -9,6 +9,40 @@ end
 local function load_cudnn()
    cudnn = require('cudnn')
 end
+local function make_data_parallel_table(model, gpus)
+   if cudnn then
+      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
+      local dpt = nn.DataParallelTable(1, true, true)
+	 :add(model, gpus)
+	 :threads(function()
+	       require 'pl'
+	       local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	       package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+	       require 'torch'
+	       require 'cunn'
+	       require 'w2nn'
+	       local cudnn = require 'cudnn'
+	       cudnn.fastest, cudnn.benchmark = fastest, benchmark
+		 end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   else
+      local dpt = nn.DataParallelTable(1, true, true)
+	    :add(model, gpus)
+	 :threads(function()
+	       require 'pl'
+	       local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	       package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+	       require 'torch'
+	       require 'cunn'
+	       require 'w2nn'
+		 end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   end
+   return model
+end
+
 if w2nn then
    return w2nn
 else
@@ -27,11 +61,19 @@ else
       model:cuda():evaluate()
       return model
    end
+   function w2nn.data_parallel(model, gpus)
+      if #gpus > 1 then
+	 return make_data_parallel_table(model, gpus)
+      else
+	 return model
+      end
+   end
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
    require 'SSIMCriterion'
    require 'InplaceClip01'
    require 'L1Criterion'
+   require 'ShakeShakeTable'
    return w2nn
 end

+ 130 - 44
tools/benchmark.lua

@@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
 cmd:option("-x_file", "", 'input image for user method')
 cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
 cmd:option("-border", 0, 'border px that will removed')
+cmd:option("-metric", "", '(jaccard)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -203,8 +204,34 @@ local function remove_border(x, border)
 		     x:size(3) - border,
 		     x:size(2) - border)
 end
+local function create_metric(metric)
+   if metric and metric:len() > 0 then
+      if metric == "jaccard" then
+	 return {
+	    name = "jaccard", 
+	    func = function (a, b) 
+	       local ga = iproc.rgb2y(a)
+	       local gb = iproc.rgb2y(b)
+	       local ba = torch.Tensor():resizeAs(ga)
+	       local bb = torch.Tensor():resizeAs(gb)
+	       ba:zero()
+	       bb:zero()
+	       ba[torch.gt(ga, 0.5)] = 1.0
+	       bb[torch.gt(gb, 0.5)] = 1.0
+	       local num_a = ba:sum()
+	       local num_b = bb:sum()
+	       local a_and_b  = ba:cmul(bb):sum()
+	       return (a_and_b / (num_a + num_b - a_and_b))
+	 end}
+      else
+	 error("unknown metric: " .. metric)
+      end
+   else
+      return nil
+   end
+end
 local function benchmark(opt, x, model1, model2)
-   local mse1, mse2
+   local mse1, mse2, am1, am2
    local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
@@ -217,6 +244,13 @@ local function benchmark(opt, x, model1, model2)
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local detail_fp = nil
+   local am = nil
+   local model1_am = 0
+   local model2_am = 0
+
+   if opt.method == "user" or opt.method == "diff" then
+      am = create_metric(opt.metric)
+   end
    if opt.save_info then
       detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
    end
@@ -406,32 +440,57 @@ local function benchmark(opt, x, model1, model2)
 	 ground_truth = remove_border(ground_truth, opt.border)
 	 model1_output = remove_border(model1_output, opt.border)
       end
-      mse1 = MSE(ground_truth, model1_output, opt.color)
-      model1_mse = model1_mse + mse1
-      model1_psnr = model1_psnr + MSE2PSNR(mse1)
-
+      if am then
+	 am1 = am.func(ground_truth, model1_output)
+	 model1_am = model1_am + am1
+      else
+	 mse1 = MSE(ground_truth, model1_output, opt.color)
+	 model1_mse = model1_mse + mse1
+	 model1_psnr = model1_psnr + MSE2PSNR(mse1)
+      end
       local won_model = 1
       if model2 then
 	 if opt.border > 0 then
 	    model2_output = remove_border(model2_output, opt.border)
 	 end
-	 mse2 = MSE(ground_truth, model2_output, opt.color)
-	 model2_mse = model2_mse + mse2
-	 model2_psnr = model2_psnr + MSE2PSNR(mse2)
-
-	 if mse1 < mse2 then
-	    won[1] = won[1] + 1
-	 elseif mse1 > mse2 then
-	    won[2] = won[2] + 1
-	    won_model = 2
+	 if am then
+	    am2 = am.func(ground_truth, model2_output)
+	    model2_am = model2_am + am2
+	 else
+	    mse2 = MSE(ground_truth, model2_output, opt.color)
+	    model2_mse = model2_mse + mse2
+	    model2_psnr = model2_psnr + MSE2PSNR(mse2)
+	 end
+	 if am then
+	    if am1 < am2 then
+	       won[1] = won[1] + 1
+	    elseif am1 > am2 then
+	       won[2] = won[2] + 1
+	       won_model = 2
+	    end
+	 else
+	    if mse1 < mse2 then
+	       won[1] = won[1] + 1
+	    elseif mse1 > mse2 then
+	       won[2] = won[2] + 1
+	       won_model = 2
+	    end
 	 end
 	 if detail_fp then
-	    detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
-					  MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	    if am then
+	       detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
+	    else
+	       detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+					     MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	    end
 	 end
       else
 	 if detail_fp then
-	    detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+	    if am then
+	       detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
+	    else
+	       detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+	    end
 	 end
       end
       if baseline_output then
@@ -455,46 +514,65 @@ local function benchmark(opt, x, model1, model2)
 	 end
       end
       if opt.show_progress or i == #x then
-	 if model2 then
-	    if baseline_output then
+	 if am then
+	    if model2 then
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
 				i, #x,
 				model1_time,
 				model2_time,
-				math.sqrt(baseline_mse / i),
-				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				baseline_psnr / i,
-				model1_psnr / i, model2_psnr / i,
-				won[1], won[2]
-		  ))
+				am.name, model1_am / i, am.name, model2_am / i
+	       ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+		  string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
 				i, #x,
 				model1_time,
-				model2_time,
-				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				model1_psnr / i, model2_psnr / i,
-				won[1], won[2]
-		  ))
+				am.name, model1_am / i
+	       ))
 	    end
 	 else
-	    if baseline_output then
-	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
-				i, #x,
-				model1_time,
-				math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
-				baseline_psnr / i, model1_psnr / i
+	    if model2 then
+	       if baseline_output then
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+				   i, #x,
+				   model1_time,
+				   model2_time,
+				   math.sqrt(baseline_mse / i),
+				   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+				   baseline_psnr / i,
+				   model1_psnr / i, model2_psnr / i,
+				   won[1], won[2]
+		  ))
+	       else
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+				   i, #x,
+				   model1_time,
+				   model2_time,
+				   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+				   model1_psnr / i, model2_psnr / i,
+				   won[1], won[2]
 		  ))
+	       end
 	    else
-	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
-				i, #x,
-				model1_time,
-				math.sqrt(model1_mse / i), model1_psnr / i
+	       if baseline_output then
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
+				   i, #x,
+				   model1_time,
+				   math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
+				   baseline_psnr / i, model1_psnr / i
 		  ))
+	       else
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
+				   i, #x,
+				   model1_time,
+				   math.sqrt(model1_mse / i), model1_psnr / i
+		  ))
+	       end
 	    end
 	 end
 	 io.stdout:flush()
@@ -515,6 +593,14 @@ local function benchmark(opt, x, model1, model2)
 	 fp:write(string.format("model2  : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
 				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
       end
+      if model1_am > 0 then
+	 fp:write(string.format("model1  : %s = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model1_am / #x), model1_time))
+      end
+      if model2_am > 0 then
+	 fp:write(string.format("model2  : %s = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model2_am / #x), model2_time))
+      end
       fp:close()
       if detail_fp then
 	 detail_fp:close()

+ 80 - 44
train.lua

@@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
    end
 end
 local function split_data(x, test_size)
-   local index = torch.randperm(#x)
-   local train_size = #x - test_size
-   local train_x = {}
-   local valid_x = {}
-   for i = 1, train_size do
-      train_x[i] = x[index[i]]
-   end
-   for i = 1, test_size do
-      valid_x[i] = x[index[train_size + i]]
+   if settings.validation_filename_split then
+      if not (x[1][2].data and x[1][2].data.basename) then
+	 error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
+      end
+      local basename_db = {}
+      for i = 1, #x do
+	 local meta = x[i][2].data
+	 if basename_db[meta.basename] then
+	    table.insert(basename_db[meta.basename], x[i])
+	 else
+	    basename_db[meta.basename] = {x[i]}
+	 end
+      end
+      local basename_list = {}
+      for k, v in pairs(basename_db) do
+	 table.insert(basename_list, v)
+      end
+      local index = torch.randperm(#basename_list)
+      local train_x = {}
+      local valid_x = {}
+      local pos = 1
+      for i = 1, #basename_list do
+	 if #valid_x >= test_size then
+	    break
+	 end
+	 local xs = basename_list[index[pos]]
+	 for j = 1, #xs do
+	    table.insert(valid_x, xs[j])
+	 end
+	 pos = pos + 1
+      end
+      for i = pos, #basename_list do
+	 local xs = basename_list[index[i]]
+	 for j = 1, #xs do
+	    table.insert(train_x, xs[j])
+	 end
+      end
+      return train_x, valid_x
+   else
+      local index = torch.randperm(#x)
+      local train_size = #x - test_size
+      local train_x = {}
+      local valid_x = {}
+      for i = 1, train_size do
+	 train_x[i] = x[index[i]]
+      end
+      for i = 1, test_size do
+	 valid_x[i] = x[index[train_size + i]]
+      end
+      return train_x, valid_x
    end
-   return train_x, valid_x
 end
 
 local g_transform_pool = nil
@@ -175,35 +215,19 @@ local function transform_pool_init(has_resize, offset)
 						    settings.crop_size, offset,
 						    n, conf)
 	    elseif settings.method == "user" then
-	       if is_validation == nil then is_validation = false end
-	       local rotate_rate = nil 
-	       local scale_rate = nil
-	       local negate_rate = nil
-	       local negate_x_rate = nil
-	       if is_validation then
-		  rotate_rate = 0
-		  scale_rate = 0
-		  negate_rate = 0
-		  negate_x_rate = 0
-	       else
-		  rotate_rate = settings.random_pairwise_rotate_rate
-		  scale_rate = settings.random_pairwise_scale_rate
-		  negate_rate = settings.random_pairwise_negate_rate
-		  negate_x_rate = settings.random_pairwise_negate_x_rate
-	       end
 	       local conf = tablex.update({
 		     gcn = settings.gcn,
 		     max_size = settings.max_size,
 		     active_cropping_rate = active_cropping_rate,
 		     active_cropping_tries = active_cropping_tries,
-		     random_pairwise_rotate_rate = rotate_rate,
+		     random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
 		     random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
 		     random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
-		     random_pairwise_scale_rate = scale_rate,
+		     random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
 		     random_pairwise_scale_min = settings.random_pairwise_scale_min,
 		     random_pairwise_scale_max = settings.random_pairwise_scale_max,
-		     random_pairwise_negate_rate = negate_rate,
-		     random_pairwise_negate_x_rate = negate_x_rate,
+		     random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
+		     random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
 		     pairwise_y_binary = settings.pairwise_y_binary,
 		     pairwise_flip = settings.pairwise_flip,
 		     rgb = (settings.color == "rgb")}, meta)
@@ -290,7 +314,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
       local batch_mse = eval_metric:forward(z, targets)
       loss = loss + criterion:forward(z, targets)
       mse = mse + batch_mse
-      psnr = psnr + (10 * math.log10(1 / batch_mse))
+      psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
 	 xlua.progress(t, #data)
@@ -322,6 +346,10 @@ local function create_criterion(model)
       return w2nn.L1Criterion():cuda()
    elseif settings.loss == "mse" then
       return w2nn.ClippedMSECriterion(0, 1.0):cuda()
+   elseif settings.loss == "bce" then
+      local bce = nn.BCECriterion()
+      bce.sizeAverage = true
+      return bce:cuda()
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -421,7 +449,10 @@ local function plot(train, valid)
 	 {'validation', torch.Tensor(valid), '-'}})
 end
 local function train()
-   local x = remove_small_image(torch.load(settings.images))
+   local x = torch.load(settings.images)
+   if settings.method ~= "user" then
+      x = remove_small_image(x)
+   end
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
@@ -429,7 +460,12 @@ local function train()
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
    else
-      model = srcnn.create(settings.model, settings.backend, settings.color)
+      if stringx.endswith(settings.model, ".lua") then
+	 local create_model = dofile(settings.model)
+	 model = create_model(srcnn, settings)
+      else
+	 model = srcnn.create(settings.model, settings.backend, settings.color)
+      end
    end
    if model.w2nn_input_size then
       if settings.crop_size ~= model.w2nn_input_size then
@@ -484,8 +520,9 @@ local function train()
 		       ch, settings.crop_size, settings.crop_size)
    end
    local instance_loss = nil
+   local pmodel = w2nn.data_parallel(model, settings.gpu)
    for epoch = 1, settings.epoch do
-      model:training()
+      pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
 	 print("learning rate: " .. adam_config.learningRate)
@@ -523,13 +560,13 @@ local function train()
       instance_loss = torch.Tensor(x:size(1)):zero()
 
       for i = 1, settings.inner_epoch do
-	 model:training()
-	 local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 pmodel:training()
+	 local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
 	 instance_loss:copy(il)
 	 print(train_score)
-	 model:evaluate()
+	 pmodel:evaluate()
 	 print("# validation")
-	 local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+	 local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
 	 table.insert(hist_train, train_score.loss)
 	 table.insert(hist_valid, score.loss)
 	 if settings.plot then
@@ -546,8 +583,9 @@ local function train()
 	    best_score = score_for_update
 	    print("* model has updated")
 	    if settings.save_history then
-	       torch.save(settings.model_file_best, model:clearState(), "ascii")
-	       torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
+	       pmodel:clearState()
+	       torch.save(settings.model_file_best, model, "ascii")
+	       torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -571,7 +609,8 @@ local function train()
 		  save_test_user(model, test_image, log)
 	       end
 	    else
-	       torch.save(settings.model_file, model:clearState(), "ascii")
+	       pmodel:clearState()
+	       torch.save(settings.model_file, model, "ascii")
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.png"):format(settings.noise_level))
@@ -597,9 +636,6 @@ local function train()
       end
    end
 end
-if settings.gpu > 0 then
-   cutorch.setDevice(settings.gpu)
-end
 torch.manualSeed(settings.seed)
 cutorch.manualSeed(settings.seed)
 print(settings)

+ 1 - 1
waifu2x.lua

@@ -276,6 +276,7 @@ local function waifu2x()
    if opt.thread > 0 then
       torch.setnumthreads(opt.thread)
    end
+   cutorch.setDevice(opt.gpu)
    if cudnn then
       cudnn.fastest = true
       if opt.l:len() > 0 then
@@ -293,6 +294,5 @@ local function waifu2x()
    else
       convert_frames(opt)
    end
-   cutorch.setDevice(opt.gpu)
 end
 waifu2x()