pairwise_transform.lua 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. require 'image'
  2. local gm = require 'graphicsmagick'
  3. local iproc = require 'iproc'
  4. local data_augmentation = require 'data_augmentation'
  5. local pairwise_transform = {}
  --- Possibly shrink a large image to half its size.
  -- A uniform random number is drawn on every call. The image is halved only
  -- when that draw is greater than `p` AND the image exceeds 768 in
  -- dimension 2 and 1024 in dimension 3; in every other case `src` is
  -- returned as-is.
  -- @param src image tensor (dimensions 2 and 3 are read as its spatial sizes)
  -- @param p threshold compared against the uniform draw; defaults to 0.25
  -- @return the result of iproc.scale at half size, or `src` unchanged
  local function random_half(src, p)
    local threshold = p or 0.25
    -- Only the Box filter is used; random selection among several filters
    -- was disabled in this file.
    local filter = "Box"
    local roll = torch.uniform()
    if not (threshold < roll and (src:size(2) > 768 and src:size(3) > 1024)) then
      return src
    end
    return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
  end
  --- Cut a random max_size x max_size window out of an oversized image.
  -- The crop happens only when BOTH dimension 2 and dimension 3 are larger
  -- than `max_size`; an image that is oversized along a single dimension is
  -- returned untouched.
  -- @param src image tensor
  -- @param max_size side length of the square window
  -- @return the cropped window from iproc.crop, or `src` unchanged
  local function crop_if_large(src, max_size)
    local dim2, dim3 = src:size(2), src:size(3)
    if dim2 <= max_size or dim3 <= max_size then
      return src
    end
    -- The dimension-2 offset is drawn first, then the dimension-3 offset.
    local top = torch.random(0, dim2 - max_size)
    local left = torch.random(0, dim3 - max_size)
    return iproc.crop(src, left, top, left + max_size, top + max_size)
  end
  --- Run the shared augmentation pipeline on a source image.
  -- Steps, in order: optional random halving, cropping to at most
  -- max(crop_size * 4, 512) per side, flip, optional color noise, optional
  -- overlay, and finally shift_1px.
  -- @param src source image tensor
  -- @param crop_size patch size requested by the caller; only used to derive
  --   the crop limit
  -- @param options table; the fields read here are `random_half`,
  --   `color_noise` and `overlay`
  -- @return the augmented image
  local function preprocess(src, crop_size, options)
    local img = src
    if options.random_half then img = random_half(img) end

    local crop_limit = math.max(crop_size * 4, 512)
    img = crop_if_large(img, crop_limit)
    img = data_augmentation.flip(img)

    if options.color_noise then img = data_augmentation.color_noise(img) end
    if options.overlay then img = data_augmentation.overlay(img) end

    img = data_augmentation.shift_1px(img)
    return img
  end
  --- Cut one aligned size x size patch pair out of x and y.
  -- A uniform number r is drawn first. When p < r a single random window is
  -- returned. Otherwise `tries` random windows are sampled and the one with
  -- the largest mean squared error between the x patch and the y patch is
  -- returned (the first sampled window wins ties, including the all-zero
  -- case).
  -- @param x degraded image tensor
  -- @param y reference image tensor; must match x in dimensions 2 and 3
  -- @param size side length of the window
  -- @param p active-cropping rate compared against the uniform draw
  -- @param tries number of candidate windows for the MSE search (>= 1)
  -- @return xc, yc  the patches cut from x and y at the same position
  local function active_cropping(x, y, size, p, tries)
    -- assert() takes the condition first and the message second. With the
    -- arguments the other way round the (truthy) message string is what gets
    -- tested, and the size check can never fail.
    assert(x:size(2) == y:size(2) and x:size(3) == y:size(3),
           "x:size == y:size")

    -- Draw one random top-left corner and cut the same window from both images.
    local function random_pair()
      local xi = torch.random(0, y:size(3) - (size + 1))
      local yi = torch.random(0, y:size(2) - (size + 1))
      local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      return xc, yc
    end

    local r = torch.uniform()
    if p < r then
      return random_pair()
    end

    -- Keep a running maximum instead of collecting and sorting every sample.
    local best_xc, best_yc, best_mse
    for _ = 1, tries do
      local xc, yc = random_pair()
      local xcf = iproc.byte2float(xc)
      local ycf = iproc.byte2float(yc)
      local mse = (xcf - ycf):pow(2):mean()
      if best_mse == nil or mse > best_mse then
        best_xc, best_yc, best_mse = xc, yc, mse
      end
    end
    if best_mse == nil then
      -- No window was sampled; fail with a clear message instead of returning nils.
      error("active_cropping: tries must be >= 1")
    end
    return best_xc, best_yc
  end
  --- Build n (input, target) training pairs for the upscaling task.
  -- The target image y is the preprocessed source. The input image x is y
  -- shrunk by a factor of 1/scale with a randomly chosen filter and then
  -- scaled back to y's size. Patch pairs are then cut with active_cropping.
  -- @param src source image tensor
  -- @param scale downscale factor (x is reduced to 1/scale before re-enlarging)
  -- @param size side length of each cropped patch
  -- @param offset margin removed from every side of the target patch
  -- @param n number of pairs to produce
  -- @param options table; reads `rgb`, `active_cropping_rate`,
  --   `active_cropping_tries`, plus the fields used by preprocess
  -- @return array of {x_patch, y_patch} entries
  function pairwise_transform.scale(src, scale, size, offset, n, options)
    -- "Box" is listed twice, so it is picked twice as often as the others.
    -- The trailing figures are kept from the file's existing annotations.
    local filter_pool = {
      "Box", "Box",  -- 0.012756949974688
      "Blackman",    -- 0.013191924552285
      --"Cartom",    -- 0.013753536746706
      --"Hanning",   -- 0.013761314529647
      --"Hermite",   -- 0.013850225205266
      "SincFast",    -- 0.014095824314306
      "Jinc",        -- 0.014244299255442
    }
    -- The filter is drawn before preprocessing consumes any random numbers.
    local downscale_filter = filter_pool[torch.random(1, #filter_pool)]

    local y = preprocess(src, size, options)
    assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)

    local ratio = 1.0 / scale
    local shrunk = iproc.scale(y, y:size(3) * ratio, y:size(2) * ratio,
                               downscale_filter)
    local x = iproc.scale(shrunk, y:size(3), y:size(2))

    -- Keep only the first plane of the YUV conversion, as a 1 x H x W tensor.
    local function first_yuv_plane(rgb)
      return image.rgb2yuv(rgb)[1]:reshape(1, rgb:size(2), rgb:size(3))
    end

    local batch = {}
    for k = 1, n do
      local xc, yc = active_cropping(x, y, size,
                                     options.active_cropping_rate,
                                     options.active_cropping_tries)
      xc = iproc.byte2float(xc)
      yc = iproc.byte2float(yc)
      if not options.rgb then
        yc = first_yuv_plane(yc)
        xc = first_yuv_plane(xc)
      end
      batch[k] = {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}
    end
    return batch
  end
  --- Build n (input, target) training pairs for the JPEG noise task.
  -- The target image y is the preprocessed source. The input image x is y
  -- pushed through one JPEG encode/decode round trip per entry of `quality`,
  -- in order; an empty `quality` list leaves x identical to y.
  -- @param src source image tensor
  -- @param quality array of JPEG quality values, applied one after another
  -- @param size side length of each cropped patch
  -- @param offset margin removed from every side of the target patch
  -- @param n number of pairs to produce
  -- @param options table; reads `jpeg_sampling_factors`, `rgb`,
  --   `active_cropping_rate`, `active_cropping_tries`, plus the fields used
  --   by preprocess
  -- @return array of {x_patch, y_patch} entries
  function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    local y = preprocess(src, size, options)
    local x = y

    local use_444 = options.jpeg_sampling_factors == 444
    for pass = 1, #quality do
      local encoder = gm.Image(x, "RGB", "DHW")
      encoder:format("jpeg")
      -- Any value other than 444 falls back to the 420 factors.
      encoder:samplingFactors(use_444 and {1.0, 1.0, 1.0} or {2.0, 1.0, 1.0})
      local blob, len = encoder:toBlob(quality[pass])
      encoder:fromBlob(blob, len)
      x = encoder:toTensor("byte", "RGB", "DHW")
    end
    -- TODO: use shift_1px after compression?

    -- Keep only the first plane of the YUV conversion, as a 1 x H x W tensor.
    local function first_yuv_plane(rgb)
      return image.rgb2yuv(rgb)[1]:reshape(1, rgb:size(2), rgb:size(3))
    end

    local batch = {}
    for k = 1, n do
      local xc, yc = active_cropping(x, y, size,
                                     options.active_cropping_rate,
                                     options.active_cropping_tries)
      xc = iproc.byte2float(xc)
      yc = iproc.byte2float(yc)
      if not options.rgb then
        yc = first_yuv_plane(yc)
        xc = first_yuv_plane(xc)
      end
      batch[k] = {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}
    end
    return batch
  end
  --- Choose the list of JPEG quality values for one training sample.
  -- An empty list means "no compression" (the input stays equal to the target).
  -- @param category "anime_style_art" or "photo"
  -- @param level noise level, 1 or 2
  -- @return array of zero to three quality values, applied in order
  local function jpeg_quality_chain(category, level)
    if category == "anime_style_art" then
      if level == 1 then
        if torch.uniform() > 0.8 then
          return {}
        end
        return {torch.random(65, 85)}
      elseif level == 2 then
        -- r selects between one, two or three compression passes; it is
        -- drawn before the pass-through check.
        local r = torch.uniform()
        if torch.uniform() > 0.8 then
          return {}
        elseif r > 0.6 then
          return {torch.random(27, 70)}
        elseif r > 0.3 then
          local quality1 = torch.random(37, 70)
          local quality2 = quality1 - torch.random(5, 10)
          return {quality1, quality2}
        else
          local quality1 = torch.random(52, 70)
          local quality2 = quality1 - torch.random(5, 15)
          local quality3 = quality1 - torch.random(15, 25)
          return {quality1, quality2, quality3}
        end
      end
      -- tostring() keeps the message intact when level is nil or a boolean;
      -- plain concatenation would raise an unrelated error in that case.
      error("unknown noise level: " .. tostring(level))
    elseif category == "photo" then
      if level == 1 then
        if torch.uniform() > 0.7 then
          return {}
        end
        return {torch.random(80, 95)}
      elseif level == 2 then
        if torch.uniform() > 0.7 then
          return {}
        end
        return {torch.random(65, 85)}
      end
      error("unknown noise level: " .. tostring(level))
    end
    error("unknown category: " .. tostring(category))
  end

  --- Build n JPEG-noise training pairs with a randomly chosen quality chain.
  -- The chain depends on `category` and `level`; the pairs themselves are
  -- produced by pairwise_transform.jpeg_.
  -- @param src source image tensor
  -- @param category "anime_style_art" or "photo"; anything else raises an error
  -- @param level noise level, 1 or 2; anything else raises an error
  -- @param size side length of each cropped patch
  -- @param offset margin removed from every side of the target patch
  -- @param n number of pairs to produce
  -- @param options options table forwarded to pairwise_transform.jpeg_
  -- @return array of {x_patch, y_patch} entries
  function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
    local quality = jpeg_quality_chain(category, level)
    return pairwise_transform.jpeg_(src, quality, size, offset, n, options)
  end
  --- Visual smoke test for the JPEG pipeline.
  -- Generates nine single-pair batches from `src` with a random noise level
  -- (1 or 2) in the "anime_style_art" category and displays both images of
  -- each pair.
  -- @param src source image tensor
  function pairwise_transform.test_jpeg(src)
    local options = {
      color_noise = true,
      random_half = true,
      overlay = true,
      active_cropping_rate = 0.5,
      active_cropping_tries = 10,
      rgb = true,
    }
    for i = 1, 9 do
      local xy = pairwise_transform.jpeg(src,
                                         "anime_style_art",
                                         torch.random(1, 2),
                                         128, 7, 1, options)
      -- Each batch entry is {x, y}: the degraded input first, the target
      -- second. The legends follow that order.
      local pair = xy[1]
      image.display({image = pair[1], legend = "x:" .. (i * 10), min = 0, max = 1})
      image.display({image = pair[2], legend = "y:" .. (i * 10), min = 0, max = 1})
    end
  end
  --- Visual smoke test for the scale pipeline.
  -- Generates ten single-pair batches from `src` at scale 2.0 and displays
  -- both images of each pair.
  -- @param src source image tensor
  function pairwise_transform.test_scale(src)
    local options = {
      color_noise = true,
      random_half = true,
      overlay = true,
      active_cropping_rate = 0.5,
      active_cropping_tries = 10,
      rgb = true,
    }
    for i = 1, 10 do
      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
      -- Each batch entry is {x, y}: the degraded input first, the target
      -- second. The legends follow that order.
      local pair = xy[1]
      image.display({image = pair[1], legend = "x:" .. (i * 10), min = 0, max = 1})
      image.display({image = pair[2], legend = "y:" .. (i * 10), min = 0, max = 1})
    end
  end
  234. return pairwise_transform