|
@@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|
|
local f = criterion:forward(output, targets)
|
|
|
local se = 0
|
|
|
if config.xInstanceLoss then
|
|
|
- for i = 1, batch_size do
|
|
|
- local el = eval_metric:forward(output[i], targets[i])
|
|
|
- se = se + el
|
|
|
- instance_loss[shuffle[t + i - 1]] = el
|
|
|
- end
|
|
|
- se = (se / batch_size)
|
|
|
+ if type(output) then
|
|
|
+ local tbl = {}
|
|
|
+ for i = 1, batch_size do
|
|
|
+ for j = 1, #output do
|
|
|
+ tbl[j] = output[j][i]
|
|
|
+ end
|
|
|
+ local el = eval_metric:forward(tbl, targets[i])
|
|
|
+ se = se + el
|
|
|
+ instance_loss[shuffle[t + i - 1]] = el
|
|
|
+ end
|
|
|
+ se = (se / batch_size)
|
|
|
+ else
|
|
|
+ for i = 1, batch_size do
|
|
|
+ local el = eval_metric:forward(output[i], targets[i])
|
|
|
+ se = se + el
|
|
|
+ instance_loss[shuffle[t + i - 1]] = el
|
|
|
+ end
|
|
|
+ se = (se / batch_size)
|
|
|
+ end
|
|
|
else
|
|
|
se = eval_metric:forward(output, targets)
|
|
|
end
|