123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Generate prototxt of waifu2x's cunet/upcunet arch. Training is not possible.
- from __future__ import print_function
- import sys
- sys.path.insert(0, "../python") # pycaffe path
- from caffe import layers as L, params as P, to_proto
- from caffe.proto import caffe_pb2
- import caffe
- def seblock(bottom, o, r):
- m = int(o / r)
- gap = L.Pooling(bottom, pool=P.Pooling.AVE, global_pooling=True)
- linear1 = L.Convolution(gap, kernel_size=1, pad=0, stride=1, num_output=m)
- relu1 = L.ReLU(linear1, in_place=True)
- linear2 = L.Convolution(relu1, kernel_size=1, pad=0, stride=1, num_output=o)
- sigmoid1 = L.Sigmoid(linear2, in_place=True)
- flatten1 = L.Flatten(sigmoid1)
- return flatten1
- def conv_relu(bottom, o, stride=1, pad=0):
- conv = L.Convolution(bottom, kernel_size=3, stride=stride, num_output=o, pad=pad)
- relu = L.ReLU(conv, in_place=True, negative_slope=0.1)
- return relu
- def unet_conv(bottom, m, o, se=True):
- conv1 = L.Convolution(bottom, kernel_size=3, stride=1,
- num_output=m, pad=0)
- relu1 = L.ReLU(conv1, in_place=True, negative_slope=0.1)
- conv2 = L.Convolution(relu1, kernel_size=3, stride=1,
- num_output=o, pad=0)
- relu2 = L.ReLU(conv2, in_place=True, negative_slope=0.1)
- if se:
- se1 = seblock(relu2, o, 8)
- return L.Scale(relu2, se1, axis=0, bias_term=False)
- else:
- return relu2
- def unet_branch(bottom, insert_f, i, o, depad):
- pool = L.Convolution(bottom, kernel_size=2, stride=2, num_output=i, pad=0)
- relu1 = L.ReLU(pool, in_place=True, negative_slope=0.1)
- feat = insert_f(relu1)
- unpool = L.Deconvolution(feat, convolution_param=dict(num_output=o, kernel_size=2, pad=0, stride=2))
- relu2 = L.ReLU(unpool, in_place=True, negative_slope=0.1)
- crop = L.Crop(bottom, relu2, crop_param=dict(axis=2, offset=depad))
- cadd = L.Eltwise(crop, relu2, operation=P.Eltwise.SUM)
- return cadd
- def unet1(bottom, ch, deconv):
- block1 = lambda bottom: unet_conv(bottom, 128, 64, True)
- conv1 = unet_conv(bottom, 32, 64, se=False)
- ub1 = unet_branch(conv1, block1, 64, 64, 4)
- conv2 = conv_relu(ub1, 64)
- if deconv:
- return L.Deconvolution(conv2, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
- else:
- return L.Convolution(conv2, kernel_size=3, stride=1, num_output=ch, pad=0)
- def unet2(bottom, ch, deconv):
- def block1(bottom):
- return unet_conv(bottom, 256, 128, se=True)
- def block2(bottom):
- conv1 = unet_conv(bottom, 64, 128, se=True)
- ub1 = unet_branch(conv1, block1, 128, 128, 4)
- conv2 = unet_conv(ub1, 64, 64, se=True)
- return conv2
- conv1 = unet_conv(bottom, 32, 64, se=False)
- ub1 = unet_branch(conv1, block2, 64, 64, 16)
- conv2 = conv_relu(ub1, 64)
- if deconv:
- return L.Deconvolution(conv2, convolution_param=dict(num_output=ch, kernel_size=4, pad=3, stride=2))
- else:
- return L.Convolution(conv2, kernel_size=3, stride=1, num_output=ch, pad=0)
- def make_upcunet():
- netoffset = 36
- ch = 3
- input_size = (256 / 2) + netoffset * 2
- assert(input_size % 4 == 0)
- data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
- u1 = unet1(data, ch=ch, deconv=True)
- u2 = unet2(u1, ch=ch, deconv=False)
- crop = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
- cadd = L.Eltwise(crop, u2, operation=P.Eltwise.SUM)
- return to_proto(cadd)
- def make_cunet():
- netoffset = 28
- ch = 3
- input_size = 256 + netoffset * 2
- assert(input_size % 4 == 0)
- data = L.Input(name="input", shape=dict(dim=[1, ch, input_size, input_size]))
- u1 = unet1(data, ch=ch, deconv=False)
- u2 = unet2(u1, ch=ch, deconv=False)
- crop = L.Crop(u1, u2, crop_param=dict(axis=2, offset=20))
- cadd = L.Eltwise(crop, u2, operation=P.Eltwise.SUM)
- return to_proto(cadd)
- def make_net():
- with open('upcunet.prototxt', 'w') as f:
- print(make_upcunet(), file=f)
- with open('cunet.prototxt', 'w') as f:
- print(make_cunet(), file=f)
- if __name__ == '__main__':
- make_net()
- # test loading the net
- caffe.Net('upcunet.prototxt', caffe.TEST)
- caffe.Net('cunet.prototxt', caffe.TEST)
|