# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hjson
import os
from typing import Dict, Union, Callable, List, Any
from dlk.utils.parser import BaseConfigParser
from dlk.utils.config import ConfigTool
from dlk.data.datamodules import datamodule_register, datamodule_config_register
from dlk.managers import manager_register, manager_config_register
from dlk.core.imodels import imodel_register, imodel_config_register
import pickle as pkl
from dlk.utils.io import open
import torch
import copy
import uuid
import json
from dlk.utils.logger import Logger
# Module-level logger obtained from dlk's Logger helper; used by Predict.trace below.
logger = Logger.get_logger()
class Predict(object):
    """Run prediction (and, eventually, TorchScript tracing) from a trained checkpoint.

    Config example (hjson)::

        {
            "_focus": {
            },
            "_link": {},
            "_search": {},
            "config": {
                "save_dir": "*@*",  # must be provided
                "data_path": "*@*", # must be provided
            },
            "task": {
                "_name": task_name
                ...
            }
        }

    The parsed config must be a single config: ``_search`` cannot be used
    for prediction. The methods below read the parsed config under its
    ``'root'`` key.
    """

    def __init__(self, config: Union[str, dict], checkpoint: Union[str, Any]):
        """Parse the config and load the checkpoint onto the CPU.

        Args:
            config: a config dict, or a path to an hjson config file. A dict
                argument is deep-copied first, so the caller's dict is never
                modified (e.g. its ``_focus`` key is not popped).
            checkpoint: a path to a checkpoint file, or any other object that
                ``torch.load`` accepts (e.g. an open binary file object).

        Raises:
            AssertionError: if the parsed config expands to more than one
                config (for example because ``_search`` was used).
        """
        super(Predict, self).__init__()
        if isinstance(config, dict):
            # Work on a private copy: "_focus" is popped below and the parser
            # receives this object, neither of which should leak to the caller.
            config = copy.deepcopy(config)
        else:
            # `open` is dlk.utils.io.open (imported at module top), not the builtin.
            with open(config) as f:
                config = hjson.load(f, object_pairs_hook=dict)
        self.focus = config.pop('_focus', {})
        configs = BaseConfigParser(config).parser_with_check()
        assert len(configs) == 1, (
            "For predict currently the config length must be 1"
            "(you cannot use _search in predict).")
        self.config = configs[0]

        # TODO: FIXME: use pytorch-lightning build in remote filesystem
        # https://pytorch-lightning.readthedocs.io/en/latest/common/remote_fs.html
        # map_location=cpu lets a checkpoint saved from GPU tensors be loaded
        # on a machine without a GPU.
        cpu = torch.device('cpu')
        if isinstance(checkpoint, str):
            with open(checkpoint, 'rb') as f:
                self.ckpt = torch.load(f, map_location=cpu)
        else:
            self.ckpt = torch.load(checkpoint, map_location=cpu)

        # "_focus" maps a dotted path inside the parsed config to a prefix;
        # each prefix + the value found at that path becomes one part of the
        # run name. Without focus entries the root "_name" is used instead.
        name_parts = []
        for source, prefix in self.focus.items():
            node = self.config
            for key in source.split('.'):
                node = node[key]
            name_parts.append(prefix + str(node))
        if name_parts:
            self.name_str = '_'.join(name_parts)
        else:
            self.name_str = self.config['root']['_name']

    def trace(self):
        """Trace the model to TorchScript (work in progress).

        Loads the data, builds the datamodule and the imodel, traces
        ``imodel.model`` on the first training batch and prints the traced
        module together with its output on that batch.

        Raises:
            NotImplementedError: always; tracing is not finished yet.
        """
        config = self.config['root']
        data = self.get_data(config)
        datamodule = self.get_datamodule(config, data)
        imodel = self.get_imodel(config, data)
        # Only the first batch is needed as the example input for tracing.
        for batch in datamodule.train_dataloader():
            script = torch.jit.trace(imodel.model,
                                     example_inputs=batch,
                                     strict=False)
            print(script)
            print(script(batch))
            break
        logger.error('The trace method is not implement yet.')
        raise NotImplementedError

    def predict(self, data=None, save_condition=False):
        """Build the datamodule, manager and imodel, then predict on the data.

        Args:
            data: pre-loaded data dict that must contain a ``'predict'`` entry.
                When ``None`` (the default) the data is loaded from the
                config's ``data_path`` via ``get_data``.
            save_condition: forwarded unchanged to ``imodel.postprocessor``.

        Returns:
            whatever ``imodel.postprocessor`` returns for the ``'predict'`` stage.
        """
        # Deep-copy so nothing downstream can alter the stored config.
        config = copy.deepcopy(self.config['root'])
        # Explicit None sentinel: only load from disk when no data was given.
        if data is None:
            data = self.get_data(config)
        datamodule = self.get_datamodule(config, data)
        manager = self.get_manager(config, self.name_str)
        imodel = self.get_imodel(config, data)
        predict_result = manager.predict(model=imodel, datamodule=datamodule)
        return imodel.postprocessor(stage='predict',
                                    list_batch_outputs=predict_result,
                                    origin_data=data['predict'],
                                    rt_config={},
                                    save_condition=save_condition)

    def get_data(self, config):
        """Load the pickled data file named by the config.

        Args:
            config: {"config": {"data_path": '..'}}

        Returns:
            the ``'data'`` entry of the unpickled dict, or ``{}`` when that
            entry is absent. The result is also stored on ``self.data``.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only point data_path at files you trust.
        with open(config['config']['data_path'], 'rb') as f:
            self.data = pkl.load(f).get('data', {})
        return self.data

    def get_datamodule(self, config, data):
        """Build the datamodule selected by the config and feed it the data.

        Args:
            config: {"task": {"datamodule": '..'}}
            data: {"train": '..', 'valid': '..', ..}

        Returns:
            the datamodule instance.
        """
        DataModule, DataModuleConfig = ConfigTool.get_leaf_module(
            datamodule_register, datamodule_config_register, 'datamodule',
            config['task']['datamodule'])
        return DataModule(DataModuleConfig, data)

    def get_manager(self, config, name):
        """Build the train/predict manager selected by the config.

        Args:
            config: {"task": {"manager": '..'}, "config": {"save_dir": '..'}}
            name: the name of this predict run, passed in the manager's rt_config.

        Returns:
            the manager instance.
        """
        Manager, ManagerConfig = ConfigTool.get_leaf_module(
            manager_register, manager_config_register, 'manager',
            config.get('task').get('manager'))
        return Manager(ManagerConfig,
                       rt_config={
                           "save_dir": config.get('config').get("save_dir"),
                           "name": name
                       })

    def get_imodel(self, config, data):
        """Build the imodel selected by the config and load the checkpoint weights.

        Args:
            config: {"task": {"imodel": '..'}}
            data: currently unused; kept for signature compatibility with
                the other ``get_*`` helpers.

        Returns:
            the imodel, with ``self.ckpt['state_dict']`` loaded and switched
            to eval mode.
        """
        IModel, IModelConfig = ConfigTool.get_leaf_module(
            imodel_register, imodel_config_register, 'imodel',
            config.get('task').get('imodel'))
        imodel = IModel(IModelConfig, checkpoint=True)
        imodel.load_state_dict(self.ckpt['state_dict'])
        imodel.eval()
        return imodel