nagadomi %!s(int64=6) %!d(string=hai) anos
pai
achega
d5c2277e0e

+ 204 - 0
appendix/arch/cunet.txt

@@ -0,0 +1,204 @@
+nn.Sequential {
+  [input -> (1) -> (2) -> (3) -> (4) -> output]
+  (1): nn.Sequential {
+    [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+    (1): nn.Sequential {
+      [input -> (1) -> (2) -> (3) -> (4) -> output]
+      (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
+      (2): nn.LeakyReLU(0.1)
+      (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
+      (4): nn.LeakyReLU(0.1)
+    }
+    (2): nn.Sequential {
+      [input -> (1) -> (2) -> output]
+      (1): nn.ConcatTable {
+        input
+          |`-> (1): nn.Sequential {
+          |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+          |      (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
+          |      (2): nn.LeakyReLU(0.1)
+          |      (3): nn.Sequential {
+          |        [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+          |        (1): nn.SpatialConvolutionMM(64 -> 128, 3x3)
+          |        (2): nn.LeakyReLU(0.1)
+          |        (3): nn.SpatialConvolutionMM(128 -> 64, 3x3)
+          |        (4): nn.LeakyReLU(0.1)
+          |        (5): nn.ConcatTable {
+          |          input
+          |            |`-> (1): nn.Identity
+          |             `-> (2): nn.Sequential {
+          |                   [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+          |                   (1): nn.Sequential {
+          |                     [input -> (1) -> (2) -> (3) -> output]
+          |                     (1): nn.Mean
+          |                     (2): nn.Mean
+          |                     (3): nn.View(-1, 64, 1, 1)
+          |                   }
+          |                   (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
+          |                   (3): nn.ReLU
+          |                   (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
+          |                   (5): nn.Sigmoid
+          |                 }
+          |             ... -> output
+          |        }
+          |        (6): w2nn.ScaleTable
+          |      }
+          |      (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
+          |      (5): nn.LeakyReLU(0.1)
+          |    }
+           `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
+           ... -> output
+      }
+      (2): nn.CAddTable
+    }
+    (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+    (4): nn.LeakyReLU(0.1)
+    (5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
+  }
+  (2): nn.ConcatTable {
+    input
+      |`-> (1): nn.Sequential {
+      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |      (1): nn.Sequential {
+      |        [input -> (1) -> (2) -> (3) -> (4) -> output]
+      |        (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
+      |        (2): nn.LeakyReLU(0.1)
+      |        (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
+      |        (4): nn.LeakyReLU(0.1)
+      |      }
+      |      (2): nn.Sequential {
+      |        [input -> (1) -> (2) -> output]
+      |        (1): nn.ConcatTable {
+      |          input
+      |            |`-> (1): nn.Sequential {
+      |            |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |      (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
+      |            |      (2): nn.LeakyReLU(0.1)
+      |            |      (3): nn.Sequential {
+      |            |        [input -> (1) -> (2) -> (3) -> output]
+      |            |        (1): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |          (1): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |            |          (2): nn.LeakyReLU(0.1)
+      |            |          (3): nn.SpatialConvolutionMM(64 -> 128, 3x3)
+      |            |          (4): nn.LeakyReLU(0.1)
+      |            |          (5): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Identity
+      |            |               `-> (2): nn.Sequential {
+      |            |                     [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |                     (1): nn.Sequential {
+      |            |                       [input -> (1) -> (2) -> (3) -> output]
+      |            |                       (1): nn.Mean
+      |            |                       (2): nn.Mean
+      |            |                       (3): nn.View(-1, 128, 1, 1)
+      |            |                     }
+      |            |                     (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
+      |            |                     (3): nn.ReLU
+      |            |                     (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
+      |            |                     (5): nn.Sigmoid
+      |            |                   }
+      |            |               ... -> output
+      |            |          }
+      |            |          (6): w2nn.ScaleTable
+      |            |        }
+      |            |        (2): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> output]
+      |            |          (1): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Sequential {
+      |            |              |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |              |      (1): nn.SpatialConvolutionMM(128 -> 128, 2x2, 2,2)
+      |            |              |      (2): nn.LeakyReLU(0.1)
+      |            |              |      (3): nn.Sequential {
+      |            |              |        [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |              |        (1): nn.SpatialConvolutionMM(128 -> 256, 3x3)
+      |            |              |        (2): nn.LeakyReLU(0.1)
+      |            |              |        (3): nn.SpatialConvolutionMM(256 -> 128, 3x3)
+      |            |              |        (4): nn.LeakyReLU(0.1)
+      |            |              |        (5): nn.ConcatTable {
+      |            |              |          input
+      |            |              |            |`-> (1): nn.Identity
+      |            |              |             `-> (2): nn.Sequential {
+      |            |              |                   [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |              |                   (1): nn.Sequential {
+      |            |              |                     [input -> (1) -> (2) -> (3) -> output]
+      |            |              |                     (1): nn.Mean
+      |            |              |                     (2): nn.Mean
+      |            |              |                     (3): nn.View(-1, 128, 1, 1)
+      |            |              |                   }
+      |            |              |                   (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
+      |            |              |                   (3): nn.ReLU
+      |            |              |                   (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
+      |            |              |                   (5): nn.Sigmoid
+      |            |              |                 }
+      |            |              |             ... -> output
+      |            |              |        }
+      |            |              |        (6): w2nn.ScaleTable
+      |            |              |      }
+      |            |              |      (4): nn.SpatialFullConvolution(128 -> 128, 2x2, 2,2)
+      |            |              |      (5): nn.LeakyReLU(0.1)
+      |            |              |    }
+      |            |               `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
+      |            |               ... -> output
+      |            |          }
+      |            |          (2): nn.CAddTable
+      |            |        }
+      |            |        (3): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |          (1): nn.SpatialConvolutionMM(128 -> 64, 3x3)
+      |            |          (2): nn.LeakyReLU(0.1)
+      |            |          (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |            |          (4): nn.LeakyReLU(0.1)
+      |            |          (5): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Identity
+      |            |               `-> (2): nn.Sequential {
+      |            |                     [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |                     (1): nn.Sequential {
+      |            |                       [input -> (1) -> (2) -> (3) -> output]
+      |            |                       (1): nn.Mean
+      |            |                       (2): nn.Mean
+      |            |                       (3): nn.View(-1, 64, 1, 1)
+      |            |                     }
+      |            |                     (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
+      |            |                     (3): nn.ReLU
+      |            |                     (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
+      |            |                     (5): nn.Sigmoid
+      |            |                   }
+      |            |               ... -> output
+      |            |          }
+      |            |          (6): w2nn.ScaleTable
+      |            |        }
+      |            |      }
+      |            |      (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
+      |            |      (5): nn.LeakyReLU(0.1)
+      |            |    }
+      |             `-> (2): nn.SpatialZeroPadding(l=-16, r=-16, t=-16, b=-16)
+      |             ... -> output
+      |        }
+      |        (2): nn.CAddTable
+      |      }
+      |      (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |      (4): nn.LeakyReLU(0.1)
+      |      (5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
+      |    }
+       `-> (2): nn.SpatialZeroPadding(l=-20, r=-20, t=-20, b=-20)
+       ... -> output
+  }
+  (3): nn.ConcatTable {
+    input
+      |`-> (1): nn.Sequential {
+      |      [input -> (1) -> (2) -> output]
+      |      (1): nn.CAddTable
+      |      (2): w2nn.InplaceClip01
+      |    }
+       `-> (2): nn.Sequential {
+             [input -> (1) -> (2) -> output]
+             (1): nn.SelectTable(2)
+             (2): w2nn.InplaceClip01
+           }
+       ... -> output
+  }
+  (4): w2nn.AuxiliaryLossTable
+}

+ 204 - 0
appendix/arch/upcunet.txt

@@ -0,0 +1,204 @@
+nn.Sequential {
+  [input -> (1) -> (2) -> (3) -> (4) -> output]
+  (1): nn.Sequential {
+    [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+    (1): nn.Sequential {
+      [input -> (1) -> (2) -> (3) -> (4) -> output]
+      (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
+      (2): nn.LeakyReLU(0.1)
+      (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
+      (4): nn.LeakyReLU(0.1)
+    }
+    (2): nn.Sequential {
+      [input -> (1) -> (2) -> output]
+      (1): nn.ConcatTable {
+        input
+          |`-> (1): nn.Sequential {
+          |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+          |      (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
+          |      (2): nn.LeakyReLU(0.1)
+          |      (3): nn.Sequential {
+          |        [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+          |        (1): nn.SpatialConvolutionMM(64 -> 128, 3x3)
+          |        (2): nn.LeakyReLU(0.1)
+          |        (3): nn.SpatialConvolutionMM(128 -> 64, 3x3)
+          |        (4): nn.LeakyReLU(0.1)
+          |        (5): nn.ConcatTable {
+          |          input
+          |            |`-> (1): nn.Identity
+          |             `-> (2): nn.Sequential {
+          |                   [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+          |                   (1): nn.Sequential {
+          |                     [input -> (1) -> (2) -> (3) -> output]
+          |                     (1): nn.Mean
+          |                     (2): nn.Mean
+          |                     (3): nn.View(-1, 64, 1, 1)
+          |                   }
+          |                   (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
+          |                   (3): nn.ReLU
+          |                   (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
+          |                   (5): nn.Sigmoid
+          |                 }
+          |             ... -> output
+          |        }
+          |        (6): w2nn.ScaleTable
+          |      }
+          |      (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
+          |      (5): nn.LeakyReLU(0.1)
+          |    }
+           `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
+           ... -> output
+      }
+      (2): nn.CAddTable
+    }
+    (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+    (4): nn.LeakyReLU(0.1)
+    (5): nn.SpatialFullConvolution(64 -> 3, 4x4, 2,2, 3,3)
+  }
+  (2): nn.ConcatTable {
+    input
+      |`-> (1): nn.Sequential {
+      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |      (1): nn.Sequential {
+      |        [input -> (1) -> (2) -> (3) -> (4) -> output]
+      |        (1): nn.SpatialConvolutionMM(3 -> 32, 3x3)
+      |        (2): nn.LeakyReLU(0.1)
+      |        (3): nn.SpatialConvolutionMM(32 -> 64, 3x3)
+      |        (4): nn.LeakyReLU(0.1)
+      |      }
+      |      (2): nn.Sequential {
+      |        [input -> (1) -> (2) -> output]
+      |        (1): nn.ConcatTable {
+      |          input
+      |            |`-> (1): nn.Sequential {
+      |            |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |      (1): nn.SpatialConvolutionMM(64 -> 64, 2x2, 2,2)
+      |            |      (2): nn.LeakyReLU(0.1)
+      |            |      (3): nn.Sequential {
+      |            |        [input -> (1) -> (2) -> (3) -> output]
+      |            |        (1): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |          (1): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |            |          (2): nn.LeakyReLU(0.1)
+      |            |          (3): nn.SpatialConvolutionMM(64 -> 128, 3x3)
+      |            |          (4): nn.LeakyReLU(0.1)
+      |            |          (5): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Identity
+      |            |               `-> (2): nn.Sequential {
+      |            |                     [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |                     (1): nn.Sequential {
+      |            |                       [input -> (1) -> (2) -> (3) -> output]
+      |            |                       (1): nn.Mean
+      |            |                       (2): nn.Mean
+      |            |                       (3): nn.View(-1, 128, 1, 1)
+      |            |                     }
+      |            |                     (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
+      |            |                     (3): nn.ReLU
+      |            |                     (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
+      |            |                     (5): nn.Sigmoid
+      |            |                   }
+      |            |               ... -> output
+      |            |          }
+      |            |          (6): w2nn.ScaleTable
+      |            |        }
+      |            |        (2): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> output]
+      |            |          (1): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Sequential {
+      |            |              |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |              |      (1): nn.SpatialConvolutionMM(128 -> 128, 2x2, 2,2)
+      |            |              |      (2): nn.LeakyReLU(0.1)
+      |            |              |      (3): nn.Sequential {
+      |            |              |        [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |              |        (1): nn.SpatialConvolutionMM(128 -> 256, 3x3)
+      |            |              |        (2): nn.LeakyReLU(0.1)
+      |            |              |        (3): nn.SpatialConvolutionMM(256 -> 128, 3x3)
+      |            |              |        (4): nn.LeakyReLU(0.1)
+      |            |              |        (5): nn.ConcatTable {
+      |            |              |          input
+      |            |              |            |`-> (1): nn.Identity
+      |            |              |             `-> (2): nn.Sequential {
+      |            |              |                   [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |              |                   (1): nn.Sequential {
+      |            |              |                     [input -> (1) -> (2) -> (3) -> output]
+      |            |              |                     (1): nn.Mean
+      |            |              |                     (2): nn.Mean
+      |            |              |                     (3): nn.View(-1, 128, 1, 1)
+      |            |              |                   }
+      |            |              |                   (2): nn.SpatialConvolutionMM(128 -> 16, 1x1)
+      |            |              |                   (3): nn.ReLU
+      |            |              |                   (4): nn.SpatialConvolutionMM(16 -> 128, 1x1)
+      |            |              |                   (5): nn.Sigmoid
+      |            |              |                 }
+      |            |              |             ... -> output
+      |            |              |        }
+      |            |              |        (6): w2nn.ScaleTable
+      |            |              |      }
+      |            |              |      (4): nn.SpatialFullConvolution(128 -> 128, 2x2, 2,2)
+      |            |              |      (5): nn.LeakyReLU(0.1)
+      |            |              |    }
+      |            |               `-> (2): nn.SpatialZeroPadding(l=-4, r=-4, t=-4, b=-4)
+      |            |               ... -> output
+      |            |          }
+      |            |          (2): nn.CAddTable
+      |            |        }
+      |            |        (3): nn.Sequential {
+      |            |          [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
+      |            |          (1): nn.SpatialConvolutionMM(128 -> 64, 3x3)
+      |            |          (2): nn.LeakyReLU(0.1)
+      |            |          (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |            |          (4): nn.LeakyReLU(0.1)
+      |            |          (5): nn.ConcatTable {
+      |            |            input
+      |            |              |`-> (1): nn.Identity
+      |            |               `-> (2): nn.Sequential {
+      |            |                     [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
+      |            |                     (1): nn.Sequential {
+      |            |                       [input -> (1) -> (2) -> (3) -> output]
+      |            |                       (1): nn.Mean
+      |            |                       (2): nn.Mean
+      |            |                       (3): nn.View(-1, 64, 1, 1)
+      |            |                     }
+      |            |                     (2): nn.SpatialConvolutionMM(64 -> 8, 1x1)
+      |            |                     (3): nn.ReLU
+      |            |                     (4): nn.SpatialConvolutionMM(8 -> 64, 1x1)
+      |            |                     (5): nn.Sigmoid
+      |            |                   }
+      |            |               ... -> output
+      |            |          }
+      |            |          (6): w2nn.ScaleTable
+      |            |        }
+      |            |      }
+      |            |      (4): nn.SpatialFullConvolution(64 -> 64, 2x2, 2,2)
+      |            |      (5): nn.LeakyReLU(0.1)
+      |            |    }
+      |             `-> (2): nn.SpatialZeroPadding(l=-16, r=-16, t=-16, b=-16)
+      |             ... -> output
+      |        }
+      |        (2): nn.CAddTable
+      |      }
+      |      (3): nn.SpatialConvolutionMM(64 -> 64, 3x3)
+      |      (4): nn.LeakyReLU(0.1)
+      |      (5): nn.SpatialConvolutionMM(64 -> 3, 3x3)
+      |    }
+       `-> (2): nn.SpatialZeroPadding(l=-20, r=-20, t=-20, b=-20)
+       ... -> output
+  }
+  (3): nn.ConcatTable {
+    input
+      |`-> (1): nn.Sequential {
+      |      [input -> (1) -> (2) -> output]
+      |      (1): nn.CAddTable
+      |      (2): w2nn.InplaceClip01
+      |    }
+       `-> (2): nn.Sequential {
+             [input -> (1) -> (2) -> output]
+             (1): nn.SelectTable(2)
+             (2): w2nn.InplaceClip01
+           }
+       ... -> output
+  }
+  (4): w2nn.AuxiliaryLossTable
+}

+ 8 - 0
appendix/cudnn2cunn.sh

@@ -0,0 +1,8 @@
+#!/bin/bash
+
+th tools/cudnn2cunn.lua -i models/test/test/upcunet_release/scale2.0x_model.t7 -o models/cunet/art/scale2.0x_model.t7
+for i in 0 1 2 3
+do
+    th tools/cudnn2cunn.lua -i models/test/cunet_release/noise${i}_model.t7 -o models/cunet/art/noise${i}_model.t7
+    th tools/cudnn2cunn.lua -i models/test/cunet_release/noise${i}_scale2.0x_model.t7 -o models/cunet/art/noise${i}_scale2.0x_model.t7
+done

+ 2 - 1
lib/LBPCriterion.lua

@@ -1,3 +1,4 @@
+-- Random Generated Local Binary Pattern Loss 
 local LBPCriterion, parent = torch.class('w2nn.LBPCriterion','nn.Criterion')
 
 local function create_filters(ch, n, k, layers)
@@ -26,7 +27,7 @@ function LBPCriterion:__init(ch, n, k, layers)
    parent.__init(self)
    self.layers = layers or 1
    self.gamma = 0.1
-   self.n = n or 32
+   self.n = n or 128
    self.k = k or 3
    self.ch = ch
    self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)

+ 1 - 0
lib/srcnn.lua

@@ -550,6 +550,7 @@ end
 
 -- Cascaded Residual U-Net with SEBlock
 
+-- unet utils adapted from https://gist.github.com/toshi-k/ca75e614f1ac12fa44f62014ac1d6465
 local function unet_conv(backend, n_input, n_middle, n_output, se)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))

+ 22 - 11
tools/cudnn2cunn.lua

@@ -6,18 +6,29 @@ require 'w2nn'
 local srcnn = require 'srcnn'
 
 local function cudnn2cunn(cudnn_model)
-   local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
-   local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
-   local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
+   local name = srcnn.name(cudnn_model)
+   local cunn_model = srcnn[name]('cunn', srcnn.channels(cudnn_model))
+   local param_layers = {
+      {cunn="nn.SpatialConvolutionMM", cudnn="cudnn.SpatialConvolution", attr={"bias", "weight"}},
+      {cunn="nn.SpatialDilatedConvolution", cudnn="cudnn.SpatialDilatedConvolution", attr={"bias", "weight"}},
+      {cunn="nn.SpatialFullConvolution", cudnn="cudnn.SpatialFullConvolution", attr={"bias", "weight"}},
+      {cunn="nn.Linear", cudnn="nn.Linear", attr={"bias", "weight"}}
+   }
+   for i = 1, #param_layers do
+      local p = param_layers[i]
+      local weight_from = cudnn_model:findModules(p.cudnn)
+      local weight_to = cunn_model:findModules(p.cunn)
+      print(p.cudnn, #weight_from)
+      assert(#weight_from == #weight_to)
    
-   assert(#weight_from == #weight_to)
-   
-   for i = 1, #weight_from do
-      local from = weight_from[i]
-      local to = weight_to[i]
-      
-      to.weight:copy(from.weight)
-      to.bias:copy(from.bias)
+      for i = 1, #weight_from do
+	 local from = weight_from[i]
+	 local to = weight_to[i]
+	 to.weight:copy(from.weight)
+	 if to.bias then
+	    to.bias:copy(from.bias)
+	 end
+      end
    end
    cunn_model:cuda()
    cunn_model:evaluate()

+ 58 - 0
tools/make_benchmark_input.lua

@@ -0,0 +1,58 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'xlua'
+local iproc = require 'iproc'
+local image_loader = require 'image_loader'
+local gm = require 'graphicsmagick'
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x-make benchmark data")
+cmd:text("Options:")
+
+cmd:option("-i", "./data/test", 'input dir')
+cmd:option("-lr", "hr", 'highres output dir')
+cmd:option("-hr", "lr", 'lowres output dir')
+cmd:option("-filter", "Sinc", 'dowsampling filter')
+
+local opt = cmd:parse(arg)
+torch.setdefaulttensortype('torch.FloatTensor')
+local function transform_scale(x, opt)
+   return iproc.scale(x,
+		      x:size(3) * 0.5,
+		      x:size(2) * 0.5,
+		      opt.filter, 1)
+end
+local function load_data_from_dir(test_dir)
+   local test_x = {}
+   local files = dir.getfiles(test_dir, "*.*")
+   for i = 1, #files do
+      local name = path.basename(files[i])
+      local e = path.extension(name)
+      local base = name:sub(0, name:len() - e:len())
+      local img = image_loader.load_byte(files[i])
+      if img then
+	 table.insert(test_x, {y = iproc.crop_mod4(img),
+			       basename = base})
+      end
+      if i % 10 == 0 then
+	 if opt.show_progress then
+	    xlua.progress(i, #files)
+	 end
+	 collectgarbage()
+      end
+   end
+   return test_x
+end
+dir.makepath(opt.lr)
+dir.makepath(opt.hr)
+local files = load_data_from_dir(opt.i)
+for i = 1, #files do
+   local y = files[i].y
+   local x = transform_scale(y, opt)
+   local hr_path = path.join(opt.hr, files[i].basename .. ".png")
+   local lr_path = path.join(opt.lr, files[i].basename .. ".png")
+   image.save(hr_path, y)
+   image.save(lr_path, x)
+end

+ 52 - 0
tools/switch_aux_output.lua

@@ -0,0 +1,52 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'os'
+require 'w2nn'
+local srcnn = require 'srcnn'
+
+local function find_aux(seq)
+   for k = 1, #seq.modules do
+      local mod = seq.modules[k]
+      local name = torch.typename(mod)
+      if name == "nn.Sequential" or name == "nn.ConcatTable" then
+	 local aux = find_aux(mod)
+	 if aux ~= nil then
+	    return aux
+	 end
+      elseif name == "w2nn.AuxiliaryLossTable" then
+	 return mod
+      end
+   end
+   return nil
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("switch the output pass of auxiliary loss")
+cmd:text("Options:")
+cmd:option("-j", 1, 'Specify the output path index (1|2)')
+cmd:option("-i", "", 'Specify the input model')
+cmd:option("-o", "", 'Specify the output model')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+
+local model = torch.load(opt.i, opt.iformat)
+if model == nil then
+   print("load error")
+   os.exit(-1)
+end
+local aux = find_aux(model)
+if aux == nil then
+   print("AuxiliaryLossTable not found")
+else
+   print(aux)
+   aux.i = opt.j
+   torch.save(opt.o, model, opt.oformat)
+end