pairwise_transform.lua 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. require 'image'
  2. local gm = require 'graphicsmagick'
  3. local iproc = require './iproc'
  4. local reconstract = require './reconstract'
-- Module table: scale, jpeg_ and jpeg are attached to it below and it is
-- returned at the end of the file.
local pairwise_transform = {}
  6. function pairwise_transform.scale(src, scale, size, offset, options)
  7. options = options or {}
  8. local yi = torch.radom(0, src:size(2) - size - 1)
  9. local xi = torch.random(0, src:size(3) - size - 1)
  10. local down_scale = 1.0 / scale
  11. local y = image.crop(src, xi, yi, xi + size, yi + size)
  12. local flip = torch.random(1, 4)
  13. local nega = torch.random(0, 1)
  14. local filters = {
  15. "Box", -- 0.012756949974688
  16. "Blackman", -- 0.013191924552285
  17. --"Cartom", -- 0.013753536746706
  18. --"Hanning", -- 0.013761314529647
  19. --"Hermite", -- 0.013850225205266
  20. --"SincFast", -- 0.014095824314306
  21. --"Jinc", -- 0.014244299255442
  22. }
  23. local downscale_filter = filters[torch.random(1, #filters)]
  24. if r == 1 then
  25. y = image.hflip(y)
  26. elseif r == 2 then
  27. y = image.vflip(y)
  28. elseif r == 3 then
  29. y = image.hflip(image.vflip(y))
  30. elseif r == 4 then
  31. -- none
  32. end
  33. if options.color_augment then
  34. y = y:float():div(255)
  35. local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
  36. for i = 1, 3 do
  37. y[i]:mul(color_scale[i])
  38. end
  39. y[torch.lt(y, 0)] = 0
  40. y[torch.gt(y, 1.0)] = 1.0
  41. y = y:mul(255):byte()
  42. end
  43. local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
  44. if options.noise and (options.noise_ratio or 0.5) > torch.uniform() then
  45. -- add noise
  46. local quality = {torch.random(70, 90)}
  47. for i = 1, #quality do
  48. x = gm.Image(x, "RGB", "DHW")
  49. x:format("jpeg")
  50. local blob, len = x:toBlob(quality[i])
  51. x:fromBlob(blob, len)
  52. x = x:toTensor("byte", "RGB", "DHW")
  53. end
  54. end
  55. if options.denoise_model and (options.denoise_ratio or 0.5) > torch.uniform() then
  56. x = reconstract(options.denoise_model, x:float():div(255), offset):mul(255):byte()
  57. end
  58. x = iproc.scale(x, y:size(3), y:size(2))
  59. y = y:float():div(255)
  60. x = x:float():div(255)
  61. y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
  62. x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
  63. return x, image.crop(y, offset, offset, size - offset, size - offset)
  64. end
  65. function pairwise_transform.jpeg_(src, quality, size, offset, color_augment)
  66. if color_augment == nil then color_augment = true end
  67. local yi = torch.random(0, src:size(2) - size - 1)
  68. local xi = torch.random(0, src:size(3) - size - 1)
  69. local y = src
  70. local x
  71. local flip = torch.random(1, 4)
  72. if color_augment then
  73. local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
  74. y = y:float():div(255)
  75. for i = 1, 3 do
  76. y[i]:mul(color_scale[i])
  77. end
  78. y[torch.lt(y, 0)] = 0
  79. y[torch.gt(y, 1.0)] = 1.0
  80. y = y:mul(255):byte()
  81. end
  82. x = y
  83. for i = 1, #quality do
  84. x = gm.Image(x, "RGB", "DHW")
  85. x:format("jpeg")
  86. local blob, len = x:toBlob(quality[i])
  87. x:fromBlob(blob, len)
  88. x = x:toTensor("byte", "RGB", "DHW")
  89. end
  90. y = image.crop(y, xi, yi, xi + size, yi + size)
  91. x = image.crop(x, xi, yi, xi + size, yi + size)
  92. x = x:float():div(255)
  93. y = y:float():div(255)
  94. if flip == 1 then
  95. y = image.hflip(y)
  96. x = image.hflip(x)
  97. elseif flip == 2 then
  98. y = image.vflip(y)
  99. x = image.vflip(x)
  100. elseif flip == 3 then
  101. y = image.hflip(image.vflip(y))
  102. x = image.hflip(image.vflip(x))
  103. elseif flip == 4 then
  104. -- none
  105. end
  106. y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
  107. x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
  108. return x, image.crop(y, offset, offset, size - offset, size - offset)
  109. end
  110. function pairwise_transform.jpeg(src, level, size, offset, color_augment)
  111. if level == 1 then
  112. return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
  113. size, offset,
  114. color_augment)
  115. elseif level == 2 then
  116. local r = torch.uniform()
  117. if r > 0.6 then
  118. return pairwise_transform.jpeg_(src, {torch.random(27, 80)},
  119. size, offset,
  120. color_augment)
  121. elseif r > 0.3 then
  122. local quality1 = torch.random(32, 40)
  123. local quality2 = quality1 - 5
  124. return pairwise_transform.jpeg_(src, {quality1, quality2},
  125. size, offset,
  126. color_augment)
  127. else
  128. local quality1 = torch.random(47, 70)
  129. return pairwise_transform.jpeg_(src, {quality1, quality1 - 10, quality1 - 20},
  130. size, offset,
  131. color_augment)
  132. end
  133. else
  134. error("unknown noise level: " .. level)
  135. end
  136. end
  137. local function test_jpeg()
  138. local loader = require 'image_loader'
  139. local src = loader.load_byte("a.jpg")
  140. for i = 2, 9 do
  141. local y, x = pairwise_transform.jpeg_(src, {i * 10}, 128, 0, false)
  142. image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
  143. image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
  144. --print(x:mean(), y:mean())
  145. end
  146. end
  147. local function test_scale()
  148. require 'nn'
  149. require 'cudnn'
  150. require './LeakyReLU'
  151. local loader = require 'image_loader'
  152. local src = loader.load_byte("e.jpg")
  153. for i = 1, 9 do
  154. local y, x = pairwise_transform.scale(src, 2.0, "Box", 128, 7, {noise = true, denoise_model = torch.load("models/noise1_model.t7")})
  155. image.display({image = y, legend = "y:" .. (i * 10)})
  156. image.display({image = x, legend = "x:" .. (i * 10)})
  157. --print(x:mean(), y:mean())
  158. end
  159. end
-- Manual visual checks (they call image.display); uncomment one to run it
-- whenever this file is loaded.
--test_jpeg()
--test_scale()
return pairwise_transform