فهرست منبع

Merge pull request #58 from nagadomi/dev

Sync from development branch
nagadomi 9 سال پیش
والد
کامیت
68593d9c51
62فایلهای تغییر یافته به همراه1986 افزوده شده و 973 حذف شده
  1. 2 0
      .gitattributes
  2. 12 1
      .gitignore
  3. 29 21
      README.md
  4. 9 1
      appendix/purge_cache.lua
  5. 10 42
      assets/index.html
  6. 7 41
      assets/index.ja.html
  7. 11 43
      assets/index.ru.html
  8. 52 0
      assets/style.css
  9. 80 0
      assets/ui.js
  10. 33 34
      convert_data.lua
  11. 0 34
      cudnn2cunn.lua
  12. 0 0
      data/.gitkeep
  13. 0 23
      export_model.lua
  14. 1 2
      images/gen.sh
  15. BIN
      images/lena_waifu2x.png
  16. BIN
      images/lena_waifu2x_ukbench.png
  17. BIN
      images/miku_CC_BY-NC_noisy_waifu2x.png
  18. BIN
      images/miku_noisy_waifu2x.png
  19. BIN
      images/miku_small.png
  20. BIN
      images/miku_small_lanczos3.png
  21. BIN
      images/miku_small_noisy_waifu2x.png
  22. BIN
      images/miku_small_waifu2x.png
  23. BIN
      images/slide.odp
  24. BIN
      images/slide.png
  25. BIN
      images/slide_noise_reduction.png
  26. BIN
      images/slide_result.png
  27. BIN
      images/slide_upscaling.png
  28. 39 0
      lib/ClippedWeightedHuberCriterion.lua
  29. 77 0
      lib/DepthExpand2x.lua
  30. 4 3
      lib/LeakyReLU.lua
  31. 31 0
      lib/LeakyReLU_deprecated.lua
  32. 25 0
      lib/WeightedMSECriterion.lua
  33. 4 24
      lib/cleanup_model.lua
  34. 17 0
      lib/compression.lua
  35. 104 0
      lib/data_augmentation.lua
  36. 81 39
      lib/image_loader.lua
  37. 112 4
      lib/iproc.lua
  38. 7 12
      lib/minibatch_adam.lua
  39. 203 240
      lib/pairwise_transform.lua
  40. 0 4
      lib/portable.lua
  41. 95 16
      lib/reconstruct.lua
  42. 31 28
      lib/settings.lua
  43. 46 52
      lib/srcnn.lua
  44. 26 0
      lib/w2nn.lua
  45. 0 0
      models/anime_style_art_rgb/noise1_model.json
  46. 33 27
      models/anime_style_art_rgb/noise1_model.t7
  47. 0 0
      models/anime_style_art_rgb/noise2_model.json
  48. 33 27
      models/anime_style_art_rgb/noise2_model.t7
  49. 0 0
      models/anime_style_art_rgb/scale2.0x_model.json
  50. 33 27
      models/anime_style_art_rgb/scale2.0x_model.t7
  51. 0 0
      models/ukbench/scale2.0x_model.json
  52. 33 27
      models/ukbench/scale2.0x_model.t7
  53. 169 0
      tools/benchmark.lua
  54. 25 0
      tools/cleanup_model.lua
  55. 43 0
      tools/cudnn2cunn.lua
  56. 43 0
      tools/cunn2cudnn.lua
  57. 54 0
      tools/export_model.lua
  58. 112 63
      train.lua
  59. 8 6
      train.sh
  60. 9 0
      train_ukbench.sh
  61. 124 42
      waifu2x.lua
  62. 119 90
      web.lua

+ 2 - 0
.gitattributes

@@ -0,0 +1,2 @@
+models/*/*.json binary
+*.t7 binary

+ 12 - 1
.gitignore

@@ -1,4 +1,15 @@
 *~
+work/
 cache/*.png
-models/*.png
+cache/url_*
+data/
+!data/.gitkeep
+
+models/
+!models/anime_style_art
+!models/anime_style_art_rgb
+!models/ukbench
+models/*/*.png
+
 waifu2x.log
+

+ 29 - 21
README.md

@@ -19,16 +19,11 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
 
 ## Public AMI
 ```
-AMI ID: ami-0be01e4f
-AMI NAME: waifu2x-server
-Instance Type: g2.2xlarge
-Region: US West (N.California)
-OS: Ubuntu 14.04
-User: ubuntu
-Created at: 2015-08-12
+TODO
 ```
 
 ## Third Party Software
+
 [Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party)
 
 ## Dependencies
@@ -37,10 +32,12 @@ Created at: 2015-08-12
 - NVIDIA GPU
 
 ### Platform
