require 'optim'
require 'cutorch'
require 'xlua'

-- Runs one epoch of Adam over train_x, drawing one minibatch per sample
-- index via `transformer`, and returns the mean minibatch loss.
local function minibatch_adam(model, criterion,
                              train_x,
                              config, transformer,
                              input_size, target_size)
   -- Flattened views of all learnable parameters and their gradients.
   local parameters, gradParameters = model:getParameters()
   config = config or {}
   local sum_loss = 0
   local count_loss = 0
   local batch_size = config.xBatchSize or 32
   -- Visit the training samples in a random order each epoch.
   local shuffle = torch.randperm(#train_x)
   -- Minibatch counter, used to trigger periodic garbage collection.
   local c = 1
   -- Reusable GPU buffers for one minibatch; targets are stored flattened.
   local inputs = torch.Tensor(batch_size,
                               input_size[1], input_size[2], input_size[3]):cuda()
   local targets = torch.Tensor(batch_size,
                                target_size[1] * target_size[2] * target_size[3]):cuda()
   -- CPU staging buffers: the transformer output is gathered here, then
   -- copied to the GPU in a single transfer per minibatch.
   local inputs_tmp = torch.Tensor(batch_size,
                                   input_size[1], input_size[2], input_size[3])
   local targets_tmp = torch.Tensor(batch_size,
                                    target_size[1] * target_size[2] * target_size[3])
   for t = 1, #train_x do
      xlua.progress(t, #train_x)
      -- The transformer is expected to expand one sample into exactly
      -- batch_size {input, target} pairs (e.g. by random crops/augmentation);
      -- fewer pairs would leave stale rows from the previous minibatch.
      local xy = transformer(train_x[shuffle[t]], false, batch_size)
      for i = 1, #xy do
         inputs_tmp[i]:copy(xy[i][1])
         targets_tmp[i]:copy(xy[i][2])
      end
      inputs:copy(inputs_tmp)
      targets:copy(targets_tmp)
      -- Closure evaluated by optim.adam: returns the loss and the gradient
      -- of the loss with respect to the flattened parameters.
      local feval = function(x)
         if x ~= parameters then
            parameters:copy(x)
         end
         gradParameters:zero()
         local output = model:forward(inputs)
         local f = criterion:forward(output, targets)
         sum_loss = sum_loss + f
         count_loss = count_loss + 1
         model:backward(inputs, criterion:backward(output, targets))
         return f, gradParameters
      end
      optim.adam(feval, parameters, config)
      c = c + 1
      -- Collect garbage periodically to keep the Lua heap from growing.
      if c % 20 == 0 then
         collectgarbage()
      end
   end
   xlua.progress(#train_x, #train_x)
   return { loss = sum_loss / count_loss }
end
return minibatch_adam
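
-- A minimal usage sketch, kept in a comment so the module stays valid Lua.
-- Everything below is hypothetical and not part of this module: the require
-- path, the toy model, the identity transformer, and the config values are
-- illustrative stand-ins. The real caller must supply a transformer with the
-- contract noted above: transformer(x, is_validation, batch_size) -> a table
-- of batch_size {input, target} pairs, inputs shaped like input_size and
-- targets flattening to target_size[1] * target_size[2] * target_size[3].
--[[
require 'nn'
require 'cunn'
local minibatch_adam = require 'minibatch_adam' -- hypothetical module path

-- Hypothetical 3x32x32 -> 3x32x32 model; the final Reshape flattens the
-- output to match the flattened targets this trainer builds.
local model = nn.Sequential()
model:add(nn.SpatialConvolution(3, 3, 3, 3, 1, 1, 1, 1))
model:add(nn.Reshape(3 * 32 * 32))
model:cuda()
local criterion = nn.MSECriterion():cuda()

-- Hypothetical transformer: replicates one sample into a full minibatch of
-- identity {input, target} pairs, purely for illustration.
local function transformer(x, is_validation, batch_size)
   local batch = {}
   for i = 1, batch_size do
      batch[i] = {x, x:reshape(3 * 32 * 32)}
   end
   return batch
end

-- Synthetic training set of 100 random 3x32x32 samples.
local train_x = {}
for i = 1, 100 do
   train_x[i] = torch.rand(3, 32, 32)
end

local config = {xBatchSize = 16, learningRate = 0.001}
for epoch = 1, 10 do
   local result = minibatch_adam(model, criterion, train_x, config,
                                 transformer,
                                 {3, 32, 32}, {3, 32, 32})
   print(("epoch %d: loss %.6f"):format(epoch, result.loss))
end
]]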