Pārlūkot izejas kodu

remove noise_scale training

nagadomi 9 gadi atpakaļ
vecāks
revīzija
da786e15ba
3 mainītis faili ar 3 papildinājumiem un 165 dzēšanām
  1. 0 144
      lib/pairwise_transform.lua
  2. 3 3
      lib/settings.lua
  3. 0 18
      train.lua

+ 0 - 144
lib/pairwise_transform.lua

@@ -276,125 +276,6 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
       error("unknown category: " .. category)
    end
 end
-function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
-   if options.random_half then
-      src = random_half(src)
-   end
-   src = crop_if_large(src, math.max(size * 4, 512))
-   local down_scale = 1.0 / scale
-   local filters = {
-      "Box",        -- 0.012756949974688
-      "Blackman",   -- 0.013191924552285
-      --"Cartom",     -- 0.013753536746706
-      --"Hanning",    -- 0.013761314529647
-      --"Hermite",    -- 0.013850225205266
-      "SincFast",   -- 0.014095824314306
-      "Jinc",       -- 0.014244299255442
-   }
-   local downscale_filter = filters[torch.random(1, #filters)]
-   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
-   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
-   local y = src
-   local x
-   
-   if options.color_noise then
-      y = color_noise(y)
-   end
-   if options.overlay then
-      y = overlay_augment(y)
-   end
-   
-   x = y
-   x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
-   for i = 1, #quality do
-      x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg")
-      if options.jpeg_sampling_factors == 444 then
-	 x:samplingFactors({1.0, 1.0, 1.0})
-      else -- 422
-	 x:samplingFactors({2.0, 1.0, 1.0})
-      end
-      local blob, len = x:toBlob(quality[i])
-      x:fromBlob(blob, len)
-      x = x:toTensor("byte", "RGB", "DHW")
-   end
-   x = iproc.scale(x, y:size(3), y:size(2))
-   y = image.crop(y,
-		  xi, yi,
-		  xi + size, yi + size)
-   x = image.crop(x,
-		  xi, yi,
-		  xi + size, yi + size)
-   x = x:float():div(255)
-   y = y:float():div(255)
-   x, y = flip_augment(x, y)
-
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
-   end
-   
-   return x, image.crop(y, offset, offset, size - offset, size - offset)
-end
-function pairwise_transform.jpeg_scale(src, scale, category, level, size, offset, options)
-   options = options or {color_noise = false, random_half = true}
-   if category == "anime_style_art" then
-      if level == 1 then
-	 if torch.uniform() > 0.7 then
-	    return pairwise_transform.jpeg_scale_(src, scale, {},
-						  size, offset, options)
-	 else
-	    return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
-						  size, offset, options)
-	 end
-      elseif level == 2 then
-	 if torch.uniform() > 0.7 then
-	    return pairwise_transform.jpeg_scale_(src, scale, {},
-						  size, offset, options)
-	 else
-	    local r = torch.uniform()
-	    if r > 0.6 then
-	       return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
-						     size, offset, options)
-	    elseif r > 0.3 then
-	       local quality1 = torch.random(37, 70)
-	       local quality2 = quality1 - torch.random(5, 10)
-	       return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
-						     size, offset, options)
-	    else
-	       local quality1 = torch.random(52, 70)
-	       local quality2 = quality1 - torch.random(5, 15)
-	       local quality3 = quality1 - torch.random(15, 25)
-	       
-	       return pairwise_transform.jpeg_scale_(src, scale,
-						     {quality1, quality2, quality3 },
-						     size, offset, options)
-	    end
-	 end
-      else
-	 error("unknown noise level: " .. level)
-      end
-   elseif category == "photo" then
-      if level == 1 then
-	 if torch.uniform() > 0.7 then
-	    return pairwise_transform.jpeg_scale_(src, scale, {},
-						  size, offset, options)
-	 else
-	 return pairwise_transform.jpeg_scale_(src, scale, {torch.random(80, 95)},
-					       size, offset, options)
-	 end
-      elseif level == 2 then
-	 return pairwise_transform.jpeg_scale_(src, scale, {torch.random(70, 85)},
-					       size, offset, options)
-      else
-	 error("unknown noise level: " .. level)
-      end
-   else
-      error("unknown category: " .. category)
-   end
-end
-
 local function test_jpeg()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
@@ -428,31 +309,6 @@ local function test_scale()
       --print(x:mean(), y:mean())
    end
 end
-local function test_jpeg_scale()
-   torch.setdefaulttensortype('torch.FloatTensor')
-   local loader = require './image_loader'
-   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
-   local options = {color_noise = true,
-		    random_half = true,
-		    overlay = true,
-		    active_cropping_ratio = 0.5,
-		    active_cropping_times = 10
-   }
-   for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, options)
-      image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
-      --print(x:mean(), y:mean())
-   end
-   for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, options)
-      image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
-      --print(x:mean(), y:mean())
-   end
-end
 local function test_color_noise()
    torch.setdefaulttensortype('torch.FloatTensor')
    local loader = require './image_loader'

+ 3 - 3
lib/settings.lua

@@ -15,14 +15,14 @@ local settings = {}
 
 local cmd = torch.CmdLine()
 cmd:text()
-cmd:text("waifu2x")
+cmd:text("waifu2x-training")
 cmd:text("Options:")
 cmd:option("-seed", 11, 'fixed input seed')
 cmd:option("-data_dir", "./data", 'data directory')
--- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slow than cunn
+-- cmd:option("-backend", "cunn", '(cunn|cudnn)') -- cudnn is slower than cunn
 cmd:option("-test", "images/miku_small.png", 'test image file')
 cmd:option("-model_dir", "./models", 'model directory')
-cmd:option("-method", "scale", '(noise|scale|noise_scale)')
+cmd:option("-method", "scale", '(noise|scale)')
 cmd:option("-noise_level", 1, '(1|2)')
 cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')

+ 0 - 18
train.lua

@@ -126,19 +126,6 @@ local function transformer(x, is_validation, n, offset)
 				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
 				       rgb = (settings.color == "rgb")
 				     })
-   elseif settings.method == "noise_scale" then
-      return pairwise_transform.jpeg_scale(x,
-					   settings.scale,
-					   settings.category,
-					   settings.noise_level,
-					   settings.crop_size, offset,
-					   n,
-					   { color_noise = color_noise,
-					     overlay = overlay,
-					     jpeg_sampling_factors = settings.jpeg_sampling_factors,
-					     random_half = settings.random_half,
-					     rgb = (settings.color == "rgb")
-					   })
    end
 end
 
@@ -195,11 +182,6 @@ local function train()
 	    local log = path.join(settings.model_dir,
 				  ("scale%.1f_best.png"):format(settings.scale))
 	    save_test_scale(model, test_image, log)
-	 elseif settings.method == "noise_scale" then
-	    local log = path.join(settings.model_dir,
-				  ("noise%d_scale%.1f_best.png"):format(settings.noise_level,
-									settings.scale))
-	    save_test_scale(model, test_image, log)
 	 end
       else
 	 lrd_count = lrd_count + 1