Browse Source

Add support for plotting loss chart

nagadomi 9 years ago
parent
commit
4d115e4bdb
2 changed files with 22 additions and 2 deletions
  1. 7 0
      lib/settings.lua
  2. 15 2
      train.lua

+ 7 - 0
lib/settings.lua

@@ -47,11 +47,18 @@ cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
 cmd:option("-save_history", 0, 'save all model (0|1)')
+cmd:option("-plot", 0, 'plot loss chart(0|1)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
    settings[k] = v
 end
+if settings.plot == 1 then
+   settings.plot = true
+   require 'gnuplot'
+else
+   settings.plot = false
+end
 if settings.save_history == 1 then
    settings.save_history = true
 else

+ 15 - 2
train.lua

@@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
       end
    end
 end
-
+local function plot(train, valid)
+   gnuplot.plot({
+	 {'training', torch.Tensor(train), '-'},
+	 {'validation', torch.Tensor(valid), '-'}})
+end
 local function train()
+   local hist_train = {}
+   local hist_valid = {}
    local LR_MIN = 1.0e-5
    local model = srcnn.create(settings.method, settings.backend, settings.color)
    local offset = reconstruct.offset_size(model)
@@ -201,10 +207,17 @@ local function train()
       print("# " .. epoch)
       resampling(x, y, train_x, pairwise_func)
       for i = 1, settings.inner_epoch do
-	 print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
+	 local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 print(train_score)
 	 model:evaluate()
 	 print("# validation")
 	 local score = validate(model, eval_metric, valid_xy)
+
+	 table.insert(hist_train, train_score.PSNR)
+	 table.insert(hist_valid, score)
+	 if settings.plot then
+	    plot(hist_train, hist_valid)
+	 end
 	 if score > best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0