Forráskód Böngészése

Merge branch 'dev'

nagadomi 8 éve
szülő
commit
d779a9d47a

+ 2 - 0
.gitignore

@@ -11,6 +11,8 @@ models/*
 !models/ukbench
 !models/photo
 !models/upconv_7
+!models/upconv_7l
+!models/srresnet_12l
 !models/vgg_7
 models/*/*.png
 models/*/*/*.png

+ 1 - 0
README.md

@@ -89,6 +89,7 @@ See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
 And install luarocks packages.
 ```
 luarocks install graphicsmagick # upgrade
+luarocks install threads # upgrade
 luarocks install lua-csnappy
 luarocks install md5
 luarocks install uuid

+ 65 - 31
appendix/benchmark.md

@@ -1,45 +1,79 @@
-# Benchmark results
+# Benchmarks
 
-Warning: This benchmark results is outdated. I will update soon.
+## Photo
 
-## Usage
+Note: waifu2x's photo models was trained on the blending dataset of [kou's photo collection](http://photosku.com/photo/category/%E6%92%AE%E5%BD%B1%E8%80%85/kou/) and [ukbench](http://vis.uky.edu/~stewe/ukbench/).
 
-```
-th tools/benchmark.lua -dir path/to/dataset_dir -method scale -color y -model1_dir path/to/model_dir
-```
+Note: PSNR in this benchmark uses a [MATLAB's rgb2ycbcr](https://jp.mathworks.com/help/images/ref/rgb2ycbcr.html?lang=en) compatible function (dynamic range [16 235], not [0 255]) for converting grayscale image. I think it's not correct PSNR. But many paper used this metric.
 
-## Dataset
+command: 
+`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Catrom -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
 
-    photo_test: 300 various photos.
-    art_test  : 90 artworks (PNG only).
+### Datasets
 
-## 2x upscaling model
+BSD100: https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/ (100 test images in BSDS300)
+Urban100: https://github.com/jbhuang0604/SelfExSR
 
-| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo   | ukbench|
-|---------------|----------------------|------------------------|---------|--------|
-| photo\_test   |                29.83 |                  29.81 |**29.89**|  29.86 |
-| art\_test     |                36.02 |               **36.24**|  34.92  |  34.85 |
+### 2x - PSNR 
 
-The evaluation metric is PSNR(Y only), higher is better.
+| Dataset/Model | Bicubic       | vgg\_7/photo  | upconv\_7/photo  | upconv\_7l/photo | resnet_14l/photo | 
+|---------------|---------------|---------------|------------------|------------------|--------------------|
+| BSD100        | 29.558        | 31.427        | 31.640           | 31.749           | 31.847             |
+| Urban100      | 26.852        | 30.057        | 30.477           | 30.759           | 31.016             |
 
-## Denosing level 1 model
+### 2x with TTA - PSNR 
 
-| Dataset/Model            | anime\_style\_art | anime\_style\_art\_rgb | photo   |
-|--------------------------|-------------------|------------------------|---------|
-| photo\_test Quality 80   |             36.07 |               **36.20**|   36.01 |
-| photo\_test Quality 50,45|             31.72 |                 32.01  |**32.31**|
-| art\_test Quality 80     |             40.39 |               **42.48**|   40.35 |
-| art\_test Quality 50,45  |             35.45 |               **36.70**|   36.27 |
+Note: TTA is an ensemble technique that is supported by waifu2x. TTA method is 8x slower than non TTA method but it improves PSNR (~+0.1 on photo, ~+0.4 on art).
 
-The evaluation metric is PSNR(RGB), higher is better.
+| Dataset/Model | Bicubic       | vgg\_7/photo  | upconv\_7/photo  | upconv\_7l/photo | resnet_14l/photo | 
+|---------------|---------------|---------------|------------------|------------------|--------------------|
+| BSD100        | 29.558        | 31.474        | 31.705           | 31.812           | 31.915             |
+| Urban100      | 26.852        | 30.140        | 30.599           | 30.868           | 31.162             |
 
-## Denosing level 2 model
+### 2x - benchmark elapsed time (sec)
 
-| Dataset/Model            | anime\_style\_art | anime\_style\_art\_rgb | photo   |
-|--------------------------|-------------------|------------------------|---------|
-| photo\_test Quality 80   |             34.03 |                  34.42 |**36.06**|
-| photo\_test Quality 50,45|             31.95 |                  32.31 |**32.42**|
-| art\_test Quality 80     |             39.20 |               **41.12**|   40.48 |
-| art\_test Quality 50,45  |             36.14 |               **37.78**|   36.55 |
+| Dataset/Model | vgg\_7/photo  | upconv\_7/photo  | upconv\_7l/photo | resnet_14l/photo |
+|---------------|---------------|------------------|------------------|--------------------|
+| BSD100        | 4.057         | 2.509            | 4.947            | 6.86               |
+| Urban100      | 16.349        | 7.083            | 14.178           | 27.87              |
+
+### 2x with TTA - benchmark elapsed time (sec)
+
+| Dataset/Model | vgg\_7/photo  | upconv\_7/photo  | upconv\_7l/photo | resnet_14l/photo |
+|---------------|---------------|------------------|------------------|--------------------|
+| BSD100        | 36.611        | 20.219           | 42.486           | 60.38              |
+| Urban100      | 132.416       | 65.125           | 129.916          | 255.20             |
+
+## Art
+
+command: 
+`th tools/benchmark.lua -dir <dataset_dir> -model1_dir <model_dir> -method scale -filter Lanczos -color y -range_bug 1 -tta <0|1> -force_cudnn 1`
+
+### Dataset
+
+art_test: This dataset contains 85 various fan-arts. Sorry, This dataset is private. 
+
+### 2x - PSNR 
+
+| Dataset/Model | Bicubic       | vgg\_7/art  | upconv\_7/art  | upconv\_7l/art | 
+|---------------|---------------|-------------|----------------|----------------|
+| art_test      | 31.022        | 37.495      | 38.330         | 39.140         |
+
+### 2x with TTA - PSNR 
+
+| Dataset/Model | Bicubic       | vgg\_7/art  | upconv\_7/art  | upconv\_7l/art | 
+|---------------|---------------|-------------|----------------|----------------|
+| art_test      | 31.022        | 37.777      | 38.677         | 39.510         |
+
+### 2x - benchmark elapsed time (sec)
+
+| Dataset/Model | vgg\_7/art  | upconv\_7/art  | upconv\_7l/art | 
+|---------------|-------------|----------------|----------------|
+| art_test      | 20.681      | 7.683          | 17.667         |
+
+### 2x with TTA - benchmark elapsed time (sec)
+
+| Dataset/Model | vgg\_7/art  | upconv\_7/art  | upconv\_7l/art | 
+|---------------|-------------|----------------|----------------|
+| art_test      | 174.674     | 77.716         | 163.932        |
 
-The evaluation metric is PSNR(RGB), higher is better.

+ 34 - 0
appendix/benchmark.sh

@@ -0,0 +1,34 @@
+#!/bin/sh
+set -x
+
+benchmark_photo() {
+    dir=./benchmarks/${1}/${2}/${3}
+    mkdir -p ${dir}
+    th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/photo -method scale -filter Catrom -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0 
+}
+run_benchmark_photo() {
+    for tta in 0 1
+    do
+	for dataset in bsd100 urban100
+	do
+	    benchmark_photo ${dataset} vgg_7 ${tta}
+	    benchmark_photo ${dataset} upconv_7 ${tta}
+	    benchmark_photo ${dataset} upconv_7l ${tta}
+	done
+    done
+}
+benchmark_art() {
+    dir=./benchmarks/${1}/${2}/${3}
+    mkdir -p ${dir}
+    th tools/benchmark.lua -dir data/${1} -model1\_dir models/${2}/art -method scale -filter Lanczos -color y -range\_bug 1 -tta ${3} -force_cudnn 1 -output_dir ${dir} -save_info 1 -show_progress 0 
+}
+run_benchmark_art() {
+    for tta in 0 1
+    do
+	benchmark_art art_test vgg_7 ${tta}
+	benchmark_art art_test upconv_7 ${tta}
+	benchmark_art art_test upconv_7l ${tta}
+    done
+}
+#run_benchmark_photo
+run_benchmark_art 

+ 524 - 0
appendix/caffe_prototxt/resnet_14l.prototxt

