nagadomi 6 gadi atpakaļ
vecāks
revīzija
06246e0d78
1 mainītis faili ar 8 papildinājumiem un 10 dzēšanām
  1. 8 10
      lib/srcnn.lua

+ 8 - 10
lib/srcnn.lua

@@ -987,16 +987,15 @@ end
 function srcnn.cunet_v6(backend, ch)
    function unet_branch(insert, backend, n_input, n_output, depad)
       local block = nn.Sequential()
-      local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
-      --block:add(w2nn.Print())
-      block:add(pooling)
+      local con = nn.ConcatTable(2)
+      local model = nn.Sequential()
+      
+      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
       block:add(insert)
       block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      local parallel = nn.ConcatTable(2)
-      parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      parallel:add(block)
-      local model = nn.Sequential()
-      model:add(parallel)
+      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
+      con:add(block)
+      model:add(con)
       model:add(nn.CAddTable())
       return model
    end
@@ -1015,7 +1014,7 @@ function srcnn.cunet_v6(backend, ch)
 	   attention:add(nn.ReLU(true))
 	   attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
 	   attention:add(nn.Sigmoid(true))
-	   con:add(nn.Identity())                                                                          
+	   con:add(nn.Identity())
 	   con:add(attention)
 	   model:add(con)
 	   model:add(w2nn.ScaleTable())
@@ -1046,7 +1045,6 @@ function srcnn.cunet_v6(backend, ch)
    local aux_con = nn.ConcatTable()
 
    model:add(unet(backend, ch, true))
-
    con:add(unet(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))