Przeglądaj źródła

Fix validation metric

nagadomi 9 lat temu
rodzic
commit
fa9355be7c
1 zmienionych plików z 6 dodań i 0 usunięć
  1. 6 0
      train.lua

+ 6 - 0
train.lua

@@ -48,6 +48,12 @@ local function make_validation_set(x, transformer, n, patches)
       xlua.progress(i, #x)
       collectgarbage()
    end
+   local new_data = {}
+   local perm = torch.randperm(#data)
+   for i = 1, perm:size(1) do
+      new_data[i] = data[perm[i]]
+   end
+   data = new_data
    return data
 end
 local function validate(model, criterion, data, batch_size)