convert_data.lua 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  2. package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
  3. require 'pl'
  4. require 'image'
  5. local compression = require 'compression'
  6. local settings = require 'settings'
  7. local image_loader = require 'image_loader'
  -- NOTE(review): nothing in this file reads MAX_SIZE; load_images crops
  -- with settings.max_size instead. Confirm it is dead before removing.
  local MAX_SIZE = 1440
  --- Randomly crop `src` so neither spatial side exceeds `max_size`.
  -- Dimension 3 is used as the x/width extent and dimension 2 as the
  -- y/height extent when calling image.crop. The crop offset is drawn with
  -- torch.random (x offset first, then y), so the result depends on torch's
  -- RNG state.
  -- @param src image tensor (anything with a `size(dim)` method)
  -- @param max_size maximum side length; a non-positive value disables cropping
  -- @return `src` itself when no cropping is needed, otherwise image.crop's result
  local function crop_if_large(src, max_size)
     -- Cropping disabled: hand back the input without touching it.
     if not (max_size > 0) then
        return src
     end
     local height = src:size(2)
     local width = src:size(3)
     -- Already small enough along both axes.
     if width <= max_size and height <= max_size then
        return src
     end
     local crop_w = math.min(max_size, width)
     local crop_h = math.min(max_size, height)
     local left = torch.random(0, width - crop_w)
     local top = torch.random(0, height - crop_h)
     local right = math.min(left + max_size, width)
     local bottom = math.min(top + max_size, height)
     return image.crop(src, left, top, right, bottom)
  end
  --- Trim the right and bottom edges so both spatial sizes are multiples of 4.
  -- The crop is anchored at (0, 0). Dimension 3 is passed to image.crop as
  -- the x/width extent and dimension 2 as the y/height extent.
  -- @param x image tensor (anything with a `size(dim)` method)
  -- @return the cropped image produced by image.crop
  local function crop_4x(x)
     local width = x:size(3)
     local height = x:size(2)
     local right = width - width % 4
     local bottom = height - height % 4
     return image.crop(x, 0, 0, right, bottom)
  end
  --- Load every image named in a list file and return them compressed.
  -- Each line of `list` is one image path. Blank lines are ignored. Images
  -- that fail to load, carry an alpha channel, or are too small for the
  -- configured crop size are skipped with a note on stderr. Accepted images
  -- are randomly cropped to settings.max_size (crop_if_large), trimmed to a
  -- multiple of 4 (crop_4x), then passed to compression.compress.
  -- Progress is reported through xlua.progress.
  -- @tparam string list path of a newline-separated file of image paths
  -- @treturn table array of compression.compress results, in list order
  local function load_images(list)
     local MARGIN = 32
     -- file.read yields nil (plus an error message) when the list cannot be
     -- read; fail loudly here rather than inside utils.split.
     local lines = utils.split(assert(file.read(list)), "\n")
     -- random_half doubles the size an image must exceed to be kept.
     local scale = settings.random_half and 2.0 or 1.0
     local min_size = settings.crop_size * scale + MARGIN
     local x = {}
     for i = 1, #lines do
        -- Drop a trailing CR so list files with CRLF line endings still work.
        local line = lines[i]:gsub("\r$", "")
        if line:find("%S") then
           local im, alpha = image_loader.load_byte(line)
           -- Check for a failed load BEFORE touching `im`. Cropping a nil
           -- image would raise instead of reporting the skip.
           if not im then
              io.stderr:write(string.format("\n%s: skip: load error.\n", line))
           elseif alpha then
              io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
           else
              im = crop_4x(crop_if_large(im, settings.max_size))
              if im:size(2) > min_size and im:size(3) > min_size then
                 x[#x + 1] = compression.compress(im)
              else
                 io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, min_size))
              end
           end
        end
        xlua.progress(i, #lines)
        -- Periodically reclaim memory held by decoded images.
        if i % 10 == 0 then
           collectgarbage()
        end
     end
     return x
  end
  -- Script entry point. Seed torch's RNG first so the random crops drawn in
  -- crop_if_large repeat for a given settings.seed. Then echo the
  -- configuration, build the compressed image list and write it out.
  torch.manualSeed(settings.seed)
  print(settings)
  local images = load_images(settings.image_list)
  torch.save(settings.images, images)