ソースを参照

Merge pull request #95 from nagadomi/dev

Merge from dev branch
nagadomi 9 年 前
コミット
7849f51f42
44 ファイル変更760 行追加465 行削除
  1. 38 0
      appendix/benchmark.md
  2. 6 0
      assets/index.es.html
  3. 6 0
      assets/index.fr.html
  4. 6 0
      assets/index.html
  5. 6 0
      assets/index.ja.html
  6. 6 0
      assets/index.pt.html
  7. 6 0
      assets/index.ru.html
  8. 6 1
      lib/LeakyReLU.lua
  9. 19 0
      lib/PSNRCriterion.lua
  10. 1 46
      lib/cleanup_model.lua
  11. 4 2
      lib/minibatch_adam.lua
  12. 5 20
      lib/pairwise_transform.lua
  13. 15 3
      lib/settings.lua
  14. 10 0
      lib/srcnn.lua
  15. 1 2
      lib/w2nn.lua
  16. 0 0
      models/anime_style_art/noise1_model.json
  17. 42 75
      models/anime_style_art/noise1_model.t7
  18. 0 0
      models/anime_style_art/noise2_model.json
  19. 42 75
      models/anime_style_art/noise2_model.t7
  20. 0 0
      models/anime_style_art/noise3_model.json
  21. 128 0
      models/anime_style_art/noise3_model.t7
  22. 0 0
      models/anime_style_art/scale2.0x_model.json
  23. 42 75
      models/anime_style_art/scale2.0x_model.t7
  24. 0 0
      models/anime_style_art_rgb/noise1_model.json
  25. 0 0
      models/anime_style_art_rgb/noise2_model.json
  26. 0 0
      models/anime_style_art_rgb/noise3_model.json
  27. 128 0
      models/anime_style_art_rgb/noise3_model.t7
  28. 0 0
      models/anime_style_art_rgb/scale2.0x_model.json
  29. 42 75
      models/anime_style_art_rgb/scale2.0x_model.t7
  30. 0 0
      models/photo/noise1_model.json
  31. 0 0
      models/photo/noise2_model.json
  32. 0 0
      models/photo/noise3_model.json
  33. 128 0
      models/photo/noise3_model.t7
  34. 0 0
      models/photo/scale2.0x_model.json
  35. 0 25
      tools/cleanup_model.lua
  36. 0 5
      tools/export_model.lua
  37. 29 12
      train.lua
  38. 0 5
      train.sh
  39. 1 4
      train_photo.sh
  40. 15 36
      waifu2x.lua
  41. 20 4
      web.lua
  42. 1 0
      webgen/locales/en.yml
  43. 1 0
      webgen/locales/ja.yml
  44. 6 0
      webgen/templates/index.html.erb

+ 38 - 0
appendix/benchmark.md

@@ -0,0 +1,38 @@
+# Benchmark results
+
+
+## dataset
+
+    photo_set: 300 various photos.
+    art_set  : 90 artworks (PNG only).
+
+## 2x upscaling model
+
+| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo   | ukbench|
+|---------------|----------------------|------------------------|---------|--------|
+| photo\_test   |                29.83 |                  29.81 |**29.89**|  29.86 |
+| art\_test     |                36.02 |               **36.24**|  34.92  |  34.85 |
+
+The evaluation metric is PSNR(Y only), higher is better.
+
+## Denosing level 1 model
+
+| Dataset/Model            | anime\_style\_art | anime\_style\_art\_rgb | photo   |
+|--------------------------|-------------------|------------------------|---------|
+| photo\_test Quality 80   |             36.07 |               **36.20**|   36.01 |
+| photo\_test Quality 50,45|             31.72 |                 32.01  |**32.31**|
+| art\_test Quality 80     |             40.39 |               **42.48**|   40.35 |
+| art\_test Quality 50,45  |             35.45 |               **36.70**|   36.27 |
+
+The evaluation metric is PSNR(RGB), higher is better.
+
+## Denosing level 2 model
+
+| Dataset/Model            | anime\_style\_art | anime\_style\_art\_rgb | photo   |
+|--------------------------|-------------------|------------------------|---------|
+| photo\_test Quality 80   |             34.03 |                  34.42 |**36.06**|
+| photo\_test Quality 50,45|             31.95 |                  32.31 |**32.42**|
+| art\_test Quality 80     |             39.20 |               **41.12**|   40.48 |
+| art\_test Quality 50,45  |             36.14 |               **37.78**|   36.55 |
+
+The evaluation metric is PSNR(RGB), higher is better.

