Bläddra i källkod

Add GCN option for user method

nagadomi 8 år sedan
förälder
incheckning
43a9b58fcb
5 ändrade filer med 32 tillägg och 2 borttagningar
  1. 1 0
      .gitignore
  2. 9 0
      lib/pairwise_transform_user.lua
  3. 15 1
      lib/reconstruct.lua
  4. 1 1
      lib/srcnn.lua
  5. 6 0
      train.lua

+ 1 - 0
.gitignore

@@ -12,6 +12,7 @@ models/*
 !models/photo
 !models/upconv_7
 !models/upconv_7l
+!models/srresnet_12l
 !models/vgg_7
 models/*/*.png
 models/*/*/*.png

+ 9 - 0
lib/pairwise_transform_user.lua

@@ -37,6 +37,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
 	 yc = iproc.rgb2y(yc)
 	 xc = iproc.rgb2y(xc)
       end
+      if options.gcn then
+	 local mean = xc:mean()
+	 local stdv = xc:std()
+	 if stdv > 0 then
+	    xc:add(-mean):div(stdv)
+	 else
+	    xc:add(-mean)
+	 end
+      end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
 

+ 15 - 1
lib/reconstruct.lua

@@ -40,6 +40,15 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
 	    break
 	 end
 	 input[j+1]:copy(x[input_indexes[i + j]])
+	 if model.w2nn_gcn then
+	    local mean = input[j + 1]:mean()
+	    local stdv = input[j + 1]:std()
+	    if stdv > 0 then
+	       input[j + 1]:add(-mean):div(stdv)
+	    else
+	       input[j + 1]:add(-mean)
+	    end
+	 end
 	 c = c + 1
       end
       input_cuda:copy(input)
@@ -80,7 +89,12 @@ local function padding_params(x, model, block_size)
    p.x_w = x:size(3)
    p.x_h = x:size(2)
    p.inner_scale = reconstruct.inner_scale(model)
-   local input_offset = math.ceil(offset / p.inner_scale)
+   local input_offset
+   if model.w2nn_input_offset then
+      input_offset = model.w2nn_input_offset
+   else
+      input_offset = math.ceil(offset / p.inner_scale)
+   end
    local input_block_size = block_size
    local process_size = input_block_size - input_offset * 2
    local h_blocks = math.floor(p.x_h / process_size) +

+ 1 - 1
lib/srcnn.lua

@@ -519,12 +519,12 @@ function srcnn.fcn_v1(backend, ch)
 
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
-
    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
+   model.w2nn_gcn = true
    
    return model
 end

+ 6 - 0
train.lua

@@ -192,6 +192,7 @@ local function transform_pool_init(has_resize, offset)
 		  negate_x_rate = settings.random_pairwise_negate_x_rate
 	       end
 	       local conf = tablex.update({
+		     gcn = settings.gcn,
 		     max_size = settings.max_size,
 		     active_cropping_rate = active_cropping_rate,
 		     active_cropping_tries = active_cropping_tries,
@@ -432,6 +433,11 @@ local function train()
 	 settings.crop_size = model.w2nn_input_size
       end
    end
+   if model.w2nn_gcn then
+      settings.gcn = true
+   else
+      settings.gcn = false
+   end
    dir.makepath(settings.model_dir)
 
    local offset = reconstruct.offset_size(model)