nagadomi 9 年之前
父節點
當前提交
48411a4dde
共有 2 個文件被更改,包括 89 次插入192 次删除
  1. 73 182
      lib/reconstruct.lua
  2. 16 10
      lib/srcnn.lua

+ 73 - 182
lib/reconstruct.lua

@@ -2,71 +2,28 @@ require 'image'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
 local srcnn = require 'srcnn'
 local srcnn = require 'srcnn'
 
 
-local function reconstruct_y(model, x, offset, block_size)
+local function reconstruct_nn(model, x, inner_scale, offset, block_size)
    if x:dim() == 2 then
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
       x = x:reshape(1, x:size(1), x:size(2))
    end
    end
-   local new_x = torch.Tensor():resizeAs(x):zero()
-   local output_size = block_size - offset * 2
-   local input = torch.CudaTensor(1, 1, block_size, block_size)
-   
-   for i = 1, x:size(2), output_size do
-      for j = 1, x:size(3), output_size do
-	 if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
-	    local index = {{},
-			   {i, i + block_size - 1},
-			   {j, j + block_size - 1}}
-	    input:copy(x[index])
-	    local output = model:forward(input):view(1, output_size, output_size)
-	    local output_index = {{},
-				  {i + offset, offset + i + output_size - 1},
-				  {offset + j, offset + j + output_size - 1}}
-	    new_x[output_index]:copy(output)
-	 end
-      end
-   end
-   return new_x
-end
-local function reconstruct_rgb(model, x, offset, block_size)
-   local new_x = torch.Tensor():resizeAs(x):zero()
-   local output_size = block_size - offset * 2
-   local input = torch.CudaTensor(1, 3, block_size, block_size)
-   
-   for i = 1, x:size(2), output_size do
-      for j = 1, x:size(3), output_size do
-	 if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
-	    local index = {{},
-			   {i, i + block_size - 1},
-			   {j, j + block_size - 1}}
-	    input:copy(x[index])
-	    local output = model:forward(input):view(3, output_size, output_size)
-	    local output_index = {{},
-				  {i + offset, offset + i + output_size - 1},
-				  {offset + j, offset + j + output_size - 1}}
-	    new_x[output_index]:copy(output)
-	 end
-      end
-   end
-   return new_x
-end
-local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
-   local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero()
-   local input_block_size = block_size / scale
+   local ch = x:size(1)
+   local new_x = torch.Tensor(x:size(1), x:size(2) * inner_scale, x:size(3) * inner_scale):zero()
+   local input_block_size = block_size / inner_scale
    local output_block_size = block_size
    local output_block_size = block_size
    local output_size = output_block_size - offset * 2
    local output_size = output_block_size - offset * 2
-   local output_size_in_input = input_block_size - offset
-   local input = torch.CudaTensor(1, 3, input_block_size, input_block_size)
-   
+   local output_size_in_input = input_block_size - math.ceil(offset / inner_scale) * 2
+   local input = torch.CudaTensor(1, ch, input_block_size, input_block_size)
    for i = 1, x:size(2), output_size_in_input do
    for i = 1, x:size(2), output_size_in_input do
-      for j = 1, new_x:size(3), output_size_in_input do
+      for j = 1, x:size(3), output_size_in_input do
 	 if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
 	 if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
 	    local index = {{},
 	    local index = {{},
 			   {i, i + input_block_size - 1},
 			   {i, i + input_block_size - 1},
 			   {j, j + input_block_size - 1}}
 			   {j, j + input_block_size - 1}}
 	    input:copy(x[index])
 	    input:copy(x[index])
-	    local output = model:forward(input):view(3, output_size, output_size)
-	    local ii = (i - 1) * scale + 1
-	    local jj = (j - 1) * scale + 1
+	    local output = model:forward(input)
+	    output = output:view(ch, output_size, output_size)
+	    local ii = (i - 1) * inner_scale + 1
+	    local jj = (j - 1) * inner_scale + 1
 	    local output_index = {{}, { ii , ii + output_size - 1 },
 	    local output_index = {{}, { ii , ii + output_size - 1 },
 	       { jj, jj + output_size - 1}}
 	       { jj, jj + output_size - 1}}
 	    new_x[output_index]:copy(output)
 	    new_x[output_index]:copy(output)
@@ -88,31 +45,44 @@ end
 function reconstruct.offset_size(model)
 function reconstruct.offset_size(model)
    return srcnn.offset_size(model)
    return srcnn.offset_size(model)
 end
 end
