|
@@ -232,7 +232,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples)
|
|
|
local index = torch.LongTensor(instance_loss:size(1))
|
|
|
local dummy = torch.Tensor(instance_loss:size(1))
|
|
|
torch.topk(dummy, index, instance_loss, k, 1, true)
|
|
|
- print("average loss: " ..instance_loss:mean() .. ", average oracle loss: " .. dummy:mean())
|
|
|
+ print("MSE of all data: " ..instance_loss:mean() .. ", MSE of oracle data: " .. dummy:mean())
|
|
|
local shuffle = torch.randperm(k)
|
|
|
local x_s = x:size()
|
|
|
local y_s = y:size()
|
|
@@ -266,7 +266,7 @@ local function remove_small_image(x)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|
|
|
- print(string.format("removed %d small images", #x - #new_x))
|
|
|
+ print(string.format("%d small images are removed", #x - #new_x))
|
|
|
|
|
|
return new_x
|
|
|
end
|
|
@@ -374,7 +374,7 @@ local function train()
|
|
|
if score.MSE < best_score then
|
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
|
best_score = score.MSE
|
|
|
- print("* update best model")
|
|
|
+ print("* Best model is updated")
|
|
|
if settings.save_history then
|
|
|
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
|
|
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
|
@@ -413,7 +413,7 @@ local function train()
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
- print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
|
|
+ print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|