浏览代码

clean; Add upcunet_v3

nagadomi 6 年之前
父节点
当前提交
b28b6172ca
共有 4 个文件被更改,包括 249 次插入30 次删除
  1. 1 1
      lib/AuxiliaryLossTable.lua
  2. 1 2
      lib/ScaleTable.lua
  3. 131 27
      lib/srcnn.lua
  4. 116 0
      tools/find_unet.py

+ 1 - 1
lib/AuxiliaryLossTable.lua

@@ -41,6 +41,6 @@ end
 function AuxiliaryLossTable:clearState()
    self.gradInput = {}
    self.output_table = {}
-   self.output_tensor:set()
+   nn.utils.clear(self, 'output_tensor')
    return parent:clearState()
 end

+ 1 - 2
lib/ScaleTable.lua

@@ -33,7 +33,6 @@ function ScaleTable:updateGradInput(input, gradOutput)
    return self.gradInput
 end
 function ScaleTable:clearState()
-   self.grad_tmp:set()
-   self.scale:set()
+   nn.utils.clear(self, {'grad_tmp','scale'})
    return parent:clearState()
 end

+ 131 - 27
lib/srcnn.lua

@@ -218,6 +218,13 @@ local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW,
 end
 srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
 
+local function GlobalAveragePooling(n_output)
+   local gap = nn.Sequential()
+   gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
+   gap:add(nn.View(-1, n_output, 1, 1))
+   return gap
+end
+srcnn.GlobalAveragePooling = GlobalAveragePooling
 
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
@@ -247,6 +254,7 @@ function srcnn.vgg_7(backend, ch)
    
    return model
 end
+
 -- VGG style net(12 layers)
 function srcnn.vgg_12(backend, ch)
    local model = nn.Sequential()
@@ -721,6 +729,38 @@ function srcnn.upconv_refine(backend, ch)
    return model
 end
 
+-- I devised this arch because of the block size and global average pooling problem,
+-- but SEBlock may possibly learn multi-scale input and no problems occur.
+local function SpatialSEBlock(backend, ave_size, n_output, r)
+   local con = nn.ConcatTable(2)
+   local attention = nn.Sequential()
+   local n_mid = math.floor(n_output / r)
+   attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
+   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.ReLU(true))
+   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.Sigmoid(true))
+   attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
+   con:add(nn.Identity())
+   con:add(attention)
+   return con
+end
+
+-- Squeeze and Excitation Block
+local function SEBlock(backend, n_output, r)
+   local con = nn.ConcatTable(2)
+   local attention = nn.Sequential()
+   local n_mid = math.floor(n_output / r)
+   attention:add(GlobalAveragePooling(n_output))
+   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.ReLU(true))
+   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid 
+   con:add(nn.Identity())
+   con:add(attention)
+   return con
+end
+
 -- cascaded residual channel attention unet
 function srcnn.upcunet(backend, ch)
    function unet_branch(insert, backend, n_input, n_output, depad)
@@ -744,17 +784,7 @@ function srcnn.upcunet(backend, ch)
 	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
 	model:add(nn.LeakyReLU(0.1, true))
 	if se then
-	   -- Squeeze and Excitation Networks
-	   local con = nn.ConcatTable(2)
-	   local attention = nn.Sequential()
-	   attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
-	   attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
-	   attention:add(nn.ReLU(true))
-	   attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
-	   attention:add(nn.Sigmoid(true))
-	   con:add(nn.Identity())
-	   con:add(attention)
-	   model:add(con)
+	   model:add(SEBlock(backend, n_output, 4))
 	   model:add(w2nn.ScaleTable())
 	end
 	return model
@@ -799,8 +829,10 @@ function srcnn.upcunet(backend, ch)
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
-   -- 72, 128, 256 are valid
-   --model.w2nn_input_size = 128
+   model.w2nn_valid_input_size = {}
+   for i = 76, 512, 4 do
+      table.insert(model.w2nn_valid_input_size, i)
+   end
 
    return model
 end
@@ -828,19 +860,7 @@ function srcnn.upcunet_v2(backend, ch)
       model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
       model:add(nn.LeakyReLU(0.1, true))
       if se then
-	 -- Spatial Squeeze and Excitation Networks
-	 local se_fac = 4
-	 local con = nn.ConcatTable(2)
-	 local attention = nn.Sequential()
-	 attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
-	 attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
-	 attention:add(nn.ReLU(true))
-	 attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
-	 attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid 
-	 attention:add(nn.SpatialUpSamplingNearest(4, 4))
-	 con:add(nn.Identity())
-	 con:add(attention)
-	 model:add(con)
+	 model:add(SpatialSEBlock(backend, 4, n_output, 4))
 	 model:add(nn.CMulTable())
       end
       return model
@@ -888,11 +908,89 @@ function srcnn.upcunet_v2(backend, ch)
 
    return model
 end
+-- cascaded residual channel attention unet
+function srcnn.upcunet_v3(backend, ch)
+   local function unet_branch(insert, backend, n_input, n_output, depad)
+      local block = nn.Sequential()
+      local con = nn.ConcatTable(2)
+      local model = nn.Sequential()
+      
+      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
+      block:add(nn.LeakyReLU(0.1, true))
+      block:add(insert)
+      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
+      block:add(nn.LeakyReLU(0.1, true))
+      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
+      con:add(block)
+      model:add(con)
+      model:add(nn.CAddTable())
+      return model
+   end
+   local function unet_conv(n_input, n_middle, n_output, se)
+	local model = nn.Sequential()
+	model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
+	model:add(nn.LeakyReLU(0.1, true))
+	model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
+	model:add(nn.LeakyReLU(0.1, true))
+	if se then
+	   model:add(SEBlock(backend, n_output, 4))
+	   model:add(w2nn.ScaleTable())
+	end
+	return model
+   end
+   -- Residual U-Net
+   local function unet(backend, ch, deconv)
+      local block1 = unet_conv(128, 256, 128, true)
+      local block2 = nn.Sequential()
+      block2:add(unet_conv(64, 64, 128, true))
+      block2:add(unet_branch(block1, backend, 128, 128, 4))
+      block2:add(unet_conv(128, 64, 64, true))
+      local model = nn.Sequential()
+      model:add(unet_conv(ch, 32, 64, false))
+      model:add(unet_branch(block2, backend, 64, 64, 16))
+      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
+      model:add(nn.LeakyReLU(0.1))
+      if deconv then
+	 model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
+      else
+	 model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
+      end
+      return model
+   end
+   local model = nn.Sequential()
+   local con = nn.ConcatTable()
+   local aux_con = nn.ConcatTable()
+
+   -- 2 cascade
+   model:add(unet(backend, ch, true))
+   con:add(unet(backend, ch, false))
+   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
+
+   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
+   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
+
+   model:add(con)
+   model:add(aux_con)
+   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
+   
+   model.w2nn_arch_name = "upcunet_v3"
+   model.w2nn_offset = 60
+   model.w2nn_scale_factor = 2
+   model.w2nn_channels = ch
+   model.w2nn_resize = true
+   model.w2nn_valid_input_size = {}
+   for i = 76, 512, 4 do
+      table.insert(model.w2nn_valid_input_size, i)
+   end
+
+   return model
+end
+
 local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
-   local arch = {"upconv_7", "upcunet", "upcunet_v2"}
+   local arch = {"upconv_7", "upcunet", "upcunet_v3"}
    local backend = "cudnn"
    for k = 1, #arch do
       model = srcnn[arch[k]](backend, 3):cuda()
@@ -947,6 +1045,12 @@ print(model)
 model:training()
 print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
 os.exit()
+local model = srcnn.upcunet_v3("cunn", 3):cuda()
+print(model)
+model:training()
+print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
+os.exit()
+bench()
 --]]
 
 return srcnn

+ 116 - 0
tools/find_unet.py

@@ -0,0 +1,116 @@
+def find_unet_v2():
+    avg_pool=4
+    print_mod = False
+    check_mod = True
+    print("cascade")
+    
+    for i in range(76, 512):
+        print("-- {}".format(i))
+        print_buf = []
+        s = i
+        # unet 1
+
+        s = s - 4 # conv3x3x2
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+            continue
+
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+           continue
+        s = s * 2 # up2x2
+        s = s - 4 # conv3x3x2
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+            continue
+        s = s * 2 # up2x2
+
+        # deconv
+        s = s
+        s = s * 2 - 4
+
+        # unet 2
+        s = s - 4 # conv3x3x2
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+            continue
+        s = s * 2 # up2x2
+        s = s - 4 # conv3x3x2
+        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
+        if check_mod and s % avg_pool != 0:
+            continue
+        s = s * 2 # up2x2
+        s = s - 2 # conv3x3 last
+        #if s % avg_pool != 0:
+        #    continue
+        print("ok", i, s)
+
+def find_unet():
+    check_mod = True
+    print_size = False
+    print("cascade")
+    
+    for i in range(76, 512):
+        print_buf = []
+        s = i
+        # unet 1
+
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2", s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2",s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        
+        s = s * 2 # up2x2
+        if print_size: print("2x",s)
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x",s)
+
+        # deconv
+        s = s - 2
+        s = s * 2 - 4
+
+        # unet 2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2",s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2",s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x",s)
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x",s)
+        s = s - 2 # conv3x3
+        s = s - 2 # conv3x3 last
+        #if s % avg_pool != 0:
+        #    continue
+        print("ok", i, s)
+        
+find_unet()
+