浏览代码

Merge branch 'dev'

nagadomi 8 年之前
父节点
当前提交
4d3d123d72
共有 3 个文件被更改，包括 55 次插入和 12 次删除
  1. + 3 - 2
      convert_data.lua
  2. + 2 - 0
      lib/settings.lua
  3. + 50 - 10
      train.lua

+ 3 - 2
convert_data.lua

@@ -95,6 +95,7 @@ local function load_images(list)
       if csv_meta and csv_meta.filters then
 	 filters = csv_meta.filters
       end
+      local basename_y = path.basename(filename)
       local im, meta = image_loader.load_byte(filename)
       local skip = false
       local alpha_color = torch.random(0, 1)
@@ -128,7 +129,7 @@ local function load_images(list)
 		     yy = iproc.rgb2y(yy)
 		  end
 		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
-				  {data = {filters = filters, has_x = true}}})
+				  {data = {filters = filters, has_x = true, basename = basename_y}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
 	       end
@@ -144,7 +145,7 @@ local function load_images(list)
 		  if settings.grayscale then
 		     im = iproc.rgb2y(im)
 		  end
-		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
+		  table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
 	       end

+ 2 - 0
lib/settings.lua

@@ -79,6 +79,7 @@ cmd:option("-update_criterion", "mse", 'mse|loss')
 cmd:option("-padding", 0, 'replication padding size')
 cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
 cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
+cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -99,6 +100,7 @@ to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
 to_bool(settings, "padding_y_zero")
 to_bool(settings, "grayscale")
+to_bool(settings, "validation_filename_split")
 
 if settings.plot then
    require 'gnuplot'

+ 50 - 10
train.lua

@@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
    end
 end
 local function split_data(x, test_size)
-   local index = torch.randperm(#x)
-   local train_size = #x - test_size
-   local train_x = {}
-   local valid_x = {}
-   for i = 1, train_size do
-      train_x[i] = x[index[i]]
-   end
-   for i = 1, test_size do
-      valid_x[i] = x[index[train_size + i]]
+   if settings.validation_filename_split then
+      if not (x[1][2].data and x[1][2].data.basename) then
+	 error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
+      end
+      local basename_db = {}
+      for i = 1, #x do
+	 local meta = x[i][2].data
+	 if basename_db[meta.basename] then
+	    table.insert(basename_db[meta.basename], x[i])
+	 else
+	    basename_db[meta.basename] = {x[i]}
+	 end
+      end
+      local basename_list = {}
+      for k, v in pairs(basename_db) do
+	 table.insert(basename_list, v)
+      end
+      local index = torch.randperm(#basename_list)
+      local train_x = {}
+      local valid_x = {}
+      local pos = 1
+      for i = 1, #basename_list do
+	 if #valid_x >= test_size then
+	    break
+	 end
+	 local xs = basename_list[index[pos]]
+	 for j = 1, #xs do
+	    table.insert(valid_x, xs[j])
+	 end
+	 pos = pos + 1
+      end
+      for i = pos, #basename_list do
+	 local xs = basename_list[index[i]]
+	 for j = 1, #xs do
+	    table.insert(train_x, xs[j])
+	 end
+      end
+      return train_x, valid_x
+   else
+      local index = torch.randperm(#x)
+      local train_size = #x - test_size
+      local train_x = {}
+      local valid_x = {}
+      for i = 1, train_size do
+	 train_x[i] = x[index[i]]
+      end
+      for i = 1, test_size do
+	 valid_x[i] = x[index[train_size + i]]
+      end
+      return train_x, valid_x
    end
-   return train_x, valid_x
 end
 
 local g_transform_pool = nil