convert_data.lua 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. local ffi = require 'ffi'
  2. require './lib/portable'
  3. require 'image'
  4. require 'snappy'
  5. local settings = require './lib/settings'
  6. local image_loader = require './lib/image_loader'
  7. local MAX_SIZE = 1440
  -- Count the number of lines in the text file at path `file`.
  --
  -- @param file  path of the file to read
  -- @return      number of lines yielded by the file's line iterator
  -- Raises an error (attributed to the caller) when the file cannot be
  -- opened, instead of failing later on a nil file handle.
  local function count_lines(file)
    local fp, open_err = io.open(file, "r")
    if not fp then
      -- io.open returns nil plus a message on failure; surface that message.
      error("count_lines: cannot open file: " .. tostring(open_err), 2)
    end
    local count = 0
    for _ in fp:lines() do
      count = count + 1
    end
    fp:close()
    return count
  end
  -- Return `src` unchanged unless `max_size` is positive and either
  -- src:size(2) or src:size(3) exceeds it. Otherwise pick a random offset
  -- (via torch.random) along each of those two dimensions and return the
  -- image.crop of `src` spanning at most `max_size` in each.
  --
  -- @param src       object exposing :size(dim); passed through to image.crop
  -- @param max_size  size limit; values <= 0 disable cropping
  local function crop_if_large(src, max_size)
    local oversized = max_size > 0
      and (src:size(2) > max_size or src:size(3) > max_size)
    if not oversized then
      return src
    end

    local width, height = src:size(3), src:size(2)
    -- The offset for dimension 3 is drawn before the one for dimension 2 so
    -- the random sequence is consumed in the same order as before.
    local left = torch.random(0, width - math.min(max_size, width))
    local top = torch.random(0, height - math.min(max_size, height))
    local right = math.min(left + max_size, width)
    local bottom = math.min(top + max_size, height)
    return image.crop(src, left, top, right, bottom)
  end
  -- Crop `x` from the origin so that x:size(3) and x:size(2) are each
  -- reduced to the largest multiple of 4 not exceeding their current value.
  --
  -- @param x  object exposing :size(dim); passed through to image.crop
  -- @return   result of image.crop(x, 0, 0, trimmed size(3), trimmed size(2))
  local function crop_4x(x)
    local width, height = x:size(3), x:size(2)
    local trimmed_width = width - (width % 4)
    local trimmed_height = height - (height % 4)
    return image.crop(x, 0, 0, trimmed_width, trimmed_height)
  end
  -- Load every image named in the list file `list` (one path per line) and
  -- return a table of packed entries.
  --
  -- Each kept entry is { im:size(), ByteStorage }, where the ByteStorage holds
  -- the snappy-compressed bytes of the image's storage.
  --
  -- An image is skipped, with a message on stderr, when:
  --   * the loader returns a truthy second value (reported as alpha channel),
  --   * the loader returns no image (load error),
  --   * after cropping, size(2) or size(3) is not greater than
  --     settings.crop_size * scale + MARGIN, where scale is 2.0 when
  --     settings.random_half is set and 1.0 otherwise.
  --
  -- @param list  path of the text file listing image paths
  -- @return      array-like table of { size, compressed ByteStorage } entries
  -- Raises an error if the list file cannot be opened.
  local function load_images(list)
    local MARGIN = 32
    local count = count_lines(list)
    local fp, open_err = io.open(list, "r")
    if not fp then
      error("load_images: cannot open image list: " .. tostring(open_err), 2)
    end

    -- The size threshold does not depend on the individual image, so compute
    -- it once up front instead of on every iteration.
    local scale = 1.0
    if settings.random_half then
      scale = 2.0
    end
    local min_size = settings.crop_size * scale + MARGIN

    local x = {}
    local c = 0
    for line in fp:lines() do
      local im, alpha = image_loader.load_byte(line)
      if alpha then
        io.stderr:write(string.format("%s: skip: reason: alpha channel.\n", line))
      elseif not im then
        -- Must be checked before cropping: the crop helpers index the image
        -- and would fail on nil, making this branch unreachable.
        io.stderr:write(string.format("%s: skip: reason: load error.\n", line))
      else
        im = crop_if_large(im, settings.max_size)
        im = crop_4x(im)
        if im:size(2) > min_size and im:size(3) > min_size then
          table.insert(x, {im:size(), torch.ByteStorage():string(snappy.compress(im:storage():string()))})
        else
          io.stderr:write(string.format("%s: skip: reason: too small (%d > size).\n", line, min_size))
        end
      end
      c = c + 1
      xlua.progress(c, count)
      -- Periodically force a collection to release images that were dropped.
      if c % 10 == 0 then
        collectgarbage()
      end
    end
    fp:close()
    return x
  end
  -- Script entry point: seed torch's random generator from the settings
  -- (the cropping step draws from torch.random), print the settings, then
  -- load the images listed in settings.image_list and save the resulting
  -- table to settings.images.
  torch.manualSeed(settings.seed)
  print(settings)
  local packed_images = load_images(settings.image_list)
  torch.save(settings.images, packed_images)