# find_unet.py (3.0 KB)

  1. def find_unet_v2():
  2. avg_pool=4
  3. print_mod = False
  4. check_mod = True
  5. print("cascade")
  6. for i in range(76, 512):
  7. print("-- {}".format(i))
  8. print_buf = []
  9. s = i
  10. # unet 1
  11. s = s - 4 # conv3x3x2
  12. s = s / 2 # down2x2
  13. s = s - 4 # conv3x3x2
  14. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  15. if check_mod and s % avg_pool != 0:
  16. continue
  17. s = s / 2 # down2x2
  18. s = s - 4 # conv3x3x2
  19. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  20. if check_mod and s % avg_pool != 0:
  21. continue
  22. s = s * 2 # up2x2
  23. s = s - 4 # conv3x3x2
  24. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  25. if check_mod and s % avg_pool != 0:
  26. continue
  27. s = s * 2 # up2x2
  28. # deconv
  29. s = s
  30. s = s * 2 - 4
  31. # unet 2
  32. s = s - 4 # conv3x3x2
  33. s = s / 2 # down2x2
  34. s = s - 4 # conv3x3x2
  35. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  36. if check_mod and s % avg_pool != 0:
  37. continue
  38. s = s / 2 # down2x2
  39. s = s - 4 # conv3x3x2
  40. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  41. if check_mod and s % avg_pool != 0:
  42. continue
  43. s = s * 2 # up2x2
  44. s = s - 4 # conv3x3x2
  45. if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
  46. if check_mod and s % avg_pool != 0:
  47. continue
  48. s = s * 2 # up2x2
  49. s = s - 2 # conv3x3 last
  50. #if s % avg_pool != 0:
  51. # continue
  52. print("ok", i, s)
  53. def find_unet():
  54. check_mod = True
  55. print_size = False
  56. print("cascade")
  57. for i in range(76, 512):
  58. print_buf = []
  59. s = i
  60. # unet 1
  61. s = s - 4 # conv3x3x2
  62. if print_size: print("1/2", s)
  63. if check_mod and s % 2 != 0:
  64. continue
  65. s = s / 2 # down2x2
  66. s = s - 4 # conv3x3x2
  67. if print_size: print("1/2",s)
  68. if check_mod and s % 2 != 0:
  69. continue
  70. s = s / 2 # down2x2
  71. s = s - 4 # conv3x3x2
  72. s = s * 2 # up2x2
  73. if print_size: print("2x",s)
  74. s = s - 4 # conv3x3x2
  75. s = s * 2 # up2x2
  76. if print_size: print("2x",s)
  77. # deconv
  78. s = s - 2
  79. s = s * 2 - 4
  80. # unet 2
  81. s = s - 4 # conv3x3x2
  82. if print_size: print("1/2",s)
  83. if check_mod and s % 2 != 0:
  84. continue
  85. s = s / 2 # down2x2
  86. s = s - 4 # conv3x3x2
  87. if print_size: print("1/2",s)
  88. if check_mod and s % 2 != 0:
  89. continue
  90. s = s / 2 # down2x2
  91. s = s - 4 # conv3x3x2
  92. s = s * 2 # up2x2
  93. if print_size: print("2x",s)
  94. s = s - 4 # conv3x3x2
  95. s = s * 2 # up2x2
  96. if print_size: print("2x",s)
  97. s = s - 2 # conv3x3
  98. s = s - 2 # conv3x3 last
  99. #if s % avg_pool != 0:
  100. # continue
  101. print("ok", i, s)
  102. find_unet()