소스 검색

No bias at FullConvolution

nagadomi 9 년 전
부모
커밋
9563d84302
5개의 변경된 파일15개의 추가작업 그리고 6개의 파일을 삭제
  1. 1 1
      appendix/train_upconv_7_art.sh
  2. 1 1
      lib/srcnn.lua
  3. 0 0
      models/upconv_7/art/scale2.0x_model.json
  4. 0 0
      models/upconv_7/art/scale2.0x_model.t7
  5. 13 4
      tools/export_model.lua

+ 1 - 1
appendix/train_upconv_7_art.sh

@@ -4,7 +4,7 @@
 th convert_data.lua -max_training_image_size 1600
 
 # scale
-th train.lua -save_history 1 -inner_epoch 1 -epoch 100 -scale 2 -model upconv_7 -method scale -model_dir models/test/upconv_7_rev3 -downsampling_filters "Box,Sinc" -test query/scale_test.png -backend cudnn -thread 4 
+th train.lua -save_history 1 -scale 2 -model upconv_7 -method scale -model_dir models/test/upconv_7_rev5 -downsampling_filters "Box,Sinc" -test query/pixel-art-small.png -backend cudnn -thread 4 -oracle_rate 0.05
 
 # noise_scale
 th train.lua -save_history 1 -model upconv_7 -method noise_scale -noise_level 0 -model_dir models/test/yuv420_rev2 -downsampling_filters "Box,Sinc" -test query/noise_test.jpg -backend cudnn -thread 4  -resume models/test/upconv_7_rev3/scale2.0x_model.t7 -style art 

+ 1 - 1
lib/srcnn.lua

@@ -248,7 +248,7 @@ function srcnn.upconv_7(backend, ch)
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3))
+   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "upconv_7"

파일 크기가 너무 크기때문에 변경 상태를 표시하지 않습니다.
+ 0 - 0
models/upconv_7/art/scale2.0x_model.json


파일 크기가 너무 크기때문에 변경 상태를 표시하지 않습니다.
+ 0 - 0
models/upconv_7/art/scale2.0x_model.t7


+ 13 - 4
tools/export_model.lua

@@ -5,7 +5,7 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
 require 'w2nn'
 local cjson = require "cjson"
 
-function meta_data(model)
+local function meta_data(model)
    local meta = {}
    for k, v in pairs(model) do
       if k:match("w2nn_") then
@@ -14,7 +14,7 @@ function meta_data(model)
    end
    return meta
 end
-function includes(s, a)
+local function includes(s, a)
    for i = 1, #a do
       if s == a[i] then
 	 return true
@@ -22,7 +22,16 @@ function includes(s, a)
    end
    return false
 end
-function export(model, output)
+
+local function get_bias(mod)
+   if mod.bias then
+      return mod.bias:float()
+   else
+      -- no bias
+      return torch.FloatTensor(mod.nOutputPlane):zero()
+   end
+end
+local function export(model, output)
    local targets = {"nn.SpatialConvolutionMM",
 		    "cudnn.SpatialConvolution",
 		    "nn.SpatialFullConvolution",
@@ -52,7 +61,7 @@ function export(model, output)
 	    padH = mod.padH,
 	    nInputPlane = mod.nInputPlane,
 	    nOutputPlane = mod.nOutputPlane,
-	    bias = torch.totable(mod.bias:float()),
+	    bias = torch.totable(get_bias(mod)),
 	    weight = weight
 	 }
 	 if first_layer then

이 변경점에서 너무 많은 파일들이 변경되어 몇몇 파일들은 표시되지 않았습니다.