convert_data.lua 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  2. package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
  3. require 'pl'
  4. require 'image'
  5. local compression = require 'compression'
  6. local settings = require 'settings'
  7. local image_loader = require 'image_loader'
  8. local iproc = require 'iproc'
  --- Randomly crop `src` to a max_size x max_size window when it is large enough.
  -- If both src:size(2) and src:size(3) are >= max_size, up to `tries` random
  -- windows are sampled. A window whose pixel standard deviation is zero
  -- (a perfectly flat region) is rejected and resampled. If every sample is
  -- flat, the last sampled window is returned anyway.
  -- NOTE(review): dims 2/3 are presumably height/width of a CxHxW tensor;
  -- the code pairs dim 2 with the y offset and dim 3 with the x offset.
  -- @param src image tensor to crop
  -- @param max_size side length of the square crop
  -- @return the cropped tensor, or `src` itself when it is smaller than max_size
  local function crop_if_large(src, max_size)
    local tries = 4
    if src:size(2) >= max_size and src:size(3) >= max_size then
      local rect
      for _ = 1, tries do
        local yi = torch.random(0, src:size(2) - max_size)
        local xi = torch.random(0, src:size(3) - max_size)
        rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
        -- Ignore simple (flat) background. A standard deviation is never
        -- negative, so the test must be strict for a constant crop to be
        -- rejected and another window tried.
        if rect:float():std() > 0 then
          break
        end
      end
      return rect
    else
      return src
    end
  end
  --- Load, filter and compress the images listed in a text file.
  -- `list` is read with file.read and split on "\n"; each resulting line is
  -- passed to image_loader.load_byte. For each line:
  --   * images whose metadata reports an alpha channel are skipped;
  --   * images that fail to load (nil image) are skipped;
  --   * otherwise the image is optionally random-cropped to
  --     settings.max_training_image_size (when > 0), passed through
  --     iproc.crop_mod4, and kept only if both dimensions exceed
  --     settings.crop_size * scale + MARGIN.
  -- Every skip is reported on stderr. A progress bar is drawn, and a GC
  -- cycle is forced every 10 lines.
  -- @param list path to a newline-separated list of image paths
  -- @return array of compression.compress() results, in list order
  local function load_images(list)
    local MARGIN = 32
    -- The size requirement doubles (scale = 2) whenever
    -- settings.random_half_rate > 0. NOTE(review): presumably because
    -- training may halve such images; confirm against the training code.
    local scale = 1.0
    if settings.random_half_rate > 0.0 then
      scale = 2.0
    end
    local min_size = settings.crop_size * scale + MARGIN
    local max_training_size = settings.max_training_image_size

    local lines = utils.split(file.read(list), "\n")
    local x = {}
    for i = 1, #lines do
      local line = lines[i]
      local im, meta = image_loader.load_byte(line)
      if meta and meta.alpha then
        io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
      else
        -- Guard against a failed load *before* preprocessing: crop_if_large
        -- calls src:size() and would raise on a nil image, so the
        -- "load error" branch below could otherwise never be reached.
        if im then
          if max_training_size > 0 then
            im = crop_if_large(im, max_training_size)
          end
          im = iproc.crop_mod4(im)
        end
        if im then
          if im:size(2) > min_size and im:size(3) > min_size then
            x[#x + 1] = compression.compress(im)
          else
            io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, min_size))
          end
        else
          io.stderr:write(string.format("\n%s: skip: load error.\n", line))
        end
      end
      xlua.progress(i, #lines)
      if i % 10 == 0 then
        collectgarbage()
      end
    end
    return x
  end
  -- Script entry point. Seed torch's RNG first, because crop_if_large picks
  -- its windows with torch.random. Then print the settings table, and
  -- serialize the accepted (compressed) images to settings.images.
  torch.manualSeed(settings.seed)
  print(settings)
  torch.save(settings.images, load_images(settings.image_list))