@@ -0,0 +1,524 @@
+name: "resnet_14l"
+layer {
+  name: "input"
+  type: "Input"
+  top: "input"
+  input_param { shape: { dim: 1 dim: 3 dim: 156 dim: 156 } }
+}
+layer {
+  name: "Convolution1"
+  type: "Convolution"
+  bottom: "input"
+  top: "Convolution1"
+  convolution_param {
+    num_output: 32
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU1"
+  type: "ReLU"
+  bottom: "Convolution1"
+  top: "Convolution1"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution2"
+  type: "Convolution"
+  bottom: "Convolution1"
+  top: "Convolution2"
+  convolution_param {
+    num_output: 64
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU2"
+  type: "ReLU"
+  bottom: "Convolution2"
+  top: "Convolution2"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution3"
+  type: "Convolution"
+  bottom: "Convolution2"
+  top: "Convolution3"
+  convolution_param {
+    num_output: 64
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU3"
+  type: "ReLU"
+  bottom: "Convolution3"
+  top: "Convolution3"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution4"
+  type: "Convolution"
+  bottom: "Convolution1"
+  top: "Convolution4"
+  convolution_param {
+    num_output: 64
+    bias_term: true
+    pad: 0
+    kernel_size: 1
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "Crop1"
+  type: "Crop"
+  bottom: "Convolution4"
+  bottom: "Convolution3"
+  top: "Crop1"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise1"
+  type: "Eltwise"
+  bottom: "Convolution3"
+  bottom: "Crop1"
+  top: "Eltwise1"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Convolution5"
+  type: "Convolution"
+  bottom: "Eltwise1"
+  top: "Convolution5"
+  convolution_param {
+    num_output: 64
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU4"
+  type: "ReLU"
+  bottom: "Convolution5"
+  top: "Convolution5"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution6"
+  type: "Convolution"
+  bottom: "Convolution5"
+  top: "Convolution6"
+  convolution_param {
+    num_output: 64
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU5"
+  type: "ReLU"
+  bottom: "Convolution6"
+  top: "Convolution6"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Crop2"
+  type: "Crop"
+  bottom: "Eltwise1"
+  bottom: "Convolution6"
+  top: "Crop2"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise2"
+  type: "Eltwise"
+  bottom: "Convolution6"
+  bottom: "Crop2"
+  top: "Eltwise2"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Convolution7"
+  type: "Convolution"
+  bottom: "Eltwise2"
+  top: "Convolution7"
+  convolution_param {
+    num_output: 128
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU6"
+  type: "ReLU"
+  bottom: "Convolution7"
+  top: "Convolution7"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution8"
+  type: "Convolution"
+  bottom: "Convolution7"
+  top: "Convolution8"
+  convolution_param {
+    num_output: 128
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU7"
+  type: "ReLU"
+  bottom: "Convolution8"
+  top: "Convolution8"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution9"
+  type: "Convolution"
+  bottom: "Eltwise2"
+  top: "Convolution9"
+  convolution_param {
+    num_output: 128
+    bias_term: true
+    pad: 0
+    kernel_size: 1
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "Crop3"
+  type: "Crop"
+  bottom: "Convolution9"
+  bottom: "Convolution8"
+  top: "Crop3"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise3"
+  type: "Eltwise"
+  bottom: "Convolution8"
+  bottom: "Crop3"
+  top: "Eltwise3"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Convolution10"
+  type: "Convolution"
+  bottom: "Eltwise3"
+  top: "Convolution10"
+  convolution_param {
+    num_output: 128
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU8"
+  type: "ReLU"
+  bottom: "Convolution10"
+  top: "Convolution10"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution11"
+  type: "Convolution"
+  bottom: "Convolution10"
+  top: "Convolution11"
+  convolution_param {
+    num_output: 128
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU9"
+  type: "ReLU"
+  bottom: "Convolution11"
+  top: "Convolution11"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Crop4"
+  type: "Crop"
+  bottom: "Eltwise3"
+  bottom: "Convolution11"
+  top: "Crop4"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise4"
+  type: "Eltwise"
+  bottom: "Convolution11"
+  bottom: "Crop4"
+  top: "Eltwise4"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Convolution12"
+  type: "Convolution"
+  bottom: "Eltwise4"
+  top: "Convolution12"
+  convolution_param {
+    num_output: 256
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU10"
+  type: "ReLU"
+  bottom: "Convolution12"
+  top: "Convolution12"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution13"
+  type: "Convolution"
+  bottom: "Convolution12"
+  top: "Convolution13"
+  convolution_param {
+    num_output: 256
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU11"
+  type: "ReLU"
+  bottom: "Convolution13"
+  top: "Convolution13"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution14"
+  type: "Convolution"
+  bottom: "Eltwise4"
+  top: "Convolution14"
+  convolution_param {
+    num_output: 256
+    bias_term: true
+    pad: 0
+    kernel_size: 1
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "Crop5"
+  type: "Crop"
+  bottom: "Convolution14"
+  bottom: "Convolution13"
+  top: "Crop5"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise5"
+  type: "Eltwise"
+  bottom: "Convolution13"
+  bottom: "Crop5"
+  top: "Eltwise5"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Convolution15"
+  type: "Convolution"
+  bottom: "Eltwise5"
+  top: "Convolution15"
+  convolution_param {
+    num_output: 256
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU12"
+  type: "ReLU"
+  bottom: "Convolution15"
+  top: "Convolution15"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Convolution16"
+  type: "Convolution"
+  bottom: "Convolution15"
+  top: "Convolution16"
+  convolution_param {
+    num_output: 256
+    bias_term: true
+    pad: 0
+    kernel_size: 3
+    stride: 1
+    weight_filler {
+      type: "msra"
+    }
+  }
+}
+layer {
+  name: "ReLU13"
+  type: "ReLU"
+  bottom: "Convolution16"
+  top: "Convolution16"
+  relu_param {
+    negative_slope: 0.1
+  }
+}
+layer {
+  name: "Crop6"
+  type: "Crop"
+  bottom: "Eltwise5"
+  bottom: "Convolution16"
+  top: "Crop6"
+  crop_param {
+    axis: 2
+    offset: 2
+    offset: 2
+  }
+}
+layer {
+  name: "Eltwise6"
+  type: "Eltwise"
+  bottom: "Convolution16"
+  bottom: "Crop6"
+  top: "Eltwise6"
+  eltwise_param {
+    operation: SUM
+  }
+}
+layer {
+  name: "Deconvolution1"
+  type: "Deconvolution"
+  bottom: "Eltwise6"
+  top: "Deconvolution1"
+  convolution_param {
+    num_output: 3
+    pad: 3
+    kernel_size: 4
+    stride: 2
+  }
+}

+ 34 - 30
convert_data.lua

@@ -82,46 +82,50 @@ local function load_images(list)
       local skip = false
       local alpha_color = torch.random(0, 1)
 
-      if meta and meta.alpha then
-	 if settings.use_transparent_png then
-	    im = alpha_util.fill(im, meta.alpha, alpha_color)
-	 else
-	    skip = true
-	 end
-      end
-      if skip then
-	 if not skip_notice then
-	    io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
-	    skip_notice = true
+      if im then
+	 if meta and meta.alpha then
+	    if settings.use_transparent_png then
+	       im = alpha_util.fill(im, meta.alpha, alpha_color)
+	    else
+	       skip = true
+	    end
 	 end
-      else
-	 if csv_meta and csv_meta.x then
-	    -- method == user
-	    local yy = im
-	    local xx, meta2 = image_loader.load_byte(csv_meta.x)
-	    if meta2 and meta2.alpha then
-	       xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
+	 if skip then
+	    if not skip_notice then
+	       io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
+	       skip_notice = true
 	    end
-	    xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
-	    table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
-			    {data = {filters = filters, has_x = true}}})
 	 else
-	    im = crop_if_large(im, settings.max_training_image_size)
-	    im = iproc.crop_mod4(im)
-	    local scale = 1.0
-	    if settings.random_half_rate > 0.0 then
-	       scale = 2.0
-	    end
-	    if im then
+	    if csv_meta and csv_meta.x then
+	       -- method == user
+	       local yy = im
+	       local xx, meta2 = image_loader.load_byte(csv_meta.x)
+	       if xx then
+		  if meta2 and meta2.alpha then
+		     xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
+		  end
+		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
+		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
+				  {data = {filters = filters, has_x = true}}})
+	       else
+		  io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
+	       end
+	    else
+	       im = crop_if_large(im, settings.max_training_image_size)
+	       im = iproc.crop_mod4(im)
+	       local scale = 1.0
+	       if settings.random_half_rate > 0.0 then
+		  scale = 2.0
+	       end
 	       if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
 		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
 	       end
-	    else
-	       io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
 	    end
 	 end
+      else
+	 io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
       end
       xlua.progress(i, #csv)
       if i % 10 == 0 then

+ 3 - 1
lib/ClippedMSECriterion.lua

@@ -5,12 +5,14 @@ function ClippedMSECriterion:__init(min, max)
    self.min = min
    self.max = max
    self.diff = torch.Tensor()
+   self.diff_pow2 = torch.Tensor()
 end
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
    self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
-   self.output = self.diff:pow(2):sum() / input:nElement()
+   self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
+   self.output = self.diff_pow2:sum() / input:nElement()
    return self.output
 end
 function ClippedMSECriterion:updateGradInput(input, target)

+ 13 - 0
lib/InplaceClip01.lua

@@ -0,0 +1,13 @@
+local Clip01, parent = torch.class("w2nn.InplaceClip01", "nn.Module")
+
+function Clip01:__init()
+   parent.__init(self)
+end
+function Clip01:updateOutput(input)
+   self.output:set(input:clamp(0, 1))
+   return self.output
+end
+function Clip01:updateGradInput(input, gradOutput)
+   self.gradInput:set(gradOutput)
+   return self.gradInput
+end

+ 27 - 0
lib/L1Criterion.lua

@@ -0,0 +1,27 @@
+-- ref: https://en.wikipedia.org/wiki/L1_loss
+local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion')
+
+function L1Criterion:__init()
+   parent.__init(self)
+   self.diff = torch.Tensor()
+   self.linear_loss_buff = torch.Tensor()
+end
+function L1Criterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   if input:dim() == 1 then
+      self.diff[1] = input[1] - target
+   else
+      for i = 1, input:size(1) do
+	 self.diff[i]:add(-1, target[i])
+      end
+   end
+   local linear_targets = self.diff
+   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum()
+   self.output = (linear_loss) / input:nElement()
+   return self.output
+end
+function L1Criterion:updateGradInput(input, target)
+   local norm = 1.0 / input:nElement()
+   self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm)
+   return self.gradInput
+end

+ 67 - 0
lib/SSIMCriterion.lua

@@ -0,0 +1,67 @@
+-- SSIM Index, ref: http://www.cns.nyu.edu/~lcv/ssim/ssim_index.m
+local SSIMCriterion, parent = torch.class('w2nn.SSIMCriterion','nn.Criterion')
+function SSIMCriterion:__init(ch, kernel_size, sigma)
+   parent.__init(self)
+   local function gaussian2d(kernel_size, sigma)
+      sigma = sigma or 1
+      local kernel = torch.Tensor(kernel_size, kernel_size)
+      local u = math.floor(kernel_size / 2) + 1
+      local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
+      for x = 1, kernel_size do
+	 for y = 1, kernel_size do
+	 kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
+	 end
+      end
+      kernel:div(kernel:sum())
+      return kernel
+   end
+   ch = ch or 1
+   kernel_size = kernel_size or 11
+   sigma = sigma or 1.5
+   local kernel = gaussian2d(kernel_size, sigma)
+   if ch > 1 then
+      local kernel_nd = torch.Tensor(ch, ch, kernel_size, kernel_size)
+      for i = 1, ch do
+	 for j = 1, ch do
+	    kernel_nd[i][j]:copy(kernel)
+	    if i ~= j then
+	       kernel_nd[i][j]:zero()
+	    end
+	 end
+      end
+      kernel = kernel_nd
+   end
+   self.c1 = 0.01^2
+   self.c2 = 0.03^2
+   self.ch = ch
+   self.conv = nn.SpatialConvolution(ch, ch, kernel_size, kernel_size, 1, 1, 0, 0):noBias()
+   self.conv.weight:copy(kernel)
+   self.mu1 = torch.Tensor()
+   self.mu2 = torch.Tensor()
+   self.mu1_sq = torch.Tensor()
+   self.mu2_sq = torch.Tensor()
+   self.mu1_mu2 = torch.Tensor()
+   self.sigma1_sq = torch.Tensor()
+   self.sigma2_sq = torch.Tensor()
+   self.sigma12 = torch.Tensor()
+   self.ssim_map = torch.Tensor()
+end
+function SSIMCriterion:updateOutput(input, target)-- dynamic range: 0-1
+   assert(input:nElement() == target:nElement())
+   local valid = self.conv:forward(input)
+   self.mu1:resizeAs(valid):copy(valid)
+   self.mu2:resizeAs(valid):copy(self.conv:forward(target))
+   self.mu1_sq:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu1)
+   self.mu2_sq:resizeAs(self.mu2):copy(self.mu2):cmul(self.mu2)
+   self.mu1_mu2:resizeAs(self.mu1):copy(self.mu1):cmul(self.mu2)
+   self.sigma1_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, input)):add(-1, self.mu1_sq))
+   self.sigma2_sq:resizeAs(valid):copy(self.conv:forward(torch.cmul(target, target)):add(-1, self.mu2_sq))
+   self.sigma12:resizeAs(valid):copy(self.conv:forward(torch.cmul(input, target)):add(-1, self.mu1_mu2))
+
+   local ssim = self.mu1_mu2:mul(2):add(self.c1):cmul(self.sigma12:mul(2):add(self.c2)):
+      cdiv(self.mu1_sq:add(self.mu2_sq):add(self.c1):cmul(self.sigma1_sq:add(self.sigma2_sq):add(self.c2))):mean()
+   return ssim
+end
+function SSIMCriterion:updateGradInput(input, target)
+   error("not implemented")
+end

+ 1 - 2
lib/alpha_util.lua

@@ -40,8 +40,7 @@ function alpha_util.make_border(rgb, alpha, offset)
 	 collectgarbage()
       end
    end
-   rgb[torch.gt(rgb, 1.0)] = 1.0
-   rgb[torch.lt(rgb, 0.0)] = 0.0
+   rgb:clamp(0.0, 1.0)
 
    return rgb
 end

+ 93 - 8
lib/data_augmentation.lua

@@ -1,7 +1,8 @@
-require 'image'
+require 'pl'
+require 'cunn'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
-
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local data_augmentation = {}
 
 local function pcacov(x)
@@ -25,8 +26,7 @@ function data_augmentation.color_noise(src, p, factor)
 	 pca_space[i]:mul(color_scale[i])
       end
       local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
-      dest[torch.lt(dest, 0.0)] = 0.0
-      dest[torch.gt(dest, 1.0)] = 1.0
+      dest:clamp(0.0, 1.0)
 
       if conversion then
 	 dest = iproc.float2byte(dest)
@@ -70,6 +70,75 @@ function data_augmentation.unsharp_mask(src, p)
       return src
    end
 end
+function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
+   size = size or "3"
+   filters = utils.split(size, ",")
+   for i = 1, #filters do
+      local s = tonumber(filters[i])
+      filters[i] = s
+   end
+   if torch.uniform() < p then
+      local src, conversion = iproc.byte2float(src)
+      local kernel_size = filters[torch.random(1, #filters)]
+      local sigma
+      if sigma_min == sigma_max then
+	 sigma = sigma_min
+      else
+	 sigma = torch.uniform(sigma_min, sigma_max)
+      end
+      local kernel = iproc.gaussian2d(kernel_size, sigma)
+      local dest = image.convolve(src, kernel, 'same')
+      if conversion then
+	 dest = iproc.float2byte(dest)
+      end
+      return dest
+   else
+      return src
+   end
+end
+function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      local scale = torch.uniform(scale_min, scale_max)
+      local h = math.floor(x:size(2) * scale)
+      local w = math.floor(x:size(3) * scale)
+      x = iproc.scale(x, w, h, "Triangle")
+      y = iproc.scale(y, w, h, "Triangle")
+      return x, y
+   else
+      return x, y
+   end
+end
+function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      local r = torch.uniform(r_min, r_max) / 360.0 * math.pi
+      x = iproc.rotate(x, r)
+      y = iproc.rotate(y, r)
+      return x, y
+   else
+      return x, y
+   end
+end
+function data_augmentation.pairwise_negate(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      x = iproc.negate(x)
+      y = iproc.negate(y)
+      return x, y
+   else
+      return x, y
+   end
+end
+function data_augmentation.pairwise_negate_x(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      x = iproc.negate(x)
+      return x, y
+   else
+      return x, y
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)
@@ -107,11 +176,11 @@ function data_augmentation.flip(src)
       src = src:transpose(2, 3):contiguous()
    end
    if flip == 1 then
-      dest = image.hflip(src)
+      dest = iproc.hflip(src)
    elseif flip == 2 then
-      dest = image.vflip(src)
+      dest = iproc.vflip(src)
    elseif flip == 3 then
-      dest = image.hflip(image.vflip(src))
+      dest = iproc.hflip(iproc.vflip(src))
    elseif flip == 4 then
       dest = src
    end
@@ -120,4 +189,20 @@ function data_augmentation.flip(src)
    end
    return dest
 end
+
+local function test_blur()
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local image =require 'image'
+   local src = image.lena()
+
+   image.display({image = src, min=0, max=1})
+   local dest = data_augmentation.blur(src, 1.0, "3,5", 0.5, 0.6)
+   image.display({image = dest, min=0, max=1})
+   dest = data_augmentation.blur(src, 1.0, "3", 1.0, 1.0)
+   image.display({image = dest, min=0, max=1})
+   dest = data_augmentation.blur(src, 1.0, "5", 0.75, 0.75)
+   image.display({image = dest, min=0, max=1})
+end
+--test_blur()
+
 return data_augmentation

+ 2 - 4
lib/image_loader.lua

@@ -22,8 +22,7 @@ function image_loader.encode_png(rgb, options)
       else
 	 rgb = rgb:clone():add(clip_eps8)
       end
-      rgb[torch.lt(rgb, 0.0)] = 0.0
-      rgb[torch.gt(rgb, 1.0)] = 1.0
+      rgb:clamp(0.0, 1.0)
       rgb = rgb:mul(255):floor():div(255)
    else
       if options.inplace then
@@ -31,8 +30,7 @@ function image_loader.encode_png(rgb, options)
       else
 	 rgb = rgb:clone():add(clip_eps16)
       end
-      rgb[torch.lt(rgb, 0.0)] = 0.0
-      rgb[torch.gt(rgb, 1.0)] = 1.0
+      rgb:clamp(0.0, 1.0)
       rgb = rgb:mul(65535):floor():div(65535)
    end
    local im

+ 142 - 6
lib/iproc.lua

@@ -1,6 +1,7 @@
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
+require 'dok'
 local image = require 'image'
-
 local iproc = {}
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 
@@ -42,8 +43,7 @@ function iproc.float2byte(src)
    if src:type() == "torch.FloatTensor" then
       conversion = true
       dest = (src + clip_eps8):mul(255.0)
-      dest[torch.lt(dest, 0.0)] = 0
-      dest[torch.gt(dest, 255.0)] = 255.0
+      dest:clamp(0, 255.0)
       dest = dest:byte()
    end
    return dest, conversion
@@ -80,6 +80,7 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
+   image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
    local flow = torch.Tensor(2, dst_height, dst_width)
@@ -90,6 +91,7 @@ function iproc.padding(img, w1, w2, h1, h2)
    return image.warp(img, flow, "simple", false, "clamp")
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
    local flow = torch.Tensor(2, dst_height, dst_width)
@@ -115,8 +117,7 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
    local dest
    if gamma ~= 0 then
       dest = src:clone():pow(gamma):add(noise)
-      dest[torch.lt(dest, 0.0)] = 0.0
-      dest[torch.gt(dest, 1.0)] = 1.0
+      dest:clamp(0.0, 1.0)
       dest:pow(1.0 / gamma)
    else
       dest = src + noise
@@ -126,6 +127,101 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
    end
    return dest
 end
+function iproc.hflip(src)
+   local t
+   if src:type() == "torch.ByteTensor" then
+      t = "byte"
+   else
+      t = "float"
+   end
+   if src:size(1) == 3 then
+      color = "RGB"
+   else
+      color = "I"
+   end
+   local im = gm.Image(src, color, "DHW")
+   return im:flop():toTensor(t, color, "DHW")
+end
+function iproc.vflip(src)
+   local t
+   if src:type() == "torch.ByteTensor" then
+      t = "byte"
+   else
+      t = "float"
+   end
+   if src:size(1) == 3 then
+      color = "RGB"
+   else
+      color = "I"
+   end
+   local im = gm.Image(src, color, "DHW")
+   return im:flip():toTensor(t, color, "DHW")
+end
+local function rotate_with_warp(src, dst, theta, mode)
+  local height
+  local width
+  if src:dim() == 2 then
+    height = src:size(1)
+    width = src:size(2)
+  elseif src:dim() == 3 then
+    height = src:size(2)
+    width = src:size(3)
+  else
+    dok.error('src image must be 2D or 3D', 'image.rotate')
+  end
+  local flow = torch.Tensor(2, height, width)
+  local kernel = torch.Tensor({{math.cos(-theta), -math.sin(-theta)},
+			       {math.sin(-theta), math.cos(-theta)}})
+  flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width))
+  flow[1]:mul(-(height -1)):add(math.floor(height / 2 + 0.5))
+  flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width))
+  flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
+  flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
+  dst:resizeAs(src)
+  return image.warp(dst, src, flow, mode, true, 'clamp')
+end
+function iproc.rotate(src, theta)
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   local dest = torch.Tensor():typeAs(src):resizeAs(src)
+   rotate_with_warp(src, dest, theta, 'bilinear')
+   dest:clamp(0, 1)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
+function iproc.negate(src)
+   if src:type() == "torch.ByteTensor" then
+      return -src + 255
+   else
+      return -src + 1
+   end
+end
+
+function iproc.gaussian2d(kernel_size, sigma)
+   sigma = sigma or 1
+   local kernel = torch.Tensor(kernel_size, kernel_size)
+   local u = math.floor(kernel_size / 2) + 1
+   local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
+   for x = 1, kernel_size do
+      for y = 1, kernel_size do
+	 kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
+      end
+   end
+   kernel:div(kernel:sum())
+   return kernel
+end
+function iproc.rgb2y(src)
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
+   dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
 
 local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)
@@ -144,6 +240,46 @@ local function test_conversion()
    print(b)
    assert(b:float():sum() == 254.0 * 3)
 end
+local function test_flip()
+   require 'sys'
+   require 'torch'
+   torch.setdefaulttensortype("torch.FloatTensor")
+   image = require 'image'
+   local src = image.lena()
+   local src_byte = src:clone():mul(255):byte()
+
+   print(src:size())
+   print((image.hflip(src) - iproc.hflip(src)):sum())
+   print((image.hflip(src_byte) - iproc.hflip(src_byte)):sum())
+   print((image.vflip(src) - iproc.vflip(src)):sum())
+   print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum())
+end
+local function test_gaussian2d()
+   local t = {3, 5, 7}
+   for i = 1, #t do
+      local kp = iproc.gaussian2d(t[i], 0.5)
+      print(kp)
+   end
+end
+local function test_conv()
+   local image = require 'image'
+   local src = image.lena()
+   local kernel = torch.Tensor(3, 3):fill(1)
+   kernel:div(kernel:sum())
+   --local blur = image.convolve(iproc.padding(src, 1, 1, 1, 1), kernel, 'valid')
+   local blur = image.convolve(src, kernel, 'same')
+   print(src:size(), blur:size())
+   local diff = (blur - src):abs()
+   image.save("diff.png", diff)
+   image.display({image = blur, min=0, max=1})
+   image.display({image = diff, min=0, max=1})
+end
+
 --test_conversion()
+--test_flip()
+--test_gaussian2d()
+--test_conv()
 
 return iproc
+
+

+ 10 - 5
lib/minibatch_adam.lua

@@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
 	 local se = 0
-	 for i = 1, batch_size do
-	    local el = eval_metric:forward(output[i], targets[i])
-	    se = se + el
-	    instance_loss[shuffle[t + i - 1]] = el
+	 if config.xInstanceLoss then
+	    for i = 1, batch_size do
+	       local el = eval_metric:forward(output[i], targets[i])
+	       se = se + el
+	       instance_loss[shuffle[t + i - 1]] = el
+	    end
+	    se = (se / batch_size)
+	 else
+	    se = eval_metric:forward(output, targets)
 	 end
-	 sum_eval = sum_eval + (se / batch_size)
+	 sum_eval = sum_eval + se
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
 	 model:backward(inputs, criterion:backward(output, targets))

+ 4 - 3
lib/pairwise_transform_jpeg.lua

@@ -1,5 +1,6 @@
 local pairwise_utils = require 'pairwise_transform_utils'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local iproc = require 'iproc'
 local pairwise_transform = {}
 
@@ -42,8 +43,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2y(yc)
+	 xc = iproc.rgb2y(xc)
       end
       if torch.uniform() < options.nr_rate then
 	 -- reducing noise

+ 4 - 3
lib/pairwise_transform_jpeg_scale.lua

@@ -1,6 +1,7 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 
 local function add_jpeg_noise_(x, quality, options)
@@ -117,8 +118,8 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2y(yc)
+	 xc = iproc.rgb2y(xc)
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 4 - 3
lib/pairwise_transform_scale.lua

@@ -1,6 +1,7 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 
 function pairwise_transform.scale(src, scale, size, offset, n, options)
@@ -50,8 +51,8 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2y(yc)
+	 xc = iproc.rgb2y(xc)
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 26 - 30
lib/pairwise_transform_user.lua

@@ -1,43 +1,30 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 
-local function crop_if_large(x, y, scale_y, max_size, mod)
-   local tries = 4
-   if y:size(2) > max_size and y:size(3) > max_size then
-      assert(max_size % 4 == 0)
-      local rect_x, rect_y
-      for i = 1, tries do
-	 local yi = torch.random(0, y:size(2) - max_size)
-	 local xi = torch.random(0, y:size(3) - max_size)
-	 if mod then
-	    yi = yi - (yi % mod)
-	    xi = xi - (xi % mod)
-	 end
-	 rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
-	 rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
-	 -- ignore simple background
-	 if rect_y:float():std() >= 0 then
-	    break
-	 end
-      end
-      return rect_x, rect_y
-   else
-      return x, y
-   end
-end
 function pairwise_transform.user(x, y, size, offset, n, options)
    assert(x:size(1) == y:size(1))
 
    local scale_y = y:size(2) / x:size(2)
    assert(x:size(3) == y:size(3) / scale_y)
 
-   x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
+   x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
    assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
    local batch = {}
-   local lowres_y = pairwise_utils.low_resolution(y)
-   local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
+   local lowres_y = nil
+   local xs ={x}
+   local ys = {y}
+   local ls = {}
+
+   if options.active_cropping_rate > 0 then
+      lowres_y = pairwise_utils.low_resolution(y)
+   end
+   if options.pairwise_flip then
+      xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
+   end
+   assert(#xs == #ys)
    for i = 1, n do
       local t = (i % #xs) + 1
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
@@ -47,8 +34,17 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2y(yc)
+	 xc = iproc.rgb2y(xc)
+      end
+      if options.gcn then
+	 local mean = xc:mean()
+	 local stdv = xc:std()
+	 if stdv > 0 then
+	    xc:add(-mean):div(stdv)
+	 else
+	    xc:add(-mean)
+	 end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 94 - 20
lib/pairwise_transform_utils.lua

@@ -1,7 +1,7 @@
-require 'image'
 require 'cunn'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local data_augmentation = require 'data_augmentation'
 local pairwise_transform_utils = {}
 
@@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
       return src
    end
 end
+function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
+   local tries = 4
+   if y:size(2) > max_size and y:size(3) > max_size then
+      assert(max_size % 4 == 0)
+      local rect_x, rect_y
+      for i = 1, tries do
+	 local yi = torch.random(0, y:size(2) - max_size)
+	 local xi = torch.random(0, y:size(3) - max_size)
+	 if mod then
+	    yi = yi - (yi % mod)
+	    xi = xi - (xi % mod)
+	 end
+	 rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
+	 rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
+	 -- ignore simple background
+	 if rect_y:float():std() >= 0 then
+	    break
+	 end
+      end
+      return rect_x, rect_y
+   else
+      return x, y
+   end
+end
 function pairwise_transform_utils.preprocess(src, crop_size, options)
    local dest = src
    local box_only = false
@@ -47,7 +71,6 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
    if box_only then
       local mod = 2 -- assert pos % 2 == 0
       dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
-      dest = data_augmentation.flip(dest)
       dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
       dest = data_augmentation.overlay(dest, options.random_overlay_rate)
       dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
@@ -55,7 +78,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
    else
       dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
       dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
-      dest = data_augmentation.flip(dest)
+      dest = data_augmentation.blur(dest, options.random_blur_rate,
+				    options.random_blur_size, 
+				    options.random_blur_sigma_min,
+				    options.random_blur_sigma_max)
       dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
       dest = data_augmentation.overlay(dest, options.random_overlay_rate)
       dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
@@ -63,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
    end
    return dest
 end
+function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
+
+   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
+   x, y = data_augmentation.pairwise_rotate(x, y,
+					    options.random_pairwise_rotate_rate,
+					    options.random_pairwise_rotate_min,
+					    options.random_pairwise_rotate_max)
+
+   local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3))))
+   local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
+   x, y = data_augmentation.pairwise_scale(x, y,
+					   options.random_pairwise_scale_rate,
+					   scale_min,
+					   scale_max)
+   x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
+   x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
+
+   x = iproc.crop_mod4(x)
+   y = iproc.crop_mod4(y)
+
+   if options.pairwise_y_binary then
+      y[torch.lt(y, 128)] = 0
+      y[torch.gt(y, 0)] = 255
+   end
+
+   return x, y
+end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
    assert("crop_size % scale == 0", size % scale == 0)
@@ -111,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 
    for j = 1, 2 do
       -- TTA
-      local xi, yi, ri
+      local xi, yi, ri, ni
       if j == 1 then
 	 xi = x
 	 ni = x_noise
@@ -123,42 +176,55 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 	    ni = x_noise:transpose(2, 3):contiguous()
 	 end
 	 yi = y:transpose(2, 3):contiguous()
-	 ri = lowres_y:transpose(2, 3):contiguous()
+	 if lowres_y then
+	    ri = lowres_y:transpose(2, 3):contiguous()
+	 end
       end
-      local xv = image.vflip(xi)
+      local xv = iproc.vflip(xi)
       local nv
       if x_noise then
-	 nv = image.vflip(ni)
+	 nv = iproc.vflip(ni)
+      end
+      local yv = iproc.vflip(yi)
+      local rv
+      if ri then
+	 rv = iproc.vflip(ri)
       end
-      local yv = image.vflip(yi)
-      local rv = image.vflip(ri)
       table.insert(xs, xi)
       if ni then
 	 table.insert(ns, ni)
       end
       table.insert(ys, yi)
-      table.insert(ls, ri)
+      if ri then
+	 table.insert(ls, ri)
+      end
 
       table.insert(xs, xv)
       if nv then
 	 table.insert(ns, nv)
       end
       table.insert(ys, yv)
-      table.insert(ls, rv)
+      if rv then
+	 table.insert(ls, rv)
+      end
 
-      table.insert(xs, image.hflip(xi))
+      table.insert(xs, iproc.hflip(xi))
       if ni then
-	 table.insert(ns, image.hflip(ni))
+	 table.insert(ns, iproc.hflip(ni))
+      end
+      table.insert(ys, iproc.hflip(yi))
+      if ri then
+	 table.insert(ls, iproc.hflip(ri))
       end
-      table.insert(ys, image.hflip(yi))
-      table.insert(ls, image.hflip(ri))
 
-      table.insert(xs, image.hflip(xv))
+      table.insert(xs, iproc.hflip(xv))
       if nv then
-	 table.insert(ns, image.hflip(nv))
+	 table.insert(ns, iproc.hflip(nv))
+      end
+      table.insert(ys, iproc.hflip(yv))
+      if rv then
+	 table.insert(ls, iproc.hflip(rv))
       end
-      table.insert(ys, image.hflip(yv))
-      table.insert(ls, image.hflip(rv))
    end
    return xs, ys, ls, ns
 end
@@ -171,6 +237,9 @@ end
 local g_lowres_model = nil
 local g_lowres_gpu = nil
 function pairwise_transform_utils.low_resolution(src)
+--[[
+   -- I am not sure that the following process is thraed-safe
+
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
       --benchmark
@@ -203,6 +272,11 @@ function pairwise_transform_utils.low_resolution(src)
 	 size(src:size(3), src:size(2), "Box"):
 	    toTensor("byte", "RGB", "DHW")
    end
+--]]
+   return gm.Image(src, "RGB", "DHW"):
+      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+      size(src:size(3), src:size(2), "Box"):
+      toTensor("byte", "RGB", "DHW")
 end
 
 return pairwise_transform_utils

+ 27 - 1
lib/reconstruct.lua

@@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
 	    break
 	 end
 	 input[j+1]:copy(x[input_indexes[i + j]])
+	 if model.w2nn_gcn then
+	    local mean = input[j + 1]:mean()
+	    local stdv = input[j + 1]:std()
+	    if stdv > 0 then
+	       input[j + 1]:add(-mean):div(stdv)
+	    else
+	       input[j + 1]:add(-mean)
+	    end
+	 end
 	 c = c + 1
       end
       input_cuda:copy(input)
@@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
    p.x_w = x:size(3)
    p.x_h = x:size(2)
    p.inner_scale = reconstruct.inner_scale(model)
-   local input_offset = math.ceil(offset / p.inner_scale)
+   local input_offset
+   if model.w2nn_input_offset then
+      input_offset = model.w2nn_input_offset
+   else
+      input_offset = math.ceil(offset / p.inner_scale)
+   end
    local input_block_size = block_size
    local process_size = input_block_size - input_offset * 2
    local h_blocks = math.floor(p.x_h / process_size) +
@@ -172,6 +186,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
    return output
 end
 function reconstruct.image(model, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    local i2rgb = false
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -194,6 +211,9 @@ function reconstruct.image(model, x, block_size)
    return x
 end
 function reconstruct.scale(model, scale, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    local i2rgb = false
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -287,6 +307,9 @@ local function tta(f, n, model, x, block_size)
    return average:div(#augments)
 end
 function reconstruct.image_tta(model, n, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    if reconstruct.is_rgb(model) then
       return tta(reconstruct.image_rgb, n, model, x, block_size)
    else
@@ -294,6 +317,9 @@ function reconstruct.image_tta(model, n, x, block_size)
    end
 end
 function reconstruct.scale_tta(model, n, scale, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    if reconstruct.is_rgb(model) then
       local f = function (model, x, offset, block_size)
 	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)

+ 22 - 1
lib/settings.lua

@@ -1,6 +1,7 @@
 require 'xlua'
 require 'pl'
 require 'trepl'
+require 'cutorch'
 
 -- global settings
 
@@ -18,7 +19,7 @@ cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
 cmd:option("-gpu", -1, 'GPU Device ID')
-cmd:option("-seed", 11, 'RNG seed')
+cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
 cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
@@ -32,6 +33,20 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise
 cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
 cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
 cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
+cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
+cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
+cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
+cmd:option("-random_blur_sigma_max", 1.0, 'max sigma for random gaussian blur')
+cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
+cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
+cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
+cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise resize for user method')
+cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
+cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
+cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
+cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 48, 'crop size')
@@ -59,6 +74,8 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
+cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -75,6 +92,8 @@ end
 to_bool(settings, "plot")
 to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
+to_bool(settings, "pairwise_y_binary")
+to_bool(settings, "pairwise_flip")
 
 if settings.plot then
    require 'gnuplot'
@@ -148,4 +167,6 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
+cutorch.setDevice(opt.gpu)
+
 return settings

+ 271 - 3
lib/srcnn.lua

@@ -136,6 +136,24 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
       error("unsupported backend:" .. backend)
    end
 end
+local function ReLU(backend)
+   if backend == "cunn" then
+      return nn.ReLU(true)
+   elseif backend == "cudnn" then
+      return cudnn.ReLU(true)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
 
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -153,6 +171,7 @@ function srcnn.vgg_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "vgg_7"
@@ -190,6 +209,7 @@ function srcnn.vgg_12(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "vgg_12"
@@ -219,6 +239,7 @@ function srcnn.dilated_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "dilated_7"
@@ -249,6 +270,7 @@ function srcnn.upconv_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "upconv_7"
@@ -257,11 +279,255 @@ function srcnn.upconv_7(backend, ch)
    model.w2nn_resize = true
    model.w2nn_channels = ch
 
+   return model
+end
+
+-- large version of upconv_7
+-- This model able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but this model is 2x slower than upconv_7.
+function srcnn.upconv_7l(backend, ch)
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
+   model:add(w2nn.InplaceClip01())
+   model:add(nn.View(-1):setNumInputDims(3))
+
+   model.w2nn_arch_name = "upconv_7l"
+   model.w2nn_offset = 14
+   model.w2nn_scale_factor = 2
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
+
+-- layerwise linear blending with skip connections
+-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
+function srcnn.skiplb_7(backend, ch)
+   local function skip(backend, i, o)
+      local con = nn.Concat(2)
+      local conv = nn.Sequential()
+      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
+      conv:add(nn.LeakyReLU(0.1, true))
+
+      -- depth concat
+      con:add(conv)
+      con:add(nn.Identity()) -- skip
+      return con
+   end
+   local model = nn.Sequential()
+   model:add(skip(backend, ch, 16))
+   model:add(skip(backend, 16+ch, 32))
+   model:add(skip(backend, 32+16+ch, 64))
+   model:add(skip(backend, 64+32+16+ch, 128))
+   model:add(skip(backend, 128+64+32+16+ch, 128))
+   model:add(skip(backend, 128+128+64+32+16+ch, 256))
+   -- input of last layer = [all layerwise output(contains input layer)].flatten
+   model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
+   model:add(w2nn.InplaceClip01())
+   model:add(nn.View(-1):setNumInputDims(3))
+   model.w2nn_arch_name = "skiplb_7"
+   model.w2nn_offset = 14
+   model.w2nn_scale_factor = 2
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
+
+-- dilated convolution + deconvolution
+-- Note: This model is not better than upconv_7. Maybe becuase of under-fitting.
+function srcnn.dilated_upconv_7(backend, ch)
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
+   model:add(w2nn.InplaceClip01())
+   model:add(nn.View(-1):setNumInputDims(3))
+
+   model.w2nn_arch_name = "dilated_upconv_7"
+   model.w2nn_offset = 20
+   model.w2nn_scale_factor = 2
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
+
+-- ref: https://arxiv.org/abs/1609.04802
+-- note: no batch-norm, no zero-paading
+function srcnn.srresnet_2x(backend, ch)
+   local function resblock(backend)
+      local seq = nn.Sequential()
+      local con = nn.ConcatTable()
+      local conv = nn.Sequential()
+      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+      conv:add(ReLU(backend))
+      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+      conv:add(ReLU(backend))
+      con:add(conv)
+      con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
+      seq:add(con)
+      seq:add(nn.CAddTable())
+      return seq
+   end
+   local model = nn.Sequential()
+   --model:add(skip(backend, ch, 64 - ch))
+   model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(resblock(backend))
+   model:add(resblock(backend))
+   model:add(resblock(backend))
+   model:add(resblock(backend))
+   model:add(resblock(backend))
+   model:add(resblock(backend))
+   model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
+   model:add(ReLU(backend))
+   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
+
+   model:add(w2nn.InplaceClip01())
+   --model:add(nn.View(-1):setNumInputDims(3))
+   model.w2nn_arch_name = "srresnet_2x"
+   model.w2nn_offset = 28
+   model.w2nn_scale_factor = 2
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
+
+   return model
+end
+
+-- large version of srresnet_2x. It's current best model but slow.
+function srcnn.resnet_14l(backend, ch)
+   local function resblock(backend, i, o)
+      local seq = nn.Sequential()
+      local con = nn.ConcatTable()
+      local conv = nn.Sequential()
+      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
+      conv:add(nn.LeakyReLU(0.1, true))
+      conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
+      conv:add(nn.LeakyReLU(0.1, true))
+      con:add(conv)
+      if i == o then
+	 con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
+      else
+	 local seq = nn.Sequential()
+	 seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
+	 seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
+	 con:add(seq)
+      end
+      seq:add(con)
+      seq:add(nn.CAddTable())
+      return seq
+   end
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(resblock(backend, 32, 64))
+   model:add(resblock(backend, 64, 64))
+   model:add(resblock(backend, 64, 128))
+   model:add(resblock(backend, 128, 128))
+   model:add(resblock(backend, 128, 256))
+   model:add(resblock(backend, 256, 256))
+   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
+   model:add(w2nn.InplaceClip01())
+   model:add(nn.View(-1):setNumInputDims(3))
+   model.w2nn_arch_name = "resnet_14l"
+   model.w2nn_offset = 28
+   model.w2nn_scale_factor = 2
+   model.w2nn_resize = true
+   model.w2nn_channels = ch
+
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
 
    return model
 end
+
+-- for segmentation
+function srcnn.fcn_v1(backend, ch)
+   -- input_size = 120
+   local model = nn.Sequential()
+   --i = 120
+   --model:cuda()
+   --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
+
+   model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
+
+   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
+
+   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
+
+   model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(nn.Dropout(0.5, false, true))
+
+   model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
+
+   model:add(w2nn.InplaceClip01())
+   model:add(nn.View(-1):setNumInputDims(3))
+   model.w2nn_arch_name = "fcn_v1"
+   model.w2nn_offset = 36
+   model.w2nn_scale_factor = 1
+   model.w2nn_channels = ch
+   model.w2nn_input_size = 120
+   --model.w2nn_gcn = true
+   
+   return model
+end
 function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
@@ -282,8 +548,10 @@ function srcnn.create(model_name, backend, color)
       error("unsupported model_name: " .. model_name)
    end
 end
-
---local model = srcnn.upconv_6("cunn", 3):cuda()
---print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
+--[[
+local model = srcnn.fcn_v1("cunn", 3):cuda()
+print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
+print(model)
+--]]
 
 return srcnn

+ 3 - 0
lib/w2nn.lua

@@ -30,5 +30,8 @@ else
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
+   require 'SSIMCriterion'
+   require 'InplaceClip01'
+   require 'L1Criterion'
    return w2nn
 end

+ 1 - 0
models/resnet_14l/README.md

@@ -0,0 +1 @@
+Currently, this models are for the benchmark.

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 136 - 0
models/resnet_14l/photo/scale2.0x_model.t7


+ 1 - 0
models/upconv_7l/README.md

@@ -0,0 +1 @@
+Currently, this models are for the benchmark.

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 0 - 0
models/upconv_7l/art/scale2.0x_model.json


A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 136 - 0
models/upconv_7l/art/scale2.0x_model.t7


A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 0 - 0
models/upconv_7l/photo/scale2.0x_model.json


A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 136 - 0
models/upconv_7l/photo/scale2.0x_model.t7


+ 116 - 31
tools/benchmark.lua

@@ -18,7 +18,7 @@ cmd:option("-dir", "./data/test", 'test image directory')
 cmd:option("-file", "", 'test image file list')
 cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
-cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff)')
+cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff|scale4)')
 cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
 cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
 cmd:option("-color", "y", '(rgb|y|r|g|b)')
@@ -46,6 +46,7 @@ cmd:option("-x_dir", "", 'input image for user method')
 cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
 cmd:option("-x_file", "", 'input image for user method')
 cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
+cmd:option("-border", 0, 'border px that will removed')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -153,12 +154,24 @@ local function baseline_scale(x, filter)
 		      x:size(2) * 2.0,
 		      filter)
 end
+local function baseline_scale4(x, filter)
+   return iproc.scale(x,
+		      x:size(3) * 4.0,
+		      x:size(2) * 4.0,
+		      filter)
+end
 local function transform_scale(x, opt)
    return iproc.scale(x,
 		      x:size(3) * 0.5,
 		      x:size(2) * 0.5,
 		      opt.filter, opt.resize_blur)
 end
+local function transform_scale4(x, opt)
+   return iproc.scale(x,
+		      x:size(3) * 0.25,
+		      x:size(2) * 0.25,
+		      opt.filter, opt.resize_blur)
+end
 
 local function transform_scale_jpeg(x, opt)
    x = iproc.scale(x,
@@ -179,9 +192,15 @@ local function transform_scale_jpeg(x, opt)
    end
    return iproc.byte2float(x)
 end
-
+local function remove_border(x, border)
+   return iproc.crop(x,
+		     border, border,
+		     x:size(3) - border,
+		     x:size(2) - border)
+end
 local function benchmark(opt, x, model1, model2)
-   local mse
+   local mse1, mse2
+   local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
    local baseline_mse = 0
@@ -192,6 +211,10 @@ local function benchmark(opt, x, model1, model2)
    local model2_time = 0
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
+   local detail_fp = nil
+   if opt.save_info then
+      detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
+   end
    if opt.tta then
       scale_f = function(model, scale, x, block_size, batch_size)
 	 return reconstruct.scale_tta(model, opt.tta_level,
@@ -204,12 +227,15 @@ local function benchmark(opt, x, model1, model2)
    end
 
    for i = 1, #x do
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
       local basename = x[i].basename
       local input, model1_output, model2_output, baseline_output, ground_truth
 
       if opt.method == "scale" then
-	 input = transform_scale(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_scale(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
 	    model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
@@ -226,9 +252,29 @@ local function benchmark(opt, x, model1, model2)
 	    model2_time = model2_time + (sys.clock() - t)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
+      elseif opt.method == "scale4" then
+	 input = transform_scale4(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
+	 if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
+	    model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
+	    if model2 then
+	       model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
+	    end
+	 end
+	 t = sys.clock()
+	 model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
+	 model1_output = scale_f(model1, 2.0, model1_output, opt.crop_size, opt.batch_size)
+	 model1_time = model1_time + (sys.clock() - t)
+	 if model2 then
+	    t = sys.clock()
+	    model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
+	    model2_output = scale_f(model2, 2.0, model2_output, opt.crop_size, opt.batch_size)
+	    model2_time = model2_time + (sys.clock() - t)
+	 end
+	 baseline_output = baseline_scale4(input, opt.baseline_filter)
       elseif opt.method == "noise" then
-	 input = transform_jpeg(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_jpeg(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then
 	    model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
@@ -246,8 +292,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = input
       elseif opt.method == "noise_scale" then
-	 input = transform_scale_jpeg(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then
 	    if model1.noise_scale_model then
@@ -312,8 +358,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       elseif opt.method == "user" then
-	 input = x[i].x
-	 ground_truth = x[i].y
+	 input = iproc.byte2float(x[i].x)
+	 ground_truth = iproc.byte2float(x[i].y)
 	 local y_scale = ground_truth:size(2) / input:size(2)
 	 if y_scale > 1 then
 	    if opt.force_cudnn and i == 1 then
@@ -347,19 +393,44 @@ local function benchmark(opt, x, model1, model2)
 	    end
 	 end
       elseif opt.method == "diff" then
-	 input = x[i].x
-	 ground_truth = x[i].y
+	 input = iproc.byte2float(x[i].x)
+	 ground_truth = iproc.byte2float(x[i].y)
 	 model1_output = input
       end
-      mse = MSE(ground_truth, model1_output, opt.color)
-      model1_mse = model1_mse + mse
-      model1_psnr = model1_psnr + MSE2PSNR(mse)
+      if opt.border > 0 then
+	 ground_truth = remove_border(ground_truth, opt.border)
+	 model1_output = remove_border(model1_output, opt.border)
+      end
+      mse1 = MSE(ground_truth, model1_output, opt.color)
+      model1_mse = model1_mse + mse1
+      model1_psnr = model1_psnr + MSE2PSNR(mse1)
+
+      local won_model = 1
       if model2 then
-	 mse = MSE(ground_truth, model2_output, opt.color)
-	 model2_mse = model2_mse + mse
-	 model2_psnr = model2_psnr + MSE2PSNR(mse)
+	 if opt.border > 0 then
+	    model2_output = remove_border(model2_output, opt.border)
+	 end
+	 mse2 = MSE(ground_truth, model2_output, opt.color)
+	 model2_mse = model2_mse + mse2
+	 model2_psnr = model2_psnr + MSE2PSNR(mse2)
+
+	 if mse1 < mse2 then
+	    won[1] = won[1] + 1
+	 elseif mse1 > mse2 then
+	    won[2] = won[2] + 1
+	    won_model = 2
+	 end
+	 if detail_fp then
+	    detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+					  MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	 end
+      else
+	 if detail_fp then
+	    detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+	 end
       end
       if baseline_output then
+	 baseline_output = remove_border(baseline_output, opt.border)
 	 mse = MSE(ground_truth, baseline_output, opt.color)
 	 baseline_mse = baseline_mse + mse
 	 baseline_psnr = baseline_psnr + MSE2PSNR(mse)
@@ -382,29 +453,31 @@ local function benchmark(opt, x, model1, model2)
 	 if model2 then
 	    if baseline_output then
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
 				i, #x,
 				model1_time,
 				model2_time,
 				math.sqrt(baseline_mse / i),
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 				baseline_psnr / i,
-				model1_psnr / i, model2_psnr / i
+				model1_psnr / i, model2_psnr / i,
+				won[1], won[2]
 		  ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
 				i, #x,
 				model1_time,
 				model2_time,
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				model1_psnr / i, model2_psnr / i
+				model1_psnr / i, model2_psnr / i,
+				won[1], won[2]
 		  ))
 	    end
 	 else
 	    if baseline_output then
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
 				i, #x,
 				model1_time,
 				math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
@@ -412,7 +485,7 @@ local function benchmark(opt, x, model1, model2)
 		  ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model1_rmse=%f, model1_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
 				i, #x,
 				model1_time,
 				math.sqrt(model1_mse / i), model1_psnr / i
@@ -438,6 +511,9 @@ local function benchmark(opt, x, model1, model2)
 				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
       end
       fp:close()
+      if detail_fp then
+	 detail_fp:close()
+      end
    end
    io.stdout:write("\n")
 end
@@ -448,7 +524,7 @@ local function load_data_from_dir(test_dir)
       local name = path.basename(files[i])
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      local img = image_loader.load_float(files[i])
+      local img = image_loader.load_byte(files[i])
       if img then
 	 table.insert(test_x, {y = iproc.crop_mod4(img),
 			       basename = base})
@@ -456,6 +532,9 @@ local function load_data_from_dir(test_dir)
       if opt.show_progress then
 	 xlua.progress(i, #files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test_x
 end
@@ -466,7 +545,7 @@ local function load_data_from_file(test_file)
       local name = path.basename(files[i])
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      local img = image_loader.load_float(files[i])
+      local img = image_loader.load_byte(files[i])
       if img then
 	 table.insert(test_x, {y = iproc.crop_mod4(img),
 			       basename = base})
@@ -474,6 +553,9 @@ local function load_data_from_file(test_file)
       if opt.show_progress then
 	 xlua.progress(i, #files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test_x
 end
@@ -519,16 +601,19 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
    end
    for i = 1, #y_files do
       local key = get_basename(y_files[i])
-      local x = image_loader.load_float(basename_db[key].x)
-      local y = image_loader.load_float(basename_db[key].y)
+      local x = image_loader.load_byte(basename_db[key].x)
+      local y = image_loader.load_byte(basename_db[key].y)
       if x and y then
 	 table.insert(test, {y = y,
 			     x = x,
-			     basename = base})
+			     basename = key})
       end
       if opt.show_progress then
 	 xlua.progress(i, #y_files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test
 end
@@ -563,7 +648,7 @@ if opt.show_progress then
    print(opt)
 end
 
-if opt.method == "scale" then
+if opt.method == "scale" or opt.method == "scale4" then
    local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
    local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
    local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn)

+ 2 - 0
tools/export_all.sh

@@ -33,5 +33,7 @@ export_model() {
 }
 export_model vgg_7/art
 export_model upconv_7/art
+export_model upconv_7l/art
 export_model vgg_7/photo
 export_model upconv_7/photo
+export_model upconv_7l/photo

+ 14 - 9
tools/export_model.lua

@@ -22,7 +22,6 @@ local function includes(s, a)
    end
    return false
 end
-
 local function get_bias(mod)
    if mod.bias then
       return mod.bias:float()
@@ -31,20 +30,18 @@ local function get_bias(mod)
       return torch.FloatTensor(mod.nOutputPlane):zero()
    end
 end
-local function export(model, output)
+local function export_weight(jmodules, seq)
    local targets = {"nn.SpatialConvolutionMM",
 		    "cudnn.SpatialConvolution",
 		    "nn.SpatialFullConvolution",
 		    "cudnn.SpatialFullConvolution"
    }
-   local jmodules = {}
-   local model_config = meta_data(model)
-   local first_layer = true
-
-   for k = 1, #model.modules do
-      local mod = model.modules[k]
+   for k = 1, #seq.modules do
+      local mod = seq.modules[k]
       local name = torch.typename(mod)
-      if includes(name, targets) then
+      if name == "nn.Sequential" or name == "nn.ConcatTable" then
+	 export_weight(jmodules, mod)
+      elseif includes(name, targets) then
 	 local weight = mod.weight:float()
 	 if name:match("FullConvolution") then
 	    weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
@@ -71,6 +68,14 @@ local function export(model, output)
 	 table.insert(jmodules, jmod)
       end
    end
+end
+local function export(model, output)
+   local jmodules = {}
+   local model_config = meta_data(model)
+   local first_layer = true
+
+   export_weight(jmodules, model)
+
    local fp = io.open(output, "w")
    if not fp then
       error("IO Error: " .. output)

+ 281 - 144
train.lua

@@ -3,15 +3,14 @@ local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^
 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 require 'optim'
 require 'xlua'
-
+require 'image'
 require 'w2nn'
+local threads = require 'threads'
 local settings = require 'settings'
 local srcnn = require 'srcnn'
 local minibatch_adam = require 'minibatch_adam'
 local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
-local compression = require 'compression'
-local pairwise_transform = require 'pairwise_transform'
 local image_loader = require 'image_loader'
 
 local function save_test_scale(model, rgb, file)
@@ -42,20 +41,218 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n, patches)
+
+local g_transform_pool = nil
+local g_mutex = nil
+local g_mutex_id = nil
+local function transform_pool_init(has_resize, offset)
+   local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
+   g_mutex = threads.Mutex()
+   g_mutex_id = g_mutex:id()
+   g_transform_pool = threads.Threads(
+      nthread,
+      threads.safe(
+      function(threadid)
+	 require 'pl'
+	 local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
+	 require 'torch'
+	 require 'nn'
+	 require 'cunn'
+
+	 torch.setnumthreads(1)
+	 torch.setdefaulttensortype("torch.FloatTensor")
+
+	 local threads = require 'threads'
+	 local compression = require 'compression'
+	 local pairwise_transform = require 'pairwise_transform'
+
+	 function transformer(x, is_validation, n)
+	    local mutex = threads.Mutex(g_mutex_id)
+	    local meta = {data = {}}
+	    local y = nil
+	    if type(x) == "table" and type(x[2]) == "table" then
+	       meta = x[2]
+	       if x[1].x and x[1].y then
+		  y = compression.decompress(x[1].y)
+		  x = compression.decompress(x[1].x)
+	       else
+		  x = compression.decompress(x[1])
+	       end
+	    else
+	       x = compression.decompress(x)
+	    end
+	    n = n or settings.patches
+	    if is_validation == nil then is_validation = false end
+	    local random_color_noise_rate = nil 
+	    local random_overlay_rate = nil
+	    local active_cropping_rate = nil
+	    local active_cropping_tries = nil
+	    if is_validation then
+	       active_cropping_rate = settings.active_cropping_rate
+	       active_cropping_tries = settings.active_cropping_tries
+	       random_color_noise_rate = 0.0
+	       random_overlay_rate = 0.0
+	    else
+	       active_cropping_rate = settings.active_cropping_rate
+	       active_cropping_tries = settings.active_cropping_tries
+	       random_color_noise_rate = settings.random_color_noise_rate
+	       random_overlay_rate = settings.random_overlay_rate
+	    end
+	    if settings.method == "scale" then
+	       local conf = tablex.update({
+		     mutex = mutex,
+		     downsampling_filters = settings.downsampling_filters,
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
+		     max_size = settings.max_size,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     rgb = (settings.color == "rgb"),
+		     x_upsampling = not has_resize,
+		     resize_blur_min = settings.resize_blur_min,
+		     resize_blur_max = settings.resize_blur_max}, meta)
+	       return pairwise_transform.scale(x,
+					       settings.scale,
+					       settings.crop_size, offset,
+					       n, conf)
+	    elseif settings.method == "noise" then
+	       local conf = tablex.update({
+		     mutex = mutex,
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
+		     max_size = settings.max_size,
+		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     nr_rate = settings.nr_rate,
+		     rgb = (settings.color == "rgb")}, meta)
+	       return pairwise_transform.jpeg(x,
+					      settings.style,
+					      settings.noise_level,
+					      settings.crop_size, offset,
+					      n, conf)
+	    elseif settings.method == "noise_scale" then
+	       local conf = tablex.update({
+		     mutex = mutex,
+		     downsampling_filters = settings.downsampling_filters,
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
+		     max_size = settings.max_size,
+		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
+		     nr_rate = settings.nr_rate,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     rgb = (settings.color == "rgb"),
+		     x_upsampling = not has_resize,
+		     resize_blur_min = settings.resize_blur_min,
+		     resize_blur_max = settings.resize_blur_max}, meta)
+	       return pairwise_transform.jpeg_scale(x,
+						    settings.scale,
+						    settings.style,
+						    settings.noise_level,
+						    settings.crop_size, offset,
+						    n, conf)
+	    elseif settings.method == "user" then
+	       if is_validation == nil then is_validation = false end
+	       local rotate_rate = nil 
+	       local scale_rate = nil
+	       local negate_rate = nil
+	       local negate_x_rate = nil
+	       if is_validation then
+		  rotate_rate = 0
+		  scale_rate = 0
+		  negate_rate = 0
+		  negate_x_rate = 0
+	       else
+		  rotate_rate = settings.random_pairwise_rotate_rate
+		  scale_rate = settings.random_pairwise_scale_rate
+		  negate_rate = settings.random_pairwise_negate_rate
+		  negate_x_rate = settings.random_pairwise_negate_x_rate
+	       end
+	       local conf = tablex.update({
+		     gcn = settings.gcn,
+		     max_size = settings.max_size,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     random_pairwise_rotate_rate = rotate_rate,
+		     random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
+		     random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
+		     random_pairwise_scale_rate = scale_rate,
+		     random_pairwise_scale_min = settings.random_pairwise_scale_min,
+		     random_pairwise_scale_max = settings.random_pairwise_scale_max,
+		     random_pairwise_negate_rate = negate_rate,
+		     random_pairwise_negate_x_rate = negate_x_rate,
+		     pairwise_y_binary = settings.pairwise_y_binary,
+		     pairwise_flip = settings.pairwise_flip,
+		     rgb = (settings.color == "rgb")}, meta)
+	       return pairwise_transform.user(x, y,
+					      settings.crop_size, offset,
+					      n, conf)
+	    end
+	 end
+      end)
+   )
+   g_transform_pool:synchronize()
+end
+
+local function make_validation_set(x, n, patches)
+   local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
    n = n or 4
    local validation_patches = math.min(16, patches or 16)
    local data = {}
+
+   g_transform_pool:synchronize()
+   torch.setnumthreads(1) -- 1
+
    for i = 1, #x do
       for k = 1, math.max(n / validation_patches, 1) do
-	 local xy = transformer(x[i], true, validation_patches)
-	 for j = 1, #xy do
-	    table.insert(data, {x = xy[j][1], y = xy[j][2]})
-	 end
+	 local input = x[i]
+	 g_transform_pool:addjob(
+	    function()
+	       local xy = transformer(input, true, validation_patches)
+	       return xy
+	    end,
+	    function(xy)
+	       for j = 1, #xy do
+		  table.insert(data, {x = xy[j][1], y = xy[j][2]})
+	       end
+	    end
+	 )
+      end
+      if i % 20 == 0 then
+	 collectgarbage()
+	 g_transform_pool:synchronize()
+	 xlua.progress(i, #x)
       end
-      xlua.progress(i, #x)
-      collectgarbage()
    end
+   g_transform_pool:synchronize()
+   torch.setnumthreads(nthread) -- revert
+
    local new_data = {}
    local perm = torch.randperm(#data)
    for i = 1, perm:size(1) do
@@ -102,144 +299,71 @@ local function validate(model, criterion, eval_metric, data, batch_size)
 end
 
 local function create_criterion(model)
-   if reconstruct.is_rgb(model) then
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.29891 * 3) -- R
-      weight[2]:fill(0.58661 * 3) -- G
-      weight[3]:fill(0.11448 * 3) -- B
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
-   else
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(1, output_w * output_w)
-      weight[1]:fill(1.0)
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
-   end
-end
-local function transformer(model, x, is_validation, n, offset)
-   local meta = {data = {}}
-   local y = nil
-   if type(x) == "table" and type(x[2]) == "table" then
-      meta = x[2]
-      if x[1].x and x[1].y then
-	 y = compression.decompress(x[1].y)
-	 x = compression.decompress(x[1].x)
+   if settings.loss == "huber" then
+      if reconstruct.is_rgb(model) then
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(3, output_w * output_w)
+	 weight[1]:fill(0.29891 * 3) -- R
+	 weight[2]:fill(0.58661 * 3) -- G
+	 weight[3]:fill(0.11448 * 3) -- B
+	 return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
       else
-	 x = compression.decompress(x[1])
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(1, output_w * output_w)
+	 weight[1]:fill(1.0)
+	 return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
       end
+   elseif settings.loss == "l1" then
+      return w2nn.L1Criterion():cuda()
+   elseif settings.loss == "mse" then
+      return w2nn.ClippedMSECriterion(0, 1.0):cuda()
    else
-      x = compression.decompress(x)
-   end
-   n = n or settings.patches
-   if is_validation == nil then is_validation = false end
-   local random_color_noise_rate = nil 
-   local random_overlay_rate = nil
-   local active_cropping_rate = nil
-   local active_cropping_tries = nil
-   if is_validation then
-      active_cropping_rate = settings.active_cropping_rate
-      active_cropping_tries = settings.active_cropping_tries
-      random_color_noise_rate = 0.0
-      random_overlay_rate = 0.0
-   else
-      active_cropping_rate = settings.active_cropping_rate
-      active_cropping_tries = settings.active_cropping_tries
-      random_color_noise_rate = settings.random_color_noise_rate
-      random_overlay_rate = settings.random_overlay_rate
-   end
-   if settings.method == "scale" then
-      local conf = tablex.update({
-	    downsampling_filters = settings.downsampling_filters,
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb"),
-	    x_upsampling = not reconstruct.has_resize(model),
-	    resize_blur_min = settings.resize_blur_min,
-	 resize_blur_max = settings.resize_blur_max}, meta)
-      return pairwise_transform.scale(x,
-				      settings.scale,
-				      settings.crop_size, offset,
-				      n, conf)
-   elseif settings.method == "noise" then
-      local conf = tablex.update({
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    nr_rate = settings.nr_rate,
-	    rgb = (settings.color == "rgb")}, meta)
-      return pairwise_transform.jpeg(x,
-				     settings.style,
-				     settings.noise_level,
-				     settings.crop_size, offset,
-				     n, conf)
-   elseif settings.method == "noise_scale" then
-      local conf = tablex.update({
-	    downsampling_filters = settings.downsampling_filters,
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
-	    nr_rate = settings.nr_rate,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb"),
-	    x_upsampling = not reconstruct.has_resize(model),
-	    resize_blur_min = settings.resize_blur_min,
-	    resize_blur_max = settings.resize_blur_max}, meta)
-      return pairwise_transform.jpeg_scale(x,
-					   settings.scale,
-					   settings.style,
-					   settings.noise_level,
-					   settings.crop_size, offset,
-					   n, conf)
-   elseif settings.method == "user" then
-      local conf = tablex.update({
-	    max_size = settings.max_size,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb")}, meta)
-      return pairwise_transform.user(x, y,
-				     settings.crop_size, offset,
-				     n, conf)
+      error("unsupported loss .." .. settings.loss)
    end
 end
 
-local function resampling(x, y, train_x, transformer, input_size, target_size)
+local function resampling(x, y, train_x)
    local c = 1
    local shuffle = torch.randperm(#train_x)
+   local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
+   torch.setnumthreads(1) -- 1
+
    for t = 1, #train_x do
-      xlua.progress(t, #train_x)
-      local xy = transformer(train_x[shuffle[t]], false, settings.patches)
-      for i = 1, #xy do
-         x[c]:copy(xy[i][1])
-	 y[c]:copy(xy[i][2])
-	 c = c + 1
-	 if c > x:size(1) then
-	    break
+      local input = train_x[shuffle[t]]
+      g_transform_pool:addjob(
+	 function()
+	    local xy = transformer(input, false, settings.patches)
+	    return xy
+	 end,
+	 function(xy)
+	    for i = 1, #xy do
+	       if c <= x:size(1) then
+		  x[c]:copy(xy[i][1])
+		  y[c]:copy(xy[i][2])
+		  c = c + 1
+	       else
+		  break
+	       end
+	    end
 	 end
+      )
+      if t % 50 == 0 then
+	 collectgarbage()
+	 g_transform_pool:synchronize()
+	 xlua.progress(t, #train_x)
       end
       if c > x:size(1) then
 	 break
       end
-      if t % 50 == 0 then
-	 collectgarbage()
-      end
    end
+   g_transform_pool:synchronize()
    xlua.progress(#train_x, #train_x)
+   torch.setnumthreads(nthread) -- revert
 end
 local function get_oracle_data(x, y, instance_loss, k, samples)
    local index = torch.LongTensor(instance_loss:size(1))
@@ -262,6 +386,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples)
 end
 
 local function remove_small_image(x)
+   local compression = require 'compression'
    local new_x = {}
    for i = 1, #x do
       local xe, meta, x_s
@@ -293,6 +418,8 @@ local function plot(train, valid)
 	 {'validation', torch.Tensor(valid), '-'}})
 end
 local function train()
+   local x = remove_small_image(torch.load(settings.images))
+   local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
    local model
@@ -301,20 +428,30 @@ local function train()
    else
       model = srcnn.create(settings.model, settings.backend, settings.color)
    end
+   if model.w2nn_input_size then
+      if settings.crop_size ~= model.w2nn_input_size then
+	 io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
+				       model.w2nn_input_size))
+	 settings.crop_size = model.w2nn_input_size
+      end
+   end
+   if model.w2nn_gcn then
+      settings.gcn = true
+   else
+      settings.gcn = false
+   end
    dir.makepath(settings.model_dir)
 
    local offset = reconstruct.offset_size(model)
-   local pairwise_func = function(x, is_validation, n)
-      return transformer(model, x, is_validation, n, offset)
-   end
+   transform_pool_init(reconstruct.has_resize(model), offset)
+
    local criterion = create_criterion(model)
    local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
-   local x = remove_small_image(torch.load(settings.images))
-   local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
       xLearningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
-      xLearningRateDecay = settings.learning_rate_decay
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
    }
    local ch = nil
    if settings.color == "y" then
@@ -324,7 +461,7 @@ local function train()
    end
    local best_score = 1000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, pairwise_func,
+   local valid_xy = make_validation_set(valid_x, 
 					settings.validation_crops,
 					settings.patches)
    valid_x = nil
@@ -358,7 +495,7 @@ local function train()
 	 if oracle_n > 0 then
 	    local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
 	    resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
-		       y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func)
+		       y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x)
 	    x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
 	    y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
 
@@ -374,7 +511,7 @@ local function train()
 			     min = 0,
 			     max = 1}))
 	 else
-	    resampling(x, y, train_x, pairwise_func)
+	    resampling(x, y, train_x)
 	 end
       else
 	 resampling(x, y, train_x, pairwise_func)
@@ -395,9 +532,9 @@ local function train()
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score.MSE < best_score then
+	 if score.loss < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
-	    best_score = score.MSE
+	    best_score = score.loss
 	    print("* model has updated")
 	    if settings.save_history then
 	       torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -446,7 +583,7 @@ local function train()
 	       end
 	    end
 	 end
-	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
+	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
 	 collectgarbage()
       end
    end

+ 2 - 0
waifu2x.lua

@@ -267,6 +267,7 @@ local function waifu2x()
    cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow')
    cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
    cmd:option("-q", 0, 'quiet (0|1)')
+   cmd:option("-gpu", 1, 'Device ID')
 
    local opt = cmd:parse(arg)
    if opt.method:len() > 0 then
@@ -292,5 +293,6 @@ local function waifu2x()
    else
       convert_frames(opt)
    end
+   cutorch.setDevice(opt.gpu)
 end
 waifu2x()

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott