require 'optim'
require 'cutorch'
require 'xlua'
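
-- Runs one epoch of Adam over shuffled minibatches and returns the averaged
-- loss/MSE/PSNR statistics together with a per-sample loss tensor.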
local function minibatch_adam(model, criterion, eval_metric,
                              train_x, train_y,
                              config)
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   if config.xEvalCount == nil then
      config.xEvalCount = 0
      config.learningRate = config.xLearningRate
   end
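   -- per-epoch accumulators for the reported statistics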
   local sum_psnr = 0
   local sum_loss = 0
   local sum_eval = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   local shuffle = torch.randperm(train_x:size(1))
   local c = 1
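   -- CPU staging buffers: each minibatch is assembled on the host and then
   -- copied to the GPU tensors in one transfer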
   local inputs_tmp = torch.Tensor(batch_size,
                                   train_x:size(2), train_x:size(3), train_x:size(4)):zero()
   local targets_tmp = torch.Tensor(batch_size,
                                    train_y:size(2)):zero()
   local inputs = inputs_tmp:clone():cuda()
   local targets = targets_tmp:clone():cuda()
   local instance_loss = torch.Tensor(train_x:size(1)):zero()
   print("## update")
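   -- iterate over full minibatches only; a trailing partial batch is skipped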
   for t = 1, train_x:size(1), batch_size do
      if t + batch_size - 1 > train_x:size(1) then
         break
      end
      for i = 1, batch_size do
         inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
         targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
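      -- closure for optim.adam: computes the loss on the current minibatch
      -- and the gradient w.r.t. the flattened parameters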
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         local se = 0
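         -- optionally record the eval metric per sample so the caller can
         -- inspect per-instance losses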
         if config.xInstanceLoss then
            -- a table output means the model has multiple output heads;
            -- `type(output)` alone is always truthy, so compare to "table"
            if type(output) == "table" then
               local tbl = {}
               for i = 1, batch_size do
                  for j = 1, #output do
                     tbl[j] = output[j][i]
                  end
                  local el = eval_metric:forward(tbl, targets[i])
                  se = se + el
                  instance_loss[shuffle[t + i - 1]] = el
               end
               se = se / batch_size
            else
               for i = 1, batch_size do
                  local el = eval_metric:forward(output[i], targets[i])
                  se = se + el
                  instance_loss[shuffle[t + i - 1]] = el
               end
               se = se / batch_size
            end
         else
            se = eval_metric:forward(output, targets)
         end
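         -- PSNR assuming a peak signal value of 1.0: 10 * log10(1 / MSE);
         -- the 1.0e-6 term guards against log10(0)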
         sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
         sum_eval = sum_eval + se
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
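      -- Adam step, then a 1/t-style learning-rate decay driven by the
      -- number of samples seen so far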
      optim.adam(feval, parameters, config)
      config.xEvalCount = config.xEvalCount + batch_size
      config.learningRate = config.xLearningRate / (1 + config.xEvalCount * config.xLearningRateDecay)
      c = c + 1
      if c % 50 == 0 then
         collectgarbage()
         xlua.progress(t, train_x:size(1))
      end
   end
   xlua.progress(train_x:size(1), train_x:size(1))
   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss }, instance_loss
end

return minibatch_adam
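
--[[ Usage sketch (hypothetical caller; `build_model`, the module path, the
tensor shapes, and the config values below are illustrative assumptions, not
part of this file):

require 'nn'
require 'cunn'
local minibatch_adam = require 'minibatch_adam'

local model = build_model():cuda()               -- any nn model
local criterion = nn.MSECriterion():cuda()
local eval_metric = nn.MSECriterion():cuda()
local train_x = torch.Tensor(1000, 3, 32, 32)    -- N x C x H x W inputs
local train_y = torch.Tensor(1000, 3 * 32 * 32)  -- flattened targets
local config = {
   xBatchSize = 32,
   xLearningRate = 0.001,
   xLearningRateDecay = 1.0e-5,
   xInstanceLoss = false,
}
for epoch = 1, 30 do
   local score, instance_loss = minibatch_adam(model, criterion, eval_metric,
                                               train_x, train_y, config)
   print(epoch, score.loss, score.MSE, score.PSNR)
end
--]]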