Browse Source

Add oracle_rate option

nagadomi 9 years ago
parent
commit
8088460a20
2 changed files with 58 additions and 7 deletions
  1. 2 0
      lib/settings.lua
  2. 56 7
      train.lua

+ 2 - 0
lib/settings.lua

@@ -56,6 +56,8 @@ cmd:option("-max_training_image_size", -1, 'if training image is larger than N,
 cmd:option("-use_transparent_png", 0, 'use transparent png (0|1)')
 cmd:option("-resize_blur_min", 0.85, 'min blur parameter for ResizeImage')
 cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
+cmd:option("-oracle_rate", 0.0, '')
+cmd:option("-oracle_drop_rate", 0.5, '')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then

+ 56 - 7
train.lua

@@ -175,20 +175,48 @@ local function transformer(model, x, is_validation, n, offset)
 end
 
 local function resampling(x, y, train_x, transformer, input_size, target_size)
-   print("## resampling")
+   local c = 1
+   local shuffle = torch.randperm(#train_x)
    for t = 1, #train_x do
       xlua.progress(t, #train_x)
-      local xy = transformer(train_x[t], false, settings.patches)
+      local xy = transformer(train_x[shuffle[t]], false, settings.patches)
       for i = 1, #xy do
-	 local index = (t - 1) * settings.patches + i
-         x[index]:copy(xy[i][1])
-	 y[index]:copy(xy[i][2])
+         x[c]:copy(xy[i][1])
+	 y[c]:copy(xy[i][2])
+	 c = c + 1
+	 if c > x:size(1) then
+	    break
+	 end
+      end
+      if c > x:size(1) then
+	 break
       end
       if t % 50 == 0 then
 	 collectgarbage()
       end
    end
+   xlua.progress(#train_x, #train_x)
 end
+local function get_oracle_data(x, y, instance_loss, k, samples)
+   local index = torch.LongTensor(instance_loss:size(1))
+   local dummy = torch.Tensor(instance_loss:size(1))
+   torch.topk(dummy, index, instance_loss, k, 1, true)
+   print("average loss: " ..instance_loss:mean() .. ", average oracle loss: " .. dummy:mean())
+   local shuffle = torch.randperm(k)
+   local x_s = x:size()
+   local y_s = y:size()
+   x_s[1] = samples
+   y_s[1] = samples
+   local oracle_x = torch.Tensor(table.unpack(torch.totable(x_s)))
+   local oracle_y = torch.Tensor(table.unpack(torch.totable(y_s)))
+
+   for i = 1, samples do
+      oracle_x[i]:copy(x[index[shuffle[i]]])
+      oracle_y[i]:copy(y[index[shuffle[i]]])
+   end
+   return oracle_x, oracle_y
+end
+
 local function remove_small_image(x)
    local new_x = {}
    for i = 1, #x do
@@ -254,12 +282,33 @@ local function train()
       x = torch.Tensor(settings.patches * #train_x,
 		       ch, settings.crop_size, settings.crop_size)
    end
+   local instance_loss = nil
+
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
-      resampling(x, y, train_x, pairwise_func)
+      print("## resampling")
+      if instance_loss then
+	 -- active learning
+	 local oracle_k = math.min(x:size(1) * (settings.oracle_rate * (1 / (1 - settings.oracle_drop_rate))), x:size(1))
+	 local oracle_n = math.min(x:size(1) * settings.oracle_rate, x:size(1))
+	 if oracle_n > 0 then
+	    local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
+	    resampling(x, y, train_x, pairwise_func)
+	    x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
+	    y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
+	 else
+	    resampling(x, y, train_x, pairwise_func)
+	 end
+      else
+	 resampling(x, y, train_x, pairwise_func)
+      end
+      collectgarbage()
+      instance_loss = torch.Tensor(x:size(1)):zero()
+
       for i = 1, settings.inner_epoch do
-	 local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 instance_loss:copy(il)
 	 print(train_score)
 	 model:evaluate()
 	 print("# validation")