alpha_util.lua 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. local w2nn = require 'w2nn'
  2. local reconstruct = require 'reconstruct'
  3. local image = require 'image'
  4. local iproc = require 'iproc'
  5. local gm = require 'graphicsmagick'
  6. alpha_util = {}
  7. function alpha_util.make_border(rgb, alpha, offset)
  8. if not alpha then
  9. return rgb
  10. end
  11. local sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda()
  12. sum2d.weight:fill(1)
  13. sum2d.bias:zero()
  14. local mask = alpha:clone()
  15. mask[torch.gt(mask, 0.0)] = 1
  16. mask[torch.eq(mask, 0.0)] = 0
  17. local mask_nega = (mask - 1):abs():byte()
  18. local eps = 1.0e-7
  19. rgb = rgb:clone()
  20. rgb[1][mask_nega] = 0
  21. rgb[2][mask_nega] = 0
  22. rgb[3][mask_nega] = 0
  23. for i = 1, offset do
  24. local mask_weight = sum2d:forward(mask:cuda()):float()
  25. local border = rgb:clone()
  26. for j = 1, 3 do
  27. border[j]:copy(sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda()))
  28. border[j]:cdiv((mask_weight + eps))
  29. rgb[j][mask_nega] = border[j][mask_nega]
  30. end
  31. mask = mask_weight:clone()
  32. mask[torch.gt(mask_weight, 0.0)] = 1
  33. mask_nega = (mask - 1):abs():byte()
  34. if border:size(2) * border:size(3) > 1024*1024 then
  35. collectgarbage()
  36. end
  37. end
  38. rgb[torch.gt(rgb, 1.0)] = 1.0
  39. rgb[torch.lt(rgb, 0.0)] = 0.0
  40. return rgb
  41. end
  42. function alpha_util.composite(rgb, alpha, model2x)
  43. if not alpha then
  44. return rgb
  45. end
  46. if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
  47. if model2x then
  48. alpha = reconstruct.scale(model2x, 2.0, alpha)
  49. else
  50. alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
  51. end
  52. end
  53. local out = torch.Tensor(4, rgb:size(2), rgb:size(3))
  54. out[1]:copy(rgb[1])
  55. out[2]:copy(rgb[2])
  56. out[3]:copy(rgb[3])
  57. out[4]:copy(alpha)
  58. return out
  59. end
  60. function alpha_util.fill(fg, alpha, val)
  61. assert(fg:size(2) == alpha:size(2) and fg:size(3) == alpha:size(3))
  62. local conversion = false
  63. fg, conversion = iproc.byte2float(fg)
  64. val = val or 0
  65. fg = fg:clone()
  66. bg = fg:clone():fill(val)
  67. bg[1]:cmul(1-alpha)
  68. bg[2]:cmul(1-alpha)
  69. bg[3]:cmul(1-alpha)
  70. fg[1]:cmul(alpha)
  71. fg[2]:cmul(alpha)
  72. fg[3]:cmul(alpha)
  73. local ret = bg:add(fg)
  74. if conversion then
  75. ret = iproc.float2byte(ret)
  76. end
  77. return ret
  78. end
  79. local function test()
  80. require 'sys'
  81. require 'trepl'
  82. torch.setdefaulttensortype("torch.FloatTensor")
  83. local image_loader = require 'image_loader'
  84. local rgb, alpha = image_loader.load_float("alpha.png")
  85. local t = sys.clock()
  86. rgb = alpha_util.make_border(rgb, alpha, 7)
  87. print(sys.clock() - t)
  88. print(rgb:min(), rgb:max())
  89. image.display({image = rgb, min = 0, max = 1})
  90. image.save("out.png", rgb)
  91. end
-- Uncomment to run the manual smoke test (the local test() function) every
-- time this file is loaded.
--test()

-- Module export; the table is also assigned to the global name alpha_util.
return alpha_util