+ 6 - 0
assets/index.es.html

@@ -106,6 +106,12 @@
 		Alto
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		Highest
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    Es necesario utilizar la reducción de ruido si la imagen dispone de artefactos de compresión; de lo contrario podría producir el efecto opuesto.

+ 6 - 0
assets/index.fr.html

@@ -106,6 +106,12 @@
 		Haute
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		Highest
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    Il est nécessaire d'utiliser la réduction du bruit si l'image possède du bruit. Autrement, cela risque de causer l'effet opposé.

+ 6 - 0
assets/index.html

@@ -106,6 +106,12 @@
 		High
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		Highest
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    You need use noise reduction if image actually has noise or it may cause opposite effect.

+ 6 - 0
assets/index.ja.html

@@ -106,6 +106,12 @@
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		最高
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。

+ 6 - 0
assets/index.pt.html

@@ -106,6 +106,12 @@
 		Alta
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		Highest
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    Quando usando a escala 2x, Nós nunca recomendamos usar um nível alto de redução de ruído, quase sempre deixa a imagem pior, faz sentido apenas para casos raros quando a imagem tinha uma qualidade muito má desde o começo.

+ 6 - 0
assets/index.ru.html

@@ -106,6 +106,12 @@
 		Сильно
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		Highest
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект.

+ 6 - 1
lib/LeakyReLU.lua

@@ -17,7 +17,7 @@ function LeakyReLU:updateOutput(input)
    
    return self.output
 end
- 
+
 function LeakyReLU:updateGradInput(input, gradOutput)
    self.gradInput:resizeAs(gradOutput)
    -- filter positive
@@ -29,3 +29,8 @@ function LeakyReLU:updateGradInput(input, gradOutput)
    
    return self.gradInput
 end
+
+function LeakyReLU:clearState()
+   nn.utils.clear(self, 'negative')
+   return parent.clearState(self)
+end

+ 19 - 0
lib/PSNRCriterion.lua

@@ -0,0 +1,19 @@
+local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')
+
+function PSNRCriterion:__init()
+   parent.__init(self)
+   self.image = torch.Tensor()
+   self.diff = torch.Tensor()
+end
+function PSNRCriterion:updateOutput(input, target)
+   self.image:resizeAs(input):copy(input)
+   self.image:clamp(0.0, 1.0)
+   self.diff:resizeAs(self.image):copy(self.image)
+   
+   local mse = math.max(self.diff:add(-1, target):pow(2):mean(), (0.1/255)^2)
+   self.output = 10 * math.log10(1.0 / mse)
+   return self.output
+end
+function PSNRCriterion:updateGradInput(input, target)
+   error("PSNRCriterion does not support backward")
+end

+ 1 - 46
lib/cleanup_model.lua

