Переглянути джерело

Add support for resizing with gamma correction

nagadomi 9 роки тому
батько
коміт
13f702b968
5 змінених файлів з 44 додано та 18 видалено
  1. 1 1
      lib/iproc.lua
  2. 10 3
      lib/pairwise_transform.lua
  3. 14 9
      lib/settings.lua
  4. 17 4
      tools/benchmark.lua
  5. 2 1
      train.lua

+ 1 - 1
lib/iproc.lua

@@ -73,7 +73,7 @@ function iproc.scale_with_gamma22(src, width, height, filter)
    im:gammaCorrection(1.0 / 2.2):
       size(math.ceil(width), math.ceil(height), filter):
       gammaCorrection(2.2)
-   local dest = im:toTensor("float", "RGB", "DHW")
+   local dest = im:toTensor("float", "RGB", "DHW"):clamp(0.0, 1.0)
    if conversion then
       dest = iproc.float2byte(dest)
    end

+ 10 - 3
lib/pairwise_transform.lua

@@ -88,9 +88,16 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
    local y = preprocess(src, size, options)
    assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
    local down_scale = 1.0 / scale
-   local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
-				     y:size(2) * down_scale, downsampling_filter),
-			 y:size(3), y:size(2))
+   local x
+   if options.gamma_correction then
+      x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
+					       y:size(2) * down_scale, downsampling_filter),
+		      y:size(3), y:size(2))
+   else
+      x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
+				  y:size(2) * down_scale, downsampling_filter),
+		      y:size(3), y:size(2))
+   end
    x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
 		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
    y = iproc.crop(y, unstable_region_offset, unstable_region_offset,

+ 14 - 9
lib/settings.lua

@@ -49,21 +49,26 @@ cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing detai
 cmd:option("-save_history", 0, 'save all model (0|1)')
 cmd:option("-plot", 0, 'plot loss chart(0|1)')
 cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
+cmd:option("-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) in scale training (0|1)')
+
+local function to_bool(settings, name)
+   if settings[name] == 1 then
+      settings[name] = true
+   else
+      settings[name] = false
+   end
+end
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
    settings[k] = v
 end
-if settings.plot == 1 then
-   settings.plot = true
+to_bool(settings, "plot")
+to_bool(settings, "save_history")
+to_bool(settings, "gamma_correction")
+
+if settings.plot then
    require 'gnuplot'
-else
-   settings.plot = false
-end
-if settings.save_history == 1 then
-   settings.save_history = true
-else
-   settings.save_history = false
 end
 if settings.save_history then
    if settings.method == "noise" then

+ 17 - 4
tools/benchmark.lua

@@ -24,6 +24,7 @@ cmd:option("-jpeg_quality", 75, 'jpeg quality')
 cmd:option("-jpeg_times", 1, 'jpeg compression times')
 cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
 cmd:option("-range_bug", 0, 'Reproducing the dynamic range bug that is caused by MATLAB\'s rgb2ycbcr(1|0)')
+cmd:option("-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) (0|1)')
 
 local opt = cmd:parse(arg)
 torch.setdefaulttensortype('torch.FloatTensor')
@@ -31,6 +32,11 @@ if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = false
 end
+if opt.gamma_correction == 1 then
+   opt.gamma_correction = true
+else
+   opt.gamma_correction = false
+end
 
 local function rgb2y_matlab(x)
    local y = torch.Tensor(1, x:size(2), x:size(3)):zero()
@@ -87,10 +93,17 @@ local function baseline_scale(x, filter)
 		      filter)
 end
 local function transform_scale(x, opt)
-   return iproc.scale(x,
-		      x:size(3) * 0.5,
-		      x:size(2) * 0.5,
-		      opt.filter)
+   if opt.gamma_correction then
+      return iproc.scale_with_gamma22(x,
+			 x:size(3) * 0.5,
+			 x:size(2) * 0.5,
+			 opt.filter)
+   else
+      return iproc.scale(x,
+			 x:size(3) * 0.5,
+			 x:size(2) * 0.5,
+			 opt.filter)
+   end
 end
 
 local function benchmark(opt, x, input_func, model1, model2)

+ 2 - 1
train.lua

@@ -120,7 +120,8 @@ local function transformer(x, is_validation, n, offset)
 					 max_size = settings.max_size,
 					 active_cropping_rate = active_cropping_rate,
 					 active_cropping_tries = active_cropping_tries,
-					 rgb = (settings.color == "rgb")
+					 rgb = (settings.color == "rgb"),
+					 gamma_correction = settings.gamma_correction
 				      })
    elseif settings.method == "noise" then
       return pairwise_transform.jpeg(x,