Sfoglia il codice sorgente

Fix crop bug in rare case

nagadomi 8 anni fa
parent
commit
763f5ddcab
3 ha cambiato i file con 25 aggiunte e 4 eliminazioni
  1. 8 2
      lib/pairwise_transform_utils.lua
  2. 7 0
      lib/srcnn.lua
  3. 10 2
      train.lua

+ 8 - 2
lib/pairwise_transform_utils.lua

@@ -125,8 +125,14 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       t = "byte"
       t = "byte"
    end
    end
    if p < r then
    if p < r then
-      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
-      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
+      local xi = 0
+      local yi = 0
+      if x:size(2) > size + 1 then
+	 xi = torch.random(0, x:size(2) - (size + 1)) * scale
+      end
+      if x:size(3) > size + 1 then
+	 yi = torch.random(0, x:size(3) - (size + 1)) * scale
+      end
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       return xc, yc
       return xc, yc

+ 7 - 0
lib/srcnn.lua

@@ -127,6 +127,8 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
       error("unsupported backend:" .. backend)
       error("unsupported backend:" .. backend)
    end
    end
 end
 end
+srcnn.SpatialConvolution = SpatialConvolution
+
 local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
 local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
    if backend == "cunn" then
       return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
       return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
@@ -136,6 +138,8 @@ local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH
       error("unsupported backend:" .. backend)
       error("unsupported backend:" .. backend)
    end
    end
 end
 end
+srcnn.SpatialFullConvolution = SpatialFullConvolution
+
 local function ReLU(backend)
 local function ReLU(backend)
    if backend == "cunn" then
    if backend == "cunn" then
       return nn.ReLU(true)
       return nn.ReLU(true)
@@ -145,6 +149,8 @@ local function ReLU(backend)
       error("unsupported backend:" .. backend)
       error("unsupported backend:" .. backend)
    end
    end
 end
 end
+srcnn.ReLU = ReLU
+
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
    if backend == "cunn" then
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@@ -154,6 +160,7 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
       error("unsupported backend:" .. backend)
       error("unsupported backend:" .. backend)
    end
    end
 end
 end
+srcnn.SpatialMaxPooling = SpatialMaxPooling
 
 
 -- VGG style net(7 layers)
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
 function srcnn.vgg_7(backend, ch)

+ 10 - 2
train.lua

@@ -418,7 +418,10 @@ local function plot(train, valid)
 	 {'validation', torch.Tensor(valid), '-'}})
 	 {'validation', torch.Tensor(valid), '-'}})
 end
 end
 local function train()
 local function train()
-   local x = remove_small_image(torch.load(settings.images))
+   local x = torch.load(settings.images)
+   if settings.method ~= "user" then
+      x = remove_small_image(x)
+   end
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_train = {}
    local hist_valid = {}
    local hist_valid = {}
@@ -426,7 +429,12 @@ local function train()
    if settings.resume:len() > 0 then
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
       model = torch.load(settings.resume, "ascii")
    else
    else
-      model = srcnn.create(settings.model, settings.backend, settings.color)
+      if stringx.endswith(settings.model, ".lua") then
+	 local create_model = dofile(settings.model)
+	 model = create_model(srcnn, settings)
+      else
+	 model = srcnn.create(settings.model, settings.backend, settings.color)
+      end
    end
    end
    if model.w2nn_input_size then
    if model.w2nn_input_size then
       if settings.crop_size ~= model.w2nn_input_size then
       if settings.crop_size ~= model.w2nn_input_size then