Bläddra i källkod

Change the evaluation metric

nagadomi 9 år sedan
förälder
incheckning
8a65db7bab
2 ändrade filer med 21 tillägg och 0 borttagningar
  1. 20 0
      lib/ClippedMSECriterion.lua
  2. 1 0
      lib/w2nn.lua

+ 20 - 0
lib/ClippedMSECriterion.lua

@@ -0,0 +1,20 @@
+local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')
+
+function ClippedMSECriterion:__init(min, max)
+   parent.__init(self)
+   self.min = min
+   self.max = max
+   self.diff = torch.Tensor()
+end
+function ClippedMSECriterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max)
+   self.diff:add(-1, target)
+   self.output = self.diff:pow(2):sum() / input:nElement()
+   return self.output
+end
+function ClippedMSECriterion:updateGradInput(input, target)
+   local norm = 1.0 / input:nElement()
+   self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
+   return self.gradInput 
+end

+ 1 - 0
lib/w2nn.lua

@@ -21,5 +21,6 @@ else
    require 'DepthExpand2x'
    require 'DepthExpand2x'
    require 'PSNRCriterion'
    require 'PSNRCriterion'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedWeightedHuberCriterion'
+   require 'ClippedMSECriterion'
    return w2nn
    return w2nn
 end
 end