# make_cunet.py (4.1 KB)
  1. # Generate prototxt of waifu2x's cunet/upcunet arch. Training is not possible.
  2. from __future__ import print_function
  3. import sys
  4. sys.path.insert(0, "../python") # pycaffe path
  5. from caffe import layers as L, params as P, to_proto
  6. from caffe.proto import caffe_pb2
  7. import caffe
  8. def seblock(bottom, o, r):
  9. m = int(o / r)
  10. gap = L.Pooling(bottom, pool=P.Pooling.AVE, global_pooling=True)
  11. linear1 = L.Convolution(gap, kernel_size=1, pad=0, stride=1, num_output=m)
  12. relu1 = L.ReLU(linear1, in_place=True)
  13. linear2 = L.Convolution(relu1, kernel_size=1, pad=0, stride=1, num_output=o)
  14. sigmoid1 = L.Sigmoid(linear2, in_place=True)
  15. flatten1 = L.Flatten(sigmoid1)
  16. return flatten1
  17. def conv_relu(bottom, o, stride=1, pad=0):
  18. conv = L.Convolution(bottom, kernel_size=3, stride=stride, num_output=o, pad=pad)
  19. relu = L.ReLU(conv, in_place=True, negative_slope=0.1)
  20. return relu
  21. def unet_conv(bottom, m, o, se=True):
  22. conv1 = L.Convolution(bottom, kernel_size=3, stride=1,
  23. num_output=m, pad=0)
  24. relu1 = L.ReLU(conv1, in_place=True, negative_slope=0.1)
  25. conv2 = L.Convolution(relu1, kernel_size=3, stride=1,
  26. num_output=o, pad=0)
  27. relu2 = L.ReLU(conv2, in_place=True, negative_slope=0.1)
  28. if se:
  29. se1 = seblock(relu2, o, 8)
  30. return L.Scale(relu2, se1, axis=0, bias_term=False)
  31. else:
  32. return relu2
  33. def unet_branch(bottom, insert_f, i, o, depad):
  34. pool = L.Convolution(bottom, kernel_size=2, stride=2, num_output=i, pad=0)
  35. relu1 = L.ReLU(pool, in_place=True, negative_slope=0.1)
  36. feat = insert_f(relu1)
  37. unpool = L.Deconvolution(feat, convolution_param=dict(num_output=o, kernel_size=2, pad=0, stride=2))
  38. relu2 = L.ReLU(unpool, in_place=True, negative_slope=0.1)
  39. crop = L.Crop(bottom, relu2, crop_param=dict(axis=2, offset=depad))
  40. cadd = L.Eltwise(crop, relu2, operation=P.Eltwise.SUM)
  41. return cadd
  42. def unet1(bottom, ch, deconv):
  43. block1 = lambda bottom: unet_conv(bottom, 128, 64, True)
  44. conv1 = unet_conv(bottom, 32, 64, se=False)
  45. ub1 = unet_branch(conv1, block1, 64, 64, 4)
  46. conv2 = conv_relu(ub1, 64)
  47. if deconv:
  48. return L.Deconvolution(conv2, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
  49. else:
  50. return L.Convolution(conv2, kernel_size=3, stride=1, num_output=ch, pad=0)
  51. def unet2(bottom, ch, deconv):
  52. def block1(bottom):
  53. return unet_conv(bottom, 256, 128, se=True)
  54. def block2(bottom):
  55. conv1 = unet_conv(bottom, 64, 128, se=True)
  56. ub1 = unet_branch(conv1, block1, 128, 128, 4)
  57. conv2 = unet_conv(ub1, 64, 64, se=True)
  58. return conv2
  59. conv1 = unet_conv(bottom, 32, 64, se=False)
  60. ub1 = unet_branch(conv1, block2, 64, 64, 16)
  61. conv2 = conv_relu(ub1, 64)
  62. if deconv:
  63. return L.Deconvolution(conv2, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
  64. else:
  65. return L.Convolution(conv2, kernel_size=3, stride=1, num_output=ch, pad=0)
  66. def make_upcunet():
  67. netoffset = 36
  68. ch = 3
  69. input_size = (256 / 2) + netoffset * 2
  70. assert(input_size % 4 == 0)
  71. data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
  72. u1 = unet1(data, ch=ch, deconv=True)
  73. u2 = unet2(u1, ch=ch, deconv=False)
  74. crop = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
  75. cadd = L.Eltwise(crop, u2, operation=P.Eltwise.SUM)
  76. return to_proto(cadd)
  77. def make_cunet():
  78. netoffset = 28
  79. ch = 3
  80. input_size = 256 + netoffset * 2
  81. assert(input_size % 4 == 0)
  82. data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
  83. u1 = unet1(data, ch=ch, deconv=False)
  84. u2 = unet2(u1, ch=ch, deconv=False)
  85. crop = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
  86. cadd = L.Eltwise(crop, u2, operation=P.Eltwise.SUM)
  87. return to_proto(cadd)
  88. def make_net():
  89. with open('upcunet.prototxt', 'w') as f:
  90. print(make_upcunet(), file=f)
  91. with open('cunet.prototxt', 'w') as f:
  92. print(make_cunet(), file=f)
  93. if __name__ == '__main__':
  94. make_net()
  95. # test loading the net
  96. caffe.Net('upcunet.prototxt', caffe.TEST)
  97. caffe.Net('cunet.prototxt', caffe.TEST)