nagadomi 6 vuotta sitten
vanhempi
commit
dd8cb71601
1 muutettua tiedostoa jossa 56 lisäystä ja 96 poistoa
  1. 56 96
      lib/srcnn.lua

+ 56 - 96
lib/srcnn.lua

@@ -1,7 +1,9 @@
 require 'w2nn'
 
--- ref: http://arxiv.org/abs/1502.01852
--- ref: http://arxiv.org/abs/1501.00092
+-- ref: https://arxiv.org/abs/1502.01852
+-- ref: https://arxiv.org/abs/1501.00092
+-- ref: https://arxiv.org/abs/1709.01507
+-- ref: https://arxiv.org/abs/1505.04597
 local srcnn = {}
 
 local function msra_filler(mod)
@@ -240,9 +242,6 @@ local function SEBlock(backend, n_output, r)
    con:add(attention)
    return con
 end
--- I devised this arch for the block size and global average pooling problem,
--- but SEBlock may possibly learn multi-scale input or just a normalization. No problems occur.
--- So this arch is not used.
 local function SpatialSEBlock(backend, ave_size, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
@@ -353,8 +352,6 @@ function srcnn.vgg_7(backend, ch)
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    
    return model
 end
@@ -378,7 +375,6 @@ function srcnn.upconv_7(backend, ch)
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
-
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
@@ -414,9 +410,6 @@ function srcnn.upconv_7l(backend, ch)
    model.w2nn_resize = true
    model.w2nn_channels = ch
 
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
    return model
 end
 
@@ -439,9 +432,6 @@ function srcnn.resnet_14l(backend, ch)
    model.w2nn_resize = true
    model.w2nn_channels = ch
 
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
    return model
 end
 
@@ -557,6 +547,21 @@ function srcnn.fcn_v1(backend, ch)
    
    return model
 end
+
+-- Cascaded Residual U-Net with SEBlock
+
+local function unet_conv(backend, n_input, n_middle, n_output, se)
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   if se then
+      model:add(SEBlock(backend, n_output, 8))
+      model:add(w2nn.ScaleTable())
+   end
+   return model
+end
 local function unet_branch(backend, insert, backend, n_input, n_output, depad)
    local block = nn.Sequential()
    local con = nn.ConcatTable(2)
@@ -573,61 +578,47 @@ local function unet_branch(backend, insert, backend, n_input, n_output, depad)
    model:add(nn.CAddTable())
    return model
 end
-local function unet_conv(backend, n_input, n_middle, n_output, se)
+local function cunet_unet1(backend, ch, deconv)
+   local block1 = unet_conv(backend, 64, 128, 64, true)
    local model = nn.Sequential()
-   model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   if se then
-      model:add(SEBlock(backend, n_output, 8))
-      model:add(w2nn.ScaleTable())
+   model:add(unet_conv(backend, ch, 32, 64, false))
+   model:add(unet_branch(backend, block1, backend, 64, 64, 4))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   if deconv then
+	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
+   else
+      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    end
    return model
 end
-
--- Cascaded Residual Channel Attention U-Net
-function srcnn.upcunet(backend, ch)
-   -- Residual U-Net
-   local function unet1(backend, ch, deconv)
-      local block1 = unet_conv(backend, 64, 128, 64, true)
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block1, backend, 64, 64, 4))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local function unet2(backend, ch, deconv)
-      local block1 = unet_conv(backend, 128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(backend, 64, 64, 128, true))
-      block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
-      block2:add(unet_conv(backend, 128, 64, 64, true))
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block2, backend, 64, 64, 16))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
+local function cunet_unet2(backend, ch, deconv)
+   local block1 = unet_conv(backend, 128, 256, 128, true)
+   local block2 = nn.Sequential()
+   block2:add(unet_conv(backend, 64, 64, 128, true))
+   block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
+   block2:add(unet_conv(backend, 128, 64, 64, true))
+   local model = nn.Sequential()
+   model:add(unet_conv(backend, ch, 32, 64, false))
+   model:add(unet_branch(backend, block2, backend, 64, 64, 16))
+   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1))
+   if deconv then
+      model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
+   else
+      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    end
+   return model
+end
+-- 2x
+function srcnn.upcunet(backend, ch)
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
 
    -- 2 cascade
-   model:add(unet1(backend, ch, true))
-   con:add(unet2(backend, ch, false))
+   model:add(cunet_unet1(backend, ch, true))
+   con:add(cunet_unet2(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
 
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
@@ -649,47 +640,15 @@ function srcnn.upcunet(backend, ch)
 
    return model
 end
--- cunet for 1x
+-- 1x
 function srcnn.cunet(backend, ch)
-   local function unet1(backend, ch, deconv)
-      local block1 = unet_conv(backend, 64, 128, 64, true)
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block1, backend, 64, 64, 4))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local function unet2(backend, ch, deconv)
-      local block1 = unet_conv(backend, 128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(backend, 64, 64, 128, true))
-      block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
-      block2:add(unet_conv(backend, 128, 64, 64, true))
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block2, backend, 64, 64, 16))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
 
    -- 2 cascade
-   model:add(unet1(backend, ch))
-   con:add(unet2(backend, ch))
+   model:add(cunet_unet1(backend, ch, false))
+   con:add(cunet_unet2(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
 
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
@@ -711,6 +670,7 @@ function srcnn.cunet(backend, ch)
 
    return model
 end
+
 local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
@@ -719,7 +679,7 @@ local function bench()
    local backend = "cudnn"
    local ch = 3
    local batch_size = 1
-   local output_size = 320
+   local output_size = 256
    for k = 1, #arch do
       model = srcnn[arch[k]](backend, ch):cuda()
       model:evaluate()
@@ -739,7 +699,7 @@ local function bench()
 	 model:forward(x)
       end
       t = sys.clock()
-      for i = 1, 100 do
+      for i = 1, 10 do
 	 local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
 	 local z = model:forward(x)
 	 dummy:add(z)