@@ -1,48 +1,3 @@
--- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
-
-local function zeroDataSize(data)
-   if type(data) == 'table' then
-      for i = 1, #data do
-	 data[i] = zeroDataSize(data[i])
-      end
-   elseif type(data) == 'userdata' then
-      data = torch.Tensor():typeAs(data)
-   end
-   return data
-end
--- Resize the output, gradInput, etc temporary tensors to zero (so that the
--- on disk size is smaller)
-local function cleanupModel(node)
-   if node.output ~= nil then
-      node.output = zeroDataSize(node.output)
-   end
-   if node.gradInput ~= nil then
-      node.gradInput = zeroDataSize(node.gradInput)
-   end
-   if node.finput ~= nil then
-      node.finput = zeroDataSize(node.finput)
-   end
-   if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
-      if node.negative ~= nil then
-	 node.negative = zeroDataSize(node.negative)
-      end
-   end
-   if tostring(node) == "nn.Dropout" then
-      if node.noise ~= nil then
-	 node.noise = zeroDataSize(node.noise)
-      end
-   end
-   -- Recurse on nodes with 'modules'
-   if (node.modules ~= nil) then
-     if (type(node.modules) == 'table') then
-	for i = 1, #node.modules do
-	   local child = node.modules[i]
-	   cleanupModel(child)
-	end
-     end
-   end
-end
 function w2nn.cleanup_model(model)
-   cleanupModel(model)
-   return model
+   return model:clearState()
 end

+ 4 - 2
lib/minibatch_adam.lua

@@ -2,12 +2,13 @@ require 'optim'
 require 'cutorch'
 require 'xlua'
 
-local function minibatch_adam(model, criterion,
+local function minibatch_adam(model, criterion, eval_metric,
 			      train_x, train_y,
 			      config)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
+   local sum_eval = 0
    local count_loss = 0
    local batch_size = config.xBatchSize or 32
    local shuffle = torch.randperm(train_x:size(1))
@@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
 	 gradParameters:zero()
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
+	 sum_eval = sum_eval + eval_metric:forward(output, targets)
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
 	 model:backward(inputs, criterion:backward(output, targets))
@@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
    end
    xlua.progress(train_x:size(1), train_x:size(1))
    
-   return { loss = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
 end
 
 return minibatch_adam

+ 5 - 20
lib/pairwise_transform.lua

@@ -82,30 +82,14 @@ local function active_cropping(x, y, size, p, tries)
    end
 end
 function pairwise_transform.scale(src, scale, size, offset, n, options)
-   local filters;
-
-   if options.style == "photo" then
-      filters = {
-	 "Box", "lanczos", "Catrom"
-      }
-   else
-      filters = {
-	 "Box","Box",  -- 0.012756949974688
-	 "Blackman",   -- 0.013191924552285
-	 --"Catrom",     -- 0.013753536746706
-	 --"Hanning",    -- 0.013761314529647
-	 --"Hermite",    -- 0.013850225205266
-	 "Sinc",   -- 0.014095824314306
-	 "Lanczos",       -- 0.014244299255442
-      }
-   end
+   local filters = options.downsampling_filters
    local unstable_region_offset = 8
-   local downscale_filter = filters[torch.random(1, #filters)]
+   local downsampling_filter = filters[torch.random(1, #filters)]
    local y = preprocess(src, size, options)
    assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
    local down_scale = 1.0 / scale
    local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
-				     y:size(2) * down_scale, downscale_filter),
+				     y:size(2) * down_scale, downsampling_filter),
 			 y:size(3), y:size(2))
    x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
 		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
@@ -184,7 +168,8 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
       if level == 1 then
 	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
 					 size, offset, n, options)
-      elseif level == 2 then
+      elseif level == 2 or level == 3 then
+	 -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
 	 local r = torch.uniform()
 	 if r > 0.6 then
 	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},

+ 15 - 3
lib/settings.lua

@@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", 'method to training (noise|scale)')
-cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-noise_level", 1, '(1|2|3)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
@@ -42,16 +42,24 @@ cmd:option("-epoch", 30, 'number of epochs to run')
 cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
 cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
-cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
+cmd:option("-validation_crops", 160, 'number of cropping region per image in validation')
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
 cmd:option("-save_history", 0, 'save all model (0|1)')
+cmd:option("-plot", 0, 'plot loss chart(0|1)')
+cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
    settings[k] = v
 end
+if settings.plot == 1 then
+   settings.plot = true
+   require 'gnuplot'
+else
+   settings.plot = false
+end
 if settings.save_history == 1 then
    settings.save_history = true
 else
@@ -88,10 +96,14 @@ if not (settings.style == "art" or
 	settings.style == "photo") then
    error(string.format("unknown style: %s", settings.style))
 end
-
 if settings.thread > 0 then
    torch.setnumthreads(tonumber(settings.thread))
 end
+if settings.downsampling_filters and settings.downsampling_filters:len() > 0 then
+   settings.downsampling_filters = settings.downsampling_filters:split(",")
+else
+   settings.downsampling_filters = {"Box", "Lanczos", "Catrom"}
+end
 
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)

