-- iproc.lua: image tensor helpers (crop, byte/float conversion, scaling, padding, noise).
  -- GraphicsMagick bindings; used by the scale functions below.
  local gm = require 'graphicsmagick'
  -- torch image package; used by the padding functions (image.warp).
  local image = require 'image'
  -- Module table; every public function is attached to it and it is returned at the end of the file.
  local iproc = {}
  -- Just under half of one 8-bit step (0.5 / 255). float2byte adds it before
  -- scaling by 255 so that the byte conversion appears to round to the nearest
  -- integer instead of truncating; the tiny 1e-7 reduction keeps exact .5
  -- cases from rounding up (see the expectations in test_conversion).
  local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
  --- Crop an image so that both its width and height are multiples of 4.
  -- The top-left corner is kept; the 0-3 surplus columns/rows are dropped
  -- from the right and bottom edges.
  -- Like iproc.crop, this accepts a 3-D tensor (width = dim 3, height = dim 2)
  -- or a 2-D tensor (width = dim 2, height = dim 1).
  -- @param src image tensor
  -- @return cropped copy of src (iproc.crop clones the selected region)
  function iproc.crop_mod4(src)
    -- Pick the spatial dimensions to match the layout iproc.crop indexes.
    local w_dim, h_dim = 3, 2
    if src:dim() == 2 then
      w_dim, h_dim = 2, 1
    end
    local width = src:size(w_dim)
    local height = src:size(h_dim)
    return iproc.crop(src, 0, 0, width - width % 4, height - height % 4)
  end
  --- Copy a rectangular region out of an image tensor.
  -- (w1, h1) is the zero-based top-left corner; (w2, h2) is the one-based
  -- inclusive bottom-right corner, so the result is (h2 - h1) x (w2 - w1).
  -- @param src 3-D tensor (first dimension kept whole) or, otherwise, a tensor
  --            indexed as 2-D (rows, columns)
  -- @return a freshly cloned tensor holding the selected region
  function iproc.crop(src, w1, h1, w2, h2)
    local rows = { h1 + 1, h2 }
    local cols = { w1 + 1, w2 }
    if src:dim() == 3 then
      return src[{ {}, rows, cols }]:clone()
    end
    -- Any other dimensionality is treated as a plain 2-D image.
    return src[{ rows, cols }]:clone()
  end
  --- Select a rectangular region of an image tensor without cloning it.
  -- Same coordinate convention as iproc.crop: (w1, h1) is the zero-based
  -- top-left corner and (w2, h2) the one-based inclusive bottom-right corner.
  -- @param src 3-D tensor (first dimension kept whole) or, otherwise, a tensor
  --            indexed as 2-D (rows, columns)
  -- @return the result of indexing src with the region (no :clone() is made)
  function iproc.crop_nocopy(src, w1, h1, w2, h2)
    local rows = { h1 + 1, h2 }
    local cols = { w1 + 1, w2 }
    if src:dim() == 3 then
      return src[{ {}, rows, cols }]
    end
    -- Any other dimensionality is treated as a plain 2-D image.
    return src[{ rows, cols }]
  end
  --- Convert a ByteTensor to a float tensor with values divided by 255.
  -- Tensors of any other type are passed through untouched.
  -- @param src input tensor
  -- @return the converted tensor (or src itself when no conversion happened)
  -- @return true if a conversion was performed, false otherwise
  function iproc.byte2float(src)
    if src:type() == "torch.ByteTensor" then
      return src:float():div(255.0), true
    end
    return src, false
  end
  --- Convert a FloatTensor to a ByteTensor scaled by 255.
  -- Tensors of any other type are passed through untouched.
  -- @param src input tensor
  -- @return the converted tensor (or src itself when no conversion happened)
  -- @return true if a conversion was performed, false otherwise
  function iproc.float2byte(src)
    if src:type() ~= "torch.FloatTensor" then
      return src, false
    end
    -- Shift by clip_eps8 before scaling, then clamp to the byte range so
    -- out-of-range values saturate at 0 / 255 when cast.
    local scaled = (src + clip_eps8):mul(255.0)
    scaled[torch.lt(scaled, 0.0)] = 0
    scaled[torch.gt(scaled, 255.0)] = 255.0
    return scaled:byte(), true
  end
  --- Resize an image through GraphicsMagick.
  -- Byte input is converted to float for processing and converted back to
  -- byte on return.
  -- @param src image tensor; it is handed to gm.Image as "RGB" when its first
  --            dimension has size 3 and as "I" otherwise
  -- @param width target width (rounded up with math.ceil)
  -- @param height target height (rounded up with math.ceil)
  -- @param filter resize filter name passed to gm; defaults to "Box"
  -- @return the resized tensor
  function iproc.scale(src, width, height, filter)
    local was_byte
    src, was_byte = iproc.byte2float(src)

    local colorspace
    if src:size(1) == 3 then
      colorspace = "RGB"
    else
      colorspace = "I"
    end

    local resized = gm.Image(src, colorspace, "DHW")
    resized:size(math.ceil(width), math.ceil(height), filter or "Box")

    local dest = resized:toTensor("float", colorspace, "DHW")
    if was_byte then
      dest = iproc.float2byte(dest)
    end
    return dest
  end
  --- Resize an RGB image, applying gamma correction around the resize.
  -- The image gets gammaCorrection(1 / 2.2), is resized, and then gets
  -- gammaCorrection(2.2). Byte input is converted to float for processing and
  -- converted back to byte on return.
  -- @param src image tensor, always handed to gm.Image as "RGB" in "DHW" layout
  -- @param width target width (rounded up with math.ceil)
  -- @param height target height (rounded up with math.ceil)
  -- @param filter resize filter name passed to gm; defaults to "Box"
  -- @return the resized tensor
  function iproc.scale_with_gamma22(src, width, height, filter)
    local was_byte
    src, was_byte = iproc.byte2float(src)

    local target_w = math.ceil(width)
    local target_h = math.ceil(height)

    local resized = gm.Image(src, "RGB", "DHW")
    resized:gammaCorrection(1.0 / 2.2)
      :size(target_w, target_h, filter or "Box")
      :gammaCorrection(2.2)

    local dest = resized:toTensor("float", "RGB", "DHW")
    if was_byte then
      dest = iproc.float2byte(dest)
    end
    return dest
  end
  --- Pad an image by warping it onto a larger canvas in "clamp" mode.
  -- The output is (height + h1 + h2) x (width + w1 + w2); w1/h1 shift the
  -- source right/down, while w2/h2 only enlarge the canvas.
  -- NOTE(review): "clamp" presumably replicates edge pixels into the border -
  -- confirm against the image.warp documentation.
  -- @param img 3-D image tensor (height = dim 2, width = dim 3)
  -- @param w1 padding added on the left
  -- @param w2 padding added on the right
  -- @param h1 padding added on the top
  -- @param h2 padding added on the bottom
  -- @return the padded image produced by image.warp
  function iproc.padding(img, w1, w2, h1, h2)
    local out_h = img:size(2) + h1 + h2
    local out_w = img:size(3) + w1 + w2

    -- For each output pixel, the source row / column it samples from:
    -- the output index shifted back by the top / left padding.
    local src_rows = torch.ger(torch.linspace(0, out_h - 1, out_h), torch.ones(out_w)):add(-h1)
    local src_cols = torch.ger(torch.ones(out_h), torch.linspace(0, out_w - 1, out_w)):add(-w1)

    local flow = torch.Tensor(2, out_h, out_w)
    flow[1] = src_rows
    flow[2] = src_cols
    return image.warp(img, flow, "simple", false, "clamp")
  end
  --- Pad an image by warping it onto a larger canvas in "pad" mode with value 0.
  -- The output is (height + h1 + h2) x (width + w1 + w2); w1/h1 shift the
  -- source right/down, while w2/h2 only enlarge the canvas.
  -- NOTE(review): "pad" with 0 presumably fills the border with zeros -
  -- confirm against the image.warp documentation.
  -- @param img 3-D image tensor (height = dim 2, width = dim 3)
  -- @param w1 padding added on the left
  -- @param w2 padding added on the right
  -- @param h1 padding added on the top
  -- @param h2 padding added on the bottom
  -- @return the padded image produced by image.warp
  function iproc.zero_padding(img, w1, w2, h1, h2)
    local out_h = img:size(2) + h1 + h2
    local out_w = img:size(3) + w1 + w2

    -- For each output pixel, the source row / column it samples from:
    -- the output index shifted back by the top / left padding.
    local src_rows = torch.ger(torch.linspace(0, out_h - 1, out_h), torch.ones(out_w)):add(-h1)
    local src_cols = torch.ger(torch.ones(out_h), torch.linspace(0, out_w - 1, out_w)):add(-w1)

    local flow = torch.Tensor(2, out_h, out_w)
    flow[1] = src_rows
    flow[2] = src_cols
    return image.warp(img, flow, "simple", false, "pad", 0)
  end
  --- Add Gaussian white noise to an image.
  -- Byte input is converted to float for processing and converted back to
  -- byte on return.
  -- When gamma is non-zero the image is raised to `gamma`, the noise is added,
  -- the result is clamped to [0, 1], and it is raised to `1 / gamma`.
  -- When gamma is 0 the noise is added directly and no clamping is applied.
  -- @param src image tensor
  -- @param std standard deviation of the noise; defaults to 0.01
  -- @param rgb_weights optional table of three factors applied to the noise
  --                    of slices 1, 2 and 3 along the first dimension
  -- @param gamma exponent applied before adding noise; defaults to 0.454545
  -- @return the noisy image
  function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
    std = std or 0.01
    local conversion
    src, conversion = iproc.byte2float(src)

    -- Allocate the noise with the same tensor type as src rather than the
    -- process-wide default type, so resizeAs/add below do not depend on
    -- torch.setdefaulttensortype having been called to match src.
    local noise = src.new():resizeAs(src):normal(0, std)
    if rgb_weights then
      noise[1]:mul(rgb_weights[1])
      noise[2]:mul(rgb_weights[2])
      noise[3]:mul(rgb_weights[3])
    end

    local dest
    if gamma ~= 0 then
      dest = src:clone():pow(gamma):add(noise)
      -- Clamp before the inverse power so negative values cannot reach pow().
      dest[torch.lt(dest, 0.0)] = 0.0
      dest[torch.gt(dest, 1.0)] = 1.0
      dest:pow(1.0 / gamma)
    else
      dest = src + noise
    end

    if conversion then
      dest = iproc.float2byte(dest)
    end
    return dest
  end
  --- Self-test for iproc.float2byte / iproc.byte2float.
  -- Checks that every 8-bit level survives a float -> byte -> float round
  -- trip, that values at or above 255 saturate at 255, and that values within
  -- half a step of 254 convert to 254. Raises via assert on failure.
  local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)
    local b = iproc.float2byte(a)
    -- Convert the byte tensor back; passing `a` (already float) here would be
    -- returned unchanged by byte2float and make the check below vacuous.
    local c = iproc.byte2float(b)
    local d = torch.linspace(0, 255, 256)
    assert((a - c):abs():sum() == 0, "float -> byte -> float round trip changed values")
    assert((d:float() - b:float()):abs():sum() == 0, "float2byte did not reproduce levels 0..255")

    -- Values at or just below the upper bound must all map to 255.
    a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0)
    b = iproc.float2byte(a)
    assert(b:float():sum() == 255.0 * 3, "values near/above 255 should convert to 255")

    -- Values within half a step of 254 must all map to 254.
    a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0)
    b = iproc.float2byte(a)
    assert(b:float():sum() == 254.0 * 3, "values within half a step of 254 should convert to 254")
  end
  -- Self-test is disabled by default; uncomment the call to run it when the module is loaded.
  --test_conversion()
  -- Export the module table.
  return iproc