-- benchmark.lua 8.9 KB
  -- Project-local helpers and third-party libraries used by the benchmark.
  require './lib/portable'
  require './lib/mynn'
  require 'xlua'
  require 'pl'
  local iproc = require './lib/iproc'
  local reconstruct = require './lib/reconstruct'
  local image_loader = require './lib/image_loader'
  local gm = require 'graphicsmagick'

  -- Command-line interface: a banner followed by the option definitions.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-benchmark")
  cmd:text("Options:")

  -- Each entry is {flag, default value, help text}; registration order is
  -- the order of this list.
  local option_specs = {
    {"-seed", 11, 'fixed input seed'},
    {"-test_dir", "./test", 'test image directory'},
    {"-jpeg_quality", 50, 'jpeg quality'},
    {"-jpeg_times", 3, 'number of jpeg compression '},
    {"-jpeg_quality_down", 5, 'reducing jpeg quality each times'},
    {"-core", 4, 'threads'},
  }
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end

  -- Parsed options; read as an upvalue by the functions defined below.
  local opt = cmd:parse(arg)

  -- Process-wide torch configuration taken from the parsed options.
  torch.setnumthreads(opt.core)
  torch.setdefaulttensortype('torch.FloatTensor')
  --- Mean squared error between two values.
  -- Both arguments must support subtraction, and the difference must provide
  -- `:pow` and `:mean` (torch tensors in this script).
  -- @param x1 first operand
  -- @param x2 second operand
  -- @return mean of the element-wise squared difference
  local function MSE(x1, x2)
    local diff = x1 - x2
    local squared = diff:pow(2)
    return squared:mean()
  end
  --- Channel-weighted mean squared error.
  -- Works on clones of both inputs, so the arguments are left untouched.
  -- Slices 1..3 along the first dimension are scaled by 0.299*3, 0.587*3
  -- and 0.114*3 before the squared difference is averaged.
  -- NOTE(review): the weights look like luma coefficients scaled by 3;
  -- presumably a luma-weighted error on RGB data -- confirm.
  -- @param x1 first image; must support :clone() and indexing [1]..[3]
  -- @param x2 second image, same requirements
  -- @return mean of the squared difference of the weighted clones
  local function YMSE(x1, x2)
    local weights = {0.299 * 3, 0.587 * 3, 0.114 * 3}
    local a = x1:clone()
    local b = x2:clone()
    for channel = 1, 3 do
      a[channel]:mul(weights[channel])
      b[channel]:mul(weights[channel])
    end
    local diff = a - b
    return diff:pow(2):mean()
  end
  --- Peak signal-to-noise ratio in decibels, using 1.0 as the peak value.
  -- Computed as 20 * log10(1 / sqrt(MSE(x1, x2))).
  -- @param x1 first operand, forwarded to MSE
  -- @param x2 second operand, forwarded to MSE
  -- @return PSNR value as a number
  local function PSNR(x1, x2)
    local rmse = math.sqrt(MSE(x1, x2))
    local ratio = 1.0 / rmse
    return 20 * (math.log(ratio) / math.log(10))
  end
  --- PSNR variant built on YMSE, in decibels.
  -- Uses 0.587 * 3 (the largest YMSE channel weight) as the peak value:
  -- 20 * log10((0.587 * 3) / sqrt(YMSE(x1, x2))).
  -- @param x1 first image, forwarded to YMSE
  -- @param x2 second image, forwarded to YMSE
  -- @return PSNR value as a number
  local function YPSNR(x1, x2)
    local peak = 0.587 * 3
    local rmse = math.sqrt(YMSE(x1, x2))
    local ratio = peak / rmse
    return 20 * (math.log(ratio) / math.log(10))
  end
  --- Degrades an image by repeated JPEG round-trips.
  -- Runs `opt.jpeg_times` passes. Pass i encodes with quality
  -- `opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down` and sampling factors
  -- {1.0, 1.0, 1.0}, then decodes the blob back into a byte tensor that
  -- feeds the next pass.
  -- @param x image accepted by gm.Image(x, "RGB", "DHW")
  -- @return the result of the last jpeg:toTensor("byte", "RGB", "DHW") call,
  --         or `x` unchanged when opt.jpeg_times < 1
  local function transform_jpeg(x)
    for i = 1, opt.jpeg_times do
      -- Quality drops by a fixed step on every additional pass.
      local quality = opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down
      -- All temporaries are local; the original leaked them as globals.
      local jpeg = gm.Image(x, "RGB", "DHW")
      jpeg:format("jpeg")
      jpeg:samplingFactors({1.0, 1.0, 1.0})
      local blob, len = jpeg:toBlob(quality)
      jpeg:fromBlob(blob, len)
      x = jpeg:toTensor("byte", "RGB", "DHW")
    end
    return x
  end
  --- Compares two noise-reduction models on JPEG-degraded copies of `x`.
  -- For every image a degraded input is produced with transform_jpeg; both
  -- the input and the ground truth are converted with :float():div(255).
  -- Each model is run through reconstruct.image and timed with sys.clock.
  -- After each image the running averages (time, MSE, PSNR for the raw JPEG
  -- input and both models) are written to stdout on a line ending in "\r";
  -- a final "\n" is written after the loop.
  -- @param x sequence of images (indexed 1..#x)
  -- @param v1_noise first model, passed to reconstruct.image
  -- @param v2_noise second model, passed to reconstruct.image
  local function noise_benchmark(x, v1_noise, v2_noise)
    local v1_mse, v2_mse, jpeg_mse = 0, 0, 0
    local v1_psnr, v2_psnr, jpeg_psnr = 0, 0, 0
    local v1_time, v2_time = 0, 0

    for i = 1, #x do
      local ground_truth = x[i]
      -- Degrade before the ground truth is converted to float.
      local input = transform_jpeg(ground_truth):float():div(255)
      ground_truth = ground_truth:float():div(255)

      -- Baseline: error of the degraded input itself.
      jpeg_mse = jpeg_mse + MSE(ground_truth, input)
      jpeg_psnr = jpeg_psnr + PSNR(ground_truth, input)

      -- Outputs are local; the original assigned them to globals because the
      -- declared local names (v1_out/v2_out) did not match the names used.
      local t = sys.clock()
      local v1_output = reconstruct.image(v1_noise, input)
      v1_time = v1_time + (sys.clock() - t)
      v1_mse = v1_mse + MSE(ground_truth, v1_output)
      v1_psnr = v1_psnr + PSNR(ground_truth, v1_output)

      t = sys.clock()
      local v2_output = reconstruct.image(v2_noise, input)
      v2_time = v2_time + (sys.clock() - t)
      v2_mse = v2_mse + MSE(ground_truth, v2_output)
      v2_psnr = v2_psnr + PSNR(ground_truth, v2_output)

      io.stdout:write(
        string.format("%d/%d; v1_time=%f, v2_time=%f, jpeg_mse=%f, v1_mse=%f, v2_mse=%f, jpeg_psnr=%f, v1_psnr=%f, v2_psnr=%f \r",
                      i, #x,
                      v1_time / i, v2_time / i,
                      jpeg_mse / i,
                      v1_mse / i, v2_mse / i,
                      jpeg_psnr / i,
                      v1_psnr / i, v2_psnr / i
        )
      )
      io.stdout:flush()
    end
    io.stdout:write("\n")
  end
  --- Compares two noise-reduction + 2x-upscale pipelines.
  -- Each ground-truth image is halved with iproc.scale using
  -- params[i].filter, JPEG round-tripped at params[i].quality, and converted
  -- with :float():div(255). A "Jinc" iproc.scale 2x upscale of that input is
  -- the baseline. Each pipeline runs reconstruct.image followed by
  -- reconstruct.scale(…, 2.0, …) and is timed with sys.clock. Running
  -- averages go to stdout on a line ending in "\r"; a final "\n" follows.
  -- @param x sequence of ground-truth images (indexed 1..#x)
  -- @param params sequence parallel to x; entries provide .filter and .quality
  -- @param v1_noise first noise model
  -- @param v1_scale first scale model
  -- @param v2_noise second noise model
  -- @param v2_scale second scale model
  local function noise_scale_benchmark(x, params, v1_noise, v1_scale, v2_noise, v2_scale)
    local v1_mse, v2_mse, jinc_mse = 0, 0, 0
    local v1_time, v2_time = 0, 0

    for i = 1, #x do
      local ground_truth = x[i]
      local downscale = iproc.scale(ground_truth,
                                    ground_truth:size(3) * 0.5,
                                    ground_truth:size(2) * 0.5,
                                    params[i].filter)

      -- JPEG round-trip of the downscaled image. `jpeg` is local here; the
      -- original declared `jpg` but assigned `jpeg`, creating a global.
      local jpeg = gm.Image(downscale, "RGB", "DHW")
      jpeg:format("jpeg")
      local blob, len = jpeg:toBlob(params[i].quality)
      jpeg:fromBlob(blob, len)
      local input = jpeg:toTensor("byte", "RGB", "DHW"):float():div(255)
      ground_truth = ground_truth:float():div(255)

      -- Baseline upscale.
      local jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
      jinc_mse = jinc_mse + MSE(ground_truth, jinc_output)

      local t = sys.clock()
      local v1_output = reconstruct.image(v1_noise, input)
      v1_output = reconstruct.scale(v1_scale, 2.0, v1_output)
      v1_time = v1_time + (sys.clock() - t)
      v1_mse = v1_mse + MSE(ground_truth, v1_output)

      t = sys.clock()
      local v2_output = reconstruct.image(v2_noise, input)
      v2_output = reconstruct.scale(v2_scale, 2.0, v2_output)
      v2_time = v2_time + (sys.clock() - t)
      v2_mse = v2_mse + MSE(ground_truth, v2_output)

      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; mse: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
                                    i, #x,
                                    v1_time / i, v2_time / i,
                                    (v1_time / i) / (v2_time / i),
                                    jinc_mse / i,
                                    v1_mse / i, (v1_mse / i) / (jinc_mse / i),
                                    v2_mse / i, (v2_mse / i) / (jinc_mse / i),
                                    (v1_mse / i) / (v2_mse / i)))
      io.stdout:flush()
    end
    io.stdout:write("\n")
  end
  --- Compares two 2x-upscale models against a "Jinc" baseline.
  -- Each ground-truth image is halved with iproc.scale using
  -- params[i].filter and converted with :float():div(255). The baseline is
  -- an iproc.scale 2x "Jinc" upscale; each model runs through
  -- reconstruct.scale(…, 2.0, input) and is timed with sys.clock. Running
  -- PSNR averages go to stdout on a line ending in "\r"; a final "\n"
  -- follows the loop.
  -- @param x sequence of ground-truth images (indexed 1..#x)
  -- @param params sequence parallel to x; entries provide .filter
  -- @param v1_scale first scale model
  -- @param v2_scale second scale model
  local function scale_benchmark(x, params, v1_scale, v2_scale)
    local v1_psnr, v2_psnr, jinc_psnr = 0, 0, 0
    local v1_time, v2_time = 0, 0

    -- PSNR in dB from a mean squared error, with 1.0 as the peak value.
    local function psnr_from_mse(mse)
      return 10 * (math.log(1.0 / mse) / math.log(10))
    end

    for i = 1, #x do
      local ground_truth = x[i]
      local downscale = iproc.scale(ground_truth,
                                    ground_truth:size(3) * 0.5,
                                    ground_truth:size(2) * 0.5,
                                    params[i].filter)
      local input = downscale:float():div(255)
      ground_truth = ground_truth:float():div(255)

      -- Baseline upscale.
      local jinc_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, "Jinc")
      jinc_psnr = jinc_psnr + psnr_from_mse(MSE(ground_truth, jinc_output))

      local t = sys.clock()
      local v1_output = reconstruct.scale(v1_scale, 2.0, input)
      v1_time = v1_time + (sys.clock() - t)
      v1_psnr = v1_psnr + psnr_from_mse(MSE(ground_truth, v1_output))

      t = sys.clock()
      local v2_output = reconstruct.scale(v2_scale, 2.0, input)
      v2_time = v2_time + (sys.clock() - t)
      v2_psnr = v2_psnr + psnr_from_mse(MSE(ground_truth, v2_output))

      -- The values reported are PSNR averages, so the label says "psnr"
      -- (the original format string mislabelled them as "mse").
      io.stdout:write(string.format("%d/%d; time: v1=%f, v2=%f, v1/v2=%f; psnr: jinc=%f, v1=%f(%f), v2=%f(%f), v1/v2=%f \r",
                                    i, #x,
                                    v1_time / i, v2_time / i,
                                    (v1_time / i) / (v2_time / i),
                                    jinc_psnr / i,
                                    v1_psnr / i, (v1_psnr / i) / (jinc_psnr / i),
                                    v2_psnr / i, (v2_psnr / i) / (jinc_psnr / i),
                                    (v1_psnr / i) / (v2_psnr / i)))
      io.stdout:flush()
    end
    io.stdout:write("\n")
  end
  --- Randomly partitions a sequence into a training part and a held-out part.
  -- The order comes from torch.randperm(#x): the first #x - test_size
  -- shuffled entries form the first result, the next test_size entries form
  -- the second.
  -- @param x sequence to split
  -- @param test_size number of entries to place in the second result
  -- @return train_x, valid_x
  local function split_data(x, test_size)
    local order = torch.randperm(#x)
    local train_count = #x - test_size
    local train_part, held_out = {}, {}
    for pos = 1, train_count do
      train_part[pos] = x[order[pos]]
    end
    -- Continue through the permutation right after the training entries.
    for pos = train_count + 1, train_count + test_size do
      held_out[pos - train_count] = x[order[pos]]
    end
    return train_part, held_out
  end
  --- Crops an image so that size(3) and size(2) become multiples of 4.
  -- The crop is anchored at (0, 0); the remainder modulo 4 is trimmed from
  -- the far edge of each of the two dimensions.
  -- @param x image supporting :size(dim) and accepted by image.crop
  -- @return the result of image.crop
  local function crop_4x(x)
    local width, height = x:size(3), x:size(2)
    local right = width - width % 4
    local bottom = height - height % 4
    return image.crop(x, 0, 0, right, bottom)
  end
  --- Loads every "*.png" file found by dir.getfiles in a directory.
  -- Each file is read with image_loader.load_byte, cropped with crop_4x and
  -- appended to the result; xlua.progress is updated after every file.
  -- @param valid_dir directory to scan
  -- @return sequence of cropped images, in the order dir.getfiles returned
  local function load_data(valid_dir)
    local files = dir.getfiles(valid_dir, "*.png")
    local total = #files
    local images = {}
    for i = 1, total do
      local cropped = crop_4x(image_loader.load_byte(files[i]))
      table.insert(images, cropped)
      xlua.progress(i, total)
    end
    return images
  end
  --- Runs the noise-reduction benchmark for one noise level.
  -- Loads "noise<level>_model.t7" (ascii format) from the global V1_DIR and
  -- V2_DIR, loads the images from `valid_dir`, and calls noise_benchmark.
  -- @param valid_dir directory of test images
  -- @param level noise level used in the model file name
  local function noise_main(valid_dir, level)
    local model_file = string.format("noise%d_model.t7", level)
    local v1_model = torch.load(path.join(V1_DIR, model_file), "ascii")
    local v2_model = torch.load(path.join(V2_DIR, model_file), "ascii")
    local images = load_data(valid_dir)
    noise_benchmark(images, v1_model, v2_model)
  end
  --- Runs the 2x-upscale benchmark.
  -- Loads "scale2.0x_model.t7" (ascii format) from the global V1_DIR and
  -- V2_DIR, loads the images from `valid_dir`, builds per-image parameters
  -- and calls scale_benchmark.
  -- NOTE(review): random_params is not defined in the visible part of this
  -- file; confirm it is provided elsewhere before enabling this entry point.
  -- @param valid_dir directory of test images
  local function scale_main(valid_dir)
    local model_file = "scale2.0x_model.t7"
    local v1_model = torch.load(path.join(V1_DIR, model_file), "ascii")
    local v2_model = torch.load(path.join(V2_DIR, model_file), "ascii")
    local images = load_data(valid_dir)
    local per_image_params = random_params(images, 2)
    scale_benchmark(images, per_image_params, v1_model, v2_model)
  end
  --- Runs the combined noise-reduction + 2x-upscale benchmark.
  -- Loads "noise2_model.t7" and "scale2.0x_model.t7" (ascii format) from the
  -- global V1_DIR and V2_DIR, loads the images from `valid_dir`, builds
  -- per-image parameters and calls noise_scale_benchmark.
  -- NOTE(review): random_params is not defined in the visible part of this
  -- file; confirm it is provided elsewhere before enabling this entry point.
  -- @param valid_dir directory of test images
  local function noise_scale_main(valid_dir)
    -- Reads one serialized model from a model directory.
    local function load_model(model_dir, file_name)
      return torch.load(path.join(model_dir, file_name), "ascii")
    end

    local v1_noise = load_model(V1_DIR, "noise2_model.t7")
    local v1_scale = load_model(V1_DIR, "scale2.0x_model.t7")
    local v2_noise = load_model(V2_DIR, "noise2_model.t7")
    local v2_scale = load_model(V2_DIR, "scale2.0x_model.t7")

    local images = load_data(valid_dir)
    local per_image_params = random_params(images, 2)
    noise_scale_benchmark(images, per_image_params, v1_noise, v1_scale, v2_noise, v2_scale)
  end
  -- Model directories. These stay global on purpose: the *_main functions
  -- are defined before this point, so they resolve V1_DIR / V2_DIR as
  -- globals; declaring them `local` here would hide them from those
  -- functions.
  V1_DIR = "models/anime_style_art_rgb"
  V2_DIR = "models/anime_style_art_rgb5"

  -- Seed both generators from the -seed option.
  -- NOTE(review): cutorch is not required in the visible part of this file;
  -- presumably one of the ./lib modules loads it -- confirm.
  torch.manualSeed(opt.seed)
  cutorch.manualSeed(opt.seed)

  -- Honour the -test_dir option (default "./test") instead of a hard-coded
  -- path, so the flag advertised in the usage text takes effect.
  noise_main(opt.test_dir, 2)
  --scale_main(opt.test_dir)
  --noise_scale_main(opt.test_dir)