Skip to content

Commit

Permalink
REF: 重构5.5
Browse files Browse the repository at this point in the history
  • Loading branch information
johncage committed Mar 21, 2022
1 parent 2f5a5c1 commit 4cd6287
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 14 deletions.
79 changes: 70 additions & 9 deletions torch_lib/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any, Dict, Optional, Union, TypeVar
from torch_lib.callback.dataset import DataParser, DataProvider
from torch_lib.callback.metrics import M_SEQ
from torch_lib.callback.run import R_SEQ
from torch_lib.utils import MultiConst, get_device, type_cast, MethodChaining, InvocationDebug
from torch_lib.callback.dataset import ConstantDataProvider, DataParser, DataProvider
from torch_lib.callback.metrics import M_SEQ, MetricCallbackExecutor
from torch_lib.callback.run import R_SEQ, RunCallbackExecutor
from torch_lib.utils import NOTHING, MultiConst, get_device, type_cast, MethodChaining, InvocationDebug, is_nothing, logger
from torch_lib.utils.type import NUMBER
from torch_lib.context import Context
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -32,28 +32,36 @@ def train(
self,
train_dataset: DATASET,
total_epochs: int = 1,
eval_dataset: DATASET = None,
run_callbacks: R_SEQ = None,
eval_dataset: DATASET = NOTHING,
run_callbacks: R_SEQ = NOTHING,
log_option = None # TODO: log system design
):
self._build_total_epochs(total_epochs)
self._build_run_callback_exec(run_callbacks)
self._build_dataset(train_dataset, 'train')
self._build_dataset(eval_dataset, 'eval')
self.ctx.build.train(self.ctx)

@InvocationDebug('ModelProxy.Predict')
def predict(
self,
dataset: DATASET,
run_callbacks: R_SEQ = None,
run_callbacks: R_SEQ = NOTHING,
log_option = None # TODO: log system design
):
self._build_run_callback_exec(run_callbacks)
self._build_dataset(dataset, 'eval')
self.ctx.build.predict(self.ctx)

@InvocationDebug('ModelProxy.Eval')
def eval(
self,
dataset: DATASET,
run_callbacks: R_SEQ = None,
run_callbacks: R_SEQ = NOTHING,
log_option = None # TODO: log system design
):
self._build_run_callback_exec(run_callbacks)
self._build_dataset(dataset, 'eval')
self.ctx.build.eval(self.ctx)

@InvocationDebug('ModelProxy.Summary')
Expand All @@ -73,7 +81,11 @@ def build(
lr_decay_options: Optional[Dict] = None,
data_parser: Optional[DataParser] = None
) -> MP:
pass
self._build_loss(loss_func)
self._build_metric_callback_exec(metric_callbacks)
self._build_data_parser(data_parser)
self._build_optimizer(optimizer, learning_rate, optimizer_options)
self._build_lr_decay(lr_decay, lr_decay_options)

@InvocationDebug('ModelProxy.TrainBuilder')
@MethodChaining
Expand Down Expand Up @@ -188,3 +200,52 @@ def build_eval(self) -> MP:
# end callback
handler.End()
])

@InvocationDebug('ModelProxy._build_loss')
def _build_loss(self, loss_func):
if loss_func is not None:
self.ctx.build.loss_func = loss_func if is_nothing(loss_func) is False else NOTHING

@InvocationDebug('ModelProxy._build_metric_callback_exec')
def _build_metric_callback_exec(self, metric_callbacks):
if metric_callbacks is not None:
self.ctx.build.metric_callback_exec = MetricCallbackExecutor(metric_callbacks) if is_nothing(metric_callbacks) is False else NOTHING

@InvocationDebug('ModelProxy._build_data_parser')
def _build_data_parser(self, data_parser):
if data_parser is not None:
self.ctx.build.data_parser = data_parser if is_nothing(data_parser) is False else NOTHING

@InvocationDebug('ModelProxy._build_run_callback_exec')
def _build_run_callback_exec(self, run_callbacks):
if run_callbacks is not None:
self.ctx.build.run_callback_exec = RunCallbackExecutor(run_callbacks) if is_nothing(run_callbacks) is False else NOTHING

@InvocationDebug('ModelProxy._build_optimizer')
def _build_optimizer(self, optimizer, learning_rate, optimizer_options):
if optimizer is not None:
if isinstance(optimizer, Optimizer):
self.ctx.build.optimizer = optimizer

@InvocationDebug('ModelProxy._build_lr_decay')
def _build_lr_decay(self, lr_decay, lr_decay_options):
pass

@InvocationDebug('ModelProxy._build_total_epochs')
def _build_total_epochs(self, total_epochs):
self.ctx.epoch.total = total_epochs if isinstance(total_epochs, int) else NOTHING

@InvocationDebug('ModelProxy._build_dataset')
def _build_dataset(self, dataset, mode: str):
if dataset is not None:
if is_nothing(dataset):
dataset = NOTHING
else:
dataset = dataset if isinstance(dataset, DataProvider) else ConstantDataProvider(dataset)

if mode == 'train':
self.ctx.build.train_provider = dataset
elif mode == 'eval':
self.ctx.build.eval_provider = dataset
else:
logger.warn('_build_dataset mode not supported.')
12 changes: 7 additions & 5 deletions torch_lib/log/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@


color_dict = {
'r': 31,
'g': 32,
'y': 33,
'b': 34,
'w': 38
'r': 31, # red
'g': 32, # green
'y': 33, # yellow
'b': 34, # blue
'p': 35, # purple
'c': 36, # cyan
'w': 38 # white
}

info_prefix = '[TORCH_LIB INFO]'
Expand Down
3 changes: 3 additions & 0 deletions torch_lib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __setattr__(self, *_):
def __setitem__(self, *_):
pass

def __len__(self):
return 0

def __iter__(self):
return self

Expand Down

0 comments on commit 4cd6287

Please sign in to comment.