nagadomi 9 年之前
父節點
當前提交
903d945652
共有 7 個文件被更改,包括 43 次插入49 次删除
  1. 1 0
      README.md
  2. 6 7
      lib/pairwise_transform.lua
  3. 20 20
      lib/settings.lua
  4. 3 3
      train.lua
  5. 6 7
      train.sh
  6. 3 8
      train_ukbench.sh
  7. 4 4
      waifu2x.lua

+ 1 - 0
README.md

@@ -152,6 +152,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
 ```
 
 ## Training Your Own Model
+Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
 
 ### Data Preparation
 

+ 6 - 7
lib/pairwise_transform.lua

@@ -7,8 +7,7 @@ local pairwise_transform = {}
 
 local function random_half(src, p)
    p = p or 0.25
-   --local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
-   local filter = "Box"
+   local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
    if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
@@ -163,8 +162,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    end
    return batch
 end
-function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
-   if category == "anime_style_art" then
+function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
+   if style == "art" then
       if level == 1 then
 	 if torch.uniform() > 0.8 then
 	    return pairwise_transform.jpeg_(src, {},
@@ -200,7 +199,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
       else
 	 error("unknown noise level: " .. level)
       end
-   elseif category == "photo" then
+   elseif style == "photo" then
       if level == 1 then
 	 if torch.uniform() > 0.7 then
 	    return pairwise_transform.jpeg_(src, {},
@@ -225,7 +224,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
 	 error("unknown noise level: " .. level)
       end
    else
-      error("unknown category: " .. category)
+      error("unknown style: " .. style)
    end
 end
 
@@ -239,7 +238,7 @@ function pairwise_transform.test_jpeg(src)
    }
    for i = 1, 9 do
       local xy = pairwise_transform.jpeg(src,
-					 "anime_style_art",
+					 "art",
 					 torch.random(1, 2),
 					 128, 7, 1, options)
       image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})

+ 20 - 20
lib/settings.lua

@@ -17,30 +17,30 @@ local cmd = torch.CmdLine()
 cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-seed", 11, 'fixed input seed')
-cmd:option("-data_dir", "./data", 'data directory')
+cmd:option("-seed", 11, 'RNG seed')
+cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
-cmd:option("-test", "images/miku_small.png", 'test image file')
+cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
-cmd:option("-method", "scale", '(noise|scale)')
+cmd:option("-method", "scale", 'method to training (noise|scale)')
 cmd:option("-noise_level", 1, '(1|2)')
-cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
+cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
-cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
-cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
-cmd:option("-scale", 2.0, 'scale')
+cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
+cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
+cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
-cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
-cmd:option("-crop_size", 128, 'crop size')
-cmd:option("-max_size", 512, 'crop if image size larger then this value.')
-cmd:option("-batch_size", 2, 'mini batch size')
-cmd:option("-epoch", 200, 'epoch')
+cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
+cmd:option("-crop_size", 46, 'crop size')
+cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
+cmd:option("-batch_size", 8, 'mini batch size')
+cmd:option("-epoch", 200, 'number of total epochs to run')
 cmd:option("-thread", -1, 'number of CPU threads')
-cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
-cmd:option("-validation_ratio", 0.1, 'validation ratio')
-cmd:option("-validation_crops", 40, 'number of crop region in validation')
+cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
+cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
+cmd:option("-validation_crops", 80, 'number of region per image in validation')
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
-cmd:option("-active_cropping_tries", 20, 'active cropping tries')
+cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
@@ -64,9 +64,9 @@ end
 if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end
-if not (settings.category == "anime_style_art" or
-	settings.category == "photo") then
-   error(string.format("unknown category: %s", settings.category))
+if not (settings.style == "art" or
+	settings.style == "photo") then
+   error(string.format("unknown style: %s", settings.style))
 end
 if settings.random_half == 1 then
    settings.random_half = true

+ 3 - 3
train.lua

@@ -87,7 +87,7 @@ local function transformer(x, is_validation, n, offset)
    if is_validation == nil then is_validation = false end
    local color_noise = nil 
    local overlay = nil
-   local active_cropping_ratio = nil
+   local active_cropping_rate = nil
    local active_cropping_tries = nil
    
    if is_validation then
@@ -117,7 +117,7 @@ local function transformer(x, is_validation, n, offset)
 				      })
    elseif settings.method == "noise" then
       return pairwise_transform.jpeg(x,
-				     settings.category,
+				     settings.style,
 				     settings.noise_level,
 				     settings.crop_size, offset,
 				     n,
@@ -142,7 +142,7 @@ local function train()
    local criterion = create_criterion(model)
    local x = torch.load(settings.images)
    local lrd_count = 0
-   local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
+   local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
       learningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,

+ 6 - 7
train.sh

@@ -2,12 +2,11 @@
 
 th convert_data.lua
 
-th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8  -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
-th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
+th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
 
-th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8  -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
-th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
-
-th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
-th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
+th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
 
+th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
+th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

+ 3 - 8
train_ukbench.sh

@@ -1,11 +1,6 @@
 #!/bin/sh
 
-th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
-th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
-
-th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
-th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
-
-th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
-th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
+th convert_data.lua -data_dir ./data/ukbench
 
+th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg -thread 4
+th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii

+ 4 - 4
waifu2x.lua

@@ -128,11 +128,11 @@ local function waifu2x()
    cmd:text()
    cmd:text("waifu2x")
    cmd:text("Options:")
-   cmd:option("-i", "images/miku_small.png", 'path of the input image')
-   cmd:option("-l", "", 'path of the image-list')
+   cmd:option("-i", "images/miku_small.png", 'path to input image')
+   cmd:option("-l", "", 'path to image-list.txt')
    cmd:option("-scale", 2, 'scale factor')
-   cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
+   cmd:option("-o", "(auto)", 'path to output file')
+   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')