minibatch_adam.lua

require 'optim'
require 'cutorch'
require 'xlua'

-- Runs one epoch of Adam over `train_x` in shuffled minibatches.
-- `transformer` maps a raw sample to an (input, target) pair;
-- `input_size`/`target_size` are {channels, height, width} per sample.
local function minibatch_adam(model, criterion,
                              train_x,
                              config, transformer,
                              input_size, target_size)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(#train_x)
   local c = 1
   -- GPU batch buffers; targets are kept flattened to match the criterion.
   local inputs = torch.Tensor(batch_size,
                               input_size[1], input_size[2], input_size[3]):cuda()
   local targets = torch.Tensor(batch_size,
                                target_size[1] * target_size[2] * target_size[3]):cuda()
   -- Host-side staging buffers, copied to the GPU once per batch.
   local inputs_tmp = torch.Tensor(batch_size,
                                   input_size[1], input_size[2], input_size[3])
   local targets_tmp = torch.Tensor(batch_size,
                                    target_size[1] * target_size[2] * target_size[3])
   for t = 1, #train_x, batch_size do
      -- Stop before an incomplete final batch.
      if t + batch_size > #train_x then
         break
      end
      xlua.progress(t, #train_x)
      for i = 1, batch_size do
         local x, y = transformer(train_x[shuffle[t + i - 1]])
         inputs_tmp[i]:copy(x)
         targets_tmp[i]:copy(y)
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      -- Closure evaluated by optim.adam: returns the loss and the gradients.
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      c = c + 1
      if c % 10 == 0 then
         collectgarbage()
      end
   end
   xlua.progress(#train_x, #train_x)
   return { mse = sum_loss / count_loss }
end

return minibatch_adam
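
--[[
Usage sketch (an assumption, not part of the original file): how this module
might be driven from a training script. The toy model, the `make_patch_pair`
transformer, and the config values below are hypothetical illustrations; the
real caller supplies its own model, data augmentation, and patch sizes.

local minibatch_adam = require 'minibatch_adam'
require 'nn'
require 'cunn'

-- Toy fully-convolutional model: two unpadded 3x3 convolutions shrink a
-- 33x33 patch to 29x29, and the View flattens each output patch so it
-- matches the flattened targets built inside minibatch_adam.
local model = nn.Sequential()
   :add(nn.SpatialConvolution(1, 16, 3, 3))
   :add(nn.ReLU(true))
   :add(nn.SpatialConvolution(16, 1, 3, 3))
   :add(nn.View(-1):setNumInputDims(3))
   :cuda()
local criterion = nn.MSECriterion():cuda()

-- `train_x` is a Lua table of raw samples; the transformer turns one sample
-- into an (input, target) tensor pair with the shapes passed below.
local transformer = function(sample)
   return make_patch_pair(sample)  -- hypothetical helper
end
local config = { xBatchSize = 32, learningRate = 0.00025 }

for epoch = 1, 30 do
   local score = minibatch_adam(model, criterion, train_x, config,
                                transformer,
                                {1, 33, 33},   -- input patch:  1 x 33 x 33
                                {1, 29, 29})   -- target patch: 1 x 29 x 29
   print(("epoch %d: mse %f"):format(epoch, score.mse))
end
--]]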