소스 검색

Add Clip(0,1) to last layer

nagadomi 8 년 전
부모
커밋
cabeeed2a7
3개의 변경된 파일18개의 추가작업 그리고 0개의 파일을 삭제
  1. 13 0
      lib/InplaceClip01.lua
  2. 4 0
      lib/srcnn.lua
  3. 1 0
      lib/w2nn.lua

+ 13 - 0
lib/InplaceClip01.lua

@@ -0,0 +1,13 @@
+local Clip01, parent = torch.class("w2nn.InplaceClip01", "nn.Module")
+
+function Clip01:__init()
+   parent.__init(self)
+end
+function Clip01:updateOutput(input)
+   self.output:set(input:clamp(0, 1))
+   return self.output
+end
+function Clip01:updateGradInput(input, gradOutput)
+   self.gradInput:set(gradOutput)
+   return self.gradInput
+end

+ 4 - 0
lib/srcnn.lua

@@ -153,6 +153,7 @@ function srcnn.vgg_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "vgg_7"
@@ -190,6 +191,7 @@ function srcnn.vgg_12(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "vgg_12"
@@ -219,6 +221,7 @@ function srcnn.dilated_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "dilated_7"
@@ -249,6 +252,7 @@ function srcnn.upconv_7(backend, ch)
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
+   model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
 
    model.w2nn_arch_name = "upconv_7"

+ 1 - 0
lib/w2nn.lua

@@ -30,5 +30,6 @@ else
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
+   require 'InplaceClip01'
    return w2nn
 end