+
 - [Torch7](http://torch.ch/)
 - [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit)
 
 ### lualocks packages (excludes torch7's default packages)
+- lua-csnappy
 - md5
 - uuid
 - [turbo](https://github.com/kernelsauce/turbo)
@@ -57,34 +54,44 @@ See: [NVIDIA CUDA Getting Started Guide for Linux](http://docs.nvidia.com/cuda/c
 Download [CUDA](http://developer.nvidia.com/cuda-downloads)
 
 ```
-sudo dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb
+sudo dpkg -i cuda-repo-ubuntu1404_7.5-18_amd64.deb
 sudo apt-get update
 sudo apt-get install cuda
 ```
 
+#### Install Package
+
+```
+sudo apt-get install libsnappy-dev
+```
+
 #### Install Torch7
 
 See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
 
-#### Validation
-
-Test the waifu2x command line tool.
+And install luarocks packages.
 ```
-th waifu2x.lua
+luarocks install graphicsmagick # upgrade
+luarocks install lua-csnappy
+luarocks install md5
+luarocks install uuid
+PREFIX=$HOME/torch/install luarocks install turbo # if you need to use web application
 ```
 
-### Setting Up the Web Application Environment (if you needed)
+#### Getting waifu2x
 
-#### Install packages
+```
+git clone --depth 1 https://github.com/nagadomi/waifu2x.git
+```
 
+#### Validation
+
+Testing the waifu2x command line tool.
 ```
-luarocks install md5
-luarocks install uuid
-PREFIX=$HOME/torch/install luarocks install turbo
+th waifu2x.lua
 ```
 
 ## Web Application
-Run.
 ```
 th web.lua
 ```
@@ -114,11 +121,11 @@ th waifu2x.lua -m noise_scale -noise_level 1 -i input_image.png -o output_image.
 th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.png
 ```
 
-See also `images/gen.sh`.
+See also `th waifu2x.lua -h`.
 
 ### Video Encoding
 
-\* `avconv` is `ffmpeg` on Ubuntu 14.04.
+\* `avconv` is alias of `ffmpeg` on Ubuntu 14.04.
 
 Extracting images and audio from a video. (range: 00:09:00 ~ 00:12:00)
 ```
@@ -144,6 +151,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
 ```
 
 ## Training Your Own Model
+Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
 
 ### Data Preparation
 
@@ -151,7 +159,7 @@ Genrating a file list.
 ```
 find /path/to/image/dir -name "*.png" > data/image_list.txt
 ```
-(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.)
+You should use noise free images. In my case, waifu2x is trained with 6000 high-resolution-noise-free-PNG images.
 
 Converting training data.
 ```

+ 9 - 1
appendix/purge_cache.lua

@@ -3,7 +3,15 @@ require 'pl'
 CACHE_DIR="cache"
 TTL = 3600 * 24
 
-local files = dir.getfiles(CACHE_DIR, "*.png")
+local files = {}
+local image_cache = dir.getfiles(CACHE_DIR, "*.png")
+local url_cache = dir.getfiles(CACHE_DIR, "url_*")
+for i = 1, #image_cache do 
+   table.insert(files, image_cache[i])
+end
+for i = 1, #url_cache do 
+   table.insert(files, url_cache[i])
+end
 local now = os.time()
 for i, f in pairs(files) do
    if now - path.getmtime(f) > TTL then

+ 10 - 42
assets/index.html

@@ -2,51 +2,17 @@
 <html>
   <head>
     <meta charset="UTF-8">
-    <link rel="canonical" href="http://waifu2x.udp.jp/">
     <title>waifu2x</title>
-    <style type="text/css">
-    body {
-      margin: 1em 2em 1em 2em;
-      background: LightGray;
-      width: 640px;
-    }
-    fieldset {
-      margin-top: 1em;
-      margin-bottom: 1em;
-    }
-    .about {
-      position: relative;
-      display: inline-block;
-      font-size: 0.9em;
-      padding: 1em 5px 0.2em 0;
-    }
-    .help {
-      font-size: 0.85em;
-      margin: 1em 0 0 0;
-    }
-    </style>
+    <link href="style.css" rel="stylesheet" type="text/css">
     <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
-    <script type="text/javascript">
-    function clear_file() {
-      var new_file = $("#file").clone();
-      new_file.change(clear_url);
-      $("#file").replaceWith(new_file);
-    }
-    function clear_url() {
-      $("#url").val("")
-    }
-    $(function (){
-      $("#url").change(clear_file);
-      $("#file").change(clear_url);
-    })
-    </script>
+    <script type="text/javascript" src="ui.js"></script>
   </head>
   <body>
     <h1>waifu2x</h1>
     <div class="header">
-      <div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
-        <img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
-        <a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
+      <div class="github-banner">
+        <img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
+        <a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
       </div>
       <a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
     </div>
@@ -66,12 +32,14 @@
           Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
         </div>
       </fieldset>
-      <fieldset>
+      <fieldset class="noise-field">
         <legend>Noise Reduction (expect JPEG Artifact)</legend>
         <label><input type="radio" name="noise" value="0"> None</label>
         <label><input type="radio" name="noise" value="1" checked="checked"> Medium</label>
         <label><input type="radio" name="noise" value="2"> High</label>
-        <div class="help">When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.</div>
+        <div class="help">
+	  When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.
+	</div>
       </fieldset>
       <fieldset>
         <legend>Upscaling</legend>
@@ -82,7 +50,7 @@
       <input type="submit"/>
     </form>
     <div class="help">
-      <ul style="padding-left: 15px;">
+      <ul class="padding-left">
         <li>If you are using Firefox, Please press the CTRL+S key to save image. "Save Image" option doesn't work.
       </ul>
     </div>

+ 7 - 41
assets/index.ja.html

@@ -2,51 +2,17 @@
 <html lang="ja">
   <head>
     <meta charset="UTF-8">
-    <link rel="canonical" href="http://waifu2x.udp.jp/">
+    <link href="style.css" rel="stylesheet" type="text/css">
     <title>waifu2x</title>
-    <style type="text/css">
-    body {
-      margin: 1em 2em 1em 2em;
-      background: LightGray;
-      width: 640px;
-    }
-    fieldset {
-      margin-top: 1em;
-      margin-bottom: 1em;
-    }
-    .about {
-      position: relative;
-      display: inline-block;
-      font-size: 0.8em;
-      padding: 1em 5px 0.2em 0;
-    }
-    .help {
-      font-size: 0.8em;
-      margin: 1em 0 0 0;
-    }
-    </style>
     <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
-    <script type="text/javascript">
-    function clear_file() {
-      var new_file = $("#file").clone();
-      new_file.change(clear_url);
-      $("#file").replaceWith(new_file);
-    }
-    function clear_url() {
-      $("#url").val("")
-    }
-    $(function (){
-      $("#url").change(clear_file);
-      $("#file").change(clear_url);
-    })
-    </script>
+    <script type="text/javascript" src="ui.js"></script>    
   </head>
   <body>
     <h1>waifu2x</h1>
     <div class="header">
-      <div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
-        <img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
-        <a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
+      <div class="github-banner">
+        <img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
+        <a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
       </div>
       <a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
     </div>
@@ -66,7 +32,7 @@
           制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
         </div>
       </fieldset>
-      <fieldset>
+      <fieldset class="noise-field">
         <legend>ノイズ除去 (JPEGノイズを想定)</legend>
         <label><input type="radio" name="noise" value="0"> なし</label>
         <label><input type="radio" name="noise" value="1" checked="checked"> 弱</label>
@@ -81,7 +47,7 @@
       <input type="submit" value="実行"/>
     </form>
     <div class="help">
-      <ul style="padding-left: 15px;">
+      <ul class="padding-left">
         <li>なし/なしで入力画像を変換せずに出力する。ブラウザのタブで変換結果を比較したい人用。
         <li>Firefoxの方は、右クリから画像が保存できないようなので、CTRL+SキーかALTキー後 ファイル - ページを保存 で画像を保存してください。
       </ul>

+ 11 - 43
assets/index.ru.html

@@ -2,51 +2,18 @@
 <html>
   <head>
     <meta charset="UTF-8">
-    <link rel="canonical" href="http://waifu2x.udp.jp/">
+    <link href="style.css" rel="stylesheet" type="text/css">
     <title>waifu2x</title>
-    <style type="text/css">
-    body {
-      margin: 1em 2em 1em 2em;
-      background: LightGray;
-      width: 640px;
-    }
-    fieldset {
-      margin-top: 1em;
-      margin-bottom: 1em;
-    }
-    .about {
-      position: relative;
-      display: inline-block;
-      font-size: 0.9em;
-      padding: 1em 5px 0.2em 0;
-    }
-    .help {
-      font-size: 0.85em;
-      margin: 1em 0 0 0;
-    }
-    </style>
+    <link href="style.css" rel="stylesheet" type="text/css">
     <script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
-    <script type="text/javascript">
-    function clear_file() {
-      var new_file = $("#file").clone();
-      new_file.change(clear_url);
-      $("#file").replaceWith(new_file);
-    }
-    function clear_url() {
-      $("#url").val("")
-    }
-    $(function (){
-      $("#url").change(clear_file);
-      $("#file").change(clear_url);
-    })
-    </script>
+    <script type="text/javascript" src="ui.js"></script>
   </head>
   <body>
     <h1>waifu2x</h1>
     <div class="header">
-      <div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
-        <img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
-        <a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
+      <div class="github-banner">
+        <img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
+        <a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
       </div>
       <a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
     </div>
@@ -66,11 +33,11 @@
           Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
         </div>
       </fieldset>
-      <fieldset>
-      <legend>Устранение шума (артефактов JPEG)</legend>
+      <fieldset class="noise-field">
+	<legend>Устранение шума (артефактов JPEG)</legend>
         <label><input type="radio" name="noise" value="0"> Нет</label>
         <label><input type="radio" name="noise" value="1" checked="checked"> Средне</label>
-        <label><input type="radio" name="noise" value="2"> Сильно (не рекомендуется)</label>
+        <label><input type="radio" name="noise" value="2"> Сильно</label>
         <div class="help">Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект. Также не рекомендуется сильное устранение шума, оно даёт выгоду только в редких случаях, когда картинка изначально была сильно испорчена.</div>
       </fieldset>
       <fieldset>
@@ -82,8 +49,9 @@
       <input type="submit"/>
     </form>
     <div class="help">
-      <ul style="padding-left: 15px;">
+      <ul class="padding-left">
         <li>Если Вы используете Firefox, для сохранения изображения Вам придётся нажать Ctrl+S (опция в меню "Сохранить изображение" работать не будет!)
+	</li>
       </ul>
     </div>
   </body>

+ 52 - 0
assets/style.css

@@ -0,0 +1,52 @@
+body {
+    margin: 1em 2em 1em 2em;
+    background: LightGray;
+    width: 640px;
+}
+fieldset {
+    margin-top: 1em;
+    margin-bottom: 1em;
+}
+.about {
+    position: relative;
+    display: inline-block;
+    font-size: 0.9em;
+    padding: 1em 5px 0.2em 0;
+}
+.help {
+    font-size: 0.8em;
+    margin: 1em 0 0 0;
+}
+.github-banner {
+    position:absolute;
+    display:block;
+    top:0;
+    left:540px;
+    max-height:140px;
+}
+.github-banner-image {
+    position: absolute;
+    display: block; 
+    left: 0;
+    top: 0;
+    width: 149px;
+    height: 149px;
+    border: 0;
+}
+.github-banner-link {
+    position: absolute;
+    display: block; 
+    left:0;
+    top:0; 
+    width:149px;
+    height:130px;
+}
+.padding-left {
+    padding-left: 15px;
+}
+.hide {
+    display: none;
+}
+.experimental {
+    margin-bottom: 1em;
+}

+ 80 - 0
assets/ui.js

@@ -0,0 +1,80 @@
+$(function (){
+    function clear_file() {
+	var new_file = $("#file").clone();
+	new_file.change(clear_url);
+	$("#file").replaceWith(new_file);
+    }
+    function clear_url() {
+	$("#url").val("")
+    }
+    function on_change_style(e) {
+	$("input[name=style]").parents("label").each(
+	    function (i, elm) {
+		$(elm).css("font-weight", "normal");
+	    });
+	var checked = $("input[name=style]:checked");
+	checked.parents("label").css("font-weight", "bold");
+	if (checked.val() == "art") {
+	    $("h1").text("waifu2x");
+	} else {
+	    $("h1").html("w<s>/a/</s>ifu2x");
+	}
+    }
+    function on_change_noise_level(e)
+    {
+	$("input[name=noise]").parents("label").each(
+	    function (i, elm) {
+		$(elm).css("font-weight", "normal");
+	    });
+	var checked = $("input[name=noise]:checked");
+	if (checked.val() != 0) {
+	    checked.parents("label").css("font-weight", "bold");
+	}
+    }
+    function on_change_scale_factor(e)
+    {
+	$("input[name=scale]").parents("label").each(
+	    function (i, elm) {
+		$(elm).css("font-weight", "normal");
+	    });
+	var checked = $("input[name=scale]:checked");
+	if (checked.val() != 0) {
+	    checked.parents("label").css("font-weight", "bold");
+	}
+    }
+    function on_change_white_noise(e)
+    {
+	$("input[name=white_noise]").parents("label").each(
+	    function (i, elm) {
+		$(elm).css("font-weight", "normal");
+	    });
+	var checked = $("input[name=white_noise]:checked");
+	if (checked.val() != 0) {
+	    checked.parents("label").css("font-weight", "bold");
+	}
+    }
+    function on_click_experimental_button(e)
+    {
+	if ($(this).hasClass("close")) {
+	    $(".experimental .container").show();
+	    $(this).removeClass("close");
+	} else {
+	    $(".experimental .container").hide();
+	    $(this).addClass("close");
+	}
+	e.preventDefault();
+	e.stopPropagation();
+    }
+    
+    $("#url").change(clear_file);
+    $("#file").change(clear_url);
+    //$("input[name=style]").change(on_change_style);
+    $("input[name=noise]").change(on_change_noise_level);
+    $("input[name=scale]").change(on_change_scale_factor);
+    //$("input[name=white_noise]").change(on_change_white_noise);
+    //$(".experimental .button").click(on_click_experimental_button)
+    
+    //on_change_style();
+    on_change_scale_factor();
+    on_change_noise_level();
+})

+ 33 - 34
convert_data.lua

@@ -1,48 +1,47 @@
-require './lib/portable'
-require 'image'
-local settings = require './lib/settings'
-local image_loader = require './lib/image_loader'
-
-local function count_lines(file)
-   local fp = io.open(file, "r")
-   local count = 0
-   for line in fp:lines() do
-      count = count + 1
-   end
-   fp:close()
-   
-   return count
-end
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 
-local function crop_4x(x)
-   local w = x:size(3) % 4
-   local h = x:size(2) % 4
-   return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
-end
+require 'pl'
+require 'image'
+local compression = require 'compression'
+local settings = require 'settings'
+local image_loader = require 'image_loader'
+local iproc = require 'iproc'
 
 local function load_images(list)
-   local count = count_lines(list)
-   local fp = io.open(list, "r")
+   local MARGIN = 32
+   local lines = utils.split(file.read(list), "\n")
    local x = {}
-   local c = 0
-   for line in fp:lines() do
-      local im = crop_4x(image_loader.load_byte(line))
-      if im then
-	 if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then
-	    table.insert(x, im)
-	 end
+   for i = 1, #lines do
+      local line = lines[i]
+      local im, alpha = image_loader.load_byte(line)
+      if alpha then
+	 io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
       else
-	 print("error:" .. line)
+	 im = iproc.crop_mod4(im)
+	 local scale = 1.0
+	 if settings.random_half then
+	    scale = 2.0
+	 end
+	 if im then
+	    if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
+	       table.insert(x, compression.compress(im))
+	    else
+	       io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
+	    end
+	 else
+	    io.stderr:write(string.format("\n%s: skip: load error.\n", line))
+	 end
       end
-      c = c + 1
-      xlua.progress(c, count)
-      if c % 10 == 0 then
+      xlua.progress(i, #lines)
+      if i % 10 == 0 then
 	 collectgarbage()
       end
    end
    return x
 end
+
+torch.manualSeed(settings.seed)
 print(settings)
 local x = load_images(settings.image_list)
 torch.save(settings.images, x)
-

+ 0 - 34
cudnn2cunn.lua

@@ -1,34 +0,0 @@
-require 'cunn'
-require 'cudnn'
-require 'cutorch'
-require './lib/LeakyReLU'
-local srcnn = require 'lib/srcnn'
-
-local function cudnn2cunn(cudnn_model)
-   local cunn_model = srcnn.waifu2x("y")
-   local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
-   local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
-
-   for i = 1, #from_seq do
-      local from = from_seq[i]
-      local to = to_seq[i]
-      to.weight:copy(from.weight)
-      to.bias:copy(from.bias)
-   end
-   cunn_model:cuda()
-   cunn_model:evaluate()
-   return cunn_model
-end
-
-local cmd = torch.CmdLine()
-cmd:text()
-cmd:text("convert cudnn model to cunn model ")
-cmd:text("Options:")
-cmd:option("-model", "./model.t7", 'path of cudnn model file')
-cmd:option("-iformat", "ascii", 'input format')
-cmd:option("-oformat", "ascii", 'output format')
-
-local opt = cmd:parse(arg)
-local cudnn_model = torch.load(opt.model, opt.iformat)
-local cunn_model = cudnn2cunn(cudnn_model)
-torch.save(opt.model, cunn_model, opt.oformat)

+ 0 - 0
data/.gitkeep


+ 0 - 23
export_model.lua

@@ -1,23 +0,0 @@
--- adapted from https://github.com/marcan/cl-waifu2x
-require './lib/portable'
-require './lib/LeakyReLU'
-local cjson = require "cjson"
-
-local model = torch.load(arg[1], "ascii")
-
-local jmodules = {}
-local modules = model:findModules("nn.SpatialConvolutionMM")
-for i = 1, #modules, 1 do
-   local module = modules[i]
-   local jmod = {
-      kW = module.kW,
-      kH = module.kH,
-      nInputPlane = module.nInputPlane,
-      nOutputPlane = module.nOutputPlane,
-      bias = torch.totable(module.bias:float()),
-      weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
-   }
-   table.insert(jmodules, jmod)
-end
-
-io.write(cjson.encode(jmodules))

+ 1 - 2
images/gen.sh

@@ -1,8 +1,7 @@
 #!/bin/sh
 
-th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
+th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
-th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
 th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
 th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png

BIN
images/lena_waifu2x.png


BIN
images/lena_waifu2x_ukbench.png


BIN
images/miku_CC_BY-NC_noisy_waifu2x.png


BIN
images/miku_noisy_waifu2x.png


BIN
images/miku_small.png


BIN
images/miku_small_lanczos3.png


BIN
images/miku_small_noisy_waifu2x.png


BIN
images/miku_small_waifu2x.png


BIN
images/slide.odp


BIN
images/slide.png


BIN
images/slide_noise_reduction.png


BIN
images/slide_result.png


BIN
images/slide_upscaling.png


+ 39 - 0
lib/ClippedWeightedHuberCriterion.lua

@@ -0,0 +1,39 @@
+-- ref: https://en.wikipedia.org/wiki/Huber_loss
+local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
+
+function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
+   parent.__init(self)
+   self.clip = clip
+   self.gamma = gamma or 1.0
+   self.weight = w:clone()
+   self.diff = torch.Tensor()
+   self.diff_abs = torch.Tensor()
+   --self.outlier_rate = 0.0
+   self.square_loss_buff = torch.Tensor()
+   self.linear_loss_buff = torch.Tensor()
+end
+function ClippedWeightedHuberCriterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
+   self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
+   for i = 1, input:size(1) do
+      self.diff[i]:add(-1, target[i]):cmul(self.weight)
+   end
+   self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
+   
+   local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
+   local linear_targets = self.diff[torch.ge(self.diff_abs, self.gamma)]
+   local square_loss = self.square_loss_buff:resizeAs(square_targets):copy(square_targets):pow(2.0):mul(0.5):sum()
+   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():add(-0.5 * self.gamma):mul(self.gamma):sum()
+
+   --self.outlier_rate = linear_targets:nElement() / input:nElement()
+   self.output = (square_loss + linear_loss) / input:nElement()
+   return self.output
+end
+function ClippedWeightedHuberCriterion:updateGradInput(input, target)
+   local norm = 1.0 / input:nElement()
+   self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
+   local outlier = torch.ge(self.diff_abs, self.gamma)
+   self.gradInput[outlier] = torch.sign(self.diff[outlier]) * self.gamma * norm
+   return self.gradInput 
+end

+ 77 - 0
lib/DepthExpand2x.lua

@@ -0,0 +1,77 @@
+if w2nn.DepthExpand2x then
+   return w2nn.DepthExpand2x
+end
+local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module')
+ 
+function DepthExpand2x:__init()
+   parent:__init()
+end
+
+function DepthExpand2x:updateOutput(input)
+   local x = input
+   -- (batch_size, depth, height, width)
+   self.shape = x:size()
+
+   assert(self.shape:size() == 4, "input must be 4d tensor")
+   assert(self.shape[2] % 4 == 0, "depth must be depth % 4 = 0")
+   -- (batch_size, width, height, depth)
+   x = x:transpose(2, 4)
+   -- (batch_size, width, height * 2, depth / 2)
+   x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
+   -- (batch_size, height * 2, width, depth / 2)
+   x = x:transpose(2, 3)
+   -- (batch_size, height * 2, width * 2, depth / 4)
+   x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
+   -- (batch_size, depth / 4, height * 2, width * 2)
+   x = x:transpose(2, 4)
+   x = x:transpose(3, 4)
+   self.output:resizeAs(x):copy(x) -- contiguous
+   
+   return self.output
+end
+
+function DepthExpand2x:updateGradInput(input, gradOutput)
+   -- (batch_size, depth / 4, height * 2, width * 2)
+   local x = gradOutput
+   -- (batch_size, height * 2, width * 2, depth / 4)
+   x = x:transpose(2, 4)
+   x = x:transpose(2, 3)
+   -- (batch_size, height * 2, width, depth / 2)
+   x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
+   -- (batch_size, width, height * 2, depth / 2)
+   x = x:transpose(2, 3)
+   -- (batch_size, width, height, depth)
+   x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
+   -- (batch_size, depth, height, width)
+   x = x:transpose(2, 4)
+   
+   self.gradInput:resizeAs(x):copy(x)
+   
+   return self.gradInput
+end
+
+function DepthExpand2x.test()
+   require 'image'
+   local function show(x)
+      local img = torch.Tensor(3, x:size(3), x:size(4))
+      img[1]:copy(x[1][1])
+      img[2]:copy(x[1][2])
+      img[3]:copy(x[1][3])
+      image.display(img)
+   end
+   local img = image.lena()
+   local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
+   for i = 0, img:size(1) * 4 - 1 do
+      src_index = ((i % 3) + 1)
+      x[1][i + 1]:copy(img[src_index])
+   end
+   show(x)
+   
+   local de2x = w2nn.DepthExpand2x()
+   out = de2x:forward(x)
+   show(out)
+   out = de2x:updateGradInput(x, out)
+   show(out)
+end
+
+return DepthExpand2x

+ 4 - 3
lib/LeakyReLU.lua

@@ -1,7 +1,8 @@
-if nn.LeakyReLU then
-   return
+if w2nn and w2nn.LeakyReLU then
+   return w2nn.LeakyReLU
 end
-local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
+
+local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module')
  
 function LeakyReLU:__init(negative_scale)
    parent.__init(self)

+ 31 - 0
lib/LeakyReLU_deprecated.lua

@@ -0,0 +1,31 @@
+if nn.LeakyReLU then
+   return nn.LeakyReLU
+end
+
+local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
+ 
+function LeakyReLU:__init(negative_scale)
+   parent.__init(self)
+   self.negative_scale = negative_scale or 0.333
+   self.negative = torch.Tensor()
+end
+ 
+function LeakyReLU:updateOutput(input)
+   self.output:resizeAs(input):copy(input):abs():add(input):div(2)
+   self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
+   self.output:add(self.negative)
+   
+   return self.output
+end
+ 
+function LeakyReLU:updateGradInput(input, gradOutput)
+   self.gradInput:resizeAs(gradOutput)
+   -- filter positive
+   self.negative:sign():add(1)
+   torch.cmul(self.gradInput, gradOutput, self.negative)
+   -- filter negative
+   self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
+   self.gradInput:add(self.negative)
+   
+   return self.gradInput
+end

+ 25 - 0
lib/WeightedMSECriterion.lua

@@ -0,0 +1,25 @@
+local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
+
+function WeightedMSECriterion:__init(w)
+   parent.__init(self)
+   self.weight = w:clone()
+   self.diff = torch.Tensor()
+   self.loss = torch.Tensor()
+end
+
+function WeightedMSECriterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   for i = 1, input:size(1) do
+      self.diff[i]:add(-1, target[i]):cmul(self.weight)
+   end
+   self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
+   self.output = self.loss:mean()
+   
+   return self.output
+end
+
+function WeightedMSECriterion:updateGradInput(input, target)
+   local norm = 2.0 / input:nElement()
+   self.gradInput:resizeAs(input):copy(self.diff):mul(norm)
+   return self.gradInput
+end

+ 4 - 24
cleanup_model.lua → lib/cleanup_model.lua

@@ -1,9 +1,5 @@
-require './lib/portable'
-require './lib/LeakyReLU'
-
-torch.setdefaulttensortype("torch.FloatTensor")
-
 -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
+
 local function zeroDataSize(data)
    if type(data) == 'table' then
       for i = 1, #data do
@@ -14,7 +10,6 @@ local function zeroDataSize(data)
    end
    return data
 end
-
 -- Resize the output, gradInput, etc temporary tensors to zero (so that the
 -- on disk size is smaller)
 local function cleanupModel(node)
@@ -27,7 +22,7 @@ local function cleanupModel(node)
    if node.finput ~= nil then
       node.finput = zeroDataSize(node.finput)
    end
-   if tostring(node) == "nn.LeakyReLU" then
+   if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
       if node.negative ~= nil then
 	 node.negative = zeroDataSize(node.negative)
       end
@@ -46,23 +41,8 @@ local function cleanupModel(node)
 	end
      end
    end
-   
-   collectgarbage()
 end
-
-local cmd = torch.CmdLine()
-cmd:text()
-cmd:text("cleanup model")
-cmd:text("Options:")
-cmd:option("-model", "./model.t7", 'path of model file')
-cmd:option("-iformat", "binary", 'input format')
-cmd:option("-oformat", "binary", 'output format')
-
-local opt = cmd:parse(arg)
-local model = torch.load(opt.model, opt.iformat)
-if model then
+function w2nn.cleanup_model(model)
    cleanupModel(model)
-   torch.save(opt.model, model, opt.oformat)
-else
-   error("model not found")
+   return model
 end

+ 17 - 0
lib/compression.lua

@@ -0,0 +1,17 @@
+-- snapply compression for ByteTensor
+require 'snappy'
+
+local compression = {}
+compression.compress = function (bt)
+   local enc = snappy.compress(bt:storage():string())
+   return {bt:size(), torch.ByteStorage():string(enc)}
+end
+compression.decompress = function(data)
+   local size = data[1]
+   local dec = snappy.decompress(data[2]:string())
+   local bt = torch.ByteTensor(unpack(torch.totable(size)))
+   bt:storage():string(dec)
+   return bt
+end
+
+return compression

+ 104 - 0
lib/data_augmentation.lua

@@ -0,0 +1,104 @@
+require 'image'
+local iproc = require 'iproc'
+
+local data_augmentation = {}
+
+local function pcacov(x)
+   local mean = torch.mean(x, 1)
+   local xm = x - torch.ger(torch.ones(x:size(1)), mean:squeeze())
+   local c = torch.mm(xm:t(), xm)
+   c:div(x:size(1) - 1)
+   local ce, cv = torch.symeig(c, 'V')
+   return ce, cv
+end
+function data_augmentation.color_noise(src, p, factor)
+   factor = factor or 0.1
+   if torch.uniform() < p then
+      local src, conversion = iproc.byte2float(src)
+      local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
+      local ce, cv = pcacov(src_t)
+      local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
+      
+      pca_space = torch.mm(src_t, cv):t():contiguous()
+      for i = 1, 3 do
+	 pca_space[i]:mul(color_scale[i])
+      end
+      local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
+      dest[torch.lt(dest, 0.0)] = 0.0
+      dest[torch.gt(dest, 1.0)] = 1.0
+
+      if conversion then
+	 dest = iproc.float2byte(dest)
+      end
+      return dest
+   else
+      return src
+   end
+end
+function data_augmentation.overlay(src, p)
+   if torch.uniform() < p then
+      local r = torch.uniform()
+      local src, conversion = iproc.byte2float(src)
+      src = src:contiguous()
+      local flip = data_augmentation.flip(src)
+      flip:mul(r):add(src * (1.0 - r))
+      if conversion then
+	 flip = iproc.float2byte(flip)
+      end
+      return flip
+   else
+      return src
+   end
+end
+function data_augmentation.shift_1px(src)
+   -- reducing the even/odd issue in nearest neighbor scaler.
+   local direction = torch.random(1, 4)
+   local x_shift = 0
+   local y_shift = 0
+   if direction == 1 then
+      x_shift = 1
+      y_shift = 0
+   elseif direction == 2 then
+      x_shift = 0
+      y_shift = 1
+   elseif direction == 3 then
+      x_shift = 1
+      y_shift = 1
+   elseif flip == 4 then
+      x_shift = 0
+      y_shift = 0
+   end
+   local w = src:size(3) - x_shift
+   local h = src:size(2) - y_shift
+   w = w - (w % 4)
+   h = h - (h % 4)
+   local dest = iproc.crop(src, x_shift, y_shift, x_shift + w, y_shift + h)
+   return dest
+end
+function data_augmentation.flip(src)
+   local flip = torch.random(1, 4)
+   local tr = torch.random(1, 2)
+   local src, conversion = iproc.byte2float(src)
+   local dest
+   
+   src = src:contiguous()
+   if tr == 1 then
+      -- pass
+   elseif tr == 2 then
+      src = src:transpose(2, 3):contiguous()
+   end
+   if flip == 1 then
+      dest = image.hflip(src)
+   elseif flip == 2 then
+      dest = image.vflip(src)
+   elseif flip == 3 then
+      dest = image.hflip(image.vflip(src))
+   elseif flip == 4 then
+      dest = src
+   end
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
+return data_augmentation

+ 81 - 39
lib/image_loader.lua

@@ -1,74 +1,118 @@
 local gm = require 'graphicsmagick'
 local ffi = require 'ffi'
+local iproc = require 'iproc'
 require 'pl'
 
 local image_loader = {}
 
-function image_loader.decode_float(blob)
-   local im, alpha = image_loader.decode_byte(blob)
-   if im then
-      im = im:float():div(255)
-   end
-   return im, alpha
-end
-function image_loader.encode_png(rgb, alpha)
-   if rgb:type() == "torch.ByteTensor" then
-      error("expect FloatTensor")
-   end
+local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
+local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
+local background_color = 0.5
+
+function image_loader.encode_png(rgb, alpha, depth)
+   depth = depth or 8
+   rgb = iproc.byte2float(rgb)
    if alpha then
       if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
-	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
+	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
       end
       local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
       rgba[1]:copy(rgb[1])
       rgba[2]:copy(rgb[2])
       rgba[3]:copy(rgb[3])
       rgba[4]:copy(alpha)
+      
+      if depth < 16 then
+	 rgba:add(clip_eps8)
+	 rgba[torch.lt(rgba, 0.0)] = 0.0
+	 rgba[torch.gt(rgba, 1.0)] = 1.0
+      else
+	 rgba:add(clip_eps16)
+	 rgba[torch.lt(rgba, 0.0)] = 0.0
+	 rgba[torch.gt(rgba, 1.0)] = 1.0
+      end
       local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
-      im:format("png")
-      return im:toBlob(9)
+      return im:depth(depth):format("PNG"):toString(9)
    else
+      if depth < 16 then
+	 rgb = rgb:clone():add(clip_eps8)
+	 rgb[torch.lt(rgb, 0.0)] = 0.0
+	 rgb[torch.gt(rgb, 1.0)] = 1.0
+      else
+	 rgb = rgb:clone():add(clip_eps16)
+	 rgb[torch.lt(rgb, 0.0)] = 0.0
+	 rgb[torch.gt(rgb, 1.0)] = 1.0
+      end
       local im = gm.Image(rgb, "RGB", "DHW")
-      im:format("png")
-      return im:toBlob(9)
+      return im:depth(depth):format("PNG"):toString(9)
    end
 end
-function image_loader.save_png(filename, rgb, alpha)
-   local blob, len = image_loader.encode_png(rgb, alpha)
+function image_loader.save_png(filename, rgb, alpha, depth)
+   depth = depth or 8
+   local blob = image_loader.encode_png(rgb, alpha, depth)
    local fp = io.open(filename, "wb")
-   fp:write(ffi.string(blob, len))
+   if not fp then
+      error("IO error: " .. filename)
+   end
+   fp:write(blob)
    fp:close()
    return true
 end
-function image_loader.decode_byte(blob)
+function image_loader.decode_float(blob)
    local load_image = function()
       local im = gm.Image()
       local alpha = nil
+      local gamma_lcd = 0.454545
       
       im:fromBlob(blob, #blob)
+      
+      if im:colorspace() == "CMYK" then
+	 im:colorspace("RGB")
+      end
+      local gamma = math.floor(im:gamma() * 1000000) / 1000000
+      if gamma ~= 0 and gamma ~= gamma_lcd then
+	 im:gammaCorrection(gamma / gamma_lcd)
+      end
       -- FIXME: How to detect that a image has an alpha channel?
       if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
 	 -- split alpha channel
 	 im = im:toTensor('float', 'RGBA', 'DHW')
-	 local sum_alpha = (im[4] - 1):sum()
-	 if sum_alpha > 0 or sum_alpha < 0 then
+	 local sum_alpha = (im[4] - 1.0):sum()
+	 if sum_alpha < 0 then
 	    alpha = im[4]:reshape(1, im:size(2), im:size(3))
+	    -- drop full transparent background
+	    local mask = torch.le(alpha, 0.0)
+	    im[1][mask] = background_color
+	    im[2][mask] = background_color
+	    im[3][mask] = background_color
 	 end
 	 local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
 	 new_im[1]:copy(im[1])
 	 new_im[2]:copy(im[2])
 	 new_im[3]:copy(im[3])
-	 im = new_im:mul(255):byte()
+	 im = new_im
       else
-	 im = im:toTensor('byte', 'RGB', 'DHW')
+	 im = im:toTensor('float', 'RGB', 'DHW')
       end
-      return {im, alpha}
+      return {im, alpha, blob}
    end
    local state, ret = pcall(load_image)
    if state then
-      return ret[1], ret[2]
+      return ret[1], ret[2], ret[3]
    else
-      return nil
+      return nil, nil, nil
+   end
+end
+function image_loader.decode_byte(blob)
+   local im, alpha
+   im, alpha, blob = image_loader.decode_float(blob)
+   
+   if im then
+      im = iproc.float2byte(im)
+      -- hmm, alpha does not convert here
+      return im, alpha, blob
+   else
+      return nil, nil, nil
    end
 end
 function image_loader.load_float(file)
@@ -90,18 +134,16 @@ function image_loader.load_byte(file)
    return image_loader.decode_byte(buff)
 end
 local function test()
-   require 'image'
-   local img
-   img = image_loader.load_float("./a.jpg")
-   if img then
-      print(img:min())
-      print(img:max())
-      image.display(img)
-   end
-   img = image_loader.load_float("./b.png")
-   if img then
-      image.display(img)
-   end
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local a = image_loader.load_float("../images/lena.png")
+   local blob = image_loader.encode_png(a)
+   local b = image_loader.decode_float(blob)
+   assert((b - a):abs():sum() == 0)
+
+   a = image_loader.load_byte("../images/lena.png")
+   blob = image_loader.encode_png(a)
+   b = image_loader.decode_byte(blob)
+   assert((b:float() - a:float()):abs():sum() == 0)
 end
 --test()
 return image_loader

+ 112 - 4
lib/iproc.lua

@@ -1,16 +1,78 @@
 local gm = require 'graphicsmagick'
 local image = require 'image'
+
 local iproc = {}
+local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 
-function iproc.scale(src, width, height, filter)
-   local t = "float"
+function iproc.crop_mod4(src)
+   local w = src:size(3) % 4
+   local h = src:size(2) % 4
+   return iproc.crop(src, 0, 0, src:size(3) - w, src:size(2) - h)
+end
+function iproc.crop(src, w1, h1, w2, h2)
+   local dest
+   if src:dim() == 3 then
+      dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
+   else -- dim == 2
+      dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
+   end
+   return dest
+end
+function iproc.crop_nocopy(src, w1, h1, w2, h2)
+   local dest
+   if src:dim() == 3 then
+      dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]
+   else -- dim == 2
+      dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]
+   end
+   return dest
+end
+function iproc.byte2float(src)
+   local conversion = false
+   local dest = src
    if src:type() == "torch.ByteTensor" then
-      t = "byte"
+      conversion = true
+      dest = src:float():div(255.0)
    end
+   return dest, conversion
+end
+function iproc.float2byte(src)
+   local conversion = false
+   local dest = src
+   if src:type() == "torch.FloatTensor" then
+      conversion = true
+      dest = (src + clip_eps8):mul(255.0)
+      dest[torch.lt(dest, 0.0)] = 0
+      dest[torch.gt(dest, 255.0)] = 255.0
+      dest = dest:byte()
+   end
+   return dest, conversion
+end
+function iproc.scale(src, width, height, filter)
+   local conversion
+   src, conversion = iproc.byte2float(src)
    filter = filter or "Box"
    local im = gm.Image(src, "RGB", "DHW")
    im:size(math.ceil(width), math.ceil(height), filter)
-   return im:toTensor(t, "RGB", "DHW")
+   local dest = im:toTensor("float", "RGB", "DHW")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
+function iproc.scale_with_gamma22(src, width, height, filter)
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   filter = filter or "Box"
+   local im = gm.Image(src, "RGB", "DHW")
+   im:gammaCorrection(1.0 / 2.2):
+      size(math.ceil(width), math.ceil(height), filter):
+      gammaCorrection(2.2)
+   local dest = im:toTensor("float", "RGB", "DHW")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
    local dst_height = img:size(2) + h1 + h2
@@ -22,5 +84,51 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2]:add(-w1)
    return image.warp(img, flow, "simple", false, "clamp")
 end
+function iproc.white_noise(src, std, rgb_weights, gamma)
+   gamma = gamma or 0.454545
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   std = std or 0.01
+
+   local noise = torch.Tensor():resizeAs(src):normal(0, std)
+   if rgb_weights then 
+      noise[1]:mul(rgb_weights[1])
+      noise[2]:mul(rgb_weights[2])
+      noise[3]:mul(rgb_weights[3])
+   end
+
+   local dest
+   if gamma ~= 0 then
+      dest = src:clone():pow(gamma):add(noise)
+      dest[torch.lt(dest, 0.0)] = 0.0
+      dest[torch.gt(dest, 1.0)] = 1.0
+      dest:pow(1.0 / gamma)
+   else
+      dest = src + noise
+   end
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
+
+local function test_conversion()
+   local a = torch.linspace(0, 255, 256):float():div(255.0)
+   local b = iproc.float2byte(a)
+   local c = iproc.byte2float(a)
+   local d = torch.linspace(0, 255, 256)
+   assert((a - c):abs():sum() == 0)
+   assert((d:float() - b:float()):abs():sum() == 0)
+
+   a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0)
+   b = iproc.float2byte(a)
+   assert(b:float():sum() == 255.0 * 3)
+
+   a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0)
+   b = iproc.float2byte(a)
+   print(b)
+   assert(b:float():sum() == 254.0 * 3)
+end
+--test_conversion()
 
 return iproc

+ 7 - 12
lib/minibatch_adam.lua

@@ -21,20 +21,15 @@ local function minibatch_adam(model, criterion,
 			       input_size[1], input_size[2], input_size[3])
    local targets_tmp = torch.Tensor(batch_size,
 				    target_size[1] * target_size[2] * target_size[3])
-   
-   for t = 1, #train_x, batch_size do
-      if t + batch_size > #train_x then
-	 break
-      end
+   for t = 1, #train_x do
       xlua.progress(t, #train_x)
-      for i = 1, batch_size do
-	 local x, y = transformer(train_x[shuffle[t + i - 1]])
-         inputs_tmp[i]:copy(x)
-	 targets_tmp[i]:copy(y)
+      local xy = transformer(train_x[shuffle[t]], false, batch_size)
+      for i = 1, #xy do
+         inputs_tmp[i]:copy(xy[i][1])
+	 targets_tmp[i]:copy(xy[i][2])
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
-      
       local feval = function(x)
 	 if x ~= parameters then
 	    parameters:copy(x)
@@ -50,13 +45,13 @@ local function minibatch_adam(model, criterion,
       optim.adam(feval, parameters, config)
       
       c = c + 1
-      if c % 10 == 0 then
+      if c % 20 == 0 then
 	 collectgarbage()
       end
    end
    xlua.progress(#train_x, #train_x)
    
-   return { mse = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss}
 end
 
 return minibatch_adam

+ 203 - 240
lib/pairwise_transform.lua

@@ -1,69 +1,80 @@
 require 'image'
 local gm = require 'graphicsmagick'
-local iproc = require './iproc'
-local reconstruct = require './reconstruct'
+local iproc = require 'iproc'
+local data_augmentation = require 'data_augmentation'
+
 local pairwise_transform = {}
 
-local function random_half(src, p, min_size)
-   p = p or 0.5
-   local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
-   if p > torch.uniform() then
+local function random_half(src, p)
+   if torch.uniform() < p then
+      local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
       return src
    end
 end
-local function color_augment(x)
-   local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
-   x = x:float():div(255)
-   for i = 1, 3 do
-      x[i]:mul(color_scale[i])
+local function crop_if_large(src, max_size)
+   local tries = 4
+   if src:size(2) > max_size and src:size(3) > max_size then
+      local rect
+      for i = 1, tries do
+	 local yi = torch.random(0, src:size(2) - max_size)
+	 local xi = torch.random(0, src:size(3) - max_size)
+	 rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
+	 -- ignore simple background
+	 if rect:float():std() >= 0 then
+	    break
+	 end
+      end
+      return rect
+   else
+      return src
    end
-   x[torch.lt(x, 0.0)] = 0.0
-   x[torch.gt(x, 1.0)] = 1.0
-   return x:mul(255):byte()
 end
-local function flip_augment(x, y)
-   local flip = torch.random(1, 4)
-   if y then
-      if flip == 1 then
-	 x = image.hflip(x)
-	 y = image.hflip(y)
-      elseif flip == 2 then
-	 x = image.vflip(x)
-	 y = image.vflip(y)
-      elseif flip == 3 then
-	 x = image.hflip(image.vflip(x))
-	 y = image.hflip(image.vflip(y))
-      elseif flip == 4 then
-      end
-      return x, y
+local function preprocess(src, crop_size, options)
+   local dest = src
+   dest = random_half(dest, options.random_half_rate)
+   dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
+   dest = data_augmentation.flip(dest)
+   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
+   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
+   dest = data_augmentation.shift_1px(dest)
+   
+   return dest
+end
+local function active_cropping(x, y, size, p, tries)
+   assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   local r = torch.uniform()
+   if p < r then
+      local xi = torch.random(0, y:size(3) - (size + 1))
+      local yi = torch.random(0, y:size(2) - (size + 1))
+      local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
+      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
+      return xc, yc
    else
-      if flip == 1 then
-	 x = image.hflip(x)
-      elseif flip == 2 then
-	 x = image.vflip(x)
-      elseif flip == 3 then
-	 x = image.hflip(image.vflip(x))
-      elseif flip == 4 then
+      local best_se = 0.0
+      local best_xc, best_yc
+      local m = torch.FloatTensor(x:size(1), size, size)
+      for i = 1, tries do
+	 local xi = torch.random(0, y:size(3) - (size + 1))
+	 local yi = torch.random(0, y:size(2) - (size + 1))
+	 local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
+	 local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
+	 local xcf = iproc.byte2float(xc)
+	 local ycf = iproc.byte2float(yc)
+	 local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
+	 if se >= best_se then
+	    best_xc = xcf
+	    best_yc = ycf
+	    best_se = se
+	 end
       end
-      return x
+      return best_xc, best_yc
    end
 end
-local INTERPOLATION_PADDING = 16
-function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_augment = true, random_half = true, rgb = true}
-   if options.random_half then
-      src = random_half(src)
-   end
-   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
-   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
-   local down_scale = 1.0 / scale
-   local y = image.crop(src,
-			xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
-			xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
+function pairwise_transform.scale(src, scale, size, offset, n, options)
    local filters = {
-      "Box",        -- 0.012756949974688
+      "Box","Box",  -- 0.012756949974688
       "Blackman",   -- 0.013191924552285
       --"Cartom",     -- 0.013753536746706
       --"Hanning",    -- 0.013761314529647
@@ -71,221 +82,173 @@ function pairwise_transform.scale(src, scale, size, offset, options)
       "SincFast",   -- 0.014095824314306
       "Jinc",       -- 0.014244299255442
    }
+   local unstable_region_offset = 8
    local downscale_filter = filters[torch.random(1, #filters)]
+   local y = preprocess(src, size, options)
+   assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
+   local down_scale = 1.0 / scale
+   local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
+				     y:size(2) * down_scale, downscale_filter),
+			 y:size(3), y:size(2))
+   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
    
-   y = flip_augment(y)
-   if options.color_augment then
-      y = color_augment(y)
-   end
-   local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
-   x = iproc.scale(x, y:size(3), y:size(2))
-   y = y:float():div(255)
-   x = x:float():div(255)
-
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   local batch = {}
+   for i = 1, n do
+      local xc, yc = active_cropping(x, y,
+				     size,
+				     options.active_cropping_rate,
+				     options.active_cropping_tries)
+      xc = iproc.byte2float(xc)
+      yc = iproc.byte2float(yc)
+      if options.rgb then
+      else
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
-
-   y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset -	INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
-   x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
-   
-   return x, y
+   return batch
 end
-function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_augment = true, random_half = true, rgb = true}
-   if options.random_half then
-      src = random_half(src)
-   end
-   local yi = torch.random(0, src:size(2) - size - 1)
-   local xi = torch.random(0, src:size(3) - size - 1)
-   local y = src
-   local x
+function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
+   local unstable_region_offset = 8
+   local y = preprocess(src, size, options)
+   local x = y
 
-   if options.color_augment then
-      y = color_augment(y)
-   end
-   x = y
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg")
-      x:samplingFactors({1.0, 1.0, 1.0})
+      x:format("jpeg"):depth(8)
+      if options.jpeg_sampling_factors == 444 then
+	 x:samplingFactors({1.0, 1.0, 1.0})
+      else -- 420
+	 x:samplingFactors({2.0, 1.0, 1.0})
+      end
       local blob, len = x:toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
    end
+   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
    
-   y = image.crop(y, xi, yi, xi + size, yi + size)
-   x = image.crop(x, xi, yi, xi + size, yi + size)
-   y = y:float():div(255)
-   x = x:float():div(255)
-   x, y = flip_augment(x, y)
-   
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
-   end
-   
-   return x, image.crop(y, offset, offset, size - offset, size - offset)
-end
-function pairwise_transform.jpeg(src, level, size, offset, options)
-   if level == 1 then
-      return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-				      size, offset,
-				      options)
-   elseif level == 2 then
-      local r = torch.uniform()
-      if r > 0.6 then
-	 return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
-					 size, offset,
-					 options)
-      elseif r > 0.3 then
-	 local quality1 = torch.random(37, 70)
-	 local quality2 = quality1 - torch.random(5, 10)
-	 return pairwise_transform.jpeg_(src, {quality1, quality2},
-					    size, offset,
-					    options)
+   local batch = {}
+   for i = 1, n do
+      local xc, yc = active_cropping(x, y, size,
+				     options.active_cropping_rate,
+				     options.active_cropping_tries)
+      xc = iproc.byte2float(xc)
+      yc = iproc.byte2float(yc)
+      if options.rgb then
       else
-	 local quality1 = torch.random(52, 70)
-	 return pairwise_transform.jpeg_(src,
-					 {quality1,
-					  quality1 - torch.random(5, 15),
-					  quality1 - torch.random(15, 25)},
-					 size, offset,
-					 options)
+	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+      end
+      if torch.uniform() < options.nr_rate then
+	 -- reducing noise
+	 table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      else
+	 -- ratain useful details
+	 table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
       end
-   else
-      error("unknown noise level: " .. level)
-   end
-end
-function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
-   if options.random_half then
-      src = random_half(src)
-   end
-   local down_scale = 1.0 / scale
-   local filters = {
-      "Box",        -- 0.012756949974688
-      "Blackman",   -- 0.013191924552285
-      --"Cartom",     -- 0.013753536746706
-      --"Hanning",    -- 0.013761314529647
-      --"Hermite",    -- 0.013850225205266
-      "SincFast",   -- 0.014095824314306
-      "Jinc",       -- 0.014244299255442
-   }
-   local downscale_filter = filters[torch.random(1, #filters)]
-   local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
-   local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
-   local y = src
-   local x
-   
-   if options.color_augment then
-      y = color_augment(y)
-   end
-   x = y
-   x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
-   for i = 1, #quality do
-      x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg")
-      x:samplingFactors({1.0, 1.0, 1.0})
-      local blob, len = x:toBlob(quality[i])
-      x:fromBlob(blob, len)
-      x = x:toTensor("byte", "RGB", "DHW")
-   end
-   x = iproc.scale(x, y:size(3), y:size(2))
-   y = image.crop(y,
-		  xi, yi,
-		  xi + size, yi + size)
-   x = image.crop(x,
-		  xi, yi,
-		  xi + size, yi + size)
-   x = x:float():div(255)
-   y = y:float():div(255)
-   x, y = flip_augment(x, y)
-
-   if options.rgb then
-   else
-      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
    end
-   
-   return x, image.crop(y, offset, offset, size - offset, size - offset)
+   return batch
 end
-function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
-   options = options or {color_augment = true, random_half = true}
-   if level == 1 then
-      return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
-					    size, offset, options)
-   elseif level == 2 then
-      local r = torch.uniform()
-      if r > 0.6 then
-	 return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
-					       size, offset, options)
-      elseif r > 0.3 then
-	 local quality1 = torch.random(37, 70)
-	 local quality2 = quality1 - torch.random(5, 10)
-	 return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
-					       size, offset, options)
+function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
+   if style == "art" then
+      if level == 1 then
+	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
+					 size, offset, n, options)
+      elseif level == 2 then
+	 local r = torch.uniform()
+	 if r > 0.6 then
+	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
+					    size, offset, n, options)
+	 elseif r > 0.3 then
+	    local quality1 = torch.random(37, 70)
+	    local quality2 = quality1 - torch.random(5, 10)
+	    return pairwise_transform.jpeg_(src, {quality1, quality2},
+					    size, offset, n, options)
+	 else
+	    local quality1 = torch.random(52, 70)
+	    local quality2 = quality1 - torch.random(5, 15)
+	    local quality3 = quality1 - torch.random(15, 25)
+	    
+	    return pairwise_transform.jpeg_(src, 
+					    {quality1, quality2, quality3},
+					    size, offset, n, options)
+	 end
+      else
+	 error("unknown noise level: " .. level)
+      end
+   elseif style == "photo" then
+      if level == 1 then
+	 return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
+					 size, offset, n,
+					 options)
+      elseif level == 2 then
+	 if torch.uniform() > 0.6 then
+	    return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
+					    size, offset, n, options)
+	 else
+	    local quality1 = torch.random(40, 60)
+	    local quality2 = quality1 - torch.random(5, 10)
+	    return pairwise_transform.jpeg_(src, {quality1, quality2},
+					    size, offset, n, options)
+	 end
       else