+ 10 - 0
lib/srcnn.lua

@@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
    end
 end
 
+function nn.SpatialConvolutionMM:clearState()
+   if self.gradWeight then
+      self.gradWeight = torch.Tensor(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):typeAs(self.gradWeight):zero()
+   end
+   if self.gradBias then
+      self.gradBias = torch.Tensor(self.nOutputPlane):typeAs(self.gradBias):zero()
+   end
+   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
+end
+
 function srcnn.channels(model)
    return model:get(model:size() - 1).weight:size(1)
 end

+ 1 - 2
lib/w2nn.lua

@@ -19,8 +19,7 @@ else
    require 'LeakyReLU'
    require 'LeakyReLU_deprecated'
    require 'DepthExpand2x'
-   require 'WeightedMSECriterion'
+   require 'PSNRCriterion'
    require 'ClippedWeightedHuberCriterion'
-   require 'cleanup_model'
    return w2nn
 end

ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art/noise1_model.json


ファイルの差分が大きいため隠しています
+ 42 - 75
models/anime_style_art/noise1_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art/noise2_model.json


ファイルの差分が大きいため隠しています
+ 42 - 75
models/anime_style_art/noise2_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art/noise3_model.json


ファイルの差分が大きいため隠しています
+ 128 - 0
models/anime_style_art/noise3_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art/scale2.0x_model.json


ファイルの差分が大きいため隠しています
+ 42 - 75
models/anime_style_art/scale2.0x_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art_rgb/noise1_model.json


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art_rgb/noise2_model.json


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art_rgb/noise3_model.json


ファイルの差分が大きいため隠しています
+ 128 - 0
models/anime_style_art_rgb/noise3_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/anime_style_art_rgb/scale2.0x_model.json


ファイルの差分が大きいため隠しています
+ 42 - 75
models/anime_style_art_rgb/scale2.0x_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/photo/noise1_model.json


ファイルの差分が大きいため隠しています
+ 0 - 0
models/photo/noise2_model.json


ファイルの差分が大きいため隠しています
+ 0 - 0
models/photo/noise3_model.json


ファイルの差分が大きいため隠しています
+ 128 - 0
models/photo/noise3_model.t7


ファイルの差分が大きいため隠しています
+ 0 - 0
models/photo/scale2.0x_model.json


+ 0 - 25
tools/cleanup_model.lua

@@ -1,25 +0,0 @@
-require 'pl'
-local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
-package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
-
-require 'w2nn'
-torch.setdefaulttensortype("torch.FloatTensor")
-
-local cmd = torch.CmdLine()
-cmd:text()
-cmd:text("cleanup model")
-cmd:text("Options:")
-cmd:option("-model", "./model.t7", 'path of model file')
-cmd:option("-iformat", "binary", 'input format')
-cmd:option("-oformat", "binary", 'output format')
-
-local opt = cmd:parse(arg)
-local model = torch.load(opt.model, opt.iformat)
-if model then
-   w2nn.cleanup_model(model)
-   model:cuda()
-   model:evaluate()
-   torch.save(opt.model, model, opt.oformat)
-else
-   error("model not found")
-end

+ 0 - 5
tools/export_model.lua

@@ -24,11 +24,6 @@ function export(model, output)
       }
       table.insert(jmodules, jmod)
    end
