Source code for dlk.core.optimizers

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""optimizers"""
import importlib
import os
from typing import Callable, Dict, Tuple, Any
from dlk.utils.register import Register
import torch.optim as optim
from dlk.utils.logger import Logger
import torch
import re


logger = Logger.get_logger()
optimizer_config_register = Register("Optimizer config register.")
optimizer_register = Register("Optimizer register.")


class BaseOptimizer(object):
    """Base class for optimizers: builds the param groups and instantiates the torch optimizer"""

    def get_optimizer(self) -> optim.Optimizer:
        """return the initialized optimizer

        Returns:
            Optimizer

        """
        raise NotImplementedError

    def init_optimizer(self, optimizer: optim.Optimizer, model: torch.nn.Module, config: Dict):
        """init the optimizer for the params of the model; the param groups are decided by config

        Args:
            optimizer: adamw, sgd, etc.
            model: pytorch model
            config: decides the param groups, lr, etc.

        Returns:
            the initialized optimizer

        """
        optimizer_special_groups = config.pop('optimizer_special_groups', {})
        params = []
        all_named_parameters = list(model.named_parameters())
        total_all_named_parameters = len(all_named_parameters)
        logger.info(f"All Named Params Num is {len(all_named_parameters)}")
        has_grouped_params = set()
        for special_group_name in optimizer_special_groups.get('order', []):
            group_config = optimizer_special_groups[special_group_name]['config']
            group_patterns = optimizer_special_groups[special_group_name]['pattern']
            # combine all patterns of this group into a single regex
            combine_patterns = []
            for pattern in group_patterns:
                combine_patterns.append(f"({pattern})")
            cc_patterns = re.compile("|".join(combine_patterns))
            group_param = []
            for n, p in all_named_parameters:
                # logger.info(f"Param name {n}")
                if n in has_grouped_params:
                    continue
                if cc_patterns.search(n):  # use regex
                    has_grouped_params.add(n)
                    group_param.append(p)
            group_config['params'] = group_param
            group_config['name'] = special_group_name
            params.append(group_config)

        # every parameter not captured by a special group falls into the default group
        reserve_params = [p for n, p in all_named_parameters if n not in has_grouped_params]
        params.append({"params": reserve_params, "name": config.pop('name')})
        logger.info(f"Param Group Nums {len(params)}")

        # sanity check: each named parameter is assigned to exactly one group
        total_param = 0
        for group in params:
            total_param = total_param + len(group['params'])
        assert total_param == total_all_named_parameters

        return optimizer(params=params, **config)

    def __call__(self):
        """the same as self.get_optimizer()
        """
        return self.get_optimizer()
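# --- Illustrative example (not part of the original module) -----------------
# A minimal sketch of the ``config`` dict that ``BaseOptimizer.init_optimizer``
# expects. Only the keys ``optimizer_special_groups``, ``order``, ``pattern``,
# ``config`` and ``name`` are read by the code above; the group names, regex
# patterns and lr/weight_decay values below are assumptions for illustration.
_example_optimizer_config = {
    "lr": 1e-3,                         # forwarded to the optimizer constructor
    "name": "default",                  # group that collects all unmatched parameters
    "optimizer_special_groups": {
        "order": ["decoder", "bias"],   # groups are filled in this order, first match wins
        "decoder": {
            "config": {"lr": 1e-4},            # per-group overrides
            "pattern": ["decoder"],            # regexes searched against parameter names
        },
        "bias": {
            "config": {"weight_decay": 0.0},
            "pattern": ["bias$", "LayerNorm\\.weight"],
        },
    },
}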
def import_optimizers(optimizers_dir, namespace):
    """import all the optimizer modules under optimizers_dir into the given namespace"""
    for file in os.listdir(optimizers_dir):
        path = os.path.join(optimizers_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            optimizer_name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(namespace + "." + optimizer_name)
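# Illustrative note (not part of the original module): given a hypothetical
# package layout such as
#     dlk/core/optimizers/__init__.py   (this file)
#     dlk/core/optimizers/adamw.py
#     dlk/core/optimizers/sgd.py
# ``import_optimizers`` imports ``dlk.core.optimizers.adamw`` and
# ``dlk.core.optimizers.sgd``, so any registration side effects in those
# modules (e.g. adding classes to ``optimizer_register``) run automatically.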
# automatically import any Python files in the optimizers directory
optimizers_dir = os.path.dirname(__file__)
import_optimizers(optimizers_dir, "dlk.core.optimizers")
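# --- Illustrative example (not part of the original module) -----------------
# A minimal sketch of what a module discovered by ``import_optimizers`` (e.g.
# a hypothetical ``dlk/core/optimizers/sgd.py``) might contain: a subclass of
# ``BaseOptimizer`` that delegates param grouping to ``init_optimizer``. The
# class name, constructor signature and config keys are assumptions for
# illustration; such a module would also typically register itself via
# ``optimizer_register``/``optimizer_config_register``, whose exact decorator
# API is not shown here.
class _ExampleSGDOptimizer(BaseOptimizer):
    """illustrative only: wrap torch.optim.SGD behind the BaseOptimizer interface"""

    def __init__(self, model: torch.nn.Module, config: Dict):
        self.model = model
        # e.g. {"lr": 0.01, "name": "default", "optimizer_special_groups": {...}}
        self.optimizer_config = dict(config)

    def get_optimizer(self) -> optim.Optimizer:
        # BaseOptimizer.init_optimizer builds the param groups and instantiates
        # torch.optim.SGD with the remaining config entries as kwargs; pass a
        # copy because init_optimizer pops keys from the config it receives
        return self.init_optimizer(optim.SGD, self.model, dict(self.optimizer_config))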