-	 local quality1 = torch.random(52, 70)
-	    return pairwise_transform.jpeg_scale_(src, scale,
-						  {quality1,
-						   quality1 - torch.random(5, 15),
-						   quality1 - torch.random(15, 25)},
-						  size, offset, options)
+	 error("unknown noise level: " .. level)
       end
    else
-      error("unknown noise level: " .. level)
-   end
-end
-
-local function test_jpeg()
-   local loader = require './image_loader'
-   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
-   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
-   image.display({image = y, legend = "y:0"})
-   image.display({image = x, legend = "x:0"})
-   for i = 2, 9 do
-      local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
-					    {i * 10}, 128, 0, {color_augment = false, random_half = true})
-      image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
-      image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
-      --print(x:mean(), y:mean())
+      error("unknown style: " .. style)
    end
 end
 
-local function test_scale()
-   local loader = require './image_loader'
-   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
+function pairwise_transform.test_jpeg(src)
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local options = {random_color_noise_rate = 0.5,
+		    random_half_rate = 0.5,
+		    random_overlay_rate = 0.5,
+		    nr_rate = 1.0,
+		    active_cropping_rate = 0.5,
+		    active_cropping_tries = 10,
+		    max_size = 256,
+		    rgb = true
+   }
+   local image = require 'image'
+   local src = image.lena()
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
-      image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
-      --print(x:mean(), y:mean())
+      local xy = pairwise_transform.jpeg(src,
+					 "art",
+					 torch.random(1, 2),
+					 128, 7, 1, options)
+      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
+      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
    end
 end
