Selaa lähdekoodia

Add new models

upconv_7 is 2.3x faster than previous model
nagadomi 9 vuotta sitten
vanhempi
commit
51ae485cd1

+ 4 - 250
lib/pairwise_transform.lua

@@ -1,255 +1,9 @@
-require 'image'
-local gm = require 'graphicsmagick'
-local iproc = require 'iproc'
-local data_augmentation = require 'data_augmentation'
-
+require 'pl'
 local pairwise_transform = {}
 
-local function random_half(src, p, filters)
-   if torch.uniform() < p then
-      local filter = filters[torch.random(1, #filters)]
-      return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
-   else
-      return src
-   end
-end
-local function crop_if_large(src, max_size)
-   local tries = 4
-   if src:size(2) > max_size and src:size(3) > max_size then
-      local rect
-      for i = 1, tries do
-	 local yi = torch.random(0, src:size(2) - max_size)
-	 local xi = torch.random(0, src:size(3) - max_size)
-	 rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
-	 -- ignore simple background
-	 if rect:float():std() >= 0 then
-	    break
-	 end
-      end
-      return rect
-   else
-      return src
-   end
-end
-local function preprocess(src, crop_size, options)
-   local dest = src
-   dest = random_half(dest, options.random_half_rate, options.downsampling_filters)
-   dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
-   dest = data_augmentation.flip(dest)
-   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
-   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
-   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
-   dest = data_augmentation.shift_1px(dest)
-   
-   return dest
-end
-local function active_cropping(x, y, size, p, tries)
-   assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
-   local r = torch.uniform()
-   local t = "float"
-   if x:type() == "torch.ByteTensor" then
-      t = "byte"
-   end
-   if p < r then
-      local xi = torch.random(0, y:size(3) - (size + 1))
-      local yi = torch.random(0, y:size(2) - (size + 1))
-      local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
-      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
-      return xc, yc
-   else
-      local lowres = gm.Image(x, "RGB", "DHW"):
-	 size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
-	 size(x:size(3), x:size(2), "Box"):
-	 toTensor(t, "RGB", "DHW")
-      local best_se = 0.0
-      local best_xc, best_yc
-      local m = torch.FloatTensor(x:size(1), size, size)
-      for i = 1, tries do
-	 local xi = torch.random(0, y:size(3) - (size + 1))
-	 local yi = torch.random(0, y:size(2) - (size + 1))
-	 local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
-	 local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
-	 local xcf = iproc.byte2float(xc)
-	 local lcf = iproc.byte2float(lc)
-	 local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
-	 if se >= best_se then
-	    best_xc = xcf
-	    best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
-	    best_se = se
-	 end
-      end
-      return best_xc, best_yc
-   end
-end
-function pairwise_transform.scale(src, scale, size, offset, n, options)
-   local filters = options.downsampling_filters
-   local unstable_region_offset = 8
-   local downsampling_filter = filters[torch.random(1, #filters)]
-   local y = preprocess(src, size, options)
-   assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
-   local down_scale = 1.0 / scale
-   local x
-   if options.gamma_correction then
-      x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
-					       y:size(2) * down_scale, downsampling_filter),
-		      y:size(3), y:size(2), options.upsampling_filter)
-   else
-      x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
-				  y:size(2) * down_scale, downsampling_filter),
-		      y:size(3), y:size(2), options.upsampling_filter)
-   end
-   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
-		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
-   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
-		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
-   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
-   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
-   
-   local batch = {}
-   for i = 1, n do
-      local xc, yc = active_cropping(x, y,
-				     size,
-				     options.active_cropping_rate,
-				     options.active_cropping_tries)
-      xc = iproc.byte2float(xc)
-      yc = iproc.byte2float(yc)
-      if options.rgb then
-      else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
-      end
-      table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
-   end
-   return batch
-end
-function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
-   local unstable_region_offset = 8
-   local y = preprocess(src, size, options)
-   local x = y
-
-   for i = 1, #quality do
-      x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg"):depth(8)
-      if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-	 -- YUV 420
-	 x:samplingFactors({2.0, 1.0, 1.0})
-      else
-	 -- YUV 444
-	 x:samplingFactors({1.0, 1.0, 1.0})
-      end
-      local blob, len = x:toBlob(quality[i])
-      x:fromBlob(blob, len)
-      x = x:toTensor("byte", "RGB", "DHW")
-   end
-   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
-		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
-   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
-		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
-   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
-   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
-   
-   local batch = {}
-   for i = 1, n do
-      local xc, yc = active_cropping(x, y, size,
-				     options.active_cropping_rate,
-				     options.active_cropping_tries)
-      xc = iproc.byte2float(xc)
-      yc = iproc.byte2float(yc)
-      if options.rgb then
-      else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
-      end
-      if torch.uniform() < options.nr_rate then
-	 -- reducing noise
-	 table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
-      else
-	 -- ratain useful details
-	 table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
-      end
-   end
-   return batch
-end
-function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
-   if style == "art" then
-      if level == 1 then
-	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-					 size, offset, n, options)
-      elseif level == 2 or level == 3 then
-	 -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
-	 local r = torch.uniform()
-	 if r > 0.6 then
-	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
-					    size, offset, n, options)
-	 elseif r > 0.3 then
-	    local quality1 = torch.random(37, 70)
-	    local quality2 = quality1 - torch.random(5, 10)
-	    return pairwise_transform.jpeg_(src, {quality1, quality2},
-					    size, offset, n, options)
-	 else
-	    local quality1 = torch.random(52, 70)
-	    local quality2 = quality1 - torch.random(5, 15)
-	    local quality3 = quality1 - torch.random(15, 25)
-	    
-	    return pairwise_transform.jpeg_(src, 
-					    {quality1, quality2, quality3},
-					    size, offset, n, options)
-	 end
-      else
-	 error("unknown noise level: " .. level)
-      end
-   elseif style == "photo" then
-      -- level adjusting by -nr_rate
-      return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
-				      size, offset, n,
-				      options)
-   else
-      error("unknown style: " .. style)
-   end
-end
+pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_scale'))
+pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_jpeg'))
 
