浏览代码

Make compatible with MATLAB based benchmarks

nagadomi 9 年之前
父节点
当前提交
b829595c21
共有 1 个文件被更改,包括 18 次插入6 次删除
  1. 18 6
      tools/benchmark.lua

+ 18 - 6
tools/benchmark.lua

@@ -17,8 +17,8 @@ cmd:option("-dir", "./data/test", 'test image directory')
 cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
 cmd:option("-method", "scale", '(scale|noise)')
 cmd:option("-method", "scale", '(scale|noise)')
-cmd:option("-filter", "Box", "downscaling filter (Box|Lanczos)")
-cmd:option("-color", "rgb", '(rgb|y)')
+cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
+cmd:option("-color", "y", '(rgb|y)')
 cmd:option("-noise_level", 1, 'model noise level')
 cmd:option("-noise_level", 1, 'model noise level')
 cmd:option("-jpeg_quality", 75, 'jpeg quality')
 cmd:option("-jpeg_quality", 75, 'jpeg quality')
 cmd:option("-jpeg_times", 1, 'jpeg compression times')
 cmd:option("-jpeg_times", 1, 'jpeg compression times')
@@ -31,21 +31,33 @@ if cudnn then
    cudnn.benchmark = false
    cudnn.benchmark = false
 end
 end
 
 
+local function rgb2y_matlab(x)
+   local y = torch.Tensor(1, x:size(2), x:size(3)):zero()
+   x = iproc.byte2float(x)
+   y:add(x[1] * 65.481)
+   y:add(x[2] * 128.553)
+   y:add(x[3] * 24.966)
+   y:add(16.0)
+   return y:byte():float()
+end
+
 local function MSE(x1, x2)
 local function MSE(x1, x2)
+   x1 = iproc.float2byte(x1):float()
+   x2 = iproc.float2byte(x2):float()
    return (x1 - x2):pow(2):mean()
    return (x1 - x2):pow(2):mean()
 end
 end
 local function YMSE(x1, x2)
 local function YMSE(x1, x2)
-   local x1_2 = image.rgb2y(x1)
-   local x2_2 = image.rgb2y(x2)
+   local x1_2 = rgb2y_matlab(x1)
+   local x2_2 = rgb2y_matlab(x2)
    return (x1_2 - x2_2):pow(2):mean()
    return (x1_2 - x2_2):pow(2):mean()
 end
 end
 local function PSNR(x1, x2)
 local function PSNR(x1, x2)
    local mse = MSE(x1, x2)
    local mse = MSE(x1, x2)
-   return 10 * math.log10(1.0 / mse)
+   return 10 * math.log10((255.0 * 255.0) / mse)
 end
 end
 local function YPSNR(x1, x2)
 local function YPSNR(x1, x2)
    local mse = YMSE(x1, x2)
    local mse = YMSE(x1, x2)
-   return 10 * math.log10(1.0 / mse)
+   return 10 * math.log10((255.0 * 255.0) / mse)
 end
 end
 
 
 local function transform_jpeg(x, opt)
 local function transform_jpeg(x, opt)