Selaa lähdekoodia

update training script

nagadomi 10 vuotta sitten
vanhempi
commit
2231423056
7 muutettua tiedostoa jossa 65 lisäystä ja 18 poistoa
  1. 1 0
      .gitignore
  2. 41 0
      convert_data.lua
  3. 5 6
      lib/minibatch_adam.lua
  4. 1 1
      lib/pairwise_transform.lua
  5. 1 1
      lib/settings.lua
  6. 6 10
      train.lua
  7. 10 0
      train.sh

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
 *~
 cache/*.png
+models/*.png
 waifu2x.log

+ 41 - 0
convert_data.lua

@@ -0,0 +1,41 @@
+require 'torch'
+local settings = require './lib/settings'
+local image_loader = require './lib/image_loader'
+
--- Count the number of lines in a text file.
-- @param file path of a readable text file
-- @return number of lines in the file
local function count_lines(file)
   -- fail loudly with the offending path instead of crashing on a nil handle
   local fp = assert(io.open(file, "r"), "count_lines: cannot open " .. file)
   local count = 0
   for _ in fp:lines() do
      count = count + 1
   end
   fp:close()
   return count
end
+
--- Load every usable training image listed in `list` (one image path per line).
-- A path that fails to load is reported with an "error:" line. Loaded images
-- whose spatial dimensions are not strictly larger than 2 * settings.crop_size
-- are skipped — presumably too small for the random crops taken by the
-- pairwise transforms (TODO confirm against lib/pairwise_transform.lua).
-- @param list path to a text file of image paths
-- @return array of image tensors as returned by image_loader.load_byte
local function load_images(list)
   require 'xlua' -- xlua.progress is used below but xlua is never required at file scope
   local count = count_lines(list)
   local fp = assert(io.open(list, "r"), "load_images: cannot open " .. list)
   local x = {}
   local c = 0
   for line in fp:lines() do
      local im = image_loader.load_byte(line)
      if im then
	 -- keep only images large enough to crop from
	 if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then
	    table.insert(x, im)
	 end
      else
	 print("error:" .. line)
      end
      c = c + 1
      xlua.progress(c, count)
      -- periodically reclaim memory from discarded decoded images
      if c % 10 == 0 then
	 collectgarbage()
      end
   end
   fp:close() -- the original leaked this handle
   return x
end
-- Dump the active configuration, then build the training image set from the
-- image list and serialize it to the configured .t7 file.
print(settings)
local images = load_images(settings.image_list)
torch.save(settings.images, images)
+

+ 5 - 6
lib/minibatch_sgd.lua → lib/minibatch_adam.lua

@@ -2,10 +2,10 @@ require 'optim'
 require 'cutorch'
 require 'xlua'
 
-local function minibatch_sgd(model, criterion,
-			     train_x,
-			     config, transformer,
-			     input_size, target_size)
+local function minibatch_adam(model, criterion,
+			      train_x,
+			      config, transformer,
+			      input_size, target_size)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
@@ -47,7 +47,6 @@ local function minibatch_sgd(model, criterion,
 	 model:backward(inputs, criterion:backward(output, targets))
 	 return f, gradParameters
       end
-      -- must use Adam!!
       optim.adam(feval, parameters, config)
       
       c = c + 1
@@ -60,4 +59,4 @@ local function minibatch_sgd(model, criterion,
    return { mse = sum_loss / count_loss}
 end
 
-return minibatch_sgd
+return minibatch_adam

+ 1 - 1
lib/pairwise_transform.lua

@@ -6,7 +6,7 @@ local pairwise_transform = {}
 
 function pairwise_transform.scale(src, scale, size, offset, options)
    options = options or {}
-   local yi = torch.radom(0, src:size(2) - size - 1)
+   local yi = torch.random(0, src:size(2) - size - 1)
    local xi = torch.random(0, src:size(3) - size - 1)
    local down_scale = 1.0 / scale
    local y = image.crop(src, xi, yi, xi + size, yi + size)

+ 1 - 1
lib/settings.lua

@@ -51,7 +51,7 @@ torch.setnumthreads(settings.core)
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-settings.validation_ratio = 01
+settings.validation_ratio = 0.1
 settings.validation_crops = 40
 settings.block_offset = 7 -- see srcnn.lua
 

+ 6 - 10
train.lua

@@ -5,7 +5,7 @@ require 'xlua'
 require 'pl'
 
 local settings = require './lib/settings'
-local minibatch_sgd = require './lib/minibatch_sgd'
+local minibatch_adam = require './lib/minibatch_adam'
 local iproc = require './lib/iproc'
 local create_model = require './lib/srcnn'
 local reconstract, reconstract_ch = require './lib/reconstract'
@@ -77,10 +77,6 @@ local function train()
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
    }
-   local denoise_model = nil
-   if settings.method == "scale" and path.exists(settings.denoise_model_file) then
-      denoise_model = torch.load(settings.denoise_model_file)
-   end
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       if settings.method == "scale" then
@@ -109,11 +105,11 @@ local function train()
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
-      print(minibatch_sgd(model, criterion, train_x, adam_config,
-			  transformer,
-			  {1, settings.crop_size, settings.crop_size},
-			  {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
-			 ))
+      print(minibatch_adam(model, criterion, train_x, adam_config,
+			   transformer,
+			   {1, settings.crop_size, settings.crop_size},
+			   {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
+			  ))
       if epoch % 1 == 0 then
 	 collectgarbage()
 	 model:evaluate()

+ 10 - 0
train.sh

@@ -0,0 +1,10 @@
+#!/bin/sh
+
+th train.lua -method noise -noise_level 1 -test images/miku_noise.png
+th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii
+
+th train.lua -method noise -noise_level 2 -test images/miku_noise.png
+th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii
+
+th train.lua -method scale -scale 2 -test images/miku_small.png
+th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii