nagadomi 6 år sedan
förälder
incheckning
1f18d1919a
1 ändrade filer med 4 tillägg och 276 borttagningar
  1. 4 276
      lib/srcnn.lua

+ 4 - 276
lib/srcnn.lua

@@ -710,281 +710,8 @@ function srcnn.upconv_refine(backend, ch)
    return model
 end
 
--- cascade u-net
-function srcnn.cunet_v1(backend, ch)
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
-      --block:add(w2nn.Print())
-      block:add(pooling)
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      local parallel = nn.ConcatTable(2)
-      parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      parallel:add(block)
-      local model = nn.Sequential()
-      model:add(parallel)
-      model:add(nn.JoinTable(2))
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output)
-	local model = nn.Sequential()
-	model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-	return model
-   end
-   function unet(backend, ch, deconv)
-      -- 
-      local block1 = unet_conv(128, 256, 128)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(32, 64, 128))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128*2, 64, 32))
-      local model = nn.Sequential()
-      model:add(unet_conv(ch, 32, 32))
-      model:add(unet_branch(block2, backend, 32, 32, 16))
-      model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1))
-      if deconv then
-	 model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   model:add(unet(backend, ch, true))
-
-   con:add(unet(backend, ch, false))
-   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-   
-   model.w2nn_arch_name = "cunet_v1"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   -- 72, 128, 256 are valid
-   --model.w2nn_input_size = 128
-
-   return model
-end
-
--- cascade u-net
-function srcnn.cunet_v2(backend, ch)
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
-      --block:add(w2nn.Print())
-      block:add(pooling)
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      local parallel = nn.ConcatTable(2)
-      parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      parallel:add(block)
-      local model = nn.Sequential()
-      model:add(parallel)
-      model:add(nn.CAddTable(2))
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output)
-	local model = nn.Sequential()
-	model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-	return model
-   end
-   -- res unet
-   function unet(backend, ch, deconv)
-      local block1 = unet_conv(128, 256, 128)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 128, 128))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 128, 64))
-      local model = nn.Sequential()
-      model:add(nn.SpatialZeroPadding(-1, -1, -1, -1))
-      model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
-      if deconv then
-	 model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-	 model:add(nn.LeakyReLU(0.1))
-	 model:add(SpatialFullConvolution(backend, 128, 64, 4, 4, 2, 2, 3, 3))
-      else
-	 model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   model:add(unet(backend, ch, true))
-   con:add(unet(backend, 64, false))
-   con:add(nn.SpatialZeroPadding(-19, -19, -19, -19))
-
-   model:add(con)
-   model:add(nn.CAddTable())
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-   
-   model.w2nn_arch_name = "cunet_v2"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   -- 72, 128, 256 are valid
-   --model.w2nn_input_size = 128
-
-   return model
-end
--- cascade u-net
-function srcnn.cunet_v3(backend, ch)
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
-      --block:add(w2nn.Print())
-      block:add(pooling)
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      local parallel = nn.ConcatTable(2)
-      parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      parallel:add(block)
-      local model = nn.Sequential()
-      model:add(parallel)
-      model:add(nn.CAddTable())
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output)
-	local model = nn.Sequential()
-	model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	return model
-   end
-   function unet(backend, ch, deconv)
-      local block1 = unet_conv(128, 256, 128)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 64, 128))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 64, 64))
-      local model = nn.Sequential()
-      model:add(unet_conv(ch, 32, 64))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
-      if deconv then
-	 model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-	 model:add(nn.LeakyReLU(0.1))
-	 model:add(SpatialFullConvolution(backend, 128, 64, 4, 4, 2, 2, 3, 3))
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-
-   model:add(unet(backend, ch, true))
-   model:add(nn.ConcatTable():add(unet(backend, 64, false)):add(nn.SpatialZeroPadding(-18, -18, -18, -18)))
-   model:add(nn.CAddTable())
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU())
-   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.InplaceClip01())
-   
-   model.w2nn_arch_name = "cunet_v3"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   -- 72, 128, 256 are valid
-   --model.w2nn_input_size = 128
-
-   return model
-end
--- cascade u-net
-function srcnn.cunet_v4(backend, ch)
-   function upconv_3(backend, n_input, n_output)
-      local model = nn.Sequential()
-      model:add(SpatialConvolution(backend, n_input, 32, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(SpatialFullConvolution(backend, 32, n_output, 4, 4, 2, 2, 3, 3):noBias())
-      return model
-   end
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
-      --block:add(w2nn.Print())
-      block:add(pooling)
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      local parallel = nn.ConcatTable(2)
-      parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      parallel:add(block)
-      local model = nn.Sequential()
-      model:add(parallel)
-      model:add(nn.CAddTable())
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output)
-	local model = nn.Sequential()
-	model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-	model:add(nn.LeakyReLU(0.1, true))
-	return model
-   end
-   function unet(backend, ch)
-      local block1 = unet_conv(128, 256, 128)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 64, 128))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 64, 64))
-      local model = nn.Sequential()
-      model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   model:add(upconv_3(backend, ch, 64))
-
-   con:add(unet(backend, 32))
-   --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single output
-
-   model:add(conn)
-   model:add(nn.CAddTable())
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU())
-   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.InplaceClip01())
-   model.w2nn_arch_name = "cunet_v3"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   -- 72, 128, 256 are valid
-   --model.w2nn_input_size = 128
-
-   return model
-end
-
-function srcnn.cunet_v6(backend, ch)
+-- cascaded residual channel attention unet
+function srcnn.upcunet(backend, ch)
    function unet_branch(insert, backend, n_input, n_output, depad)
       local block = nn.Sequential()
       local con = nn.ConcatTable(2)
@@ -1044,6 +771,7 @@ function srcnn.cunet_v6(backend, ch)
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
 
+   -- 2 cascade
    model:add(unet(backend, ch, true))
    con:add(unet(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
@@ -1055,7 +783,7 @@ function srcnn.cunet_v6(backend, ch)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    
-   model.w2nn_arch_name = "cunet_v6"
+   model.w2nn_arch_name = "upcunet"
    model.w2nn_offset = 60
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch