Bladeren bron

Update cunet arch optimized by benchmark

nagadomi 6 jaren geleden
bovenliggende
commit
62770a901a
1 gewijzigde bestanden met toevoegingen van 19 en 206 verwijderingen
  1. 19 206
      lib/srcnn.lua

+ 19 - 206
lib/srcnn.lua

@@ -485,15 +485,15 @@ function srcnn.upcresnet(backend, ch)
 
    -- 2 cascade
    model:add(resnet(backend, ch, true))
-   con:add(nn.Sequential():add(resnet(backend, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- output is odd
+   con:add(nn.Sequential():add(resnet(backend, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- output size must be odd
    con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
 
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
+   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01()))
+   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01()))
 
    model:add(con)
    model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
+   model:add(w2nn.AuxiliaryLossTable(1))
 
    model.w2nn_arch_name = "upcresnet"
    model.w2nn_offset = 22
@@ -557,7 +557,6 @@ function srcnn.fcn_v1(backend, ch)
    
    return model
 end
-
 local function unet_branch(backend, insert, backend, n_input, n_output, depad)
    local block = nn.Sequential()
    local con = nn.ConcatTable(2)
@@ -589,146 +588,6 @@ end
 
 -- Cascaded Residual Channel Attention U-Net
 function srcnn.upcunet(backend, ch)
-   -- Residual U-Net
-   local function unet(backend, ch, deconv)
-      local block1 = unet_conv(backend, 128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(backend, 64, 64, 128, true))
-      block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
-      block2:add(unet_conv(backend, 128, 64, 64, true))
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block2, backend, 64, 64, 16))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   -- 2 cascade
-   model:add(unet(backend, ch, true))
-   con:add(unet(backend, ch, false))
-   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-   
-   model.w2nn_arch_name = "upcunet"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   model.w2nn_valid_input_size = {}
-   for i = 76, 512, 4 do
-      table.insert(model.w2nn_valid_input_size, i)
-   end
-
-   return model
-end
--- cunet for 1x
-function srcnn.cunet(backend, ch)
-   local function unet(backend, ch)
-      local block1 = unet_conv(backend, 128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(backend, 64, 64, 128, true))
-      block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
-      block2:add(unet_conv(backend, 128, 64, 64, true))
-
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block2, backend, 64, 64, 16))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   -- 2 cascade
-   model:add(unet(backend, ch))
-   con:add(unet(backend, ch))
-   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-   
-   model.w2nn_arch_name = "cunet"
-   model.w2nn_offset = 40
-   model.w2nn_scale_factor = 1
-   model.w2nn_channels = ch
-   model.w2nn_resize = false
-   model.w2nn_valid_input_size = {}
-   for i = 100, 512, 4 do
-      table.insert(model.w2nn_valid_input_size, i)
-   end
-
-   return model
-end
-
-function srcnn.upcunet_s_p0(backend, ch)
-   -- Residual U-Net
-   local function unet1(backend, ch, deconv)
-      local block1 = unet_conv(backend, 64, 128, 64, true)
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 64, false))
-      model:add(unet_branch(backend, block1, backend, 64, 64, 4))
-      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   -- 2 cascade
-   model:add(unet1(backend, ch, true))
-   con:add(unet1(backend, ch, false))
-   con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
-   --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-
-   model.w2nn_arch_name = "upcunet_s_p0"
-   model.w2nn_offset = 24
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   model.w2nn_valid_input_size = {}
-   for i = 76, 512, 4 do
-      table.insert(model.w2nn_valid_input_size, i)
-   end
-
-   return model
-end
-function srcnn.upcunet_s_p1(backend, ch)
    -- Residual U-Net
    local function unet1(backend, ch, deconv)
       local block1 = unet_conv(backend, 64, 128, 64, true)
@@ -769,7 +628,6 @@ function srcnn.upcunet_s_p1(backend, ch)
    -- 2 cascade
    model:add(unet1(backend, ch, true))
    con:add(unet2(backend, ch, false))
-   --con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
 
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
@@ -779,7 +637,7 @@ function srcnn.upcunet_s_p1(backend, ch)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
 
-   model.w2nn_arch_name = "upcunet_s_p1"
+   model.w2nn_arch_name = "upcunet"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
@@ -791,9 +649,8 @@ function srcnn.upcunet_s_p1(backend, ch)
 
    return model
 end
-
-function srcnn.upcunet_s_p2(backend, ch)
-   -- Residual U-Net
+-- cunet for 1x
+function srcnn.cunet(backend, ch)
    local function unet1(backend, ch, deconv)
       local block1 = unet_conv(backend, 64, 128, 64, true)
       local model = nn.Sequential()
@@ -831,54 +688,8 @@ function srcnn.upcunet_s_p2(backend, ch)
    local aux_con = nn.ConcatTable()
 
    -- 2 cascade
-   model:add(unet2(backend, ch, true))
-   con:add(unet1(backend, ch, false))
-   con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
-   --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-
-   model.w2nn_arch_name = "upcunet_s_p2"
-   model.w2nn_offset = 48
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   model.w2nn_valid_input_size = {}
-   for i = 76, 512, 4 do
-      table.insert(model.w2nn_valid_input_size, i)
-   end
-
-   return model
-end
-function srcnn.cunet_s(backend, ch)
-   local function unet(backend, ch)
-      local block1 = unet_conv(backend, 128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(backend, 32, 64, 128, true))
-      block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
-      block2:add(unet_conv(backend, 128, 64, 32, true))
-
-      local model = nn.Sequential()
-      model:add(unet_conv(backend, ch, 32, 32, false))
-      model:add(unet_branch(backend, block2, backend, 32, 32, 16))
-      model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   -- 2 cascade
-   model:add(unet(backend, ch))
-   con:add(unet(backend, ch))
+   model:add(unet1(backend, ch))
+   con:add(unet2(backend, ch))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
 
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
@@ -887,9 +698,9 @@ function srcnn.cunet_s(backend, ch)
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-
-   model.w2nn_arch_name = "cunet_s"
-   model.w2nn_offset = 40
+   
+   model.w2nn_arch_name = "cunet"
+   model.w2nn_offset = 28
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_resize = false
@@ -900,13 +711,11 @@ function srcnn.cunet_s(backend, ch)
 
    return model
 end
-
 local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
-   local arch = {"upconv_7", "upresnet_s","upcresnet", "resnet_14l", "upcunet", "upcunet_s_p0", "upcunet_s_p1", "upcunet_s_p2"}
-   --local arch = {"upconv_7", "upcunet","upcunet_v0", "upcunet_s", "vgg_7", "cunet", "cunet_s"}
+   local arch = {"upconv_7", "upcunet", "vgg_7", "cunet"}
    local backend = "cudnn"
    local ch = 3
    local batch_size = 1
@@ -915,7 +724,12 @@ local function bench()
       model = srcnn[arch[k]](backend, ch):cuda()
       model:evaluate()
       local dummy = nil
-      local crop_size = (output_size + model.w2nn_offset * 2) / 2
+      local crop_size = nil
+      if model.w2nn_resize then
+	 crop_size = (output_size + model.w2nn_offset * 2) / 2
+      else
+	 crop_size = (output_size + model.w2nn_offset * 2)
+      end
       local dummy = torch.Tensor(batch_size, ch, output_size, output_size):zero():cuda()
 
       print(arch[k], output_size, crop_size)
@@ -962,5 +776,4 @@ print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()):size())
 bench()
 os.exit()
 --]]
-
 return srcnn