
Merge pull request #71 from nagadomi/photo

Add support for photo scaling/JPEG denoising.
nagadomi 9 years ago
parent
commit
e4b239ec14

+ 11 - 1
README.md

@@ -1,6 +1,7 @@
 # waifu2x
 
-Image Super-Resolution for anime-style-art using Deep Convolutional Neural Networks.
+Image Super-Resolution for Anime-style art using Deep Convolutional Neural Networks.
+It also supports photos.
 
 Demo-Application can be found at http://waifu2x.udp.jp/ .
 
@@ -123,6 +124,15 @@ th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.
 
 See also `th waifu2x.lua -h`.
 
+### Using photo model
+
+To use the photo model, add `-model_dir models/photo` to the command-line options.
+For example:
+
+```
+th waifu2x.lua -model_dir models/photo -m scale -i input_image.png -o output_image.png
+```
+
 ### Video Encoding
 
 \* `avconv` is alias of `ffmpeg` on Ubuntu 14.04.
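
The `-model_dir` switch should compose with the other modes as well; photo denoising plus 2x upscaling would presumably be `th waifu2x.lua -model_dir models/photo -m noise_scale -noise_level 1 -i input_image.png -o output_image.png`, a hypothetical invocation assembled from the options above rather than a command taken from this PR.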

+ 8 - 2
assets/index.html

@@ -4,7 +4,8 @@
     <meta charset="UTF-8">
     <title>waifu2x</title>
     <link href="style.css" rel="stylesheet" type="text/css">
-    <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
+    <script type="text/javascript" src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
+    <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
     <script type="text/javascript" src="ui.js"></script>
   </head>
   <body>
@@ -17,7 +18,7 @@
       <a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
     </div>
     <div class="about">
-      <div>Single-Image Super-Resolution for anime/fan-art using Deep Convolutional Neural Networks. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
+      <div>Single-Image Super-Resolution for Anime-Style Art using Deep Convolutional Neural Networks. It also supports photos. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
     </div>
     <form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
       <fieldset>
@@ -32,6 +33,11 @@
           Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
         </div>
       </fieldset>
+      <fieldset>
+        <legend>Style</legend>
+        <label><input type="radio" name="style" value="art" checked>Art</label>
+        <label><input type="radio" name="style" value="photo">Photo</label>
+      </fieldset>
       <fieldset class="noise-field">
         <legend>Noise Reduction (expect JPEG Artifact)</legend>
         <label><input type="radio" name="noise" value="0"> None</label>

+ 7 - 1
assets/index.ja.html

@@ -5,6 +5,7 @@
     <link href="style.css" rel="stylesheet" type="text/css">
     <title>waifu2x</title>
     <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
+    <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
     <script type="text/javascript" src="ui.js"></script>    
   </head>
   <body>
@@ -17,7 +18,7 @@
       <a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
     </div>
     <div class="about">
-      <div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
+      <div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. 写真にも対応. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
     </div>
     <form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
       <fieldset>
@@ -32,6 +33,11 @@
           制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
         </div>
       </fieldset>
+      <fieldset>
+        <legend>スタイル</legend>
+        <label><input type="radio" name="style" value="art" checked>イラスト</label>
+        <label><input type="radio" name="style" value="photo">写真</label>
+      </fieldset>
       <fieldset class="noise-field">
         <legend>ノイズ除去 (JPEGノイズを想定)</legend>
         <label><input type="radio" name="noise" value="0"> なし</label>

+ 6 - 0
assets/index.ru.html

@@ -6,6 +6,7 @@
     <title>waifu2x</title>
     <link href="style.css" rel="stylesheet" type="text/css">
     <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
+    <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
     <script type="text/javascript" src="ui.js"></script>
   </head>
   <body>
@@ -33,6 +34,11 @@
           Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
         </div>
       </fieldset>
+      <fieldset>
+        <legend>Стиль</legend>
+        <label><input type="radio" name="style" value="art" checked>Произведение искусства</label>
+        <label><input type="radio" name="style" value="photo">Фото</label>
+      </fieldset>
       <fieldset class="noise-field">
 	<legend>Устранение шума (артефактов JPEG)</legend>
         <label><input type="radio" name="noise" value="0"> Нет</label>

+ 16 - 24
assets/ui.js

@@ -1,4 +1,5 @@
 $(function (){
+    var expires = 365;
     function clear_file() {
 	var new_file = $("#file").clone();
 	new_file.change(clear_url);
@@ -19,6 +20,7 @@ $(function (){
 	} else {
 	    $("h1").html("w<s>/a/</s>ifu2x");
 	}
+	$.cookie("style", checked.val(), {expires: expires});
     }
     function on_change_noise_level(e)
     {
@@ -30,6 +32,7 @@ $(function (){
 	if (checked.val() != 0) {
 	    checked.parents("label").css("font-weight", "bold");
 	}
+	$.cookie("noise", checked.val(), {expires: expires});
     }
     function on_change_scale_factor(e)
     {
@@ -41,40 +44,29 @@ $(function (){
 	if (checked.val() != 0) {
 	    checked.parents("label").css("font-weight", "bold");
 	}
+	$.cookie("scale", checked.val(), {expires: expires});
     }
-    function on_change_white_noise(e)
+    function restore_from_cookie()
     {
-	$("input[name=white_noise]").parents("label").each(
-	    function (i, elm) {
-		$(elm).css("font-weight", "normal");
-	    });
-	var checked = $("input[name=white_noise]:checked");
-	if (checked.val() != 0) {
-	    checked.parents("label").css("font-weight", "bold");
+	if ($.cookie("style")) {
+	    $("input[name=style]").filter("[value=" + $.cookie("style") + "]").prop("checked", true);
 	}
-    }
-    function on_click_experimental_button(e)
-    {
-	if ($(this).hasClass("close")) {
-	    $(".experimental .container").show();
-	    $(this).removeClass("close");
-	} else {
-	    $(".experimental .container").hide();
-	    $(this).addClass("close");
+	if ($.cookie("noise")) {
+	    $("input[name=noise]").filter("[value=" + $.cookie("noise") + "]").prop("checked", true);
+	}
+	if ($.cookie("scale")) {
+	    $("input[name=scale]").filter("[value=" + $.cookie("scale") + "]").prop("checked", true);
 	}
-	e.preventDefault();
-	e.stopPropagation();
     }
     
     $("#url").change(clear_file);
     $("#file").change(clear_url);
-    //$("input[name=style]").change(on_change_style);
+    $("input[name=style]").change(on_change_style);
     $("input[name=noise]").change(on_change_noise_level);
     $("input[name=scale]").change(on_change_scale_factor);
-    //$("input[name=white_noise]").change(on_change_white_noise);
-    //$(".experimental .button").click(on_click_experimental_button)
-    
-    //on_change_style();
+
+    restore_from_cookie();
+    on_change_style();
     on_change_scale_factor();
     on_change_noise_level();
 })
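
The ui.js rewrite drops the dead `white_noise` and "experimental" handlers and adds a small persistence layer instead: each change handler now stores its radio value with `$.cookie(name, value, {expires: 365})`, and `restore_from_cookie()` re-checks the saved buttons on page load before the initial `on_change_style()` / `on_change_scale_factor()` / `on_change_noise_level()` calls repaint the bold highlighting. This is why all three HTML pages above gain the `jquery.cookie` script tag (note that only the English page also switches the jQuery CDN itself to https).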

+ 20 - 0
lib/data_augmentation.lua

@@ -1,5 +1,6 @@
 require 'image'
 local iproc = require 'iproc'
+local gm = require 'graphicsmagick'
 
 local data_augmentation = {}
 
@@ -50,6 +51,25 @@ function data_augmentation.overlay(src, p)
       return src
    end
 end
+function data_augmentation.unsharp_mask(src, p)
+   if torch.uniform() < p then
+      local radius = 0 -- auto
+      local sigma = torch.uniform(0.5, 1.5)
+      local amount = torch.uniform(0.1, 0.9)
+      local threshold = torch.uniform(0.0, 0.05)
+      local unsharp = gm.Image(src, "RGB", "DHW"):
+	 unsharpMask(radius, sigma, amount, threshold):
+	 toTensor("float", "RGB", "DHW")
+      
+      if src:type() == "torch.ByteTensor" then
+	 return iproc.float2byte(unsharp)
+      else
+	 return unsharp
+      end
+   else
+      return src
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)
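
`unsharp_mask` sharpens a sample with probability `p`, drawing random GraphicsMagick unsharp-mask parameters (sigma 0.5-1.5, amount 0.1-0.9), and converts back to byte when the input was a `ByteTensor`. A minimal usage sketch, assuming `lib/` is on `package.path`; the 0.5 rate is illustrative, mirroring the `test_*` defaults further down:

```
local data_augmentation = require 'data_augmentation'
-- src: RGB image tensor in DHW layout; sharpen roughly half of the samples
local dest = data_augmentation.unsharp_mask(src, 0.5)
```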

+ 20 - 19
lib/minibatch_adam.lua

@@ -3,30 +3,32 @@ require 'cutorch'
 require 'xlua'
 
 local function minibatch_adam(model, criterion,
-			      train_x,
-			      config, transformer,
-			      input_size, target_size)
+			      train_x, train_y,
+			      config)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
    local count_loss = 0
    local batch_size = config.xBatchSize or 32
-   local shuffle = torch.randperm(#train_x)
+   local shuffle = torch.randperm(train_x:size(1))
    local c = 1
-   local inputs = torch.Tensor(batch_size,
-			       input_size[1], input_size[2], input_size[3]):cuda()
-   local targets = torch.Tensor(batch_size,
-				target_size[1] * target_size[2] * target_size[3]):cuda()
    local inputs_tmp = torch.Tensor(batch_size,
-			       input_size[1], input_size[2], input_size[3])
+				   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
    local targets_tmp = torch.Tensor(batch_size,
-				    target_size[1] * target_size[2] * target_size[3])
-   for t = 1, #train_x do
-      xlua.progress(t, #train_x)
-      local xy = transformer(train_x[shuffle[t]], false, batch_size)
-      for i = 1, #xy do
-         inputs_tmp[i]:copy(xy[i][1])
-	 targets_tmp[i]:copy(xy[i][2])
+				    train_y:size(2)):zero()
+   local inputs = inputs_tmp:clone():cuda()
+   local targets = targets_tmp:clone():cuda()
+
+   print("## update")
+   for t = 1, train_x:size(1), batch_size do
+      if t + batch_size -1 > train_x:size(1) then
+	 break
+      end
+      xlua.progress(t, train_x:size(1))
+
+      for i = 1, batch_size do
+         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
+	 targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
@@ -43,13 +45,12 @@ local function minibatch_adam(model, criterion,
 	 return f, gradParameters
       end
       optim.adam(feval, parameters, config)
-      
       c = c + 1
-      if c % 20 == 0 then
+      if c % 50 == 0 then
 	 collectgarbage()
       end
    end
-   xlua.progress(#train_x, #train_x)
+   xlua.progress(train_x:size(1), train_x:size(1))
    
    return { loss = sum_loss / count_loss}
 end
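
The trainer's contract changes here: instead of a list of compressed images plus a `transformer` callback, it now receives two pre-built tensors and performs a plain shuffled sweep in steps of `batch_size` (dropping the final partial batch). A sketch of the new call, with shapes inferred from the diff and `model`/`criterion` assumed to be on the GPU already:

```
local minibatch_adam = require 'minibatch_adam'
-- train_x: (N, ch, crop, crop) inputs; train_y: (N, d) flattened targets
local config = { learningRate = 0.001, xBatchSize = 32 }
local stats = minibatch_adam(model, criterion, train_x, train_y, config)
print(stats.loss) -- mean loss over the sweep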

+ 43 - 34
lib/pairwise_transform.lua

@@ -7,7 +7,7 @@ local pairwise_transform = {}
 
 local function random_half(src, p)
    if torch.uniform() < p then
-      local filter = ({"Box","Box","Blackman","Sinc","Lanczos"})[torch.random(1, 5)]
+      local filter = ({"Box","Box","Blackman","Sinc","Lanczos", "Catrom"})[torch.random(1, 6)]
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
       return src
@@ -38,6 +38,7 @@ local function preprocess(src, crop_size, options)
    dest = data_augmentation.flip(dest)
    dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
    dest = data_augmentation.overlay(dest, options.random_overlay_rate)
+   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
    dest = data_augmentation.shift_1px(dest)
    
    return dest
@@ -45,6 +46,10 @@ end
 local function active_cropping(x, y, size, p, tries)
    assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
    local r = torch.uniform()
+   local t = "float"
+   if x:type() == "torch.ByteTensor" then
+      t = "byte"
+   end
    if p < r then
       local xi = torch.random(0, y:size(3) - (size + 1))
       local yi = torch.random(0, y:size(2) - (size + 1))
@@ -52,6 +57,10 @@ local function active_cropping(x, y, size, p, tries)
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       return xc, yc
    else
+      local lowres = gm.Image(x, "RGB", "DHW"):
+	 size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
+	 size(x:size(3), x:size(2), "Box"):
+	 toTensor(t, "RGB", "DHW")
       local best_se = 0.0
       local best_xc, best_yc
       local m = torch.FloatTensor(x:size(1), size, size)
@@ -59,13 +68,13 @@ local function active_cropping(x, y, size, p, tries)
 	 local xi = torch.random(0, y:size(3) - (size + 1))
 	 local yi = torch.random(0, y:size(2) - (size + 1))
 	 local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
-	 local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
+	 local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
 	 local xcf = iproc.byte2float(xc)
-	 local ycf = iproc.byte2float(yc)
-	 local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
+	 local lcf = iproc.byte2float(lc)
+	 local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
 	 if se >= best_se then
 	    best_xc = xcf
-	    best_yc = ycf
+	    best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
 	    best_se = se
 	 end
       end
@@ -73,15 +82,23 @@ local function active_cropping(x, y, size, p, tries)
    end
 end
 function pairwise_transform.scale(src, scale, size, offset, n, options)
-   local filters = {
-      "Box","Box",  -- 0.012756949974688
-      "Blackman",   -- 0.013191924552285
-      --"Cartom",     -- 0.013753536746706
-      --"Hanning",    -- 0.013761314529647
-      --"Hermite",    -- 0.013850225205266
-      "Sinc",   -- 0.014095824314306
-      "Lanczos",       -- 0.014244299255442
-   }
+   local filters;
+
+   if options.style == "photo" then
+      filters = {
+	 "Box", "Lanczos", "Catrom"
+      }
+   else
+      filters = {
+	 "Box","Box",  -- 0.012756949974688
+	 "Blackman",   -- 0.013191924552285
+	 --"Catrom",     -- 0.013753536746706
+	 --"Hanning",    -- 0.013761314529647
+	 --"Hermite",    -- 0.013850225205266
+	 "Sinc",   -- 0.014095824314306
+	 "Lanczos",       -- 0.014244299255442
+      }
+   end
    local unstable_region_offset = 8
    local downscale_filter = filters[torch.random(1, #filters)]
    local y = preprocess(src, size, options)
@@ -122,10 +139,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg"):depth(8)
-      if options.jpeg_sampling_factors == 444 then
-	 x:samplingFactors({1.0, 1.0, 1.0})
-      else -- 420
+      if torch.uniform() < options.jpeg_chroma_subsampling_rate then
+	 -- YUV 420
 	 x:samplingFactors({2.0, 1.0, 1.0})
+      else
+	 -- YUV 444
+	 x:samplingFactors({1.0, 1.0, 1.0})
       end
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
@@ -188,23 +207,10 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
 	 error("unknown noise level: " .. level)
       end
    elseif style == "photo" then
-      if level == 1 then
-	 return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
-					 size, offset, n,
-					 options)
-      elseif level == 2 then
-	 if torch.uniform() > 0.6 then
-	    return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
-					    size, offset, n, options)
-	 else
-	    local quality1 = torch.random(40, 60)
-	    local quality2 = quality1 - torch.random(5, 10)
-	    return pairwise_transform.jpeg_(src, {quality1, quality2},
-					    size, offset, n, options)
-	 end
-      else
-	 error("unknown noise level: " .. level)
-      end
+      -- level adjusting by -nr_rate
+      return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
+				      size, offset, n,
+				      options)
    else
       error("unknown style: " .. style)
    end
@@ -215,6 +221,8 @@ function pairwise_transform.test_jpeg(src)
    local options = {random_color_noise_rate = 0.5,
 		    random_half_rate = 0.5,
 		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
+		    jpeg_chroma_subsampling_rate = 0.5,
 		    nr_rate = 1.0,
 		    active_cropping_rate = 0.5,
 		    active_cropping_tries = 10,
@@ -237,6 +245,7 @@ function pairwise_transform.test_scale(src)
    local options = {random_color_noise_rate = 0.5,
 		    random_half_rate = 0.5,
 		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
 		    active_cropping_rate = 0.5,
 		    active_cropping_tries = 10,
 		    max_size = 256,
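
The rewritten `active_cropping` no longer measures candidate crops against `y`. It builds, once per call, a Box-filtered half-then-full-size copy of `x` and keeps the crop whose squared error against that blurred copy is largest: regions that a cheap down/up-scale destroys are exactly the high-frequency regions worth training on, so the active branch biases sampling toward hard patches. The criterion, restated from the diff (`t` is "byte" or "float" to match `x`):

```
-- detail-free reference: halve then restore with a Box filter
local lowres = gm.Image(x, "RGB", "DHW"):
   size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
   size(x:size(3), x:size(2), "Box"):
   toTensor(t, "RGB", "DHW")
-- over `tries` random positions, keep the crop of x maximizing
-- sum((crop(x) - crop(lowres))^2), paired with the matching crop of y
```

For photos, the JPEG degradation also collapses to a single pass at quality 30-70, with noise strength now governed by `-nr_rate`, and chroma subsampling (4:2:0 vs 4:4:4) becomes a coin flip weighted by `-jpeg_chroma_subsampling_rate`.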

+ 32 - 14
lib/settings.lua

@@ -30,35 +30,53 @@ cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
 cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
 cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
+cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
-cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
+cmd:option("-learning_rate", 0.001, 'learning rate for adam')
 cmd:option("-crop_size", 46, 'crop size')
 cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
-cmd:option("-batch_size", 8, 'mini batch size')
-cmd:option("-epoch", 200, 'number of total epochs to run')
+cmd:option("-batch_size", 32, 'mini batch size')
+cmd:option("-patches", 16, 'number of patch samples')
+cmd:option("-inner_epoch", 4, 'number of inner epochs')
+cmd:option("-epoch", 30, 'number of epochs to run')
 cmd:option("-thread", -1, 'number of CPU threads')
-cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
+cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
 cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
 cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
+cmd:option("-save_history", 0, 'save all model (0|1)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
    settings[k] = v
 end
-if settings.method == "noise" then
-   settings.model_file = string.format("%s/noise%d_model.t7",
-				       settings.model_dir, settings.noise_level)
-elseif settings.method == "scale" then
-   settings.model_file = string.format("%s/scale%.1fx_model.t7",
-				       settings.model_dir, settings.scale)
-elseif settings.method == "noise_scale" then
-   settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
-				       settings.model_dir, settings.noise_level, settings.scale)
+if settings.save_history == 1 then
+   settings.save_history = true
 else
-   error("unknown method: " .. settings.method)
+   settings.save_history = false
+end
+if settings.save_history then
+   if settings.method == "noise" then
+      settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
+					  settings.model_dir, settings.noise_level)
+   elseif settings.method == "scale" then
+      settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
+					  settings.model_dir, settings.scale)
+   else
+      error("unknown method: " .. settings.method)
+   end
+else
+   if settings.method == "noise" then
+      settings.model_file = string.format("%s/noise%d_model.t7",
+					  settings.model_dir, settings.noise_level)
+   elseif settings.method == "scale" then
+      settings.model_file = string.format("%s/scale%.1fx_model.t7",
+					  settings.model_dir, settings.scale)
+   else
+      error("unknown method: " .. settings.method)
+   end
 end
 if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")

File diff suppressed because it is too large
+ 0 - 0
models/photo/noise1_model.json


File diff suppressed because it is too large
+ 161 - 0
models/photo/noise1_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/photo/noise2_model.json


File diff suppressed because it is too large
+ 161 - 0
models/photo/noise2_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/photo/scale2.0x_model.json


File diff suppressed because it is too large
+ 161 - 0
models/photo/scale2.0x_model.t7


+ 38 - 44
tools/benchmark.lua

@@ -23,6 +23,7 @@ cmd:option("-noise_level", 1, 'model noise level')
 cmd:option("-jpeg_quality", 75, 'jpeg quality')
 cmd:option("-jpeg_times", 1, 'jpeg compression times')
 cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
+cmd:option("-range_bug", 0, 'Reproducing the dynamic range bug that is caused by MATLAB\'s rgb2ycbcr(1|0)')
 
 local opt = cmd:parse(arg)
 torch.setdefaulttensortype('torch.FloatTensor')
@@ -41,25 +42,33 @@ local function rgb2y_matlab(x)
    return y:byte():float()
 end
 
-local function MSE(x1, x2)
+local function RGBMSE(x1, x2)
    x1 = iproc.float2byte(x1):float()
    x2 = iproc.float2byte(x2):float()
    return (x1 - x2):pow(2):mean()
 end
 local function YMSE(x1, x2)
-   local x1_2 = rgb2y_matlab(x1)
-   local x2_2 = rgb2y_matlab(x2)
-   return (x1_2 - x2_2):pow(2):mean()
+   if opt.range_bug == 1 then
+      local x1_2 = rgb2y_matlab(x1)
+      local x2_2 = rgb2y_matlab(x2)
+      return (x1_2 - x2_2):pow(2):mean()
+   else
+      local x1_2 = image.rgb2y(x1):mul(255.0)
+      local x2_2 = image.rgb2y(x2):mul(255.0)
+      return (x1_2 - x2_2):pow(2):mean()
+   end
 end
-local function PSNR(x1, x2)
-   local mse = MSE(x1, x2)
-   return 10 * math.log10((255.0 * 255.0) / mse)
+local function MSE(x1, x2, color)
+   if color == "y" then
+      return YMSE(x1, x2)
+   else
+      return RGBMSE(x1, x2)
+   end
 end
-local function YPSNR(x1, x2)
-   local mse = YMSE(x1, x2)
+local function PSNR(x1, x2, color)
+   local mse = MSE(x1, x2, color)
    return 10 * math.log10((255.0 * 255.0) / mse)
 end
-
 local function transform_jpeg(x, opt)
    for i = 1, opt.jpeg_times do
       jpeg = gm.Image(x, "RGB", "DHW")
@@ -69,7 +78,7 @@ local function transform_jpeg(x, opt)
       jpeg:fromBlob(blob, len)
       x = jpeg:toTensor("byte", "RGB", "DHW")
    end
-   return x
+   return iproc.byte2float(x)
 end
 local function baseline_scale(x, filter)
    return iproc.scale(x,
@@ -110,62 +119,47 @@ local function benchmark(opt, x, input_func, model1, model2)
 	 end
 	 baseline_output = baseline_scale(input, opt.filter)
       end
-      if opt.color == "y" then
-	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
-	 model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
-	 if model2 then
-	    model2_mse = model2_mse + YMSE(ground_truth, model2_output)
-	    model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
-	 end
-	 if baseline_output then
-	    baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
-	    baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
-	 end
-      elseif opt.color == "rgb" then
-	 model1_mse = model1_mse + MSE(ground_truth, model1_output)
-	 model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
-	 if model2 then
-	    model2_mse = model2_mse + MSE(ground_truth, model2_output)
-	    model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
-	 end
-	 if baseline_output then
-	    baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
-	    baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
-	 end
-      else
-	 error("Unknown color: " .. opt.color)
+      model1_mse = model1_mse + MSE(ground_truth, model1_output, opt.color)
+      model1_psnr = model1_psnr + PSNR(ground_truth, model1_output, opt.color)
+      if model2 then
+	 model2_mse = model2_mse + MSE(ground_truth, model2_output, opt.color)
+	 model2_psnr = model2_psnr + PSNR(ground_truth, model2_output, opt.color)
+      end
+      if baseline_output then
+	 baseline_mse = baseline_mse + MSE(ground_truth, baseline_output, opt.color)
+	 baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output, opt.color)
       end
       if model2 then
 	 if baseline_output then
 	    io.stdout:write(
-	       string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
+	       string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
 			     i, #x,
-			     baseline_mse / i,
-			     model1_mse / i, model2_mse / i,
+			     math.sqrt(baseline_mse / i),
+			     math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 			     baseline_psnr / i,
 			     model1_psnr / i, model2_psnr / i
 	    ))
 	 else
 	    io.stdout:write(
-	       string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
+	       string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
 			     i, #x,
-			     model1_mse / i, model2_mse / i,
+			     math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 			     model1_psnr / i, model2_psnr / i
 	    ))
 	 end
       else
 	 if baseline_output then
 	    io.stdout:write(
-	       string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
+	       string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
 			     i, #x,
-			     baseline_mse / i, model1_mse / i,
+			     math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
 			     baseline_psnr / i, model1_psnr / i
 	    ))
 	 else
 	    io.stdout:write(
-	       string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
+	       string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r",
 			     i, #x,
-			     model1_mse / i, model1_psnr / i
+			     math.sqrt(model1_mse / i), model1_psnr / i
 	    ))
 	 end
       end
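
The two color paths now share a single `MSE(x1, x2, color)` dispatcher, and the progress line reports RMSE instead of raw MSE. On 8-bit intensities the printed quantities reduce to the following (restated from the code above, not new logic):

```
-- as computed in tools/benchmark.lua
local function psnr(mse) return 10 * math.log10((255.0 * 255.0) / mse) end
local function rmse(mse) return math.sqrt(mse) end
```

Note that the running sums accumulate MSE, so the displayed value is `sqrt(mean MSE)` across images, while PSNR is averaged per image.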

+ 81 - 39
train.lua

@@ -35,14 +35,14 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n, batch_size)
+local function make_validation_set(x, transformer, n, patches)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, math.max(n / batch_size, 1) do
-	 local xy = transformer(x[i], true, batch_size)
-	 local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
-	 local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
+      for k = 1, math.max(n / patches, 1) do
+	 local xy = transformer(x[i], true, patches)
+	 local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+	 local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
 	 for j = 1, #xy do
 	    tx[j]:copy(xy[j][1])
 	    ty[j]:copy(xy[j][2])
@@ -83,7 +83,8 @@ local function create_criterion(model)
 end
 local function transformer(x, is_validation, n, offset)
    x = compression.decompress(x)
-   n = n or settings.batch_size;
+   n = n or settings.patches
+
    if is_validation == nil then is_validation = false end
    local random_color_noise_rate = nil 
    local random_overlay_rate = nil
@@ -110,6 +111,7 @@ local function transformer(x, is_validation, n, offset)
 					 random_half_rate = settings.random_half_rate,
 					 random_color_noise_rate = random_color_noise_rate,
 					 random_overlay_rate = random_overlay_rate,
+					 random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
 					 max_size = settings.max_size,
 					 active_cropping_rate = active_cropping_rate,
 					 active_cropping_tries = active_cropping_tries,
@@ -125,8 +127,9 @@ local function transformer(x, is_validation, n, offset)
 					random_half_rate = settings.random_half_rate,
 					random_color_noise_rate = random_color_noise_rate,
 					random_overlay_rate = random_overlay_rate,
+					random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
 					max_size = settings.max_size,
-					jpeg_sampling_factors = settings.jpeg_sampling_factors,
+					jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
 					active_cropping_rate = active_cropping_rate,
 					active_cropping_tries = active_cropping_tries,
 					nr_rate = settings.nr_rate,
@@ -135,7 +138,24 @@ local function transformer(x, is_validation, n, offset)
    end
 end
 
+local function resampling(x, y, train_x, transformer, input_size, target_size)
+   print("## resampling")
+   for t = 1, #train_x do
+      xlua.progress(t, #train_x)
+      local xy = transformer(train_x[t], false, settings.patches)
+      for i = 1, #xy do
+	 local index = (t - 1) * settings.patches + i
+         x[index]:copy(xy[i][1])
+	 y[index]:copy(xy[i][2])
+      end
+      if t % 50 == 0 then
+	 collectgarbage()
+      end
+   end
+end
+
 local function train()
+   local LR_MIN = 1.0e-5
    local model = srcnn.create(settings.method, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
    local pairwise_func = function(x, is_validation, n)
@@ -143,12 +163,12 @@ local function train()
    end
    local criterion = create_criterion(model)
    local x = torch.load(settings.images)
-   local lrd_count = 0
    local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
    }
+   local lrd_count = 0
    local ch = nil
    if settings.color == "y" then
       ch = 1
@@ -159,48 +179,70 @@ local function train()
    print("# make validation-set")
    local valid_xy = make_validation_set(valid_x, pairwise_func,
 					settings.validation_crops,
-					settings.batch_size)
+					settings.patches)
    valid_x = nil
    
    collectgarbage()
    model:cuda()
    print("load .. " .. #train_x)
+
+   local x = torch.Tensor(settings.patches * #train_x,
+			  ch, settings.crop_size, settings.crop_size)
+   local y = torch.Tensor(settings.patches * #train_x,
+			  ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
+
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
-      print(minibatch_adam(model, criterion, train_x, adam_config,
-			   pairwise_func,
-			   {ch, settings.crop_size, settings.crop_size},
-			   {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
-			  ))
-      model:evaluate()
-      print("# validation")
-      local score = validate(model, criterion, valid_xy)
-      if score < best_score then
-	 local test_image = image_loader.load_float(settings.test) -- reload
-	 lrd_count = 0
-	 best_score = score
-	 print("* update best model")
-	 torch.save(settings.model_file, model)
-	 if settings.method == "noise" then
-	    local log = path.join(settings.model_dir,
-				  ("noise%d_best.png"):format(settings.noise_level))
-	    save_test_jpeg(model, test_image, log)
-	 elseif settings.method == "scale" then
-	    local log = path.join(settings.model_dir,
-				  ("scale%.1f_best.png"):format(settings.scale))
-	    save_test_scale(model, test_image, log)
-	 end
-      else
-	 lrd_count = lrd_count + 1
-	 if lrd_count > 5 then
+      resampling(x, y, train_x, pairwise_func)
+      for i = 1, settings.inner_epoch do
+	 print(minibatch_adam(model, criterion, x, y, adam_config))
+	 model:evaluate()
+	 print("# validation")
+	 local score = validate(model, criterion, valid_xy)
+	 if score < best_score then
+	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0
-	    adam_config.learningRate = adam_config.learningRate * 0.9
-	    print("* learning rate decay: " .. adam_config.learningRate)
+	    best_score = score
+	    print("* update best model")
+	    if settings.save_history then
+	       local model_clone = model:clone()
+	       w2nn.cleanup_model(model_clone)
+	       torch.save(string.format(settings.model_file, epoch, i), model_clone)
+	       if settings.method == "noise" then
+		  local log = path.join(settings.model_dir,
+					("noise%d_best.%d-%d.png"):format(settings.noise_level,
+									  epoch, i))
+		  save_test_jpeg(model, test_image, log)
+	       elseif settings.method == "scale" then
+		  local log = path.join(settings.model_dir,
+					("scale%.1f_best.%d-%d.png"):format(settings.scale,
+									    epoch, i))
+		  save_test_scale(model, test_image, log)
+	       end
+	    else
+	       torch.save(settings.model_file, model)
+	       if settings.method == "noise" then
+		  local log = path.join(settings.model_dir,
+					("noise%d_best.png"):format(settings.noise_level))
+		  save_test_jpeg(model, test_image, log)
+	       elseif settings.method == "scale" then
+		  local log = path.join(settings.model_dir,
+					("scale%.1f_best.png"):format(settings.scale))
+		  save_test_scale(model, test_image, log)
+	       end
+	    end
+	 else
+	    lrd_count = lrd_count + 1
+	    if lrd_count > 2 and adam_config.learningRate > LR_MIN then
+	       adam_config.learningRate = adam_config.learningRate * 0.8
+	       print("* learning rate decay: " .. adam_config.learningRate)
+	       lrd_count = 0
+	    end
 	 end
+	 print("current: " .. score .. ", best: " .. best_score)
+	 collectgarbage()
       end
-      print("current: " .. score .. ", best: " .. best_score)
-      collectgarbage()
    end
 end
 if settings.gpu > 0 then
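
train.lua restructures training into a two-level loop: each outer epoch re-samples `settings.patches` crops per training image into the preallocated `x`/`y` tensors, and each of `settings.inner_epoch` inner passes reuses those tensors for one Adam sweep plus validation, saving the best model and decaying the learning rate by 0.8 (down to `LR_MIN = 1.0e-5`) after three non-improving validations. Condensed control flow, with comments standing in for the omitted branches:

```
for epoch = 1, settings.epoch do
   resampling(x, y, train_x, pairwise_func)     -- regenerate the patch tensors
   for i = 1, settings.inner_epoch do
      print(minibatch_adam(model, criterion, x, y, adam_config))
      local score = validate(model, criterion, valid_xy)
      -- improved:   save the model (suffixed "epoch-i" when -save_history is set)
      -- stalled 3x: adam_config.learningRate = adam_config.learningRate * 0.8
   end
end
```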

+ 7 - 7
web.lua

@@ -12,8 +12,9 @@ local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
 local image_loader = require 'image_loader'
 local alpha_util = require 'alpha_util'
+local gm = require 'graphicsmagick'
 
--- Notes:  turbo and xlua has different implementation of string:split().
+-- Note:  turbo and xlua have different implementations of string:split().
 --         Therefore, string:split() has conflict issue.
 --         In this script, use turbo's string:split().
 local turbo = require 'turbo'
@@ -36,13 +37,13 @@ if cudnn then
    cudnn.benchmark = false
 end
 local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
-local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
+local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
 local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
 local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
 local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
---local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
---local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
---local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
+local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
+local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
 local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
 local CACHE_DIR = path.join(ROOT, "cache")
 local MAX_NOISE_IMAGE = 2560 * 2560
@@ -143,7 +144,7 @@ local function convert(x, alpha, options)
 	    x = reconstruct.image(art_noise2_model, x)
 	    cleanup_model(art_noise2_model)
 	 end
-      else --[[photo
+      else -- photo
 	 if options.border then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
 	 end
@@ -163,7 +164,6 @@ local function convert(x, alpha, options)
 	    x = reconstruct.image(photo_noise2_model, x)
 	    cleanup_model(photo_noise2_model)
 	 end
-      --]]
       end
       image_loader.save_png(cache_file, x)
 

Some files were not shown because too many files changed in this diff