nagadomi преди 8 години
родител
ревизия
f16950438c
променени са 1 файла, в които са добавени 12 реда и са изтрити 0 реда
  1. 12 0
      lib/reconstruct.lua

+ 12 - 0
lib/reconstruct.lua

@@ -172,6 +172,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
    return output
    return output
 end
 end
 function reconstruct.image(model, x, block_size)
 function reconstruct.image(model, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    local i2rgb = false
    local i2rgb = false
    if x:size(1) == 1 then
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -194,6 +197,9 @@ function reconstruct.image(model, x, block_size)
    return x
    return x
 end
 end
 function reconstruct.scale(model, scale, x, block_size)
 function reconstruct.scale(model, scale, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    local i2rgb = false
    local i2rgb = false
    if x:size(1) == 1 then
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -287,6 +293,9 @@ local function tta(f, n, model, x, block_size)
    return average:div(#augments)
    return average:div(#augments)
 end
 end
 function reconstruct.image_tta(model, n, x, block_size)
 function reconstruct.image_tta(model, n, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    if reconstruct.is_rgb(model) then
    if reconstruct.is_rgb(model) then
       return tta(reconstruct.image_rgb, n, model, x, block_size)
       return tta(reconstruct.image_rgb, n, model, x, block_size)
    else
    else
@@ -294,6 +303,9 @@ function reconstruct.image_tta(model, n, x, block_size)
    end
    end
 end
 end
 function reconstruct.scale_tta(model, n, scale, x, block_size)
 function reconstruct.scale_tta(model, n, scale, x, block_size)
+   if model.w2nn_input_size then
+      block_size = model.w2nn_input_size
+   end
    if reconstruct.is_rgb(model) then
    if reconstruct.is_rgb(model) then
       local f = function (model, x, offset, block_size)
       local f = function (model, x, offset, block_size)
 	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)
 	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)