-function reconstruct.no_resize(model)
-   return srcnn.has_resize(model)
+function reconstruct.has_resize(model)
+   return srcnn.scale_factor(model) > 1
+end
+function reconstruct.inner_scale(model)
+   return srcnn.scale_factor(model)
+end
+local function padding_params(x, model, block_size)
+   local p = {}
+   local offset = reconstruct.offset_size(model)
+   p.x_w = x:size(3)
+   p.x_h = x:size(2)
+   p.inner_scale = reconstruct.inner_scale(model)
+   local input_offset = math.ceil(offset / p.inner_scale)
+   local input_block_size = block_size / p.inner_scale
+   local process_size = input_block_size - input_offset * 2
+   local h_blocks = math.floor(p.x_h / process_size) +
+      ((p.x_h % process_size == 0 and 0) or 1)
+   local w_blocks = math.floor(p.x_w / process_size) +
+      ((p.x_w % process_size == 0 and 0) or 1)
+   local h = (h_blocks * process_size) + input_offset * 2
+   local w = (w_blocks * process_size) + input_offset * 2
+   p.pad_h1 = input_offset
+   p.pad_w1 = input_offset
+   p.pad_h2 = (h - input_offset) - p.x_h
+   p.pad_w2 = (w - input_offset) - p.x_w
+   return p
 end
 end
 function reconstruct.image_y(model, x, offset, block_size)
 function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    block_size = block_size or 128
-   local output_size = block_size - offset * 2
-   local h_blocks = math.floor(x:size(2) / output_size) +
-      ((x:size(2) % output_size == 0 and 0) or 1)
-   local w_blocks = math.floor(x:size(3) / output_size) +
-      ((x:size(3) % output_size == 0 and 0) or 1)
-   
-   local h = offset + h_blocks * output_size + offset
-   local w = offset + w_blocks * output_size + offset
-   local pad_h1 = offset
-   local pad_w1 = offset
-   local pad_h2 = (h - offset) - x:size(2)
-   local pad_w2 = (w - offset) - x:size(3)
-   x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_y(model, x[1], offset, block_size)
+   local p = padding_params(x, model, block_size)
+   x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   x = iproc.crop(x, p.pad_w1, p.pad_w2, p.pad_w1 + p.x_w, p.pad_w2 + p.x_h)
+   y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    y[torch.lt(y, 0)] = 0
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    y[torch.gt(y, 1)] = 1
    x[1]:copy(y)
    x[1]:copy(y)
-   local output = image.yuv2rgb(iproc.crop(x,
-					   pad_w1, pad_h1,
-					   x:size(3) - pad_w2, x:size(2) - pad_h2))
+   local output = image.yuv2rgb(x)
    output[torch.lt(output, 0)] = 0
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
    output[torch.gt(output, 1)] = 1
    x = nil
    x = nil
@@ -124,38 +94,25 @@ end
 function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
 function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    block_size = block_size or 128
-
    local x_lanczos
    local x_lanczos
-   if reconstruct.no_resize(model) then
+   if reconstruct.has_resize(model) then
       x_lanczos = x:clone()
       x_lanczos = x:clone()
    else
    else
       x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
       x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
       x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
       x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
    end
    end
-   if x:size(2) * x:size(3) > 2048*2048 then
+   local p = padding_params(x, model, block_size)
+   if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
       collectgarbage()
    end
    end
