Browse Source

Convert model files; Add new pretrained model

- Add new pretrained model to ./models/upconv_7
- Move old models to ./models/vgg_7
- Use nn.LeakyReLU instead of w2nn.LeakyReLU
- Add useful attribute to .json

New JSON attribute:
The first layer has a `model_config` attribute.
It contains:
  model_arch: architecture name of the model. See `lib/srcnn.lua`.
  scale_factor: if scale_factor > 1, model:forward() scales the image resolution by scale_factor.
  channels: number of input/output channels. If channels == 3, the model is an RGB model.
  offset: number of pixels removed from the output.
          for example:
            (scale_factor=1, offset=7, input=100x100) => output=(100-7)x(100-7)
            (scale_factor=2, offset=12, input=100x100) => output=(100*2-12)x(100*2-12)
Each layer also has a `class_name` attribute.
nagadomi 9 years ago
parent
commit
a210090033
61 changed files with 1385 additions and 1379 deletions
  1. 3 0
      .gitignore
  2. 4 4
      lib/pairwise_transform_scale.lua
  3. 41 38
      lib/srcnn.lua
  4. 1 0
      models/anime_style_art
  5. 0 0
      models/anime_style_art/noise1_model.json
  6. 0 0
      models/anime_style_art/noise2_model.json
  7. 0 0
      models/anime_style_art/noise3_model.json
  8. 0 0
      models/anime_style_art/scale2.0x_model.json
  9. 1 0
      models/anime_style_art_rgb
  10. 0 0
      models/anime_style_art_rgb/noise1_model.json
  11. 0 0
      models/anime_style_art_rgb/noise2_model.json
  12. 0 0
      models/anime_style_art_rgb/noise3_model.json
  13. 0 0
      models/anime_style_art_rgb/scale2.0x_model.json
  14. 1 0
      models/photo
  15. 0 0
      models/photo/noise1_model.json
  16. 0 0
      models/photo/noise2_model.json
  17. 0 0
      models/photo/noise3_model.json
  18. 0 0
      models/photo/scale2.0x_model.json
  19. 1 0
      models/ukbench
  20. 0 0
      models/ukbench/scale2.0x_model.json
  21. 1 0
      models/upconv_7/art/noise1_model.json
  22. 1 0
      models/upconv_7/art/noise1_model.t7
  23. 1 0
      models/upconv_7/art/noise2_model.json
  24. 1 0
      models/upconv_7/art/noise2_model.t7
  25. 1 0
      models/upconv_7/art/noise3_model.json
  26. 1 0
      models/upconv_7/art/noise3_model.t7
  27. 0 0
      models/upconv_7/art/scale2.0x_model.json
  28. 132 0
      models/upconv_7/art/scale2.0x_model.t7
  29. 0 0
      models/vgg_7/art/noise1_model.json
  30. 48 77
      models/vgg_7/art/noise1_model.t7
  31. 0 0
      models/vgg_7/art/noise2_model.json
  32. 48 77
      models/vgg_7/art/noise2_model.t7
  33. 0 0
      models/vgg_7/art/noise3_model.json
  34. 71 81
      models/vgg_7/art/noise3_model.t7
  35. 0 0
      models/vgg_7/art/scale2.0x_model.json
  36. 48 77
      models/vgg_7/art/scale2.0x_model.t7
  37. 0 0
      models/vgg_7/art_y/noise1_model.json
  38. 70 80
      models/vgg_7/art_y/noise1_model.t7
  39. 0 0
      models/vgg_7/art_y/noise2_model.json
  40. 69 79
      models/vgg_7/art_y/noise2_model.t7
  41. 0 0
      models/vgg_7/art_y/noise3_model.json
  42. 102 119
      models/vgg_7/art_y/noise3_model.t7
  43. 0 0
      models/vgg_7/art_y/scale2.0x_model.json
  44. 102 119
      models/vgg_7/art_y/scale2.0x_model.t7
  45. 0 0
      models/vgg_7/photo/noise1_model.json
  46. 144 160
      models/vgg_7/photo/noise1_model.t7
  47. 0 0
      models/vgg_7/photo/noise2_model.json
  48. 144 160
      models/vgg_7/photo/noise2_model.t7
  49. 0 0
      models/vgg_7/photo/noise3_model.json
  50. 102 119
      models/vgg_7/photo/noise3_model.t7
  51. 0 0
      models/vgg_7/photo/scale2.0x_model.json
  52. 48 77
      models/vgg_7/photo/scale2.0x_model.t7
  53. 0 0
      models/vgg_7/ukbench/scale2.0x_model.json
  54. 48 77
      models/vgg_7/ukbench/scale2.0x_model.t7
  55. 21 0
      tools/export.sh
  56. 54 16
      tools/export_model.lua
  57. 23 0
      tools/rebuild.sh
  58. 49 15
      tools/rebuild_model.lua
  59. 2 2
      train.lua
  60. 1 1
      waifu2x.lua
  61. 1 1
      web.lua

