iproc.lua 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  -- GraphicsMagick bindings; only the Image class is used by this module.
  local gm = { Image = require 'graphicsmagick.Image' }
  -- torch 'image' package, required lazily by the functions that need it.
  local image = nil
  -- Module table returned at the end of the file.
  local iproc = {}
  -- Half of one 1/255 quantization step, reduced by a tiny relative margin
  -- (1e-7). float2byte adds it before scaling by 255.
  local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
  --- Crops an image so that its width and height are multiples of 4.
  -- Pixels are dropped from the right and bottom edges.
  -- Works for 3-D (C x H x W) and 2-D (H x W) tensors, matching the two
  -- layouts accepted by iproc.crop. Previously only 3-D input worked,
  -- because sizes were read from dims 3 and 2 unconditionally.
  -- @param src image tensor
  -- @return a cropped copy of src
  function iproc.crop_mod4(src)
     local ndim = src:dim()
     -- The last two dimensions are always height then width.
     local width = src:size(ndim)
     local height = src:size(ndim - 1)
     return iproc.crop(src, 0, 0, width - width % 4, height - height % 4)
  end
  --- Returns a copy of the region of src with columns w1..w2-1 and rows
  -- h1..h2-1 (0-based start offsets, exclusive end offsets).
  -- A 3-D tensor is cropped in its last two dimensions (all channels kept).
  -- Any other tensor is indexed as 2-D rows x columns.
  -- @return a new tensor that does not share storage with src
  function iproc.crop(src, w1, h1, w2, h2)
     local rows, cols = { h1 + 1, h2 }, { w1 + 1, w2 }
     local view
     if src:dim() == 3 then
        view = src[{ {}, rows, cols }]
     else -- treated as 2-D
        view = src[{ rows, cols }]
     end
     return view:clone()
  end
  --- Same region selection as iproc.crop, but returns a view instead of a copy.
  -- Columns w1..w2-1 and rows h1..h2-1 (0-based start, exclusive end).
  -- The result shares storage with src, so writes to it modify src.
  function iproc.crop_nocopy(src, w1, h1, w2, h2)
     local rows = { h1 + 1, h2 }
     local cols = { w1 + 1, w2 }
     if src:dim() == 3 then
        return src[{ {}, rows, cols }]
     end
     -- anything that is not 3-D is indexed as 2-D (rows, columns)
     return src[{ rows, cols }]
  end
  --- Converts a torch.ByteTensor with values 0..255 to a FloatTensor in [0, 1].
  -- Tensors of any other type are returned as-is (not copied).
  -- @return the (possibly converted) tensor, and true if a conversion happened
  function iproc.byte2float(src)
     if src:type() ~= "torch.ByteTensor" then
        return src, false
     end
     return src:float():div(255.0), true
  end
  --- Converts a torch.FloatTensor in [0, 1] to a ByteTensor in 0..255.
  -- Values are rounded and out-of-range values saturate at 0 / 255.
  -- Tensors of any other type are returned as-is.
  -- @return the (possibly converted) tensor, and true if a conversion happened
  function iproc.float2byte(src)
     if src:type() ~= "torch.FloatTensor" then
        return src, false
     end
     -- Adding just under half a quantization step before scaling turns the
     -- truncating float -> byte cast below into round-to-nearest.
     local scaled = (src + clip_eps8):mul(255.0)
     scaled:clamp(0.0, 255.0)
     return scaled:byte(), true
  end
  --- Resizes an image with GraphicsMagick.
  -- @param src image tensor in D x H x W layout; 3 channels are treated as
  --   "RGB", anything else as single-channel "I"
  -- @param width target width (rounded up to an integer)
  -- @param height target height (rounded up to an integer)
  -- @param filter GraphicsMagick filter name (default "Box")
  -- @param blur blur factor forwarded to gm.Image:size
  -- @return the resized image: a ByteTensor for ByteTensor input,
  --   otherwise a FloatTensor
  function iproc.scale(src, width, height, filter, blur)
     local was_byte
     src, was_byte = iproc.byte2float(src)
     local colorspace = (src:size(1) == 3) and "RGB" or "I"
     local im = gm.Image(src, colorspace, "DHW")
     im:size(math.ceil(width), math.ceil(height), filter or "Box", blur)
     local out = im:toTensor("float", colorspace, "DHW")
     if was_byte then
        out = iproc.float2byte(out)
     end
     return out
  end
  --- Resizes an RGB image in a gamma-adjusted space.
  -- Gamma correction of 1/2.2 is applied before resizing and 2.2 after,
  -- and the result is clamped to [0, 1] before any byte conversion.
  -- @param src 3-channel image tensor in D x H x W layout (always read as "RGB")
  -- @param width target width (rounded up to an integer)
  -- @param height target height (rounded up to an integer)
  -- @param filter GraphicsMagick filter name (default "Box")
  -- @param blur blur factor forwarded to gm.Image:size
  -- @return the resized image: a ByteTensor for ByteTensor input,
  --   otherwise a FloatTensor
  function iproc.scale_with_gamma22(src, width, height, filter, blur)
     local was_byte
     src, was_byte = iproc.byte2float(src)
     filter = filter or "Box"
     local target_w, target_h = math.ceil(width), math.ceil(height)
     local im = gm.Image(src, "RGB", "DHW")
     im:gammaCorrection(1.0 / 2.2)
       :size(target_w, target_h, filter, blur)
       :gammaCorrection(2.2)
     local out = im:toTensor("float", "RGB", "DHW"):clamp(0.0, 1.0)
     if was_byte then
        out = iproc.float2byte(out)
     end
     return out
  end
  --- Pads an image by w1/w2 columns (left/right) and h1/h2 rows (top/bottom).
  -- Padding values come from image.warp in "clamp" mode, i.e. samples that
  -- fall outside img are clamped to its border.
  -- @param img tensor indexed as C x H x W (sizes read from dims 2 and 3)
  -- @return the padded image
  function iproc.padding(img, w1, w2, h1, h2)
     image = image or require 'image'
     local out_h = img:size(2) + h1 + h2
     local out_w = img:size(3) + w1 + w2
     -- Absolute sampling coordinates (offset = false): row i of the output
     -- reads source row i - h1, and column j reads source column j - w1.
     local rows = torch.ger(torch.linspace(0, out_h - 1, out_h), torch.ones(out_w))
     local cols = torch.ger(torch.ones(out_h), torch.linspace(0, out_w - 1, out_w))
     local flow = torch.Tensor(2, out_h, out_w)
     flow[1] = rows
     flow[2] = cols
     flow[1]:add(-h1)
     flow[2]:add(-w1)
     return image.warp(img, flow, "simple", false, "clamp")
  end
  --- Pads an image with zeros by w1/w2 columns (left/right) and h1/h2 rows
  -- (top/bottom), using image.warp in "pad" mode with pad value 0.
  -- @param img tensor indexed as C x H x W (sizes read from dims 2 and 3)
  -- @return the padded image
  function iproc.zero_padding(img, w1, w2, h1, h2)
     image = image or require 'image'
     local out_h = img:size(2) + h1 + h2
     local out_w = img:size(3) + w1 + w2
     -- Absolute sampling coordinates shifted by the top/left padding, so
     -- output pixels in the border map outside img and receive the pad value.
     local rows = torch.ger(torch.linspace(0, out_h - 1, out_h), torch.ones(out_w))
     local cols = torch.ger(torch.ones(out_h), torch.linspace(0, out_w - 1, out_w))
     local flow = torch.Tensor(2, out_h, out_w)
     flow[1] = rows
     flow[2] = cols
     flow[1]:add(-h1)
     flow[2]:add(-w1)
     return image.warp(img, flow, "simple", false, "pad", 0)
  end
  --- Adds Gaussian (white) noise to an image.
  -- @param src image tensor; a ByteTensor is converted to float in [0, 1]
  --   first and converted back to bytes at the end
  -- @param std standard deviation of the noise (default 0.01)
  -- @param rgb_weights optional table of per-channel multipliers for the
  --   noise; when given, src must have at least 3 channels (noise[1..3])
  -- @param gamma exponent applied to src before the noise is added
  --   (default 0.454545, about 1/2.2). The result is clamped to [0, 1] and
  --   raised to 1/gamma. Pass 0 to add the noise directly instead.
  -- @return the noisy image (ByteTensor if src was a ByteTensor)
  function iproc.white_noise(src, std, rgb_weights, gamma)
     gamma = gamma or 0.454545
     std = std or 0.01
     local conversion
     src, conversion = iproc.byte2float(src)
     -- Allocate the noise with src's own tensor type rather than the global
     -- default type. Torch refuses to add tensors of different types, so a
     -- default-typed buffer broke whenever src's type differed from it.
     local noise = src.new():resizeAs(src):normal(0, std)
     if rgb_weights then
        noise[1]:mul(rgb_weights[1])
        noise[2]:mul(rgb_weights[2])
        noise[3]:mul(rgb_weights[3])
     end
     local dest
     if gamma ~= 0 then
        dest = src:clone():pow(gamma):add(noise)
        -- clamp so pow(1/gamma) never sees negative values
        dest:clamp(0.0, 1.0)
        dest:pow(1.0 / gamma)
     else
        -- NOTE: this path does not clamp; out-of-range values are only
        -- saturated by float2byte when the input was a ByteTensor.
        dest = src + noise
     end
     if conversion then
        dest = iproc.float2byte(dest)
     end
     return dest
  end
  --- Mirrors an image left-to-right using GraphicsMagick's flop.
  -- @param src image tensor in D x H x W layout; 3 channels are treated as
  --   "RGB", anything else as single-channel "I"
  -- @return the mirrored image: a ByteTensor for ByteTensor input,
  --   otherwise a FloatTensor
  function iproc.hflip(src)
     local t = "float"
     if src:type() == "torch.ByteTensor" then
        t = "byte"
     end
     -- `color` is local so it does not leak into the global environment.
     local color = "I"
     if src:size(1) == 3 then
        color = "RGB"
     end
     local im = gm.Image(src, color, "DHW")
     return im:flop():toTensor(t, color, "DHW")
  end
  --- Mirrors an image top-to-bottom using GraphicsMagick's flip.
  -- @param src image tensor in D x H x W layout; 3 channels are treated as
  --   "RGB", anything else as single-channel "I"
  -- @return the mirrored image: a ByteTensor for ByteTensor input,
  --   otherwise a FloatTensor
  function iproc.vflip(src)
     local t = "float"
     if src:type() == "torch.ByteTensor" then
        t = "byte"
     end
     -- `color` is local so it does not leak into the global environment.
     local color = "I"
     if src:size(1) == 3 then
        color = "RGB"
     end
     local im = gm.Image(src, color, "DHW")
     return im:flip():toTensor(t, color, "DHW")
  end
  -- Adapted from torch/image.
  ----------------------------------------------------------------------
  -- iproc.rgb2yuv([output,] input)
  -- Converts an RGB image to YUV. input[1], input[2] and input[3] are read
  -- as R, G and B, and output[1..3] receive Y, U and V. Y uses the BT.601
  -- luma weights (0.299, 0.587, 0.114).
  -- output is resized to match input (a new tensor is created if omitted).
  -- It must not alias input: Y is written before R/G/B are read again for U.
  --
  function iproc.rgb2yuv(...)
     local nargs = select('#', ...)
     local output, input
     if nargs == 2 then
        output, input = ...
     elseif nargs == 1 then
        input = ...
     else
        print(dok.usage('image.rgb2yuv',
                        'transforms an image from RGB to YUV', nil,
                        {type='torch.Tensor', help='input image', req=true},
                        '',
                        {type='torch.Tensor', help='output image', req=true},
                        {type='torch.Tensor', help='input image', req=true}
        ))
        dok.error('missing input', 'image.rgb2yuv')
     end
     output = output or input.new()
     output:resizeAs(input)
     -- source channels
     local r, g, b = input[1], input[2], input[3]
     -- destination channels
     local y, u, v = output[1], output[2], output[3]
     y:zero():add(0.299, r):add(0.587, g):add(0.114, b)
     u:zero():add(-0.14713, r):add(-0.28886, g):add(0.436, b)
     v:zero():add(0.615, r):add(-0.51499, g):add(-0.10001, b)
     return output
  end
  ----------------------------------------------------------------------
  -- iproc.yuv2rgb([output,] input)
  -- Converts a YUV image back to RGB, using the inverse coefficients of
  -- iproc.rgb2yuv. input[1..3] are read as Y, U and V, and output[1..3]
  -- receive R, G and B.
  -- output is resized to match input (a new tensor is created if omitted).
  -- It must not alias input: R is written before Y/U/V are read again for G.
  --
  function iproc.yuv2rgb(...)
     local nargs = select('#', ...)
     local output, input
     if nargs == 2 then
        output, input = ...
     elseif nargs == 1 then
        input = ...
     else
        print(dok.usage('image.yuv2rgb',
                        'transforms an image from YUV to RGB', nil,
                        {type='torch.Tensor', help='input image', req=true},
                        '',
                        {type='torch.Tensor', help='output image', req=true},
                        {type='torch.Tensor', help='input image', req=true}
        ))
        dok.error('missing input', 'image.yuv2rgb')
     end
     output = output or input.new()
     output:resizeAs(input)
     -- source channels
     local y, u, v = input[1], input[2], input[3]
     -- destination channels
     local r, g, b = output[1], output[2], output[3]
     r:copy(y):add(1.13983, v)
     g:copy(y):add(-0.39465, u):add(-0.58060, v)
     b:copy(y):add(2.03211, u)
     return output
  end
  --- Self-test for iproc.byte2float / iproc.float2byte (requires torch).
  -- It is not run on module load; see the commented call near the end of
  -- the file.
  local function test_conversion()
     local a = torch.linspace(0, 255, 256):float():div(255.0)
     local b = iproc.float2byte(a)
     -- Round-trip the bytes back to float. Converting `a` here would be a
     -- no-op, because `a` is already a FloatTensor, making the check vacuous.
     local c = iproc.byte2float(b)
     local d = torch.linspace(0, 255, 256)
     assert((a - c):abs():sum() == 0, "byte2float(float2byte(x)) ~= x")
     assert((d:float() - b:float()):abs():sum() == 0, "float2byte mismatch")
     -- inputs that must all end up as 255 (one of them saturates)
     a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0)
     b = iproc.float2byte(a)
     assert(b:float():sum() == 255.0 * 3)
     -- inputs that must all round to 254
     a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0)
     b = iproc.float2byte(a)
     assert(b:float():sum() == 254.0 * 3)
  end
  --- Manual check that iproc.hflip / iproc.vflip agree with torch/image.
  -- Prints the total absolute difference for float and byte inputs
  -- (0 means identical). Note that it changes the default tensor type
  -- to torch.FloatTensor.
  local function test_flip()
     require 'sys'
     require 'torch'
     torch.setdefaulttensortype("torch.FloatTensor")
     image = require 'image'
     local src = image.lena()
     local src_byte = src:clone():mul(255):byte()
     -- Computed in float and with abs(): a signed sum can cancel to 0 for
     -- different images, and ByteTensor subtraction wraps around.
     local function total_abs_diff(x, y)
        return (x:float() - y:float()):abs():sum()
     end
     print(src:size())
     print(total_abs_diff(image.hflip(src), iproc.hflip(src)))
     print(total_abs_diff(image.hflip(src_byte), iproc.hflip(src_byte)))
     print(total_abs_diff(image.vflip(src), iproc.vflip(src)))
     print(total_abs_diff(image.vflip(src_byte), iproc.vflip(src_byte)))
  end
  -- Uncomment to run the self-tests when this module is loaded.
  --test_conversion()
  --test_flip()
  return iproc