-local function test_jpeg_scale()
-   local loader = require './image_loader'
-   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
-   for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_augment = true, random_half = true})
-      image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
-      --print(x:mean(), y:mean())
-   end
-   for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_augment = true, random_half = true})
-      image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
-      image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
-      print(y:size(), x:size())
-      --print(x:mean(), y:mean())
+function pairwise_transform.test_scale(src)
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local options = {random_color_noise_rate = 0.5,
+		    random_half_rate = 0.5,
+		    random_overlay_rate = 0.5,
+		    active_cropping_rate = 0.5,
+		    active_cropping_tries = 10,
+		    max_size = 256,
+		    rgb = true
+   }
+   local image = require 'image'
+   local src = image.lena()
+
+   for i = 1, 10 do
+      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
+      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
+      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
    end
 end
---test_scale()
---test_jpeg()
---test_jpeg_scale()
-
 return pairwise_transform

+ 0 - 4
lib/portable.lua

@@ -1,4 +0,0 @@
-require 'torch'
-require 'cutorch'
-require 'nn'
-require 'cunn'

+ 95 - 16
lib/reconstruct.lua

@@ -1,5 +1,5 @@
 require 'image'
-local iproc = require './iproc'
+local iproc = require 'iproc'
 
 local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then
@@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
    end
    return new_x
 end