+ 3 - 0
.gitignore

@@ -10,7 +10,10 @@ models/*
 !models/anime_style_art_rgb
 !models/anime_style_art_rgb
 !models/ukbench
 !models/ukbench
 !models/photo
 !models/photo
+!models/upconv_7
+!models/vgg_7
 models/*/*.png
 models/*/*.png
+models/*/*/*.png
 
 
 waifu2x.log
 waifu2x.log
 waifu2x-*.log
 waifu2x-*.log

+ 4 - 4
lib/pairwise_transform_scale.lua

@@ -46,10 +46,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
 
 
    for i = 1, n do
    for i = 1, n do
       local xc, yc = pairwise_utils.active_cropping(x, y,
       local xc, yc = pairwise_utils.active_cropping(x, y,
-							size,
-							scale_inner,
-							options.active_cropping_rate,
-							options.active_cropping_tries)
+						    size,
+						    scale_inner,
+						    options.active_cropping_rate,
+						    options.active_cropping_tries)
       xc = iproc.byte2float(xc)
       xc = iproc.byte2float(xc)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then

+ 41 - 38
lib/srcnn.lua

@@ -44,7 +44,8 @@ function srcnn.channels(model)
 end
 end
 function srcnn.backend(model)
 function srcnn.backend(model)
    local conv = model:findModules("cudnn.SpatialConvolution")
    local conv = model:findModules("cudnn.SpatialConvolution")
-   if #conv > 0 then
+   local fullconv = model:findModules("cudnn.SpatialFullConvolution")
+   if #conv > 0 or #fullconv > 0 then
       return "cudnn"
       return "cudnn"
    else
    else
       return "cunn"
       return "cunn"
@@ -132,17 +133,17 @@ end
 function srcnn.vgg_7(backend, ch)
 function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    model:add(nn.View(-1):setNumInputDims(3))
 
 
@@ -159,27 +160,27 @@ end
 function srcnn.vgg_12(backend, ch)
 function srcnn.vgg_12(backend, ch)
    local model = nn.Sequential()
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    model:add(nn.View(-1):setNumInputDims(3))
 
 
@@ -198,17 +199,17 @@ end
 function srcnn.dilated_7(backend, ch)
 function srcnn.dilated_7(backend, ch)
    local model = nn.Sequential()
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
    model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    model:add(nn.View(-1):setNumInputDims(3))
 
 
@@ -229,17 +230,17 @@ function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()
    local model = nn.Sequential()
 
 
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
    model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
 
 
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_arch_name = "upconv_7"
@@ -257,19 +258,19 @@ function srcnn.upconv_8_4x(backend, ch)
    local model = nn.Sequential()
    local model = nn.Sequential()
 
 
    model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
    model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.LeakyReLU(0.1))
+   model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, 3, 4, 4, 2, 2, 1, 1))
    model:add(SpatialFullConvolution(backend, 64, 3, 4, 4, 2, 2, 1, 1))
 
 
    model.w2nn_arch_name = "upconv_8_4x"
    model.w2nn_arch_name = "upconv_8_4x"
@@ -296,7 +297,9 @@ function srcnn.create(model_name, backend, color)
       error("unsupported color: " .. color)
       error("unsupported color: " .. color)
    end
    end
    if srcnn[model_name] then
    if srcnn[model_name] then
-      return srcnn[model_name](backend, ch)
+      local model = srcnn[model_name](backend, ch)
+      assert(model.w2nn_offset == (model.w2nn_offset / model.w2nn_scale_factor) * model.w2nn_scale_factor)
+      return model
    else
    else
       error("unsupported model_name: " .. model_name)
       error("unsupported model_name: " .. model_name)
    end
    end

+ 1 - 0
models/anime_style_art

@@ -0,0 +1 @@
+./vgg_7/art_y

File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art/noise1_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art/noise2_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art/noise3_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art/scale2.0x_model.json


+ 1 - 0
models/anime_style_art_rgb

@@ -0,0 +1 @@
+./vgg_7/art

File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/noise1_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/noise2_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/noise3_model.json


File diff suppressed because it is too large
+ 0 - 0
models/anime_style_art_rgb/scale2.0x_model.json


+ 1 - 0
models/photo

@@ -0,0 +1 @@
+./vgg_7/photo

File diff suppressed because it is too large
+ 0 - 0
models/photo/noise1_model.json


File diff suppressed because it is too large
+ 0 - 0
models/photo/noise2_model.json


File diff suppressed because it is too large
+ 0 - 0
models/photo/noise3_model.json


File diff suppressed because it is too large
+ 0 - 0
models/photo/scale2.0x_model.json


+ 1 - 0
models/ukbench

@@ -0,0 +1 @@
+./vgg_7/ukbench

File diff suppressed because it is too large
+ 0 - 0
models/ukbench/scale2.0x_model.json


+ 1 - 0
models/upconv_7/art/noise1_model.json

@@ -0,0 +1 @@
+../../vgg_7/art/noise1_model.json

+ 1 - 0
models/upconv_7/art/noise1_model.t7

@@ -0,0 +1 @@
+../../vgg_7/art/noise1_model.t7

+ 1 - 0
models/upconv_7/art/noise2_model.json

@@ -0,0 +1 @@
+../../vgg_7/art/noise2_model.json

+ 1 - 0
models/upconv_7/art/noise2_model.t7

@@ -0,0 +1 @@
+../../vgg_7/art/noise2_model.t7

+ 1 - 0
models/upconv_7/art/noise3_model.json

@@ -0,0 +1 @@
+../../vgg_7/art/noise3_model.json

+ 1 - 0
models/upconv_7/art/noise3_model.t7

@@ -0,0 +1 @@
+../../vgg_7/art/noise3_model.t7

File diff suppressed because it is too large
+ 0 - 0
models/upconv_7/art/scale2.0x_model.json


File diff suppressed because it is too large
+ 132 - 0
models/upconv_7/art/scale2.0x_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art/noise1_model.json


File diff suppressed because it is too large
+ 48 - 77
models/vgg_7/art/noise1_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art/noise2_model.json


File diff suppressed because it is too large
+ 48 - 77
models/vgg_7/art/noise2_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art/noise3_model.json


File diff suppressed because it is too large
+ 71 - 81
models/vgg_7/art/noise3_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art/scale2.0x_model.json


File diff suppressed because it is too large
+ 48 - 77
models/vgg_7/art/scale2.0x_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art_y/noise1_model.json


File diff suppressed because it is too large
+ 70 - 80
models/vgg_7/art_y/noise1_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art_y/noise2_model.json


File diff suppressed because it is too large
+ 69 - 79
models/vgg_7/art_y/noise2_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art_y/noise3_model.json


File diff suppressed because it is too large
+ 102 - 119
models/vgg_7/art_y/noise3_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/art_y/scale2.0x_model.json


File diff suppressed because it is too large
+ 102 - 119
models/vgg_7/art_y/scale2.0x_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/photo/noise1_model.json


File diff suppressed because it is too large
+ 144 - 160
models/vgg_7/photo/noise1_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/photo/noise2_model.json


File diff suppressed because it is too large
+ 144 - 160
models/vgg_7/photo/noise2_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/photo/noise3_model.json


File diff suppressed because it is too large
+ 102 - 119
models/vgg_7/photo/noise3_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/photo/scale2.0x_model.json


File diff suppressed because it is too large
+ 48 - 77
models/vgg_7/photo/scale2.0x_model.t7


File diff suppressed because it is too large
+ 0 - 0
models/vgg_7/ukbench/scale2.0x_model.json


File diff suppressed because it is too large
+ 48 - 77
models/vgg_7/ukbench/scale2.0x_model.t7


+ 21 - 0
tools/export.sh

@@ -0,0 +1,21 @@
+#!/bin/sh -x
+
+export_model() {
+    if [ -f models/${1}/scale2.0x_model.t7 ] && [ ! -h models/${1}/scale2.0x_model.t7 ] ; then
+	th tools/export_model.lua -i models/${1}/scale2.0x_model.t7 -o models/${1}/scale2.0x_model.json
+    fi
+    if [ -f models/${1}/noise1_model.t7 ] && [ ! -h models/${1}/noise1_model.t7 ]; then
+	th tools/export_model.lua -i models/${1}/noise1_model.t7 -o models/${1}/noise1_model.json
+    fi
+    if [ -f models/${1}/noise2_model.t7 ] && [ ! -h models/${1}/noise2_model.t7 ]; then
+	th tools/export_model.lua -i models/${1}/noise2_model.t7 -o models/${1}/noise2_model.json
+    fi
+    if [ -f models/${1}/noise3_model.t7 ] && [ ! -h models/${1}/noise3_model.t7 ]; then
+	th tools/export_model.lua -i models/${1}/noise3_model.t7 -o models/${1}/noise3_model.json
+    fi
+}
+export_model vgg_7/art
+export_model vgg_7/art_y
+export_model vgg_7/photo
+export_model vgg_7/ukbench
+export_model upconv_7/art

+ 54 - 16
tools/export_model.lua

@@ -5,24 +5,62 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
 require 'w2nn'
 require 'w2nn'
 local cjson = require "cjson"
 local cjson = require "cjson"
 
 
+function meta_data(model)
+   local meta = {}
+   for k, v in pairs(model) do
+      if k:match("w2nn_") then
+	 meta[k:gsub("w2nn_", "")] = v
+      end
+   end
+   return meta
+end
+function includes(s, a)
+   for i = 1, #a do
+      if s == a[i] then
+	 return true
+      end
+   end
+   return false
+end
 function export(model, output)
 function export(model, output)
+   local targets = {"nn.SpatialConvolutionMM",
+		    "cudnn.SpatialConvolution",
+		    "nn.SpatialFullConvolution",
+		    "cudnn.SpatialFullConvolution"
+   }
    local jmodules = {}
    local jmodules = {}
-   local modules = model:findModules("nn.SpatialConvolutionMM")
-   if #modules == 0 then
-      -- cudnn model
-      modules = model:findModules("cudnn.SpatialConvolution")
-   end
-   for i = 1, #modules, 1 do
-      local module = modules[i]
-      local jmod = {
-	 kW = module.kW,
-	 kH = module.kH,
-	 nInputPlane = module.nInputPlane,
-	 nOutputPlane = module.nOutputPlane,
-	 bias = torch.totable(module.bias:float()),
-	 weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
-      }
-      table.insert(jmodules, jmod)
+   local model_config = meta_data(model)
+   local first_layer = true
+
+   for k = 1, #model.modules do
+      local mod = model.modules[k]
+      local name = torch.typename(mod)
+      if includes(name, targets) then
+	 local weight = mod.weight:float()
+	 if name:match("FullConvolution") then
+	    weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
+	 else
+	    weight = torch.totable(weight:reshape(mod.nOutputPlane, mod.nInputPlane, mod.kH, mod.kW))
+	 end
+	 local jmod = {
+	    class_name = name,
+	    kW = mod.kW,
+	    kH = mod.kH,
+	    dH = mod.dH,
+	    dW = mod.dW,
+	    padW = mod.padW,
+	    padH = mod.padH,
+	    nInputPlane = mod.nInputPlane,
+	    nOutputPlane = mod.nOutputPlane,
+	    bias = torch.totable(mod.bias:float()),
+	    weight = weight
+	 }
+	 if first_layer then
+	    first_layer = false
+	    jmod.model_config = model_config
+	 end
+	 table.insert(jmodules, jmod)
+      end
    end
    end
    local fp = io.open(output, "w")
    local fp = io.open(output, "w")
    if not fp then
    if not fp then

+ 23 - 0
tools/rebuild.sh

@@ -0,0 +1,23 @@
+#!/bin/sh -x
+
+# maybe you should backup models
+
+rebuild() {
+    if [ -f models/${1}/scale2.0x_model.t7 ] && [ ! -h models/${1}/scale2.0x_model.t7 ] ; then
+	th tools/rebuild_model.lua -i models/${1}/scale2.0x_model.t7 -o models/${1}/scale2.0x_model.t7 -backend cunn -model $2
+    fi
+    if [ -f models/${1}/noise1_model.t7 ] && [ ! -h models/${1}/noise1_model.t7 ]; then
+	th tools/rebuild_model.lua -i models/${1}/noise1_model.t7 -o models/${1}/noise1_model.t7 -backend cunn -model $2
+    fi
+    if [ -f models/${1}/noise2_model.t7 ] && [ ! -h models/${1}/noise2_model.t7 ]; then
+	th tools/rebuild_model.lua -i models/${1}/noise2_model.t7 -o models/${1}/noise2_model.t7 -backend cunn -model $2
+    fi
+    if [ -f models/${1}/noise3_model.t7 ] && [ ! -h models/${1}/noise3_model.t7 ]; then
+	th tools/rebuild_model.lua -i models/${1}/noise3_model.t7 -o models/${1}/noise3_model.t7 -backend cunn -model $2
+    fi
+}
+rebuild vgg_7/art vgg_7
+rebuild vgg_7/art_y vgg_7
+rebuild vgg_7/photo vgg_7
+rebuild vgg_7/ukbench vgg_7
+rebuild upconv_7/art upconv_7

+ 49 - 15
tools/rebuild_model.lua

@@ -5,19 +5,52 @@ require 'os'
 require 'w2nn'
 require 'w2nn'
 local srcnn = require 'srcnn'
 local srcnn = require 'srcnn'
 
 
-local function rebuild(old_model, model)
-   local new_model = srcnn.create(model, srcnn.backend(old_model), srcnn.color(old_model))
-   local weight_from = old_model:findModules("nn.SpatialConvolutionMM")
-   local weight_to = new_model:findModules("nn.SpatialConvolutionMM")
-
-   assert(#weight_from == #weight_to)
-   
-   for i = 1, #weight_from do
-      local from = weight_from[i]
-      local to = weight_to[i]
-      
-      to.weight:copy(from.weight)
-      to.bias:copy(from.bias)
+local function rebuild(old_model, model, backend)
+   local targets = {
+      {"nn.SpatialConvolutionMM", 
+       {cunn = "nn.SpatialConvolutionMM", 
+	cudnn = "cudnn.SpatialConvolution"
+       }
+      },
+      {"cudnn.SpatialConvolution",
+       {cunn = "nn.SpatialConvolutionMM", 
+	cudnn = "cudnn.SpatialConvolution"
+       }
+      },
+      {"nn.SpatialFullConvolution",
+       {cunn = "nn.SpatialFullConvolution", 
+	cudnn = "cudnn.SpatialFullConvolution"
+       }
+      },
+      {"cudnn.SpatialFullConvolution",
+       {cunn = "nn.SpatialFullConvolution", 
+	cudnn = "cudnn.SpatialFullConvolution"
+       }
+      }
+   }
+   if backend:len() == 0 then
+      backend = srcnn.backend(old_model)
+   end
+   local new_model = srcnn.create(model, backend, srcnn.color(old_model))
+   for k = 1, #targets do
+      local weight_from = old_model:findModules(targets[k][1])
+      local weight_to = new_model:findModules(targets[k][2][backend])
+      if #weight_from > 0 then
+	 if #weight_from ~= #weight_to then
+	    error(targets[k][1] .. ": weight_from: " .. #weight_from .. ", weight_to: " .. #weight_to)
+	 end
+	 for i = 1, #weight_from do
+	    local from = weight_from[i]
+	    local to = weight_to[i]
+	    
+	    if to.weight then
+	       to.weight:copy(from.weight)
+	    end
+	    if to.bias then
+	       to.bias:copy(from.bias)
+	    end
+	 end
+      end
    end
    end
    new_model:cuda()
    new_model:cuda()
    new_model:evaluate()
    new_model:evaluate()
@@ -30,7 +63,8 @@ cmd:text("waifu2x rebuild cunn model")
 cmd:text("Options:")
 cmd:text("Options:")
 cmd:option("-i", "", 'Specify the input model')
 cmd:option("-i", "", 'Specify the input model')
 cmd:option("-o", "", 'Specify the output model')
 cmd:option("-o", "", 'Specify the output model')
-cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12)')
+cmd:option("-backend", "", 'Specify the CUDA backend (cunn|cudnn)')
+cmd:option("-model", "vgg_7", 'Specify the model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
 cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
 cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
 cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
 cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
 
 
@@ -40,5 +74,5 @@ if not path.isfile(opt.i) then
    os.exit(-1)
    os.exit(-1)
 end
 end
 local old_model = torch.load(opt.i, opt.iformat)
 local old_model = torch.load(opt.i, opt.iformat)
-local new_model = rebuild(old_model, opt.model)
+local new_model = rebuild(old_model, opt.model, opt.backend)
 torch.save(opt.o, new_model, opt.oformat)
 torch.save(opt.o, new_model, opt.oformat)

+ 2 - 2
train.lua

@@ -149,7 +149,7 @@ local function transformer(model, x, is_validation, n, offset)
 					 active_cropping_tries = active_cropping_tries,
 					 active_cropping_tries = active_cropping_tries,
 					 rgb = (settings.color == "rgb"),
 					 rgb = (settings.color == "rgb"),
 					 gamma_correction = settings.gamma_correction,
 					 gamma_correction = settings.gamma_correction,
-					 x_upsampling = not srcnn.has_resize(model)
+					 x_upsampling = not reconstruct.has_resize(model)
 				      })
 				      })
    elseif settings.method == "noise" then
    elseif settings.method == "noise" then
       return pairwise_transform.jpeg(x,
       return pairwise_transform.jpeg(x,
@@ -245,7 +245,7 @@ local function train()
    local x = nil
    local x = nil
    local y = torch.Tensor(settings.patches * #train_x,
    local y = torch.Tensor(settings.patches * #train_x,
 			  ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
 			  ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
-   if srcnn.has_resize(model) then
+   if reconstruct.has_resize(model) then
       x = torch.Tensor(settings.patches * #train_x,
       x = torch.Tensor(settings.patches * #train_x,
 		       ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
 		       ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
    else
    else

+ 1 - 1
waifu2x.lua

@@ -179,7 +179,7 @@ local function waifu2x()
    cmd:option("-scale", 2, 'scale factor')
    cmd:option("-scale", 2, 'scale factor')
    cmd:option("-o", "(auto)", 'path to output file')
    cmd:option("-o", "(auto)", 'path to output file')
    cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
    cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
-   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
+   cmd:option("-model_dir", "./models/upconv_7/art", 'path to model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-crop_size", 128, 'patch size per process')

+ 1 - 1
web.lua

@@ -38,7 +38,7 @@ if cudnn then
    cudnn.fastest = true
    cudnn.fastest = true
    cudnn.benchmark = false
    cudnn.benchmark = false
 end
 end
-local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
+local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
 local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
 local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
 local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
 local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
 local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
 local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")

Some files were not shown because too many files changed in this diff