-function pairwise_transform.test_jpeg(src)
-   torch.setdefaulttensortype("torch.FloatTensor")
-   local options = {random_color_noise_rate = 0.5,
-		    random_half_rate = 0.5,
-		    random_overlay_rate = 0.5,
-		    random_unsharp_mask_rate = 0.5,
-		    jpeg_chroma_subsampling_rate = 0.5,
-		    nr_rate = 1.0,
-		    active_cropping_rate = 0.5,
-		    active_cropping_tries = 10,
-		    max_size = 256,
-		    rgb = true
-   }
-   local image = require 'image'
-   local src = image.lena()
-   for i = 1, 9 do
-      local xy = pairwise_transform.jpeg(src,
-					 "art",
-					 torch.random(1, 2),
-					 128, 7, 1, options)
-      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
-      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
-   end
-end
-function pairwise_transform.test_scale(src)
-   torch.setdefaulttensortype("torch.FloatTensor")
-   local options = {random_color_noise_rate = 0.5,
-		    random_half_rate = 0.5,
-		    random_overlay_rate = 0.5,
-		    random_unsharp_mask_rate = 0.5,
-		    active_cropping_rate = 0.5,
-		    active_cropping_tries = 10,
-		    max_size = 256,
-		    rgb = true
-   }
-   local image = require 'image'
-   local src = image.lena()
+print(pairwise_transform)
 
-   for i = 1, 10 do
-      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
-      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
-      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
-   end
-end
 return pairwise_transform

+ 117 - 0
lib/pairwise_transform_jpeg.lua