-   local output_size = block_size - offset * 2
-   local h_blocks = math.floor(x:size(2) / output_size) +
-      ((x:size(2) % output_size == 0 and 0) or 1)
-   local w_blocks = math.floor(x:size(3) / output_size) +
-      ((x:size(3) % output_size == 0 and 0) or 1)
-   
-   local h = offset + h_blocks * output_size + offset
-   local w = offset + w_blocks * output_size + offset
-   local pad_h1 = offset
-   local pad_w1 = offset
-   local pad_h2 = (h - offset) - x:size(2)
-   local pad_w2 = (w - offset) - x:size(3)
-   x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   x_lanczos = image.rgb2yuv(iproc.padding(x_lanczos, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_y(model, x[1], offset, block_size)
+   x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
+   x_lanczos = image.rgb2yuv(x_lanczos)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    y[torch.lt(y, 0)] = 0
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    y[torch.gt(y, 1)] = 1
    x_lanczos[1]:copy(y)
    x_lanczos[1]:copy(y)
-   local output = image.yuv2rgb(iproc.crop(x_lanczos,
-					   pad_w1, pad_h1,
-					   x_lanczos:size(3) - pad_w2, x_lanczos:size(2) - pad_h2))
+   local output = image.yuv2rgb(x_lanczos)
    output[torch.lt(output, 0)] = 0
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
    output[torch.gt(output, 1)] = 1
    x = nil
    x = nil
@@ -167,27 +124,13 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
 end
 end
 function reconstruct.image_rgb(model, x, offset, block_size)
 function reconstruct.image_rgb(model, x, offset, block_size)
    block_size = block_size or 128
    block_size = block_size or 128
-   local output_size = block_size - offset * 2
-   local h_blocks = math.floor(x:size(2) / output_size) +
-      ((x:size(2) % output_size == 0 and 0) or 1)
-   local w_blocks = math.floor(x:size(3) / output_size) +
-      ((x:size(3) % output_size == 0 and 0) or 1)
-   
-   local h = offset + h_blocks * output_size + offset
-   local w = offset + w_blocks * output_size + offset
-   local pad_h1 = offset
-   local pad_w1 = offset
-   local pad_h2 = (h - offset) - x:size(2)
-   local pad_w2 = (w - offset) - x:size(3)
-
-   x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-   if x:size(2) * x:size(3) > 2048*2048 then
+   local p = padding_params(x, model, block_size)
+   x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
+   if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
       collectgarbage()
    end
    end
-   local y = reconstruct_rgb(model, x, offset, block_size)
-   local output = iproc.crop(y,
-			     pad_w1, pad_h1,
-			     y:size(3) - pad_w2, y:size(2) - pad_h2)
+   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    output[torch.lt(output, 0)] = 0
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
    output[torch.gt(output, 1)] = 1
    x = nil
    x = nil
@@ -197,79 +140,27 @@ function reconstruct.image_rgb(model, x, offset, block_size)
    return output
    return output
 end
 end
 function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
 function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
-   if reconstruct.no_resize(model) then
-      block_size = block_size or 128
-      local input_block_size = block_size / scale
-      local x_w = x:size(3)
-      local x_h = x:size(2)
-      local process_size = input_block_size - offset * 2
-      -- TODO: under construction!! bug in 4x
-      local h_blocks = math.floor(x_h / process_size) + 2
---	 ((x_h % process_size == 0 and 0) or 1)
-      local w_blocks = math.floor(x_w / process_size) + 2
---	 ((x_w % process_size == 0 and 0) or 1)
-      local h = offset + (h_blocks * process_size) + offset
-      local w = offset + (w_blocks * process_size) + offset
-      local pad_h1 = offset
-      local pad_w1 = offset
-
-      local pad_h2 = (h - offset) - x:size(2)
-      local pad_w2 = (w - offset) - x:size(3)
-
-      x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-      if x:size(2) * x:size(3) > 2048*2048 then
-	 collectgarbage()
-      end
-      local y 
-      y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
-      local output = iproc.crop(y,
-				pad_w1, pad_h1,
-				pad_w1 + x_w * scale, pad_h1 + x_h * scale)
-      output[torch.lt(output, 0)] = 0
-      output[torch.gt(output, 1)] = 1
-      x = nil
-      y = nil
-      collectgarbage()
-
-      return output
-   else
-      upsampling_filter = upsampling_filter or "Box"
-      block_size = block_size or 128
+   upsampling_filter = upsampling_filter or "Box"
+   block_size = block_size or 128
+   if not reconstruct.has_resize(model) then
       x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
       x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
-      if x:size(2) * x:size(3) > 2048*2048 then
-	 collectgarbage()
-      end
-      local output_size = block_size - offset * 2
-      local h_blocks = math.floor(x:size(2) / output_size) +
-	 ((x:size(2) % output_size == 0 and 0) or 1)
-      local w_blocks = math.floor(x:size(3) / output_size) +
-	 ((x:size(3) % output_size == 0 and 0) or 1)
-      
-      local h = offset + h_blocks * output_size + offset
-      local w = offset + w_blocks * output_size + offset
-      local pad_h1 = offset
-      local pad_w1 = offset
-      local pad_h2 = (h - offset) - x:size(2)
-      local pad_w2 = (w - offset) - x:size(3)
-      x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-      if x:size(2) * x:size(3) > 2048*2048 then
-	 collectgarbage()
-      end
-      local y 
-      y = reconstruct_rgb(model, x, offset, block_size)
-      local output = iproc.crop(y,
-				pad_w1, pad_h1,
-				y:size(3) - pad_w2, y:size(2) - pad_h2)
-      output[torch.lt(output, 0)] = 0
-      output[torch.gt(output, 1)] = 1
-      x = nil
-      y = nil
+   end
+   local p = padding_params(x, model, block_size)
+   x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
+   if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
       collectgarbage()
-
-      return output
    end
    end
-end
+   local y
+   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
+   output[torch.lt(output, 0)] = 0
+   output[torch.gt(output, 1)] = 1
+   x = nil
+   y = nil
+   collectgarbage()
 
 
+   return output
+end
 function reconstruct.image(model, x, block_size)
 function reconstruct.image(model, x, block_size)
    local i2rgb = false
    local i2rgb = false
    if x:size(1) == 1 then
    if x:size(1) == 1 then

+ 16 - 10
lib/srcnn.lua

@@ -59,7 +59,7 @@ function srcnn.color(model)
    end
    end
 end
 end
 function srcnn.name(model)
 function srcnn.name(model)
-   if model.w2nn_arch_name then
+   if model.w2nn_arch_name ~= nil then
       return model.w2nn_arch_name
       return model.w2nn_arch_name
    else
    else
       local conv = model:findModules("nn.SpatialConvolutionMM")
       local conv = model:findModules("nn.SpatialConvolutionMM")
@@ -71,7 +71,7 @@ function srcnn.name(model)
       elseif #conv == 12 then
       elseif #conv == 12 then
 	 return "vgg_12"
 	 return "vgg_12"
       else
       else
-	 error("unsupported model name")
+	 error("unsupported model")
       end
       end
    end
    end
 end
 end
@@ -91,19 +91,21 @@ function srcnn.offset_size(model)
 	 end
 	 end
 	 return math.floor(offset)
 	 return math.floor(offset)
       else
       else
-	 error("unsupported model name")
+	 error("unsupported model")
       end
       end
    end
    end
 end
 end
-function srcnn.has_resize(model)
-   if model.w2nn_resize ~= nil then
-      return model.w2nn_resize
+function srcnn.scale_factor(model)
+   if model.w2nn_scale_factor ~= nil then
+      return model.w2nn_scale_factor
    else
    else
       local name = srcnn.name(model)
       local name = srcnn.name(model)
-      if name:match("upconv") ~= nil then
-	 return true
+      if name == "upconv_7" then
+	 return 2
+      elseif name == "upconv_8_4x" then
+	 return 4
       else
       else
-	 return false
+	 return 1
       end
       end
    end
    end
 end
 end
@@ -146,7 +148,7 @@ function srcnn.vgg_7(backend, ch)
 
 
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_offset = 7
-   model.w2nn_resize = false
+   model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_channels = ch
    --model:cuda()
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
@@ -183,6 +185,7 @@ function srcnn.vgg_12(backend, ch)
 
 
    model.w2nn_arch_name = "vgg_12"
    model.w2nn_arch_name = "vgg_12"
    model.w2nn_offset = 12
    model.w2nn_offset = 12
+   model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_resize = false
    model.w2nn_channels = ch
    model.w2nn_channels = ch
    --model:cuda()
    --model:cuda()
@@ -211,6 +214,7 @@ function srcnn.dilated_7(backend, ch)
 
 
    model.w2nn_arch_name = "dilated_7"
    model.w2nn_arch_name = "dilated_7"
    model.w2nn_offset = 12
    model.w2nn_offset = 12
+   model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_resize = false
    model.w2nn_channels = ch
    model.w2nn_channels = ch
 
 
@@ -240,6 +244,7 @@ function srcnn.upconv_7(backend, ch)
 
 
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 12
    model.w2nn_offset = 12
+   model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_resize = true
    model.w2nn_channels = ch
    model.w2nn_channels = ch
 
 
@@ -269,6 +274,7 @@ function srcnn.upconv_8_4x(backend, ch)
 
 
    model.w2nn_arch_name = "upconv_8_4x"
    model.w2nn_arch_name = "upconv_8_4x"
    model.w2nn_offset = 12
    model.w2nn_offset = 12
+   model.w2nn_scale_factor = 4
    model.w2nn_resize = true
    model.w2nn_resize = true
    model.w2nn_channels = ch
    model.w2nn_channels = ch