浏览代码

Use clearState()

nagadomi 9 年之前
父节点
当前提交
9b238bd693
共有 3 个文件被更改,包括 4 次插入51 次删除
  1. 1 46
      lib/cleanup_model.lua
  2. 2 4
      train.lua
  3. 1 1
      web.lua

+ 1 - 46
lib/cleanup_model.lua

@@ -1,48 +1,3 @@
--- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
-
-local function zeroDataSize(data)
-   if type(data) == 'table' then
-      for i = 1, #data do
-	 data[i] = zeroDataSize(data[i])
-      end
-   elseif type(data) == 'userdata' then
-      data = torch.Tensor():typeAs(data)
-   end
-   return data
-end
--- Resize the output, gradInput, etc temporary tensors to zero (so that the
--- on disk size is smaller)
-local function cleanupModel(node)
-   if node.output ~= nil then
-      node.output = zeroDataSize(node.output)
-   end
-   if node.gradInput ~= nil then
-      node.gradInput = zeroDataSize(node.gradInput)
-   end
-   if node.finput ~= nil then
-      node.finput = zeroDataSize(node.finput)
-   end
-   if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
-      if node.negative ~= nil then
-	 node.negative = zeroDataSize(node.negative)
-      end
-   end
-   if tostring(node) == "nn.Dropout" then
-      if node.noise ~= nil then
-	 node.noise = zeroDataSize(node.noise)
-      end
-   end
-   -- Recurse on nodes with 'modules'
-   if (node.modules ~= nil) then
-     if (type(node.modules) == 'table') then
-	for i = 1, #node.modules do
-	   local child = node.modules[i]
-	   cleanupModel(child)
-	end
-     end
-   end
-end
 function w2nn.cleanup_model(model)
-   cleanupModel(model)
-   return model
+   return model:clearState()
 end

+ 2 - 4
train.lua

@@ -206,9 +206,7 @@ local function train()
 	    best_score = score
 	    print("* update best model")
 	    if settings.save_history then
-	       local model_clone = model:clone()
-	       w2nn.cleanup_model(model_clone)
-	       torch.save(string.format(settings.model_file, epoch, i), model_clone)
+	       torch.save(string.format(settings.model_file, epoch, i), model:clearState())
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -221,7 +219,7 @@ local function train()
 		  save_test_scale(model, test_image, log)
 	       end
 	    else
-	       torch.save(settings.model_file, model)
+	       torch.save(settings.model_file, model:clearState())
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.png"):format(settings.noise_level))

+ 1 - 1
web.lua

@@ -103,7 +103,7 @@ local function get_image(req)
 end
 local function cleanup_model(model)
    if CLEANUP_MODEL then
-      w2nn.cleanup_model(model) -- release GPU memory
+      model:clearState() -- release GPU memory
    end
 end
 local function convert(x, alpha, options)