@@ -0,0 +1,117 @@
+local pairwise_utils = require 'pairwise_transform_utils'
+local gm = require 'graphicsmagick'
+local iproc = require 'iproc'
+local pairwise_transform = {}
+
+function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
+   local unstable_region_offset = 8
+   local y = pairwise_utils.preprocess(src, size, options)
+   local x = y
+
+   for i = 1, #quality do
+      x = gm.Image(x, "RGB", "DHW")
+      x:format("jpeg"):depth(8)
+      if torch.uniform() < options.jpeg_chroma_subsampling_rate then
+	 -- YUV 420
+	 x:samplingFactors({2.0, 1.0, 1.0})
+      else
+	 -- YUV 444
+	 x:samplingFactors({1.0, 1.0, 1.0})
+      end
+      local blob, len = x:toBlob(quality[i])
+      x:fromBlob(blob, len)
+      x = x:toTensor("byte", "RGB", "DHW")
+   end
+   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   
+   local batch = {}
+   for i = 1, n do
+      local xc, yc = pairwise_utils.active_cropping(x, y, size, 1,
+						    options.active_cropping_rate,
+						    options.active_cropping_tries)
+      xc = iproc.byte2float(xc)
+      yc = iproc.byte2float(yc)
+      if options.rgb then
+      else
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      if torch.uniform() < options.nr_rate then
+	 -- reducing noise
+	 table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      else
+	 -- ratain useful details
+	 table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      end
+   end
+   return batch
+end
+function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
+   if style == "art" then
+      if level == 1 then
+	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
+					 size, offset, n, options)
+      elseif level == 2 or level == 3 then
+	 -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
+	 local r = torch.uniform()
+	 if r > 0.6 then
+	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
+					    size, offset, n, options)
+	 elseif r > 0.3 then
+	    local quality1 = torch.random(37, 70)
+	    local quality2 = quality1 - torch.random(5, 10)
+	    return pairwise_transform.jpeg_(src, {quality1, quality2},
+					    size, offset, n, options)
+	 else
+	    local quality1 = torch.random(52, 70)
+	    local quality2 = quality1 - torch.random(5, 15)
+	    local quality3 = quality1 - torch.random(15, 25)
+	    
+	    return pairwise_transform.jpeg_(src, 
+					    {quality1, quality2, quality3},
+					    size, offset, n, options)
+	 end
+      else
+	 error("unknown noise level: " .. level)
+      end
+   elseif style == "photo" then
+      -- level adjusting by -nr_rate
+      return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
+				      size, offset, n,
+				      options)
+   else
+      error("unknown style: " .. style)
+   end
+end
+
+function pairwise_transform.test_jpeg(src)
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local options = {random_color_noise_rate = 0.5,
+		    random_half_rate = 0.5,
+		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
+		    jpeg_chroma_subsampling_rate = 0.5,
+		    nr_rate = 1.0,
+		    active_cropping_rate = 0.5,
+		    active_cropping_tries = 10,
+		    max_size = 256,
+		    rgb = true
+   }
+   local image = require 'image'
+   local src = image.lena()
+   for i = 1, 9 do
+      local xy = pairwise_transform.jpeg(src,
+					 "art",
+					 torch.random(1, 2),
+					 128, 7, 1, options)
+      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
+      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
+   end
+end
+return pairwise_transform
+

+ 86 - 0
lib/pairwise_transform_scale.lua

