Explorar o código

Use w2nn.load_model to use cudnn #369

nagadomi %!s(int64=4) %!d(string=hai) anos
pai
achega
c9b0fe7f41
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      train.lua

+ 1 - 1
train.lua

@@ -526,7 +526,7 @@ local function train()
    }
    }
    local model
    local model
    if settings.resume:len() > 0 then
    if settings.resume:len() > 0 then
-      model = torch.load(settings.resume, "ascii")
+      model = w2nn.load_model(settings.resume, settings.backend == "cudnn", "ascii")
       adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
       adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
       print(string.format("set eval count = %d", adam_config.xEvalCount))
       print(string.format("set eval count = %d", adam_config.xEvalCount))
       if adam_config.xEvalCount > 0 then
       if adam_config.xEvalCount > 0 then