convert_data.lua 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  -- Module search-path setup and dependencies.
  --
  -- Penlight ('pl') is loaded first because the package.path line below uses
  -- the global `path` table, which Penlight provides; requiring it afterwards
  -- only works when the host interpreter has already loaded Penlight.
  require 'pl'
  -- Source path of this script with the leading "@" chunk-name marker removed.
  -- Only the first result of gsub (the string) is kept by the assignment.
  local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  -- Prepend "<script dir>/lib/?.lua;" so the project modules required below
  -- are searched for next to this script.
  package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
  require 'image'
  local compression = require 'compression'
  local settings = require 'settings'
  local image_loader = require 'image_loader'
  local iproc = require 'iproc'
  --- Load the images named in a list file and return the usable ones, compressed.
  --
  -- The list file is read in full and split on "\n"; every entry is handed to
  -- `image_loader.load_byte`. An entry is skipped, with a message on stderr,
  -- when:
  --   * the loader reports an alpha channel,
  --   * the image could not be loaded (or cropping yields no image), or
  --   * its size along dimension 2 or 3 is not strictly greater than
  --     `settings.crop_size * scale + MARGIN`, where `scale` is 2.0 if
  --     `settings.random_half_rate > 0.0` and 1.0 otherwise.
  -- Progress is reported through `xlua.progress`, and a garbage collection is
  -- forced every 10 entries.
  --
  -- @param list path of a text file with one image path per line
  -- @return array-like table of `compression.compress(im)` results, in list order
  local function load_images(list)
    local MARGIN = 32
    local lines = utils.split(file.read(list), "\n")

    -- The minimum size depends only on settings, so compute it once rather
    -- than on every iteration.
    local scale = 1.0
    if settings.random_half_rate > 0.0 then
      scale = 2.0
    end
    local min_size = settings.crop_size * scale + MARGIN

    local x = {}
    for i = 1, #lines do
      local line = lines[i]
      local im, alpha = image_loader.load_byte(line)
      if alpha then
        io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
      elseif not im then
        -- Check for a failed load BEFORE cropping: the load-error branch is
        -- meant to handle a nil image, so nil must not reach crop_mod4 first.
        io.stderr:write(string.format("\n%s: skip: load error.\n", line))
      else
        im = iproc.crop_mod4(im)
        if not im then
          -- Kept from the original flow in case crop_mod4 can return nil.
          io.stderr:write(string.format("\n%s: skip: load error.\n", line))
        elseif im:size(2) > min_size and im:size(3) > min_size then
          x[#x + 1] = compression.compress(im)
        else
          io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, min_size))
        end
      end
      xlua.progress(i, #lines)
      if i % 10 == 0 then
        collectgarbage()
      end
    end
    return x
  end
  -- Script driver: seed torch's random number generator from the settings,
  -- print the settings, then load the images listed in settings.image_list
  -- and save the resulting table to settings.images.
  torch.manualSeed(settings.seed)
  print(settings)
  torch.save(settings.images, load_images(settings.image_list))