-- minibatch_adam.lua

require 'optim'
require 'cutorch'
require 'xlua'
-- Runs one epoch of Adam updates over shuffled minibatches on the GPU and
-- returns the average training loss and evaluation score.
local function minibatch_adam(model, criterion, eval_metric,
                              train_x, train_y,
                              config)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local sum_eval = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   -- Visit the training examples in a random order each epoch.
   local shuffle = torch.randperm(train_x:size(1))
   local c = 1
   -- Host-side staging buffers: each minibatch is assembled here, then copied
   -- to the GPU in a single transfer.
   local inputs_tmp = torch.Tensor(batch_size,
                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
   local targets_tmp = torch.Tensor(batch_size,
                                    train_y:size(2)):zero()
   local inputs = inputs_tmp:clone():cuda()
   local targets = targets_tmp:clone():cuda()
   print("## update")
   for t = 1, train_x:size(1), batch_size do
      -- Drop the last incomplete minibatch.
      if t + batch_size - 1 > train_x:size(1) then
         break
      end
      for i = 1, batch_size do
         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      -- Closure evaluated by optim.adam: returns the minibatch loss and the
      -- gradient of the loss with respect to the flattened parameters.
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_eval = sum_eval + eval_metric:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      c = c + 1
      -- Periodically collect garbage and report progress.
      if c % 50 == 0 then
         collectgarbage()
         xlua.progress(t, train_x:size(1))
      end
   end
   xlua.progress(train_x:size(1), train_x:size(1))
   return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss }
end

return minibatch_adam
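
-- Example usage (a minimal sketch, not part of the original module; `model`,
-- `psnr_metric`, `train_x`, and `train_y` are placeholders assumed to be
-- built elsewhere, and `nn`/`cunn` are assumed to be loaded):
--
--   local minibatch_adam = require 'minibatch_adam'
--   local criterion = nn.MSECriterion():cuda()
--   local adam_config = { learningRate = 0.001, xBatchSize = 32 }
--   for epoch = 1, 30 do
--      local score = minibatch_adam(model, criterion, psnr_metric,
--                                   train_x, train_y, adam_config)
--      print(("epoch %d: loss=%f, PSNR=%f"):format(epoch, score.loss, score.PSNR))
--   end
--
-- Note that `adam_config` is also passed straight to optim.adam, so Adam's
-- state (moment estimates) is kept in the same table across calls.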