Explorar el Código

performance tuning

nagadomi hace 6 años
padre
commit
daadbaccae
Se ha modificado 1 fichero con 27 adiciones y 27 borrados
  1. 27 27
      lib/reconstruct.lua

+ 27 - 27
lib/reconstruct.lua

@@ -30,38 +30,38 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
 	 end
       end
    end
-   local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
-   local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
-   for i = 1, #input_indexes, batch_size do
-      local c = 0
-      local output
-      for j = 0, batch_size - 1 do
-	 if i + j > #input_indexes then
-	    break
-	 end
-	 input[j+1]:copy(x[input_indexes[i + j]])
-	 if model.w2nn_gcn then
-	    local mean = input[j + 1]:mean()
-	    local stdv = input[j + 1]:std()
-	    if stdv > 0 then
-	       input[j + 1]:add(-mean):div(stdv)
-	    else
-	       input[j + 1]:add(-mean)
-	    end
+   local input = torch.Tensor(#input_indexes, ch, input_block_size, input_block_size)
+   local input_cuda = torch.CudaTensor():resize(input:size())
+   local output_cuda = torch.CudaTensor():resize(new_x:size())
+   for i = 1, #input_indexes do
+      input[i]:copy(x[input_indexes[i]])
+      if model.w2nn_gcn then
+	 local mean = input[i]:mean()
+	 local stdv = input[i]:std()
+	 if stdv > 0 then
+	    input[i]:add(-mean):div(stdv)
+	 else
+	    input[i]:add(-mean)
 	 end
-	 c = c + 1
       end
-      input_cuda:copy(input)
-      if c == batch_size then
-	 output = model:forward(input_cuda)
-      else
-	 output = model:forward(input_cuda:narrow(1, 1, c))
+   end
+   input_cuda:copy(input)
+   local batch_n = math.floor(#input_indexes / batch_size)
+   local batch_rem = #input_indexes % batch_size
+   for i = 1, batch_n * batch_size, batch_size do
+      local output = model:forward(input_cuda:narrow(1, i, batch_size))
+      for j = 0, batch_size - 1 do
+	 output_cuda[output_indexes[i + j]]:copy(output[j + 1])
       end
-      --output = output:view(batch_size, ch, output_size, output_size)
-      for j = 0, c - 1 do
-	 new_x[output_indexes[i + j]]:copy(output[j+1])
+   end
+   if batch_rem > 0 then
+      local i = 1 + batch_n * batch_size
+      local output = model:forward(input_cuda:narrow(1, i, batch_rem))
+      for j = 0, batch_rem - 1 do
+	 output_cuda[output_indexes[i + j]]:copy(output[j+1])
       end
    end
+   new_x:copy(output_cuda)
    return new_x
 end
 local reconstruct = {}