-   jmodules[1].color = "RGB"
-   jmodules[1].gamma = 0
-   jmodules[#jmodules].color = "RGB"
-   jmodules[#jmodules].gamma = 0
-   
    local fp = io.open(output, "w")
    if not fp then
       error("IO Error: " .. output)

+ 29 - 12
train.lua

@@ -78,7 +78,11 @@ local function create_criterion(model)
       weight[3]:fill(0.11448 * 3) -- B
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    else
-      return nn.MSECriterion():cuda()
+      local offset = reconstruct.offset_size(model)
+      local output_w = settings.crop_size - offset * 2
+      local weight = torch.Tensor(1, output_w * output_w)
+      weight[1]:fill(1.0)
+      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    end
 end
 local function transformer(x, is_validation, n, offset)
@@ -91,8 +95,8 @@ local function transformer(x, is_validation, n, offset)
    local active_cropping_rate = nil
    local active_cropping_tries = nil
    if is_validation then
-      active_cropping_rate = 0
-      active_cropping_tries = 0
+      active_cropping_rate = settings.active_cropping_rate
+      active_cropping_tries = settings.active_cropping_tries
       random_color_noise_rate = 0.0
       random_overlay_rate = 0.0
    else
@@ -108,6 +112,7 @@ local function transformer(x, is_validation, n, offset)
 				      settings.crop_size, offset,
 				      n,
 				      {
+					 downsampling_filters = settings.downsampling_filters,
 					 random_half_rate = settings.random_half_rate,
 					 random_color_noise_rate = random_color_noise_rate,
 					 random_overlay_rate = random_overlay_rate,
@@ -153,8 +158,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
       end
    end
 end
-
+local function plot(train, valid)
+   gnuplot.plot({
+	 {'training', torch.Tensor(train), '-'},
+	 {'validation', torch.Tensor(valid), '-'}})
+end
 local function train()
+   local hist_train = {}
+   local hist_valid = {}
    local LR_MIN = 1.0e-5
    local model = srcnn.create(settings.method, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
@@ -162,6 +173,7 @@ local function train()
       return transformer(x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
+   local eval_metric = w2nn.PSNRCriterion():cuda()
    local x = torch.load(settings.images)
    local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
@@ -175,7 +187,7 @@ local function train()
    elseif settings.color == "rgb" then
       ch = 3
    end
-   local best_score = 100000.0
+   local best_score = 0.0
    print("# make validation-set")
    local valid_xy = make_validation_set(valid_x, pairwise_func,
 					settings.validation_crops,
@@ -196,19 +208,24 @@ local function train()
       print("# " .. epoch)
       resampling(x, y, train_x, pairwise_func)
       for i = 1, settings.inner_epoch do
-	 print(minibatch_adam(model, criterion, x, y, adam_config))
+	 local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 print(train_score)
 	 model:evaluate()
 	 print("# validation")
-	 local score = validate(model, criterion, valid_xy)
-	 if score < best_score then
+	 local score = validate(model, eval_metric, valid_xy)
+
+	 table.insert(hist_train, train_score.PSNR)
+	 table.insert(hist_valid, score)
+	 if settings.plot then
+	    plot(hist_train, hist_valid)
+	 end
+	 if score > best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0
 	    best_score = score
 	    print("* update best model")
 	    if settings.save_history then
-	       local model_clone = model:clone()
-	       w2nn.cleanup_model(model_clone)
-	       torch.save(string.format(settings.model_file, epoch, i), model_clone)
+	       torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -221,7 +238,7 @@ local function train()
 		  save_test_scale(model, test_image, log)
 	       end
 	    else
-	       torch.save(settings.model_file, model)
+	       torch.save(settings.model_file, model:clearState(), "ascii")
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.png"):format(settings.noise_level))

+ 0 - 5
train.sh

@@ -3,10 +3,5 @@
 th convert_data.lua
 
 th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
-th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
-
 th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
-th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
-
 th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
-th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

+ 1 - 4
train_photo.sh

@@ -2,11 +2,8 @@
 
 th convert_data.lua -style photo -data_dir ./data/photo -model_dir models/photo
 
-th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo_uk -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
-th tools/cleanup_model.lua -model models/photo/scale2.0x_model.t7 -oformat ascii
+th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
 
 th train.lua -style photo -method noise -noise_level 1 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.6 -epoch 33
-th tools/cleanup_model.lua -model models/photo/noise1_model.t7 -oformat ascii
 
 th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.8 -epoch 38
-th tools/cleanup_model.lua -model models/photo/noise2_model.t7 -oformat ascii

+ 15 - 36
waifu2x.lua

@@ -69,7 +69,8 @@ local function convert_image(opt)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
-   local model_path, noise1_model, noise2_model, scale_model
+   local model_path, scale_model
+   local noise_model = {}
    local scale_f, image_f
    if opt.tta == 1 then
       scale_f = reconstruct.scale_tta
@@ -84,16 +85,10 @@ local function convert_frames(opt)
       if not scale_model then
 	 error("Load Error: " .. model_path)
       end
-   elseif opt.m == "noise" and opt.noise_level == 1 then
-      model_path = path.join(opt.model_dir, "noise1_model.t7")
-      noise1_model = torch.load(model_path, "ascii")
-      if not noise1_model then
-	 error("Load Error: " .. model_path)
-      end
-   elseif opt.m == "noise" and opt.noise_level == 2 then
-      model_path = path.join(opt.model_dir, "noise2_model.t7")
-      noise2_model = torch.load(model_path, "ascii")
-      if not noise2_model then
+   elseif opt.m == "noise" then
+      model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
+      noise_model[opt.noise_level] = torch.load(model_path, "ascii")
+      if not noise_model[opt.noise_level] then
 	 error("Load Error: " .. model_path)
       end
    elseif opt.m == "noise_scale" then
@@ -102,18 +97,10 @@ local function convert_frames(opt)
       if not scale_model then
 	 error("Load Error: " .. model_path)
       end
-      if opt.noise_level == 1 then
-	 model_path = path.join(opt.model_dir, "noise1_model.t7")
-	 noise1_model = torch.load(model_path, "ascii")
-	 if not noise1_model then
-	    error("Load Error: " .. model_path)
-	 end
-      elseif opt.noise_level == 2 then
-	 model_path = path.join(opt.model_dir, "noise2_model.t7")
-	 noise2_model = torch.load(model_path, "ascii")
-	 if not noise2_model then
-	    error("Load Error: " .. model_path)
-	 end
+      model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
+      noise_model[opt.noise_level] = torch.load(model_path, "ascii")
+      if not noise_model[opt.noise_level] then
+	 error("Load Error: " .. model_path)
       end
    end
    local fp = io.open(opt.l)
@@ -130,24 +117,16 @@ local function convert_frames(opt)
       if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
 	 local x, alpha = image_loader.load_float(lines[i])
 	 local new_x = nil
-	 if opt.m == "noise" and opt.noise_level == 1 then
-	    new_x = image_f(noise1_model, x, opt.crop_size)
-	    new_x = alpha_util.composite(new_x, alpha)
-	 elseif opt.m == "noise" and opt.noise_level == 2 then
-	    new_x = image_f(noise2_model, x, opt.crop_size)
+	 if opt.m == "noise" then
+	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
-	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    x = image_f(noise1_model, x, opt.crop_size)
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
-	    new_x = alpha_util.composite(new_x, alpha, scale_model)
-	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
+	 elseif opt.m == "noise_scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    x = image_f(noise2_model, x, opt.crop_size)
+	    x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
@@ -185,7 +164,7 @@ local function waifu2x()
    cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
    cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
-   cmd:option("-noise_level", 1, '(1|2)')
+   cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")

+ 20 - 4
web.lua

@@ -38,12 +38,14 @@ if cudnn then
 end
 local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
 local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
+local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
 local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
 local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
-local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+local art_noise3_model = torch.load(path.join(ART_MODEL_DIR, "noise3_model.t7"), "ascii")
 local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
 local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
 local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
+local photo_noise3_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise3_model.t7"), "ascii")
 local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
 local CACHE_DIR = path.join(ROOT, "cache")
 local MAX_NOISE_IMAGE = 2560 * 2560
@@ -114,7 +116,7 @@ local function get_image(req)
 end
 local function cleanup_model(model)
    if CLEANUP_MODEL then
-      w2nn.cleanup_model(model) -- release GPU memory
+      model:clearState() -- release GPU memory
    end
 end
 local function convert(x, alpha, options)
@@ -151,9 +153,12 @@ local function convert(x, alpha, options)
 	 elseif options.method == "noise1" then
 	    x = reconstruct.image(art_noise1_model, x)
 	    cleanup_model(art_noise1_model)
-	 else -- options.method == "noise2"
+	 elseif options.method == "noise2" then
 	    x = reconstruct.image(art_noise2_model, x)
 	    cleanup_model(art_noise2_model)
+	 elseif options.method == "noise3" then
+	    x = reconstruct.image(art_noise3_model, x)
+	    cleanup_model(art_noise3_model)
 	 end
       else -- photo
 	 if options.border then
@@ -174,6 +179,9 @@ local function convert(x, alpha, options)
 	 elseif options.method == "noise2" then
 	    x = reconstruct.image(photo_noise2_model, x)
 	    cleanup_model(photo_noise2_model)
+	 elseif options.method == "noise3" then
+	    x = reconstruct.image(photo_noise3_model, x)
+	    cleanup_model(photo_noise3_model)
 	 end
       end
       image_loader.save_png(cache_file, x)
@@ -229,17 +237,25 @@ function APIHandler:post()
 				   alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 elseif noise == 2 then
-	    prefix = style .. "_noise1_"
+	    prefix = style .. "_noise2_"
 	    x = convert(x, alpha, {method = "noise2", style = style,
 				   prefix = prefix .. hash, 
 				   alpha_prefix = alpha_prefix, border = border})
 	    border = false
+	 elseif noise == 3 then
+	    prefix = style .. "_noise3_"
+	    x = convert(x, alpha, {method = "noise3", style = style,
+				   prefix = prefix .. hash, 
+				   alpha_prefix = alpha_prefix, border = border})
+	    border = false
 	 end
 	 if scale == 1 or scale == 2 then
 	    if noise == 1 then
 	       prefix = style .. "_noise1_scale_"
 	    elseif noise == 2 then
 	       prefix = style .. "_noise2_scale_"
+	    elseif noise == 3 then
+	       prefix = style .. "_noise3_scale_"
 	    else
 	       prefix = style .. "_scale_"
 	    end

+ 1 - 0
webgen/locales/en.yml

@@ -14,6 +14,7 @@ expect_jpeg: expect JPEG artifact
 nr_none: None
 nr_medium: Medium
 nr_high: High
+nr_highest: Highest
 nr_hint: "You need use noise reduction if image actually has noise or it may cause opposite effect."
 upscaling: Upscaling
 up_none: None

+ 1 - 0
webgen/locales/ja.yml

@@ -14,6 +14,7 @@ expect_jpeg: JPEGノイズを想定
 nr_none: なし
 nr_medium: 中
 nr_high: 高
+nr_highest: 最高
 nr_hint: "ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。"
 upscaling: 拡大
 up_none: なし

+ 6 - 0
webgen/templates/index.html.erb

@@ -106,6 +106,12 @@
 		<%= t[:nr_high] %>
 	      </span>
 	    </label>
+	    <label>
+	      <input type="radio" name="noise" class="radio" value="3">
+	      <span class="r-text">
+		<%= t[:nr_highest] %>
+	      </span>
+	    </label>
 	  </div>
 	  <div class="option-hint">
 	    <%= t[:nr_hint] %>

この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません