convert_data.lua 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
  4. require 'image'
  5. local compression = require 'compression'
  6. local settings = require 'settings'
  7. local image_loader = require 'image_loader'
  8. local iproc = require 'iproc'
  9. local alpha_util = require 'alpha_util'
  --- Randomly crop a large image down to a `max_size` x `max_size` square.
  -- Dimension 2 of `src` is treated as the y axis (height) and dimension 3 as
  -- the x axis (width), matching the argument order passed to `iproc.crop`.
  -- If either side is smaller than `max_size` the image is returned untouched.
  -- Otherwise up to `tries` random windows are drawn; the first one whose pixel
  -- standard deviation is >= `min_std` is returned, and if none qualifies the
  -- last window drawn is returned.
  -- @param src image tensor from image_loader (presumably C x H x W — TODO confirm)
  -- @tparam number max_size side length of the square crop
  -- @tparam[opt=0] number min_std minimum std-dev for a crop to be accepted.
  --   With the default of 0 every crop is accepted (a std-dev is never
  --   negative), which preserves the historical behaviour; raise it to skip
  --   flat, background-only regions.
  -- @tparam[opt=4] number tries maximum number of random windows to draw
  --   (values below 1 are treated as 1)
  -- @return the cropped tensor, or `src` itself when no crop is needed
  local function crop_if_large(src, max_size, min_std, tries)
    min_std = min_std or 0
    tries = math.max(1, tries or 4)
    local height, width = src:size(2), src:size(3)
    if height < max_size or width < max_size then
      return src
    end
    local rect
    for _ = 1, tries do
      local yi = torch.random(0, height - max_size)
      local xi = torch.random(0, width - max_size)
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- ignore simple background: only accept crops with enough variation
      if rect:float():std() >= min_std then
        break
      end
    end
    return rect
  end
  --- Build the training set from an image list file.
  -- Each non-blank line of `list` has the form `filename[,filter:filter:...]`;
  -- a trailing "\r" is stripped so CRLF list files also work.
  -- An image is kept when all of the following hold:
  --   * it loads;
  --   * it either has no alpha channel, or settings.use_transparent_png is set
  --     (in which case alpha_util.fill is applied with a random 0/1 argument);
  --   * after the optional crop_if_large and iproc.crop_mod4 steps, both of its
  --     sides exceed the minimum size.
  -- Each kept image is stored as
  -- `{compression.compress(im), {data = {filters = filters}}}`.
  -- Skipped images are reported on stderr; progress is shown per list line.
  -- @tparam string list path of the image list file
  -- @treturn table array of compressed samples
  -- @raise when the list file cannot be read
  local function load_images(list)
    local MARGIN = 32
    local lines = utils.split(assert(file.read(list)), "\n")
    -- The size threshold depends only on settings, so compute it once.
    -- It doubles when random_half_rate > 0 (presumably so a 2x downscaled
    -- copy still fits a crop — TODO confirm).
    local scale = 1.0
    if settings.random_half_rate > 0.0 then
      scale = 2.0
    end
    local min_size = settings.crop_size * scale + MARGIN
    local x = {}
    local skip_notice = false
    for i = 1, #lines do
      local line = lines[i]:gsub("\r$", "")
      if line ~= "" then
        local fields = utils.split(line, ",")
        local filename = fields[1]
        local filters = fields[2]
        if filters then
          filters = utils.split(filters, ":")
        end
        local im, meta = image_loader.load_byte(filename)
        if not im then
          -- must be checked before any processing: crop_if_large/crop_mod4
          -- would otherwise index a nil image and abort the whole run
          io.stderr:write(string.format("\n%s: skip: load error.\n", filename))
        elseif meta and meta.alpha and not settings.use_transparent_png then
          if not skip_notice then
            io.stderr:write("skip transparent png (settings.use_transparent_png=0)\n")
            skip_notice = true
          end
        else
          if meta and meta.alpha then
            im = alpha_util.fill(im, meta.alpha, torch.random(0, 1))
          end
          if settings.max_training_image_size > 0 then
            im = crop_if_large(im, settings.max_training_image_size)
          end
          im = iproc.crop_mod4(im)
          if im:size(2) > min_size and im:size(3) > min_size then
            x[#x + 1] = {compression.compress(im), {data = {filters = filters}}}
          else
            io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, min_size))
          end
        end
      end
      xlua.progress(i, #lines)
      -- decoded images are large; reclaim them periodically
      if i % 10 == 0 then
        collectgarbage()
      end
    end
    return x
  end
  -- Script entry point. It prints `settings` and seeds torch's RNG with
  -- settings.seed; the seeding makes the torch.random draws made while loading
  -- repeatable. It then converts every entry of settings.image_list and writes
  -- the resulting table to settings.images with torch.save.
  print(settings)
  torch.manualSeed(settings.seed)
  local x = load_images(settings.image_list)
  torch.save(settings.images, x)