-function model_is_rgb(model)
+local reconstruct = {}
+function reconstruct.is_rgb(model)
    if model:get(model:size() - 1).weight:size(1) == 3 then
       -- 3ch RGB
       return true
@@ -57,8 +58,23 @@ function model_is_rgb(model)
       return false
    end
 end
-
-local reconstruct = {}
+function reconstruct.offset_size(model)
+   local conv = model:findModules("nn.SpatialConvolutionMM")
+   if #conv > 0 then
+      local offset = 0
+      for i = 1, #conv do
+	 offset = offset + (conv[i].kW - 1) / 2
+      end
+      return math.floor(offset)
+   else
+      conv = model:findModules("cudnn.SpatialConvolution")
+      local offset = 0
+      for i = 1, #conv do
+	 offset = offset + (conv[i].kW - 1) / 2
+      end
+      return math.floor(offset)
+   end
+end
 function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local output_size = block_size - offset * 2
@@ -78,7 +94,7 @@ function reconstruct.image_y(model, x, offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv[1]:copy(y)
-   local output = image.yuv2rgb(image.crop(yuv,
+   local output = image.yuv2rgb(iproc.crop(yuv,
 					   pad_w1, pad_h1,
 					   yuv:size(3) - pad_w2, yuv:size(2) - pad_h2))
    output[torch.lt(output, 0)] = 0
@@ -110,7 +126,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv_jinc[1]:copy(y)
-   local output = image.yuv2rgb(image.crop(yuv_jinc,
+   local output = image.yuv2rgb(iproc.crop(yuv_jinc,
 					   pad_w1, pad_h1,
 					   yuv_jinc:size(3) - pad_w2, yuv_jinc:size(2) - pad_h2))
    output[torch.lt(output, 0)] = 0
@@ -135,7 +151,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
    local pad_w2 = (w - offset) - x:size(3)
    local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
    local y = reconstruct_rgb(model, input, offset, block_size)
-   local output = image.crop(y,
+   local output = iproc.crop(y,
 			     pad_w1, pad_h1,
 			     y:size(3) - pad_w2, y:size(2) - pad_h2)
    collectgarbage()
@@ -162,7 +178,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
    local pad_w2 = (w - offset) - x:size(3)
    local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
    local y = reconstruct_rgb(model, input, offset, block_size)
-   local output = image.crop(y,
+   local output = iproc.crop(y,
 			     pad_w1, pad_h1,
 			     y:size(3) - pad_w2, y:size(2) - pad_h2)
    output[torch.lt(output, 0)] = 0
@@ -172,18 +188,81 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
    return output
 end
 
-function reconstruct.image(model, x, offset, block_size)
-   if model_is_rgb(model) then
-      return reconstruct.image_rgb(model, x, offset, block_size)
+function reconstruct.image(model, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return reconstruct.image_rgb(model, x,
+				   reconstruct.offset_size(model), block_size)
+   else
+      return reconstruct.image_y(model, x,
+				 reconstruct.offset_size(model), block_size)
+   end
+end
+function reconstruct.scale(model, scale, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return reconstruct.scale_rgb(model, scale, x,
+				   reconstruct.offset_size(model), block_size)
    else
-      return reconstruct.image_y(model, x, offset, block_size)
+      return reconstruct.scale_y(model, scale, x,
+				 reconstruct.offset_size(model), block_size)
    end
 end
-function reconstruct.scale(model, scale, x, offset, block_size)
-   if model_is_rgb(model) then
-      return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+local function tta(f, model, x, block_size)
+   local average = nil
+   local offset = reconstruct.offset_size(model)
+   for i = 1, 4 do 
+      local flip_f, iflip_f
+      if i == 1 then
+	 flip_f = function (a) return a end
+	 iflip_f = function (a) return a end
+      elseif i == 2 then
+	 flip_f = image.vflip
+	 iflip_f = image.vflip
+      elseif i == 3 then
+	 flip_f = image.hflip
+	 iflip_f = image.hflip
+      elseif i == 4 then
+	 flip_f = function (a) return image.hflip(image.vflip(a)) end
+	 iflip_f = function (a) return image.vflip(image.hflip(a)) end
+      end
+      for j = 1, 2 do
+	 local tr_f, itr_f
+	 if j == 1 then
+	    tr_f = function (a) return a end
+	    itr_f = function (a) return a end
+	 elseif j == 2 then
+	    tr_f = function(a) return a:transpose(2, 3):contiguous() end
+	    itr_f = function(a) return a:transpose(2, 3):contiguous() end
+	 end
+	 local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
+				     offset, block_size)))
+	 if not average then
+	    average = out
+	 else
+	    average:add(out)
+	 end
+      end
+   end
+   return average:div(8.0)
+end
+function reconstruct.image_tta(model, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return tta(reconstruct.image_rgb, model, x, block_size)
    else
-      return reconstruct.scale_y(model, scale, x, offset, block_size)
+      return tta(reconstruct.image_y, model, x, block_size)
+   end
+end
+function reconstruct.scale_tta(model, scale, x, block_size)
+   if reconstruct.is_rgb(model) then
+      local f = function (model, x, offset, block_size)
+	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+      end
+      return tta(f, model, x, block_size)
+		 
+   else
+      local f = function (model, x, offset, block_size)
+	 return reconstruct.scale_y(model, scale, x, offset, block_size)
+      end
+      return tta(f, model, x, block_size)
    end
 end
 

+ 31 - 28
lib/settings.lua

@@ -1,5 +1,6 @@
 require 'xlua'
 require 'pl'
+require 'trepl'
 
 -- global settings
 
@@ -14,22 +15,34 @@ local settings = {}
 
 local cmd = torch.CmdLine()
 cmd:text()
-cmd:text("waifu2x")
+cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-seed", 11, 'fixed input seed')
-cmd:option("-data_dir", "./data", 'data directory')
-cmd:option("-test", "images/miku_small.png", 'test image file')
+cmd:option("-gpu", -1, 'GPU Device ID')
+cmd:option("-seed", 11, 'RNG seed')
+cmd:option("-data_dir", "./data", 'path to data directory')
+cmd:option("-backend", "cunn", '(cunn|cudnn)')
+cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
-cmd:option("-method", "scale", '(noise|scale|noise_scale)')
+cmd:option("-method", "scale", 'method to training (noise|scale)')
 cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
-cmd:option("-scale", 2.0, 'scale')
+cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
+cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
+cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
+cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
-cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
-cmd:option("-crop_size", 128, 'crop size')
-cmd:option("-batch_size", 2, 'mini batch size')
-cmd:option("-epoch", 200, 'epoch')
-cmd:option("-core", 2, 'cpu core')
+cmd:option("-crop_size", 46, 'crop size')
+cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
+cmd:option("-batch_size", 8, 'mini batch size')
+cmd:option("-epoch", 200, 'number of total epochs to run')
+cmd:option("-thread", -1, 'number of CPU threads')
+cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
+cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
+cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
+cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
+cmd:option("-active_cropping_tries", 10, 'active cropping tries')
+cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
@@ -53,26 +66,16 @@ end
 if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end
-if settings.random_half == 1 then
-   settings.random_half = true
-else
-   settings.random_half = false
+if not (settings.style == "art" or
+	settings.style == "photo") then
+   error(string.format("unknown style: %s", settings.style))
+end
+
+if settings.thread > 0 then
+   torch.setnumthreads(tonumber(settings.thread))
 end
-torch.setnumthreads(settings.core)
 
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-settings.validation_ratio = 0.1
-settings.validation_crops = 40
-
-local srcnn = require './srcnn'
-if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
-   settings.create_model = srcnn.waifu4x
-   settings.block_offset = 13
-else
-   settings.create_model = srcnn.waifu2x
-   settings.block_offset = 7
-end
-
 return settings

+ 46 - 52
lib/srcnn.lua

@@ -1,74 +1,68 @@
-require './LeakyReLU'
+require 'w2nn'
 
-function nn.SpatialConvolutionMM:reset(stdv)
-   stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane))
-   self.weight:normal(0, stdv)
-   self.bias:fill(0)
-end
+-- ref: http://arxiv.org/abs/1502.01852
+-- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
-function srcnn.waifu2x(color)
+function srcnn.channels(model)
+   return model:get(model:size() - 1).weight:size(1)
+end
+function srcnn.waifu2x_cunn(ch)
    local model = nn.Sequential()
-   local ch = nil
-   if color == "rgb" then
-      ch = 3
-   elseif color == "y" then
-      ch = 1
-   else
-      if color then
-	 error("unknown color: " .. color)
-      else
-	 error("unknown color: nil")
-      end
-   end
-   
    model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
+   model:add(w2nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
---model:cuda()
---print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
-   return model, 7
+   return model
 end
-
--- current 4x is worse than 2x * 2
-function srcnn.waifu4x(color)
+function srcnn.waifu2x_cudnn(ch)
    local model = nn.Sequential()
-
-   local ch = nil
+   model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.LeakyReLU(0.1))
+   model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(nn.View(-1):setNumInputDims(3))
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+   
+   return model
+end
+function srcnn.create(model_name, backend, color)
+   local ch = 3
    if color == "rgb" then
       ch = 3
    elseif color == "y" then
       ch = 1
    else
-      error("unknown color: " .. color)
+      error("unsupported color: " + color)
+   end
+   if backend == "cunn" then
+      return srcnn.waifu2x_cunn(ch)
+   elseif backend == "cudnn" then
+      return srcnn.waifu2x_cudnn(ch)
+   else
+      error("unsupported backend: " +  backend)
    end
-   
-   model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
-   model:add(nn.View(-1):setNumInputDims(3))
-   
-   return model, 13
 end
 return srcnn

+ 26 - 0
lib/w2nn.lua

@@ -0,0 +1,26 @@
+local function load_nn()
+   require 'torch'
+   require 'nn'
+end
+local function load_cunn()
+   require 'cutorch'
+   require 'cunn'
+end
+local function load_cudnn()
+   require 'cudnn'
+   cudnn.benchmark = true
+end
+if w2nn then
+   return w2nn
+else
+   pcall(load_cunn)
+   pcall(load_cudnn)
+   w2nn = {}
+   require 'LeakyReLU'
+   require 'LeakyReLU_deprecated'
+   require 'DepthExpand2x'
+   require 'WeightedMSECriterion'
+   require 'ClippedWeightedHuberCriterion'
+   require 'cleanup_model'
+   return w2nn
+end

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 0 - 0
models/anime_style_art_rgb/noise1_model.json


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 33 - 27
models/anime_style_art_rgb/noise1_model.t7


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 0 - 0
models/anime_style_art_rgb/noise2_model.json


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 33 - 27
models/anime_style_art_rgb/noise2_model.t7


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 0 - 0
models/anime_style_art_rgb/scale2.0x_model.json


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 33 - 27
models/anime_style_art_rgb/scale2.0x_model.t7


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 0 - 0
models/ukbench/scale2.0x_model.json


تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 33 - 27
models/ukbench/scale2.0x_model.t7


+ 169 - 0
tools/benchmark.lua

@@ -0,0 +1,169 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'xlua'
+require 'w2nn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
+local gm = require 'graphicsmagick'
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x-benchmark")
+cmd:text("Options:")
+
+cmd:option("-dir", "./data/test", 'test image directory')
+cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
+cmd:option("-model2_dir", "", 'model2 directory (optional)')
+cmd:option("-method", "scale", '(scale|noise)')
+cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
+cmd:option("-color", "rgb", '(rgb|y)')
+cmd:option("-noise_level", 1, 'model noise level')
+cmd:option("-jpeg_quality", 75, 'jpeg quality')
+cmd:option("-jpeg_times", 1, 'jpeg compression times')
+cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
+
+local opt = cmd:parse(arg)
+torch.setdefaulttensortype('torch.FloatTensor')
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
+
+local function MSE(x1, x2)
+   return (x1 - x2):pow(2):mean()
+end
+local function YMSE(x1, x2)
+   local x1_2 = image.rgb2y(x1)
+   local x2_2 = image.rgb2y(x2)
+   return (x1_2 - x2_2):pow(2):mean()
+end
+local function PSNR(x1, x2)
+   local mse = MSE(x1, x2)
+   return 10 * math.log10(1.0 / mse)
+end
+local function YPSNR(x1, x2)
+   local mse = YMSE(x1, x2)
+   return 10 * math.log10(1.0 / mse)
+end
+
+local function transform_jpeg(x, opt)
+   for i = 1, opt.jpeg_times do
+      jpeg = gm.Image(x, "RGB", "DHW")
+      jpeg:format("jpeg")
+      jpeg:samplingFactors({1.0, 1.0, 1.0})
+      blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
+      jpeg:fromBlob(blob, len)
+      x = jpeg:toTensor("byte", "RGB", "DHW")
+   end
+   return x
+end
+local function transform_scale(x, opt)
+   return iproc.scale(x,
+		      x:size(3) * 0.5,
+		      x:size(2) * 0.5,
+		      opt.filter)
+end
+
+local function benchmark(opt, x, input_func, model1, model2)
+   local model1_mse = 0
+   local model2_mse = 0
+   local model1_psnr = 0
+   local model2_psnr = 0
+   
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local input, model1_output, model2_output
+
+      input = input_func(ground_truth, opt)
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+      
+      t = sys.clock()
+      if input:size(3) == ground_truth:size(3) then
+	 model1_output = reconstruct.image(model1, input)
+	 if model2 then
+	    model2_output = reconstruct.image(model2, input)
+	 end
+      else
+	 model1_output = reconstruct.scale(model1, 2.0, input)
+	 if model2 then
+	    model2_output = reconstruct.scale(model2, 2.0, input)
+	 end
+      end
+      if opt.color == "y" then
+	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
+	 model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
+	 if model2 then
+	    model2_mse = model2_mse + YMSE(ground_truth, model2_output)
+	    model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
+	 end
+      elseif opt.color == "rgb" then
+	 model1_mse = model1_mse + MSE(ground_truth, model1_output)
+	 model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
+	 if model2 then
+	    model2_mse = model2_mse + MSE(ground_truth, model2_output)
+	    model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
+	 end
+      else
+	 error("Unknown color: " .. opt.color)
+      end
+      if model2 then
+	 io.stdout:write(
+	    string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
+			  i, #x,
+			  model1_mse / i, model2_mse / i,
+			  model1_psnr / i, model2_psnr / i
+	    ))
+      else
+	 io.stdout:write(
+	    string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
+			  i, #x,
+			  model1_mse / i, model1_psnr / i
+	    ))
+      end
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+local function load_data(test_dir)
+   local test_x = {}
+   local files = dir.getfiles(test_dir, "*.*")
+   for i = 1, #files do
+      table.insert(test_x, iproc.crop_mod4(image_loader.load_byte(files[i])))
+      xlua.progress(i, #files)
+   end
+   return test_x
+end
+function load_model(filename)
+   return torch.load(filename, "ascii")
+end
+print(opt)
+if opt.method == "scale" then
+   local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
+   local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
+   local test_x = load_data(opt.dir)
+   benchmark(opt, test_x, transform_scale, model1, model2)
+elseif opt.method == "noise" then
+   local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
+   local test_x = load_data(opt.dir)
+   benchmark(opt, test_x, transform_jpeg, model1, model2)
+end

+ 25 - 0
tools/cleanup_model.lua

@@ -0,0 +1,25 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+
+require 'w2nn'
+torch.setdefaulttensortype("torch.FloatTensor")
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("cleanup model")
+cmd:text("Options:")
+cmd:option("-model", "./model.t7", 'path of model file')
+cmd:option("-iformat", "binary", 'input format')
+cmd:option("-oformat", "binary", 'output format')
+
+local opt = cmd:parse(arg)
+local model = torch.load(opt.model, opt.iformat)
+if model then
+   w2nn.cleanup_model(model)
+   model:cuda()
+   model:evaluate()
+   torch.save(opt.model, model, opt.oformat)
+else
+   error("model not found")
+end

+ 43 - 0
tools/cudnn2cunn.lua

@@ -0,0 +1,43 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'os'
+require 'w2nn'
+local srcnn = require 'srcnn'
+
+local function cudnn2cunn(cudnn_model)
+   local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
+   local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
+   local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
+   
+   assert(#weight_from == #weight_to)
+   
+   for i = 1, #weight_from do
+      local from = weight_from[i]
+      local to = weight_to[i]
+      
+      to.weight:copy(from.weight)
+      to.bias:copy(from.bias)
+   end
+   cunn_model:cuda()
+   cunn_model:evaluate()
+   return cunn_model
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x cudnn model to cunn model converter")
+cmd:text("Options:")
+cmd:option("-i", "", 'Specify the input cunn model')
+cmd:option("-o", "", 'Specify the output cudnn model')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+local cudnn_model = torch.load(opt.i, opt.iformat)
+local cunn_model = cudnn2cunn(cudnn_model)
+torch.save(opt.o, cunn_model, opt.oformat)

+ 43 - 0
tools/cunn2cudnn.lua

@@ -0,0 +1,43 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'os'
+require 'w2nn'
+local srcnn = require 'srcnn'
+
+local function cunn2cudnn(cunn_model)
+   local cudnn_model = srcnn.waifu2x_cudnn(srcnn.channels(cunn_model))
+   local weight_from = cunn_model:findModules("nn.SpatialConvolutionMM")
+   local weight_to = cudnn_model:findModules("cudnn.SpatialConvolution")
+
+   assert(#weight_from == #weight_to)
+   
+   for i = 1, #weight_from do
+      local from = weight_from[i]
+      local to = weight_to[i]
+      
+      to.weight:copy(from.weight)
+      to.bias:copy(from.bias)
+   end
+   cudnn_model:cuda()
+   cudnn_model:evaluate()
+   return cudnn_model
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x cunn model to cudnn model converter")
+cmd:text("Options:")
+cmd:option("-i", "", 'Specify the input cudnn model')
+cmd:option("-o", "", 'Specify the output cunn model')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+local cunn_model = torch.load(opt.i, opt.iformat)
+local cudnn_model = cunn2cudnn(cunn_model)
+torch.save(opt.o, cudnn_model, opt.oformat)

+ 54 - 0
tools/export_model.lua

@@ -0,0 +1,54 @@
+-- adapted from https://github.com/marcan/cl-waifu2x
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'w2nn'
+local cjson = require "cjson"
+
+function export(model, output)
+   local jmodules = {}
+   local modules = model:findModules("nn.SpatialConvolutionMM")
+   if #modules == 0 then
+      -- cudnn model
+      modules = model:findModules("cudnn.SpatialConvolution")
+   end
+   for i = 1, #modules, 1 do
+      local module = modules[i]
+      local jmod = {
+	 kW = module.kW,
+	 kH = module.kH,
+	 nInputPlane = module.nInputPlane,
+	 nOutputPlane = module.nOutputPlane,
+	 bias = torch.totable(module.bias:float()),
+	 weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
+      }
+      table.insert(jmodules, jmod)
+   end
+   jmodules[1].color = "RGB"
+   jmodules[1].gamma = 0
+   jmodules[#jmodules].color = "RGB"
+   jmodules[#jmodules].gamma = 0
+   
+   local fp = io.open(output, "w")
+   if not fp then
+      error("IO Error: " .. output)
+   end
+   fp:write(cjson.encode(jmodules))
+   fp:close()
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x export model")
+cmd:text("Options:")
+cmd:option("-i", "input.t7", 'Specify the input torch model')
+cmd:option("-o", "output.json", 'Specify the output json file')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+local model = torch.load(opt.i, opt.iformat)
+export(model, opt.o)

+ 112 - 63
train.lua

@@ -1,21 +1,25 @@
-require './lib/portable'
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 require 'optim'
 require 'xlua'
-require 'pl'
 
-local settings = require './lib/settings'
-local minibatch_adam = require './lib/minibatch_adam'
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local pairwise_transform = require './lib/pairwise_transform'
-local image_loader = require './lib/image_loader'
+require 'w2nn'
+local settings = require 'settings'
+local srcnn = require 'srcnn'
+local minibatch_adam = require 'minibatch_adam'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local compression = require 'compression'
+local pairwise_transform = require 'pairwise_transform'
+local image_loader = require 'image_loader'
 
 local function save_test_scale(model, rgb, file)
-   local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset)
+   local up = reconstruct.scale(model, settings.scale, rgb)
    image.save(file, up)
 end
 local function save_test_jpeg(model, rgb, file)
-   local im, count = reconstruct.image(model, rgb, settings.block_offset)
+   local im, count = reconstruct.image(model, rgb)
    image.save(file, im)
 end
 local function split_data(x, test_size)
@@ -31,14 +35,19 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n)
+local function make_validation_set(x, transformer, n, batch_size)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, n do
-	 local x, y = transformer(x[i], true)
-	 table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
-			     y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+      for k = 1, math.max(n / batch_size, 1) do
+	 local xy = transformer(x[i], true, batch_size)
+	 local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+	 local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
+	 for j = 1, #xy do
+	    tx[j]:copy(xy[j][1])
+	    ty[j]:copy(xy[j][2])
+	 end
+	 table.insert(data, {x = tx, y = ty})
       end
       xlua.progress(i, #x)
       collectgarbage()
@@ -50,24 +59,92 @@ local function validate(model, criterion, data)
    for i = 1, #data do
       local z = model:forward(data[i].x:cuda())
       loss = loss + criterion:forward(z, data[i].y:cuda())
-      xlua.progress(i, #data)
-      if i % 10 == 0 then
+      if i % 100 == 0 then
+	 xlua.progress(i, #data)
 	 collectgarbage()
       end
    end
+   xlua.progress(#data, #data)
    return loss / #data
 end
 
+local function create_criterion(model)
+   if reconstruct.is_rgb(model) then
+      local offset = reconstruct.offset_size(model)
+      local output_w = settings.crop_size - offset * 2
+      local weight = torch.Tensor(3, output_w * output_w)
+      weight[1]:fill(0.29891 * 3) -- R
+      weight[2]:fill(0.58661 * 3) -- G
+      weight[3]:fill(0.11448 * 3) -- B
+      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+   else
+      return nn.MSECriterion():cuda()
+   end
+end
+local function transformer(x, is_validation, n, offset)
+   x = compression.decompress(x)
+   n = n or settings.batch_size;
+   if is_validation == nil then is_validation = false end
+   local random_color_noise_rate = nil 
+   local random_overlay_rate = nil
+   local active_cropping_rate = nil
+   local active_cropping_tries = nil
+   if is_validation then
+      active_cropping_rate = 0
+      active_cropping_tries = 0
+      random_color_noise_rate = 0.0
+      random_overlay_rate = 0.0
+   else
+      active_cropping_rate = settings.active_cropping_rate
+      active_cropping_tries = settings.active_cropping_tries
+      random_color_noise_rate = settings.random_color_noise_rate
+      random_overlay_rate = settings.random_overlay_rate
+   end
+   
+   if settings.method == "scale" then
+      return pairwise_transform.scale(x,
+				      settings.scale,
+				      settings.crop_size, offset,
+				      n,
+				      {
+					 random_half_rate = settings.random_half_rate,
+					 random_color_noise_rate = random_color_noise_rate,
+					 random_overlay_rate = random_overlay_rate,
+					 max_size = settings.max_size,
+					 active_cropping_rate = active_cropping_rate,
+					 active_cropping_tries = active_cropping_tries,
+					 rgb = (settings.color == "rgb")
+				      })
+   elseif settings.method == "noise" then
+      return pairwise_transform.jpeg(x,
+				     settings.style,
+				     settings.noise_level,
+				     settings.crop_size, offset,
+				     n,
+				     {
+					random_half_rate = settings.random_half_rate,
+					random_color_noise_rate = random_color_noise_rate,
+					random_overlay_rate = random_overlay_rate,
+					max_size = settings.max_size,
+					jpeg_sampling_factors = settings.jpeg_sampling_factors,
+					active_cropping_rate = active_cropping_rate,
+					active_cropping_tries = active_cropping_tries,
+					nr_rate = settings.nr_rate,
+					rgb = (settings.color == "rgb")
+				     })
+   end
+end
+
 local function train()
-   local model, offset = settings.create_model(settings.color)
-   assert(offset == settings.block_offset)
-   local criterion = nn.MSECriterion():cuda()
+   local model = srcnn.create(settings.method, settings.backend, settings.color)
+   local offset = reconstruct.offset_size(model)
+   local pairwise_func = function(x, is_validation, n)
+      return transformer(x, is_validation, n, offset)
+   end
+   local criterion = create_criterion(model)
    local x = torch.load(settings.images)
    local lrd_count = 0
-   local train_x, valid_x = split_data(x,
-				       math.floor(settings.validation_ratio * #x),
-				       settings.validation_crops)
-   local test = image_loader.load_float(settings.test)
+   local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
@@ -78,38 +155,11 @@ local function train()
    elseif settings.color == "rgb" then
       ch = 3
    end
-   local transformer = function(x, is_validation)
-      if is_validation == nil then is_validation = false end
-      if settings.method == "scale" then
-	 return pairwise_transform.scale(x,
-					 settings.scale,
-					 settings.crop_size, offset,
-					 { color_augment = not is_validation,
-					   random_half = settings.random_half,
-					   rgb = (settings.color == "rgb")
-					 })
-      elseif settings.method == "noise" then
-	 return pairwise_transform.jpeg(x,
-					settings.noise_level,
-					settings.crop_size, offset,
-					{ color_augment = not is_validation,
-					  random_half = settings.random_half,
-					  rgb = (settings.color == "rgb")
-					})
-      elseif settings.method == "noise_scale" then
-	 return pairwise_transform.jpeg_scale(x,
-					      settings.scale,
-					      settings.noise_level,
-					      settings.crop_size, offset,
-					      { color_augment = not is_validation,
-						random_half = settings.random_half,
-						rgb = (settings.color == "rgb")
-					      })
-      end
-   end
    local best_score = 100000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, transformer, 20)
+   local valid_xy = make_validation_set(valid_x, pairwise_func,
+					settings.validation_crops,
+					settings.batch_size)
    valid_x = nil
    
    collectgarbage()
@@ -119,7 +169,7 @@ local function train()
       model:training()
       print("# " .. epoch)
       print(minibatch_adam(model, criterion, train_x, adam_config,
-			   transformer,
+			   pairwise_func,
 			   {ch, settings.crop_size, settings.crop_size},
 			   {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
 			  ))
@@ -127,6 +177,7 @@ local function train()
       print("# validation")
       local score = validate(model, criterion, valid_xy)
       if score < best_score then
+	 local test_image = image_loader.load_float(settings.test) -- reload
 	 lrd_count = 0
 	 best_score = score
 	 print("* update best model")
@@ -134,22 +185,17 @@ local function train()
 	 if settings.method == "noise" then
 	    local log = path.join(settings.model_dir,
 				  ("noise%d_best.png"):format(settings.noise_level))
-	    save_test_jpeg(model, test, log)
+	    save_test_jpeg(model, test_image, log)
 	 elseif settings.method == "scale" then
 	    local log = path.join(settings.model_dir,
 				  ("scale%.1f_best.png"):format(settings.scale))
-	    save_test_scale(model, test, log)
-	 elseif settings.method == "noise_scale" then
-	    local log = path.join(settings.model_dir,
-				  ("noise%d_scale%.1f_best.png"):format(settings.noise_level,
-									settings.scale))
-	    save_test_scale(model, test, log)
+	    save_test_scale(model, test_image, log)
 	 end
       else
 	 lrd_count = lrd_count + 1
 	 if lrd_count > 5 then
 	    lrd_count = 0
-	    adam_config.learningRate = adam_config.learningRate * 0.8
+	    adam_config.learningRate = adam_config.learningRate * 0.9
 	    print("* learning rate decay: " .. adam_config.learningRate)
 	 end
       end
@@ -157,6 +203,9 @@ local function train()
       collectgarbage()
    end
 end
+if settings.gpu > 0 then
+   cutorch.setDevice(settings.gpu)
+end
 torch.manualSeed(settings.seed)
 cutorch.manualSeed(settings.seed)
 print(settings)

+ 8 - 6
train.sh

@@ -1,10 +1,12 @@
 #!/bin/sh
 
-th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
+th convert_data.lua
 
-th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
+th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
 
-th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
-th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
+th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
+
+th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

+ 9 - 0
train_ukbench.sh

@@ -0,0 +1,9 @@
+#!/bin/sh
+
+th convert_data.lua -data_dir ./data/ukbench
+
+#th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.png -nr_rate 0.9 -jpeg_sampling_factors 420 # -thread 4 -backend cudnn 
+#th tools/cleanup_model.lua -model models/ukbench/noise2_model.t7 -oformat ascii
+
+th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg # -thread 4 -backend cudnn
+th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii

+ 124 - 42
waifu2x.lua

@@ -1,12 +1,11 @@
-require './lib/portable'
-require 'sys'
 require 'pl'
-require './lib/LeakyReLU'
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
-local BLOCK_OFFSET = 7
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
+require 'sys'
+require 'w2nn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
@@ -14,43 +13,109 @@ local function convert_image(opt)
    local x, alpha = image_loader.load_float(opt.i)
    local new_x = nil
    local t = sys.clock()
+   local scale_f, image_f
+   if opt.tta == 1 then
+      scale_f = reconstruct.scale_tta
+      image_f = reconstruct.image_tta
+   else
+      scale_f = reconstruct.scale
+      image_f = reconstruct.image
+   end
    if opt.o == "(auto)" then
       local name = path.basename(opt.i)
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
+      opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
    end
    if opt.m == "noise" then
-      local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
-      model:evaluate()
-      new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
+      local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
+      local model = torch.load(model_path, "ascii")
+      if not model then
+	 error("Load Error: " .. model_path)
+      end
+      new_x = image_f(model, x, opt.crop_size)
    elseif opt.m == "scale" then
-      local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-      model:evaluate()
-      new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+      local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      local model = torch.load(model_path, "ascii")
+      if not model then
+	 error("Load Error: " .. model_path)
+      end
+      new_x = scale_f(model, opt.scale, x, opt.crop_size)
    elseif opt.m == "noise_scale" then
-      local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
-      local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-      noise_model:evaluate()
-      scale_model:evaluate()
-      x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
-      new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+      local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
+      local noise_model = torch.load(noise_model_path, "ascii")
+      local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      local scale_model = torch.load(scale_model_path, "ascii")
+      
+      if not noise_model then
+	 error("Load Error: " .. noise_model_path)
+      end
+      if not scale_model then
+	 error("Load Error: " .. scale_model_path)
+      end
+      x = image_f(noise_model, x, opt.crop_size)
+      new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
    else
       error("undefined method:" .. opt.method)
    end
-   image_loader.save_png(opt.o, new_x, alpha)
+   if opt.white_noise == 1 then
+      new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
+   end
+   image_loader.save_png(opt.o, new_x, alpha, opt.depth)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
-   local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii")
-   local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii")
-   local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-
-   noise1_model:evaluate()
-   noise2_model:evaluate()
-   scale_model:evaluate()
-   
+   local model_path, noise1_model, noise2_model, scale_model
+   local scale_f, image_f
+   if opt.tta == 1 then
+      scale_f = reconstruct.scale_tta
+      image_f = reconstruct.image_tta
+   else
+      scale_f = reconstruct.scale
+      image_f = reconstruct.image
+   end
+   if opt.m == "scale" then
+      model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      scale_model = torch.load(model_path, "ascii")
+      if not scale_model then
+	 error("Load Error: " .. model_path)
+      end
+   elseif opt.m == "noise" and opt.noise_level == 1 then
+      model_path = path.join(opt.model_dir, "noise1_model.t7")
+      noise1_model = torch.load(model_path, "ascii")
+      if not noise1_model then
+	 error("Load Error: " .. model_path)
+      end
+   elseif opt.m == "noise" and opt.noise_level == 2 then
+      model_path = path.join(opt.model_dir, "noise2_model.t7")
+      noise2_model = torch.load(model_path, "ascii")
+      if not noise2_model then
+	 error("Load Error: " .. model_path)
+      end
+   elseif opt.m == "noise_scale" then
+      model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      scale_model = torch.load(model_path, "ascii")
+      if not scale_model then
+	 error("Load Error: " .. model_path)
+      end
+      if opt.noise_level == 1 then
+	 model_path = path.join(opt.model_dir, "noise1_model.t7")
+	 noise1_model = torch.load(model_path, "ascii")
+	 if not noise1_model then
+	    error("Load Error: " .. model_path)
+	 end
+      elseif opt.noise_level == 2 then
+	 model_path = path.join(opt.model_dir, "noise2_model.t7")
+	 noise2_model = torch.load(model_path, "ascii")
+	 if not noise2_model then
+	    error("Load Error: " .. model_path)
+	 end
+      end
+   end
    local fp = io.open(opt.l)
+   if not fp then
+      error("Open Error: " .. opt.l)
+   end
    local count = 0
    local lines = {}
    for line in fp:lines() do
@@ -62,20 +127,24 @@ local function convert_frames(opt)
 	 local x, alpha = image_loader.load_float(lines[i])
 	 local new_x = nil
 	 if opt.m == "noise" and opt.noise_level == 1 then
-	    new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
+	    new_x = image_f(noise1_model, x, opt.crop_size)
 	 elseif opt.m == "noise" and opt.noise_level == 2 then
-	    new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+	    new_x = image_func(noise2_model, x, opt.crop_size)
 	 elseif opt.m == "scale" then
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-	    x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    x = image_f(noise1_model, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
-	    x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	    x = image_f(noise2_model, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 else
 	    error("undefined method:" .. opt.method)
 	 end
+	 if opt.white_noise == 1 then
+	    new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
+	 end
+
 	 local output = nil
 	 if opt.o == "(auto)" then
 	    local name = path.basename(lines[i])
@@ -85,7 +154,7 @@ local function convert_frames(opt)
 	 else
 	    output = string.format(opt.o, i)
 	 end
-	 image_loader.save_png(output, new_x, alpha)
+	 image_loader.save_png(output, new_x, alpha, opt.depth)
 	 xlua.progress(i, #lines)
 	 if i % 10 == 0 then
 	    collectgarbage()
@@ -101,17 +170,30 @@ local function waifu2x()
    cmd:text()
    cmd:text("waifu2x")
    cmd:text("Options:")
-   cmd:option("-i", "images/miku_small.png", 'path of the input image')
-   cmd:option("-l", "", 'path of the image-list')
+   cmd:option("-i", "images/miku_small.png", 'path to input image')
+   cmd:option("-l", "", 'path to image-list.txt')
    cmd:option("-scale", 2, 'scale factor')
-   cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
+   cmd:option("-o", "(auto)", 'path to output file')
+   cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
+   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
+   cmd:option("-thread", -1, "number of CPU threads")
+   cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
+   cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
+   cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
    
    local opt = cmd:parse(arg)
+   if opt.thread > 0 then
+      torch.setnumthreads(opt.thread)
+   end
+   if cudnn then
+      cudnn.fastest = true
+      cudnn.benchmark = false
+   end
+   
    if string.len(opt.l) == 0 then
       convert_image(opt)
    else

+ 119 - 90
web.lua

@@ -1,11 +1,21 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+local ROOT = path.dirname(__FILE__)
+package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
 _G.TURBO_SSL = true
-local turbo = require 'turbo'
+
+require 'w2nn'
 local uuid = require 'uuid'
 local ffi = require 'ffi'
 local md5 = require 'md5'
-require 'pl'
-require './lib/portable'
-require './lib/LeakyReLU'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
+
+-- Notes:  turbo and xlua has different implementation of string:split().
+--         Therefore, string:split() has conflict issue.
+--         In this script, use turbo's string:split().
+local turbo = require 'turbo'
 
 local cmd = torch.CmdLine()
 cmd:text()
@@ -13,24 +23,27 @@ cmd:text("waifu2x-api")
 cmd:text("Options:")
 cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-core", 2, 'number of CPU cores')
+cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
 torch.setdefaulttensortype('torch.FloatTensor')
-torch.setnumthreads(opt.core)
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
-
-local MODEL_DIR = "./models/anime_style_art_rgb"
-
-local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
-local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
-local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
-
-local USE_CACHE = true
-local CACHE_DIR = "./cache"
+if opt.thread > 0 then
+   torch.setnumthreads(opt.thread)
+end
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
+local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
+local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
+local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
+local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
+local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
+--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
+local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
+local CACHE_DIR = path.join(ROOT, "cache")
 local MAX_NOISE_IMAGE = 2560 * 2560
 local MAX_SCALE_IMAGE = 1280 * 1280
 local CURL_OPTIONS = {
@@ -40,7 +53,6 @@ local CURL_OPTIONS = {
    max_redirects = 2
 }
 local CURL_MAX_SIZE = 2 * 1024 * 1024
-local BLOCK_OFFSET = 7 -- see srcnn.lua
 
 local function valid_size(x, scale)
    if scale == 0 then
@@ -50,20 +62,16 @@ local function valid_size(x, scale)
    end
 end
 
-local function get_image(req)
-   local file = req:get_argument("file", "")
-   local url = req:get_argument("url", "")
-   local blob = nil
-   local img = nil
-   local alpha = nil
-   if file and file:len() > 0 then
-      blob = file
-      img, alpha = image_loader.decode_float(blob)
-   elseif url and url:len() > 0 then
+local function cache_url(url)
+   local hash = md5.sumhexa(url)
+   local cache_file = path.join(CACHE_DIR, "url_" .. hash)
+   if path.exists(cache_file) then
+      return image_loader.load_float(cache_file)
+   else
       local res = coroutine.yield(
 	 turbo.async.HTTPClient({verify_ca=false},
-				nil,
-				CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
+	    nil,
+	    CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
       )
       if res.code == 200 then
 	 local content_type = res.headers:get("Content-Type", true)
@@ -71,33 +79,64 @@ local function get_image(req)
 	    content_type = content_type[1]
 	 end
 	 if content_type and content_type:find("image") then
-	    blob = res.body
-	    img, alpha = image_loader.decode_float(blob)
+	    local fp = io.open(cache_file, "wb")
+	    local blob = res.body
+	    fp:write(blob)
+	    fp:close()
+	    return image_loader.decode_float(blob)
 	 end
       end
    end
-   return img, blob, alpha
+   return nil, nil, nil
 end
-
-local function apply_denoise1(x)
-   return reconstruct.image(noise1_model, x, BLOCK_OFFSET)
-end
-local function apply_denoise2(x)
-   return reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+local function get_image(req)
+   local file = req:get_argument("file", "")
+   local url = req:get_argument("url", "")
+   if file and file:len() > 0 then
+      return image_loader.decode_float(file)
+   elseif url and url:len() > 0 then
+      return cache_url(url)
+   end
+   return nil, nil, nil
 end
-local function apply_scale2x(x)
-   return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET)
+local function cleanup_model(model)
+   if CLEANUP_MODEL then
+      w2nn.cleanup_model(model) -- release GPU memory
+   end
 end
-local function cache_do(cache, x, func)
-   if path.exists(cache) then
-      return image.load(cache)
+local function convert(x, options)
+   local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
+   if path.exists(cache_file) then
+      return image.load(cache_file)
    else
-      x = func(x)
-      image.save(cache, x)
+      if options.style == "art" then
+	 if options.method == "scale" then
+	    x = reconstruct.scale(art_scale2_model, 2.0, x)
+	    cleanup_model(art_scale2_model)
+	 elseif options.method == "noise1" then
+	    x = reconstruct.image(art_noise1_model, x)
+	    cleanup_model(art_noise1_model)
+	 else -- options.method == "noise2"
+	    x = reconstruct.image(art_noise2_model, x)
+	    cleanup_model(art_noise2_model)
+	 end
+      else --[[photo
+	 if options.method == "scale" then
+	    x = reconstruct.scale(photo_scale2_model, 2.0, x)
+	    cleanup_model(photo_scale2_model)
+	 elseif options.method == "noise1" then
+	    x = reconstruct.image(photo_noise1_model, x)
+	    cleanup_model(photo_noise1_model)
+	 elseif options.method == "noise2" then
+	    x = reconstruct.image(photo_noise2_model, x)
+	    cleanup_model(photo_noise2_model)
+	 end
+      --]]
+      end
+      image.save(cache_file, x)
       return x
    end
 end
-
 local function client_disconnected(handler)
    return not(handler.request and
 		 handler.request.connection and
@@ -112,63 +151,51 @@ function APIHandler:post()
       self:write("client disconnected")
       return
    end
-   local x, src, alpha = get_image(self)
+   local x, alpha, blob = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
+   local white_noise = tonumber(self:get_argument("white_noise", "0"))
+   local style = self:get_argument("style", "art")
+   if style ~= "art" then
+      style = "photo" -- style must be art or photo
+   end
    if x and valid_size(x, scale) then
-      if USE_CACHE and (noise ~= 0 or scale ~= 0) then
-	 local hash = md5.sumhexa(src)
-	 local cache_noise1 = path.join(CACHE_DIR, hash .. "_noise1.png")
-	 local cache_noise2 = path.join(CACHE_DIR, hash .. "_noise2.png")
-	 local cache_scale = path.join(CACHE_DIR, hash .. "_scale.png")
-	 local cache_noise1_scale = path.join(CACHE_DIR, hash .. "_noise1_scale.png")
-	 local cache_noise2_scale = path.join(CACHE_DIR, hash .. "_noise2_scale.png")
-	 
+      if (noise ~= 0 or scale ~= 0) then
+	 local hash = md5.sumhexa(blob)
 	 if noise == 1 then
-	    x = cache_do(cache_noise1, x, apply_denoise1)
+	    x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
 	 elseif noise == 2 then
-	    x = cache_do(cache_noise2, x, apply_denoise2)
+	    x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
 	 end
 	 if scale == 1 or scale == 2 then
 	    if noise == 1 then
-	       x = cache_do(cache_noise1_scale, x, apply_scale2x)
+	       x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
 	    elseif noise == 2 then
-	       x = cache_do(cache_noise2_scale, x, apply_scale2x)
+	       x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
 	    else
-	       x = cache_do(cache_scale, x, apply_scale2x)
+	       x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
 	    end
 	    if scale == 1 then
-	       x = iproc.scale(x,
-			       math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
-			       math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
-			       "Jinc")
+	       x = iproc.scale_with_gamma22(x,
+					    math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
+					    math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
+					    "Jinc")
 	    end
 	 end
-      elseif noise ~= 0 or scale ~= 0 then
-	 if noise == 1 then
-	    x = apply_denoise1(x)
-	 elseif noise == 2 then
-	    x = apply_denoise2(x)
-	 end
-	 if scale == 1 then
-	    local x16 = {math.floor(x:size(3) * 1.6 + 0.5), math.floor(x:size(2) * 1.6 + 0.5)}
-	    x = apply_scale2x(x)
-	    x = iproc.scale(x, x16[1], x16[2], "Jinc")
-	 elseif scale == 2 then
-	    x = apply_scale2x(x)
+	 if white_noise == 1 then
+	    x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
 	 end
       end
       local name = uuid() .. ".png"
-      local blob, len = image_loader.encode_png(x, alpha)
-      
+      local blob = image_loader.encode_png(x, alpha)
       self:set_header("Content-Disposition", string.format('filename="%s"', name))
       self:set_header("Content-Type", "image/png")
-      self:set_header("Content-Length", string.format("%d", len))
-      self:write(ffi.string(blob, len))
+      self:set_header("Content-Length", string.format("%d", #blob))
+      self:write(blob)
    else
       if not x then
 	 self:set_status(400)
-	 self:write("ERROR: unsupported image format.")
+	 self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
       else
 	 self:set_status(400)
 	 self:write("ERROR: image size exceeds maximum allowable size.")
@@ -177,9 +204,9 @@ function APIHandler:post()
    collectgarbage()
 end
 local FormHandler = class("FormHandler", turbo.web.RequestHandler)
-local index_ja = file.read("./assets/index.ja.html")
-local index_ru = file.read("./assets/index.ru.html")
-local index_en = file.read("./assets/index.html")
+local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
+local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
+local index_en = file.read(path.join(ROOT, "assets", "index.html"))
 function FormHandler:get()
    local lang = self.request.headers:get("Accept-Language")
    if lang then
@@ -209,9 +236,11 @@ turbo.log.categories = {
 local app = turbo.web.Application:new(
    {
       {"^/$", FormHandler},
-      {"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
-      {"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
-      {"^/index.ru.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")},
+      {"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
+      {"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
+      {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
+      {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
+      {"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
       {"^/api$", APIHandler},
    }
 )

برخی فایل ها در این مقایسه diff نمایش داده نمی شوند زیرا تعداد فایل ها بسیار زیاد است