minibatch_adam.lua

require 'optim'
require 'cutorch'
require 'xlua'
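
-- One epoch of minibatch Adam over (train_x, train_y): shuffles the data,
-- steps optim.adam once per minibatch, and returns averaged loss/MSE/PSNR
-- plus a per-sample loss tensor (filled only when config.xInstanceLoss is set).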
local function minibatch_adam(model, criterion, eval_metric,
                              train_x, train_y,
                              config)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   -- First call: initialize the example counter and the working learning rate.
   if config.xEvalCount == nil then
      config.xEvalCount = 0
      config.learningRate = config.xLearningRate
   end
   local sum_psnr = 0
   local sum_loss = 0
   local sum_eval = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(train_x:size(1))
   local c = 1
   -- CPU staging buffers, copied into the CUDA tensors once per minibatch.
   local inputs_tmp = torch.Tensor(batch_size,
                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
   local targets_tmp = torch.Tensor(batch_size,
                                    train_y:size(2)):zero()
   local inputs = inputs_tmp:clone():cuda()
   local targets = targets_tmp:clone():cuda()
   local instance_loss = torch.Tensor(train_x:size(1)):zero()

   print("## update")
   for t = 1, train_x:size(1), batch_size do
      -- Skip the final partial minibatch rather than training on it.
      if t + batch_size - 1 > train_x:size(1) then
         break
      end
      -- Fill the staging buffers with a shuffled minibatch, then copy to the GPU.
      for i = 1, batch_size do
         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
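      -- Closure handed to optim.adam: computes the criterion loss on the
      -- current minibatch and the gradient w.r.t. the flattened parameters.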
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         local se = 0
         if config.xInstanceLoss then
            if type(output) == "table" then
               -- Multi-output model: collect every head's prediction for one
               -- sample before scoring it.
               local tbl = {}
               for i = 1, batch_size do
                  for j = 1, #output do
                     tbl[j] = output[j][i]
                  end
                  local el = eval_metric:forward(tbl, targets[i])
                  se = se + el
                  instance_loss[shuffle[t + i - 1]] = el
               end
               se = (se / batch_size)
            else
               -- Single-output model: score each sample individually.
               for i = 1, batch_size do
                  local el = eval_metric:forward(output[i], targets[i])
                  se = se + el
                  instance_loss[shuffle[t + i - 1]] = el
               end
               se = (se / batch_size)
            end
         else
            se = eval_metric:forward(output, targets)
         end
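         -- PSNR derived from the MSE, assuming targets normalized to [0, 1];
         -- the 1.0e-6 term guards against taking log of zero.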
         sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
         sum_eval = sum_eval + se
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      config.xEvalCount = config.xEvalCount + batch_size
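      -- Inverse-time decay of the Adam step size, keyed to the number of
      -- training examples consumed so far.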
      config.learningRate = config.xLearningRate / (1 + config.xEvalCount * config.xLearningRateDecay)
      c = c + 1
      if c % 50 == 0 then
         collectgarbage()
         xlua.progress(t, train_x:size(1))
      end
   end
   xlua.progress(train_x:size(1), train_x:size(1))
   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss }, instance_loss
end
return minibatch_adam
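
-- Example usage (a sketch only; the model, criterion, eval_metric, training
-- tensors, and the hyperparameter values below are placeholders defined
-- elsewhere, and the require path assumes this file is on package.path):
--
--   local minibatch_adam = require 'minibatch_adam'
--   local config = {
--      xLearningRate = 0.001,       -- base Adam step size (placeholder value)
--      xLearningRateDecay = 1.0e-6, -- inverse-time decay factor (placeholder value)
--      xBatchSize = 32,
--      xInstanceLoss = false
--   }
--   for epoch = 1, 30 do
--      local score, instance_loss = minibatch_adam(model, criterion, eval_metric,
--                                                  train_x, train_y, config)
--      print(string.format("epoch %d: loss=%f MSE=%f PSNR=%f",
--                          epoch, score.loss, score.MSE, score.PSNR))
--   end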