Ver código fonte

Use conv2d instead of nn.SpatialConvolutionMM

nagadomi 8 anos atrás
pai
commit
f65132dadb
2 arquivos alterados com 63 adições e 10 exclusões
  1. 1 10
      lib/data_augmentation.lua
  2. 62 0
      lib/iproc.lua

+ 1 - 10
lib/data_augmentation.lua

@@ -71,16 +71,12 @@ function data_augmentation.unsharp_mask(src, p)
       return src
    end
 end
-data_augmentation.blur_conv = {}
 function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
    size = size or "3"
    filters = utils.split(size, ",")
    for i = 1, #filters do
       local s = tonumber(filters[i])
       filters[i] = s
-      if not data_augmentation.blur_conv[s] then
-	 data_augmentation.blur_conv[s] = nn.SpatialConvolutionMM(1, 1, s, s, 1, 1, (s - 1) / 2, (s - 1) / 2):noBias():cuda()
-      end
    end
    if torch.uniform() < p then
       local src, conversion = iproc.byte2float(src)
@@ -92,12 +88,7 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
 	 sigma = torch.uniform(sigma_min, sigma_max)
       end
       local kernel = iproc.gaussian2d(kernel_size, sigma)
-      data_augmentation.blur_conv[kernel_size].weight:copy(kernel)
-      local dest = torch.Tensor(3, src:size(2), src:size(3))
-      dest[1]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[1]:reshape(1, src:size(2), src:size(3)):cuda()))
-      dest[2]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[2]:reshape(1, src:size(2), src:size(3)):cuda()))
-      dest[3]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[3]:reshape(1, src:size(2), src:size(3)):cuda()))
-
+      local dest = iproc.convolve(src, kernel, 'same')
       if conversion then
 	 dest = iproc.float2byte(dest)
       end

+ 62 - 0
lib/iproc.lua

@@ -1,6 +1,7 @@
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
 local image = nil
+require 'dok'
 
 local iproc = {}
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
@@ -267,6 +268,67 @@ function iproc.gaussian2d(kernel_size, sigma)
    kernel:div(kernel:sum())
    return kernel
 end
+
+-- Adapted from image.convolve.
+--
+-- Convolves an input image with a kernel and returns the result.
+-- The call form is chosen from the number of arguments:
+--   iproc.convolve(src, kernel)
+--   iproc.convolve(src, kernel, mode)       -- third argument is a string
+--   iproc.convolve(dst, src, kernel)        -- third argument is not a string
+--   iproc.convolve(dst, src, kernel, mode)
+-- mode must be 'full', 'valid' or 'same'; when omitted, the 'V' (valid)
+-- variant of torch.conv2 is used.
+-- Any other argument count prints the usage text and raises through dok.error.
+-- Returns the convolution result. For mode 'same' this is the full result
+-- narrowed to src's sizes along the last two dimensions, so when dst was
+-- supplied the returned value is that narrowed result, not dst itself.
+function iproc.convolve(...)
+   local dst, src, kernel, mode
+   local nargs = select('#', ...)
+   local args = {...}
+   if nargs == 4 then
+      dst, src, kernel, mode = args[1], args[2], args[3], args[4]
+   elseif nargs == 3 then
+      -- A trailing string can only be the mode; otherwise the first
+      -- argument is taken as the destination.
+      if type(args[3]) == 'string' then
+         src, kernel, mode = args[1], args[2], args[3]
+      else
+         dst, src, kernel = args[1], args[2], args[3]
+      end
+   elseif nargs == 2 then
+      src, kernel = args[1], args[2]
+   else
+      print(dok.usage('iproc.convolve',
+                       'convolves an input image with a kernel, returns the result', nil,
+                       {type='torch.Tensor', help='input image', req=true},
+                       {type='torch.Tensor', help='kernel', req=true},
+                       {type='string', help='type: full | valid | same', default='valid'},
+                       '',
+                       {type='torch.Tensor', help='destination', req=true},
+                       {type='torch.Tensor', help='input image', req=true},
+                       {type='torch.Tensor', help='kernel', req=true},
+                       {type='string', help='type: full | valid | same', default='valid'}))
+      -- Report the error under this function's own name (it previously
+      -- said 'image.convolve', a leftover from the code it was copied from).
+      dok.error('incorrect arguments', 'iproc.convolve')
+   end
+   if mode and mode ~= 'valid' and mode ~= 'full' and mode ~= 'same' then
+      dok.error('mode has to be one of: full | valid | same', 'iproc.convolve')
+   end
+   -- 'same' is computed as a full convolution and cropped afterwards.
+   local md = (((mode == 'full') or (mode == 'same')) and 'F') or 'V'
+   if kernel:nDimension() == 2 and src:nDimension() == 3 then
+      -- Replicate the 2D kernel once per slice of src's first dimension.
+      local k3d = src.new(src:size(1), kernel:size(1), kernel:size(2))
+      for i = 1, src:size(1) do
+         k3d[i]:copy(kernel)
+      end
+      kernel = k3d
+   end
+   if dst then
+      torch.conv2(dst, src, kernel, md)
+   else
+      dst = torch.conv2(src, kernel, md)
+   end
+   if mode == 'same' then
+      -- cx / cy index the last and second-to-last dimensions of the result.
+      local cx = dst:dim()
+      local cy = cx - 1
+      local ofy = math.ceil(kernel:size(cy) / 2)
+      local ofx = math.ceil(kernel:size(cx) / 2)
+      dst = dst:narrow(cy, ofy, src:size(cy)):narrow(cx, ofx, src:size(cx))
+   end
+   return dst
+end
 local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)
    local b = iproc.float2byte(a)