gen.lua 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. require 'pl'
  2. require 'image'
  3. require 'trepl'
  4. local gm = require 'graphicsmagick'
  5. torch.setdefaulttensortype("torch.FloatTensor")
  --- Draw a random RGB colour as a 3-element table {r, g, b}.
  -- When torch.uniform() > 0.8 an "extreme" colour is returned instead:
  -- black {0, 0, 0} if `black` is truthy, white {1, 1, 1} otherwise.
  -- Otherwise each channel (r, then g, then b) is drawn independently:
  -- if torch.uniform() > 0.7 it is a hard 0 or 1 from torch.random(0, 1),
  -- else a continuous value from torch.uniform(0, 1).
  -- @param black selects the extreme: truthy -> black, falsy/nil -> white
  -- @return table of three channel values
  local function color(black)
    if torch.uniform() > 0.8 then
      local level = black and 0 or 1
      return {level, level, level}
    end
    local rgb = {}
    for channel = 1, 3 do
      if torch.uniform() > 0.7 then
        rgb[channel] = torch.random(0, 1)
      else
        rgb[channel] = torch.uniform(0, 1)
      end
    end
    return rgb
  end
  --- Build a random lattice predicate.
  -- Draws two periods with torch.random(2, 4) (first, then second) and
  -- returns a function f(u, v) that is true exactly when u is a multiple of
  -- the first period and v is a multiple of the second.
  -- @return predicate function(u, v) -> boolean
  local function gen_mod()
    local period_u = torch.random(2, 4)
    local period_v = torch.random(2, 4)
    return function(u, v)
      return u % period_u == 0 and v % period_v == 0
    end
  end
  --- Render a sheet of randomly generated dot-pattern tiles.
  -- Each tile starts as an s x s RGB block filled with a background colour.
  -- Inside a random margin (1..3), a pixel gets the foreground colour when the
  -- lattice predicate from gen_mod() holds for its cell coordinates (pixel
  -- coordinates integer-divided by a random cell size in 1..5).
  -- With probability ~0.5 per tile, "cross and skip" mode is used:
  --   * some rows have their lattice coordinates offset by one cell;
  --   * each candidate pixel is kept only if torch.uniform() > 0.25.
  -- Outside that mode, tiles with cell size >= 3 may be rotated by 45 degrees.
  -- Every tile is upscaled 2x with "simple" interpolation. The tiles are then
  -- laid out in a grid via image.toDisplayTensor.
  -- @param n number of tiles (optional, default 64); the grid uses
  --          floor(sqrt(n)) tiles per row
  -- @param s tile side in pixels before the 2x upscale (optional, default 24)
  -- @return the tensor returned by image.toDisplayTensor for the n tiles
  local function dot(n, s)
    n = n or 64
    s = s or 24
    local side = s * 2
    local tiles = torch.Tensor(n, 3, side, side)
    for i = 1, n do
      local block = torch.Tensor(3, s, s)
      local margin = torch.random(1, 3)
      local size = torch.random(1, 5)
      local mod = gen_mod()
      -- One of fg/bg comes from color() (white when it hits its extreme case),
      -- the other from color(true) (black when it hits its extreme case).
      local fg, bg
      if torch.uniform() > 0.5 then
        fg = color()
        bg = color(true)
      else
        fg = color(true)
        bg = color()
      end
      local use_cross_and_skip = torch.uniform() > 0.5
      for c = 1, 3 do
        block[c]:fill(bg[c])
      end
      for y = margin, s - margin do
        -- cross/skip mode: randomly offset this row's lattice by b in {0, 1}
        local b = 0
        if use_cross_and_skip and torch.random(0, 1) == 1 then
          b = torch.random(0, 1)
        end
        local yc = math.floor(y / size)
        for x = margin, s - margin do
          local xc = math.floor(x / size)
          -- cross/skip mode also drops roughly a quarter of candidate pixels.
          -- (The original misspelled the flag here, so this never ran.)
          local keep = (not use_cross_and_skip) or torch.uniform() > 0.25
          if keep and mod(yc + b, xc + b) then
            for c = 1, 3 do
              block[c][y][x] = fg[c]
            end
          end
        end
      end
      block = image.scale(block, side, side, "simple")
      -- rotate only outside cross/skip mode and only for coarse cells
      if (not use_cross_and_skip) and size >= 3 and torch.uniform() > 0.5 then
        block = image.rotate(block, math.pi / 4, "bilinear")
      end
      tiles[i]:copy(block)
    end
    -- nrow must be a whole number of tiles. math.pow is also deprecated
    -- since Lua 5.3, so use math.sqrt and floor the result.
    local nrow = math.floor(math.sqrt(n))
    return image.toDisplayTensor({input = tiles, padding = 0, nrow = nrow, min = 0, max = 1})
  end
  --- Generate one output image. The only generator is the dot-pattern sheet
  -- from dot(), which is called with no arguments.
  -- @return whatever dot() returns
  local gen = function()
    local sheet = dot()
    return sheet
  end
  -- Command-line entry point.
  -- Options: -o <output directory> (required), -n <number of images, default 64>.
  -- Writes files 1.png .. <n>.png into the -o directory, one gen() result each.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("dot image generator")
  cmd:text("Options:")
  cmd:option("-o", "", 'output directory')
  cmd:option("-n", 64, 'number of images')
  local opt = cmd:parse(arg)
  if #opt.o == 0 then
    -- no output directory given: print usage and exit with failure status
    cmd:help()
    os.exit(1)
  end
  for idx = 1, opt.n do
    local sheet = gen()
    local filename = path.join(opt.o, idx .. ".png")
    image.save(filename, sheet)
  end