|
@@ -30,38 +30,38 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
- local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
|
|
|
- local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
|
|
|
- for i = 1, #input_indexes, batch_size do
|
|
|
- local c = 0
|
|
|
- local output
|
|
|
- for j = 0, batch_size - 1 do
|
|
|
- if i + j > #input_indexes then
|
|
|
- break
|
|
|
- end
|
|
|
- input[j+1]:copy(x[input_indexes[i + j]])
|
|
|
- if model.w2nn_gcn then
|
|
|
- local mean = input[j + 1]:mean()
|
|
|
- local stdv = input[j + 1]:std()
|
|
|
- if stdv > 0 then
|
|
|
- input[j + 1]:add(-mean):div(stdv)
|
|
|
- else
|
|
|
- input[j + 1]:add(-mean)
|
|
|
- end
|
|
|
+ local input = torch.Tensor(#input_indexes, ch, input_block_size, input_block_size)
|
|
|
+ local input_cuda = torch.CudaTensor():resize(input:size())
|
|
|
+ local output_cuda = torch.CudaTensor():resize(new_x:size())
|
|
|
+ for i = 1, #input_indexes do
|
|
|
+ input[i]:copy(x[input_indexes[i]])
|
|
|
+ if model.w2nn_gcn then
|
|
|
+ local mean = input[i]:mean()
|
|
|
+ local stdv = input[i]:std()
|
|
|
+ if stdv > 0 then
|
|
|
+ input[i]:add(-mean):div(stdv)
|
|
|
+ else
|
|
|
+ input[i]:add(-mean)
|
|
|
end
|
|
|
- c = c + 1
|
|
|
end
|
|
|
- input_cuda:copy(input)
|
|
|
- if c == batch_size then
|
|
|
- output = model:forward(input_cuda)
|
|
|
- else
|
|
|
- output = model:forward(input_cuda:narrow(1, 1, c))
|
|
|
+ end
|
|
|
+ input_cuda:copy(input)
|
|
|
+ local batch_n = math.floor(#input_indexes / batch_size)
|
|
|
+ local batch_rem = #input_indexes % batch_size
|
|
|
+ for i = 1, batch_n * batch_size, batch_size do
|
|
|
+ local output = model:forward(input_cuda:narrow(1, i, batch_size))
|
|
|
+ for j = 0, batch_size - 1 do
|
|
|
+ output_cuda[output_indexes[i + j]]:copy(output[j + 1])
|
|
|
end
|
|
|
- --output = output:view(batch_size, ch, output_size, output_size)
|
|
|
- for j = 0, c - 1 do
|
|
|
- new_x[output_indexes[i + j]]:copy(output[j+1])
|
|
|
+ end
|
|
|
+ if batch_rem > 0 then
|
|
|
+ local i = 1 + batch_n * batch_size
|
|
|
+ local output = model:forward(input_cuda:narrow(1, i, batch_rem))
|
|
|
+ for j = 0, batch_rem - 1 do
|
|
|
+ output_cuda[output_indexes[i + j]]:copy(output[j+1])
|
|
|
end
|
|
|
end
|
|
|
+ new_x:copy(output_cuda)
|
|
|
return new_x
|
|
|
end
|
|
|
local reconstruct = {}
|