@@ -0,0 +1,86 @@
+local pairwise_utils = require 'pairwise_transform_utils'
+local iproc = require 'iproc'
+local pairwise_transform = {}
+
+function pairwise_transform.scale(src, scale, size, offset, n, options)
+   local filters = options.downsampling_filters
+   local unstable_region_offset = 8
+   local downsampling_filter = filters[torch.random(1, #filters)]
+   local y = pairwise_utils.preprocess(src, size, options)
+   assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
+   local down_scale = 1.0 / scale
+   local x
+   if options.gamma_correction then
+      local small = iproc.scale_with_gamma22(y, y:size(3) * down_scale,
+					     y:size(2) * down_scale, downsampling_filter)
+      if options.x_upsampling then
+	 x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
+      else
+	 x = small
+      end
+   else
+      local small = iproc.scale(y, y:size(3) * down_scale,
+				  y:size(2) * down_scale, downsampling_filter)
+      if options.x_upsampling then
+	 x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
+      else
+	 x = small
+      end
+   end
+
+   if options.x_upsampling then
+      x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		     x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+      y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		     y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+      assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+      assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   else
+      assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
+   end
+   local scale_inner = scale
+   if options.x_upsampling then
+      scale_inner = 1
+   end
+   local batch = {}
+
+   for i = 1, n do
+      local xc, yc = pairwise_utils.active_cropping(x, y,
+							size,
+							scale_inner,
+							options.active_cropping_rate,
+							options.active_cropping_tries)
+      xc = iproc.byte2float(xc)
+      yc = iproc.byte2float(yc)
+      if options.rgb then
+      else
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+   end
+   return batch
+end
+function pairwise_transform.test_scale(src)
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local options = {random_color_noise_rate = 0.5,
+		    random_half_rate = 0.5,
+		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
+		    active_cropping_rate = 0.5,
+		    active_cropping_tries = 10,
+		    max_size = 256,
+		    x_upsampling = false,
+		    downsampling_filters = "Box",
+		    rgb = true
+   }
+   local image = require 'image'
+   local src = image.lena()
+
+   for i = 1, 10 do
+      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
+      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
+      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
+   end
+end
+return pairwise_transform

+ 91 - 0
lib/pairwise_transform_utils.lua

@@ -0,0 +1,91 @@
+require 'image'
+local gm = require 'graphicsmagick'
+local iproc = require 'iproc'
+local data_augmentation = require 'data_augmentation'
+local pairwise_transform_utils = {}
+
+function pairwise_transform_utils.random_half(src, p, filters)
+   if torch.uniform() < p then
+      local filter = filters[torch.random(1, #filters)]
+      return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
+   else
+      return src
+   end
+end
+function pairwise_transform_utils.crop_if_large(src, max_size)
+   local tries = 4
+   if src:size(2) > max_size and src:size(3) > max_size then
+      local rect
+      for i = 1, tries do
+	 local yi = torch.random(0, src:size(2) - max_size)
+	 local xi = torch.random(0, src:size(3) - max_size)
+	 rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
+	 -- ignore simple background
+	 if rect:float():std() >= 0 then
+	    break
+	 end
+      end
+      return rect
+   else
+      return src
+   end
+end
+function pairwise_transform_utils.preprocess(src, crop_size, options)
+   local dest = src
+   dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
+   dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
+   dest = data_augmentation.flip(dest)
+   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
+   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
+   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
+   dest = data_augmentation.shift_1px(dest)
+   
+   return dest
+end
+function pairwise_transform_utils.active_cropping(x, y, size, scale, p, tries)
+   assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
+   assert("crop_size % scale == 0", size % scale == 0)
+   local r = torch.uniform()
+   local t = "float"
+   if x:type() == "torch.ByteTensor" then
+      t = "byte"
+   end
+   if p < r then
+      local xi = torch.random(0, x:size(3) - (size + 1))
+      local yi = torch.random(0, x:size(2) - (size + 1))
+      local yc = iproc.crop(y, xi * scale, yi * scale, xi * scale + size, yi * scale + size)
+      local xc = iproc.crop(x, xi, yi, xi + size / scale, yi + size / scale)
+      return xc, yc
+   else
+      local test_scale = 2
+      if test_scale < scale then
+	 test_scale = scale
+      end
+      local lowres = gm.Image(y, "RGB", "DHW"):
+	    size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
+	    size(y:size(3), y:size(2), "Box"):
+	    toTensor(t, "RGB", "DHW")
+      local best_se = 0.0
+      local best_xi, best_yi
+      local m = torch.FloatTensor(y:size(1), size, size)
+      for i = 1, tries do
+	 local xi = torch.random(0, x:size(3) - (size + 1)) * scale
+	 local yi = torch.random(0, x:size(2) - (size + 1)) * scale
+	 local xc = iproc.crop(y, xi, yi, xi + size, yi + size)
+	 local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
+	 local xcf = iproc.byte2float(xc)
+	 local lcf = iproc.byte2float(lc)
+	 local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
+	 if se >= best_se then
+	    best_xi = xi
+	    best_yi = yi
+	    best_se = se
+	 end
+      end
+      local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
+      local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
+      return xc, yc
+   end
+end
+
+return pairwise_transform_utils

+ 105 - 32
lib/reconstruct.lua

@@ -49,6 +49,32 @@ local function reconstruct_rgb(model, x, offset, block_size)
    end
    return new_x
 end
+local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
+   local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero()
+   local input_block_size = block_size / scale
+   local output_block_size = block_size
+   local output_size = output_block_size - offset * 2
+   local output_size_in_input = input_block_size - offset
+   local input = torch.CudaTensor(1, 3, input_block_size, input_block_size)
+   
+   for i = 1, x:size(2), output_size_in_input do
+      for j = 1, new_x:size(3), output_size_in_input do
+	 if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
+	    local index = {{},
+			   {i, i + input_block_size - 1},
+			   {j, j + input_block_size - 1}}
+	    input:copy(x[index])
+	    local output = model:forward(input):view(3, output_size, output_size)
+	    local ii = (i - 1) * scale + 1
+	    local jj = (j - 1) * scale + 1
+	    local output_index = {{}, { ii , ii + output_size - 1 },
+	       { jj, jj + output_size - 1}}
+	    new_x[output_index]:copy(output)
+	 end
+      end
+   end
+   return new_x
+end
 local reconstruct = {}
 function reconstruct.is_rgb(model)
    if srcnn.channels(model) == 3 then
@@ -62,6 +88,9 @@ end
 function reconstruct.offset_size(model)
    return srcnn.offset_size(model)
 end
+function reconstruct.no_resize(model)
+   return srcnn.has_resize(model)
+end
 function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local output_size = block_size - offset * 2
@@ -95,8 +124,14 @@ end
 function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
-   local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
-   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
+
+   local x_lanczos
+   if reconstruct.no_resize(model) then
+      x_lanczos = x:clone()
+   else
+      x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
+      x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
+   end
    if x:size(2) * x:size(3) > 2048*2048 then
       collectgarbage()
    end
@@ -162,39 +197,77 @@ function reconstruct.image_rgb(model, x, offset, block_size)
    return output
 end
 function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
-   upsampling_filter = upsampling_filter or "Box"
-   block_size = block_size or 128
-   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
-   if x:size(2) * x:size(3) > 2048*2048 then
+   if reconstruct.no_resize(model) then
+      block_size = block_size or 128
+      local input_block_size = block_size / scale
+      local x_w = x:size(3)
+      local x_h = x:size(2)
+      local process_size = input_block_size - offset * 2
+      -- TODO: under construction!! bug in 4x
+      local h_blocks = math.floor(x_h / process_size) + 2
+--	 ((x_h % process_size == 0 and 0) or 1)
+      local w_blocks = math.floor(x_w / process_size) + 2
+--	 ((x_w % process_size == 0 and 0) or 1)
+      local h = offset + (h_blocks * process_size) + offset
+      local w = offset + (w_blocks * process_size) + offset
+      local pad_h1 = offset
+      local pad_w1 = offset
+
+      local pad_h2 = (h - offset) - x:size(2)
+      local pad_w2 = (w - offset) - x:size(3)
+
+      x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+      if x:size(2) * x:size(3) > 2048*2048 then
+	 collectgarbage()
+      end
+      local y 
+      y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
+      local output = iproc.crop(y,
+				pad_w1, pad_h1,
+				pad_w1 + x_w * scale, pad_h1 + x_h * scale)
+      output[torch.lt(output, 0)] = 0
+      output[torch.gt(output, 1)] = 1
+      x = nil
+      y = nil
       collectgarbage()
-   end
-   local output_size = block_size - offset * 2
-   local h_blocks = math.floor(x:size(2) / output_size) +
-      ((x:size(2) % output_size == 0 and 0) or 1)
-   local w_blocks = math.floor(x:size(3) / output_size) +
-      ((x:size(3) % output_size == 0 and 0) or 1)
-   
-   local h = offset + h_blocks * output_size + offset
-   local w = offset + w_blocks * output_size + offset
-   local pad_h1 = offset
-   local pad_w1 = offset
-   local pad_h2 = (h - offset) - x:size(2)
-   local pad_w2 = (w - offset) - x:size(3)
-   x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-   if x:size(2) * x:size(3) > 2048*2048 then
+
+      return output
+   else
+      upsampling_filter = upsampling_filter or "Box"
+      block_size = block_size or 128
+      x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
+      if x:size(2) * x:size(3) > 2048*2048 then
+	 collectgarbage()
+      end
+      local output_size = block_size - offset * 2
+      local h_blocks = math.floor(x:size(2) / output_size) +
+	 ((x:size(2) % output_size == 0 and 0) or 1)
+      local w_blocks = math.floor(x:size(3) / output_size) +
+	 ((x:size(3) % output_size == 0 and 0) or 1)
+      
+      local h = offset + h_blocks * output_size + offset
+      local w = offset + w_blocks * output_size + offset
+      local pad_h1 = offset
+      local pad_w1 = offset
+      local pad_h2 = (h - offset) - x:size(2)
+      local pad_w2 = (w - offset) - x:size(3)
+      x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+      if x:size(2) * x:size(3) > 2048*2048 then
+	 collectgarbage()
+      end
+      local y 
+      y = reconstruct_rgb(model, x, offset, block_size)
+      local output = iproc.crop(y,
+				pad_w1, pad_h1,
+				y:size(3) - pad_w2, y:size(2) - pad_h2)
+      output[torch.lt(output, 0)] = 0
+      output[torch.gt(output, 1)] = 1
+      x = nil
+      y = nil
       collectgarbage()
-   end
-   local y = reconstruct_rgb(model, x, offset, block_size)
-   local output = iproc.crop(y,
-			     pad_w1, pad_h1,
-			     y:size(3) - pad_w2, y:size(2) - pad_h2)
-   output[torch.lt(output, 0)] = 0
-   output[torch.gt(output, 1)] = 1
-   x = nil
-   y = nil
-   collectgarbage()
 
-   return output
+      return output
+   end
 end
 
 function reconstruct.image(model, x, block_size)

+ 2 - 2
lib/settings.lua

@@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", 'method to training (noise|scale)')
-cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
+cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
 cmd:option("-noise_level", 1, '(1|2|3)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
@@ -34,7 +34,7 @@ cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution im
 cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.0005, 'learning rate for adam')
-cmd:option("-crop_size", 46, 'crop size')
+cmd:option("-crop_size", 48, 'crop size')
 cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
 cmd:option("-batch_size", 8, 'mini batch size')
 cmd:option("-patches", 16, 'number of patch samples')

+ 167 - 26
lib/srcnn.lua

@@ -9,14 +9,23 @@ function nn.SpatialConvolutionMM:reset(stdv)
    self.weight:normal(0, stdv)
    self.bias:zero()
 end
+function nn.SpatialFullConvolution:reset(stdv)
+   stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
+   self.weight:normal(0, stdv)
+   self.bias:zero()
+end
 if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
       stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
       self.weight:normal(0, stdv)
       self.bias:zero()
    end
+   function cudnn.SpatialFullConvolution:reset(stdv)
+      stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
+      self.weight:normal(0, stdv)
+      self.bias:zero()
+   end
 end
-
 function nn.SpatialConvolutionMM:clearState()
    if self.gradWeight then
       self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
@@ -26,9 +35,12 @@ function nn.SpatialConvolutionMM:clearState()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
 end
-
 function srcnn.channels(model)
-   return model:get(model:size() - 1).weight:size(1)
+   if model.w2nn_channels ~= nil then
+      return model.w2nn_channels
+   else
+      return model:get(model:size() - 1).weight:size(1)
+   end
 end
 function srcnn.backend(model)
    local conv = model:findModules("cudnn.SpatialConvolution")
@@ -47,32 +59,54 @@ function srcnn.color(model)
    end
 end
 function srcnn.name(model)
-   local backend_cudnn = false
-   local conv = model:findModules("nn.SpatialConvolutionMM")
-   if #conv == 0 then
-      backend_cudnn = true
-      conv = model:findModules("cudnn.SpatialConvolution")
-   end
-   if #conv == 7 then
-      return "vgg_7"
-   elseif #conv == 12 then
-      return "vgg_12"
+   if model.w2nn_arch_name then
+      return model.w2nn_arch_name
    else
-      return nil
+      local conv = model:findModules("nn.SpatialConvolutionMM")
+      if #conv == 0 then
+	 conv = model:findModules("cudnn.SpatialConvolution")
+      end
+      if #conv == 7 then
+	 return "vgg_7"
+      elseif #conv == 12 then
+	 return "vgg_12"
+      else
+	 error("unsupported model name")
+      end
    end
 end
 function srcnn.offset_size(model)
-   local conv = model:findModules("nn.SpatialConvolutionMM")
-   if #conv == 0 then
-      conv = model:findModules("cudnn.SpatialConvolution")
+   if model.w2nn_offset ~= nil then
+      return model.w2nn_offset
+   else
+      local name = srcnn.name(model)
+      if name:match("vgg_") then
+	 local conv = model:findModules("nn.SpatialConvolutionMM")
+	 if #conv == 0 then
+	    conv = model:findModules("cudnn.SpatialConvolution")
+	 end
+	 local offset = 0
+	 for i = 1, #conv do
+	    offset = offset + (conv[i].kW - 1) / 2
+	 end
+	 return math.floor(offset)
+      else
+	 error("unsupported model name")
+      end
    end
-   local offset = 0
-   for i = 1, #conv do
-      offset = offset + (conv[i].kW - 1) / 2
+end
+function srcnn.has_resize(model)
+   if model.w2nn_resize ~= nil then
+      return model.w2nn_resize
+   else
+      local name = srcnn.name(model)
+      if name:match("upconv") ~= nil then
+	 return true
+      else
+	 return false
+      end
    end
-   return math.floor(offset)
 end
-
 local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
       return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
@@ -82,6 +116,15 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
       error("unsupported backend:" .. backend)
    end
 end
+local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
 
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -100,6 +143,11 @@ function srcnn.vgg_7(backend, ch)
    model:add(w2nn.LeakyReLU(0.1))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
+
+   model.w2nn_arch_name = "vgg_7"
+   model.w2nn_offset = 7
+   model.w2nn_resize = false
+   model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
@@ -132,12 +180,103 @@ function srcnn.vgg_12(backend, ch)
    model:add(w2nn.LeakyReLU(0.1))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
+
+   model.w2nn_arch_name = "vgg_12"
+   model.w2nn_offset = 12
+   model.w2nn_resize = false
+   model.w2nn_channels = ch
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+   
+   return model
+end
+
+-- Dilated Convolution (7 layers)
+function srcnn.dilated_7(backend, ch)
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(nn.View(-1):setNumInputDims(3))
+
+   model.w2nn_arch_name = "dilated_7"
+   model.w2nn_offset = 12
+   model.w2nn_resize = false
+   model.w2nn_channels = ch
+
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
    return model
 end
 
+-- Up Convolution
+function srcnn.upconv_7(backend, ch)
+   local model = nn.Sequential()
+
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
+
+   model.w2nn_arch_name = "upconv_7"
+   model.w2nn_offset = 12
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
+function srcnn.upconv_8_4x(backend, ch)
+   local model = nn.Sequential()
+
+   model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(SpatialFullConvolution(backend, 64, 3, 4, 4, 2, 2, 1, 1))
+
+   model.w2nn_arch_name = "upconv_8_4x"
+   model.w2nn_offset = 12
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
 function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
@@ -150,12 +289,14 @@ function srcnn.create(model_name, backend, color)
    else
       error("unsupported color: " .. color)
    end
-   if model_name == "vgg_7" then
-      return srcnn.vgg_7(backend, ch)
-   elseif model_name == "vgg_12" then
-      return srcnn.vgg_12(backend, ch)
+   if srcnn[model_name] then
+      return srcnn[model_name](backend, ch)
    else
       error("unsupported model_name: " .. model_name)
    end
 end
+
+--local model = srcnn.upconv_8_4x("cunn", 3):cuda()
+--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
+
 return srcnn

+ 33 - 7
train.lua

@@ -15,7 +15,9 @@ local pairwise_transform = require 'pairwise_transform'
 local image_loader = require 'image_loader'
 
 local function save_test_scale(model, rgb, file)
-   local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter)
+   local up = reconstruct.scale(model, settings.scale, rgb,
+				settings.scale * settings.crop_size,
+				settings.upsampling_filter)
    image.save(file, up)
 end
 local function save_test_jpeg(model, rgb, file)
@@ -96,6 +98,7 @@ local function create_criterion(model)
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
+
       weight[1]:fill(0.29891 * 3) -- R
       weight[2]:fill(0.58661 * 3) -- G
       weight[3]:fill(0.11448 * 3) -- B
@@ -108,7 +111,7 @@ local function create_criterion(model)
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    end
 end
-local function transformer(x, is_validation, n, offset)
+local function transformer(model, x, is_validation, n, offset)
    x = compression.decompress(x)
    n = n or settings.patches
 
@@ -145,7 +148,8 @@ local function transformer(x, is_validation, n, offset)
 					 active_cropping_rate = active_cropping_rate,
 					 active_cropping_tries = active_cropping_tries,
 					 rgb = (settings.color == "rgb"),
-					 gamma_correction = settings.gamma_correction
+					 gamma_correction = settings.gamma_correction,
+					 x_upsampling = not srcnn.has_resize(model)
 				      })
    elseif settings.method == "noise" then
       return pairwise_transform.jpeg(x,
@@ -183,6 +187,22 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
       end
    end
 end
+local function remove_small_image(x)
+   local new_x = {}
+   for i = 1, #x do
+      local x_s = compression.size(x[i])
+      if x_s[2] / settings.scale > settings.crop_size + 16 and
+      x_s[3] / settings.scale > settings.crop_size + 16 then
+	 table.insert(new_x, x[i])
+      end
+      if i % 100 == 0 then
+	 collectgarbage()
+      end
+   end
+   print(string.format("removed %d small images", #x - #new_x))
+
+   return new_x
+end
 local function plot(train, valid)
    gnuplot.plot({
 	 {'training', torch.Tensor(train), '-'},
@@ -194,11 +214,11 @@ local function train()
    local model = srcnn.create(settings.model, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
    local pairwise_func = function(x, is_validation, n)
-      return transformer(x, is_validation, n, offset)
+      return transformer(model, x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
    local eval_metric = nn.MSECriterion():cuda()
-   local x = torch.load(settings.images)
+   local x = remove_small_image(torch.load(settings.images))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
       learningRate = settings.learning_rate,
@@ -222,10 +242,16 @@ local function train()
    model:cuda()
    print("load .. " .. #train_x)
 
-   local x = torch.Tensor(settings.patches * #train_x,
-			  ch, settings.crop_size, settings.crop_size)
+   local x = nil
    local y = torch.Tensor(settings.patches * #train_x,
 			  ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
+   if srcnn.has_resize(model) then
+      x = torch.Tensor(settings.patches * #train_x,
+		       ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
+   else
+      x = torch.Tensor(settings.patches * #train_x,
+		       ch, settings.crop_size, settings.crop_size)
+   end
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)