"""
**emloop** main loop for training models.

The MainLoop requires an AbstractModel, an AbstractDataset and a list of AbstractHooks.
With these, it manages iterating through the streams, training the model and executing the hooks.
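A minimal usage sketch (``MyModel``, ``MyDataset`` and the ``'valid'`` stream are
illustrative placeholders, not part of this module)::

    main_loop = MainLoop(model=MyModel(), dataset=MyDataset(), hooks=[],
                         extra_streams=['valid'])
    main_loop.run_training()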
"""
import logging
from typing import Iterable, Callable, List, Dict, Optional
from collections import OrderedDict
from .datasets import AbstractDataset
from .models.abstract_model import AbstractModel
from .hooks.abstract_hook import AbstractHook, TrainingTerminated
from .utils import Timer, TrainingTrace, TrainingTraceKeys
from .utils.misc import CaughtInterrupts
from .datasets.stream_wrapper import StreamWrapper
from .constants import EL_DEFAULT_TRAIN_STREAM, EL_PREDICT_STREAM
from .types import EpochData


class MainLoop(CaughtInterrupts):  # pylint: disable=too-many-instance-attributes
"""**emloop** main loop for training and model inference."""
EMPTY_ACTIONS = ['ignore', 'warn', 'error']
"""Possible actions to be taken when a batch/stream is empty."""
UNUSED_SOURCE_ACTIONS = ['ignore', 'warn', 'error']
"""Possible actions to be taken when a stream source is unused by the trained model."""
INCORRECT_CONFIG_ACTIONS = ['ignore', 'warn', 'error']
"""Possible actions to be taken when a mainloop config contains some unexpected arguments."""

    def __init__(self,  # pylint: disable=too-many-arguments
model: AbstractModel, dataset: AbstractDataset,
hooks: Iterable[AbstractHook]=(),
train_stream_name: str=EL_DEFAULT_TRAIN_STREAM,
extra_streams: List[str]=(), # pylint: disable=invalid-sequence-index
buffer: int=0,
on_empty_batch: str='error',
on_empty_stream: str='error',
on_unused_sources: str='warn',
                 on_incorrect_config: str='error',
fixed_batch_size: Optional[int]=None,
fixed_epoch_size: Optional[int]=None,
skip_zeroth_epoch: bool=False,
**kwargs):
"""
:param model: trained model
:param dataset: loaded dataset
:param hooks: training hooks
:param train_stream_name: name of the training stream
:param extra_streams: additional stream names to be evaluated between epochs
:param buffer: size of the batch buffer, 0 means no buffer
:param on_empty_batch: action to take when batch is empty; one of :py:attr:`MainLoop.EMPTY_ACTIONS`
:param on_empty_stream: action to take when stream is empty; one of :py:attr:`MainLoop.EMPTY_ACTIONS`
        :param on_unused_sources: action to take when the stream provides unused sources; one of
                                  :py:attr:`MainLoop.UNUSED_SOURCE_ACTIONS`
:param on_incorrect_config: action to take when mainloop config contains unexpected arguments; one of
:py:attr:`MainLoop.INCORRECT_CONFIG_ACTIONS`
:param fixed_batch_size: if specified, main_loop removes all batches that do not have the specified size
:param fixed_epoch_size: if specified, cut the train stream to epochs of at most ``fixed_epoch_size`` batches
:param skip_zeroth_epoch: if specified, main loop skips the 0th epoch
        :raise AssertionError: in case of an unsupported value of ``on_empty_batch``, ``on_empty_stream``, \
            ``on_unused_sources`` or ``on_incorrect_config``
"""
assert on_empty_batch in MainLoop.EMPTY_ACTIONS
assert on_empty_stream in MainLoop.EMPTY_ACTIONS
assert on_unused_sources in MainLoop.UNUSED_SOURCE_ACTIONS
assert on_incorrect_config in MainLoop.INCORRECT_CONFIG_ACTIONS
if kwargs:
if on_incorrect_config == 'error':
raise ValueError('Config yaml contains some unexpected arguments in mainloop section. '
'Set `main_loop.on_incorrect_config` to `warn` in order to suppress this error.\n'
'Extra arguments: {}'.format(kwargs))
elif on_incorrect_config == 'warn':
logging.warning('Config yaml contains some unexpected arguments in mainloop section. '
'Set `main_loop.on_incorrect_config` to `ignore` in order to suppress this warning. '
'Extra arguments: %s', kwargs)
self._model = model
self._dataset = dataset
self._hooks = hooks
self._buffer = buffer
self._on_empty_batch = on_empty_batch
self._on_empty_stream = on_empty_stream
self._on_unused_sources = on_unused_sources
self._fixed_batch_size = fixed_batch_size
self._fixed_epoch_size = fixed_epoch_size
self._extra_sources_warned = False
self._epoch_profile = {}
self._train_stream_name = train_stream_name
self._extra_streams = list(extra_streams)
self._skip_zeroth_epoch = skip_zeroth_epoch
self._streams = {}
self._epochs_done = None
for hook in self._hooks:
hook.register_mainloop(self)
super().__init__()

    @property
def epochs_done(self) -> Optional[int]:
"""Number of training epochs done in the last call of :py:meth:`self._run_training`."""
return self._epochs_done

    @property
def fixed_epoch_size(self) -> Optional[int]:
"""Fixed epoch size parameter as specified in :py:meth:`self.__init__`."""
return self._fixed_epoch_size

    @property
def extra_streams(self) -> List[str]:
"""List of extra stream names as specified in :py:meth:`self.__init__`."""
return self._extra_streams

    def _create_epoch_data(self, streams: Optional[Iterable[str]]=None) -> EpochData:
"""Create empty epoch data double dict."""
if streams is None:
streams = [self._train_stream_name] + self._extra_streams
return OrderedDict([(stream_name, OrderedDict()) for stream_name in streams])
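        # Illustrative shape of the returned double dict (the 'train'/'valid' stream
        # names and the 'loss' variable are hypothetical examples; the inner dicts are
        # filled in later by the hooks):
        #
        #     OrderedDict([('train', OrderedDict()),   # e.g. later {'loss': [0.3, ...]}
        #                  ('valid', OrderedDict())])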

    def _check_sources(self, batch: Dict[str, object]) -> None:
"""
Check for unused and missing sources.
:param batch: batch to be checked
:raise ValueError: if a source is missing or unused and ``self._on_unused_sources`` is set to ``error``
"""
unused_sources = [source for source in batch.keys() if source not in self._model.input_names]
missing_sources = [source for source in self._model.input_names if source not in batch.keys()]
# check stream sources
if unused_sources:
if self._on_unused_sources == 'warn' and not self._extra_sources_warned:
logging.warning('Some sources provided by the stream do not match model placeholders. Set '
'`main_loop.on_unused_sources` to `ignore` in order to suppress this warning. '
'Extra sources: %s', unused_sources)
self._extra_sources_warned = True
elif self._on_unused_sources == 'error':
                raise ValueError('Some sources provided by the stream do not match model placeholders. Set '
                                 '`main_loop.on_unused_sources` to `warn` in order to suppress this error.\n'
                                 'Extra sources: {}'.format(unused_sources))
if missing_sources:
raise ValueError('Stream does not provide all required sources. Missing sources: {}'
.format(missing_sources))
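        # Example with hypothetical names: if model.input_names == ['images', 'labels'],
        # a batch {'images': ..., 'labels': ..., 'ids': ...} reports 'ids' as an unused
        # source, while a batch providing only {'images': ...} raises a missing-source
        # ValueError for 'labels'.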

    def _run_epoch(self, stream: StreamWrapper, train: bool) -> None:
"""
Iterate through the given stream and evaluate/train the model with the received batches.
        Calls :py:meth:`emloop.hooks.AbstractHook.after_batch` on all hooks after each batch.
:param stream: stream to iterate
:param train: if set to ``True``, the model will be trained
:raise ValueError: in case of empty batch when ``on_empty_batch`` is set to ``error``
:raise ValueError: in case of empty stream when ``on_empty_stream`` is set to ``error``
:raise ValueError: in case of two batch variables having different lengths
"""
nonempty_batch_count = 0
for i, batch_input in enumerate(stream):
self.raise_check_interrupt()
batch_sizes = {len(source) for source in batch_input.values()}
if len(batch_sizes) == 0 or batch_sizes == {0}:
if self._on_empty_batch == 'warn':
                    logging.warning('%i-th batch in stream `%s` appears to be empty (%i-th empty batch in total). Set '
                                    '`main_loop.on_empty_batch` to `ignore` in order to suppress this warning.',
                                    i, stream.name, i - nonempty_batch_count)
elif self._on_empty_batch == 'error':
                    raise ValueError('{}-th batch in stream `{}` appears to be empty ({}-th empty batch in total). Set '
                                     '`main_loop.on_empty_batch` to `warn` in order to change this error into a warning; '
                                     'set it to `ignore` to remove it.'.format(i, stream.name, i - nonempty_batch_count))
continue
elif self._fixed_batch_size:
if batch_sizes != {self._fixed_batch_size}:
var, len_ = [(k, len(v)) for k, v in batch_input.items() if len(v) != self._fixed_batch_size][0]
                    logging.debug('%i-th batch in stream `%s` has variable `%s` of length %i inconsistent with '
                                  '`main_loop.fixed_batch_size` = %i', i, stream.name, var, len_, self._fixed_batch_size)
continue
nonempty_batch_count += 1
self._check_sources(batch_input)
with Timer('eval_batch_{}'.format(stream.name), self._epoch_profile):
batch_output = self._model.run(batch=batch_input, train=train, stream=stream)
assert set(batch_input.keys()).isdisjoint(set(batch_output)), 'Batch inputs and outputs must not overlap.'
with Timer('after_batch_hooks_{}'.format(stream.name), self._epoch_profile):
batch_data = {**batch_input, **batch_output}
for hook in self._hooks:
hook.after_batch(stream_name=stream.name, batch_data=batch_data)
if nonempty_batch_count == 0:
if self._on_empty_stream == 'warn':
logging.warning('Stream `%s` appears to be empty. Set `main_loop.on_empty_stream` to `ignore` in order '
'to suppress this warning.', stream.name)
elif self._on_empty_stream == 'error':
raise ValueError('Stream `{}` appears to be empty. Set '
'`main_loop.on_empty_stream` to `warn` in order to change this error into warning; '
'set to `ignore` to remove it.'.format(stream.name))

    def train_by_stream(self, stream: StreamWrapper) -> None:
"""
Train the model with the given stream.
:param stream: stream to train with
"""
self._run_epoch(stream=stream, train=True)

    def evaluate_stream(self, stream: StreamWrapper) -> None:
"""
Evaluate the given stream.
:param stream: stream to be evaluated
"""
self._run_epoch(stream=stream, train=False)

    def get_stream(self, stream_name: str) -> StreamWrapper:
"""
Get a :py:class:`StreamWrapper` with the given name.
:param stream_name: stream name
        :return: :py:class:`StreamWrapper` wrapping the stream of the given name
:raise AttributeError: if the dataset does not provide the function creating the stream
"""
if stream_name not in self._streams:
stream_fn_name = '{}_stream'.format(stream_name)
try:
stream_fn = getattr(self._dataset, stream_fn_name)
stream_epoch_limit = -1
if self._fixed_epoch_size is not None and stream_name == self._train_stream_name:
stream_epoch_limit = self._fixed_epoch_size
self._streams[stream_name] = StreamWrapper(stream_fn, buffer_size=self._buffer,
epoch_size=stream_epoch_limit, name=stream_name,
profile=self._epoch_profile)
except AttributeError as ex:
raise AttributeError('The dataset does not have a function for creating a stream named `{}`. '
'The function has to be named `{}`.'.format(stream_name, stream_fn_name)) from ex
return self._streams[stream_name]
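        # Naming convention sketch (``MyDataset`` and its batch contents are hypothetical):
        # a stream named 'train' is looked up as the dataset method `train_stream`, e.g.
        #
        #     class MyDataset(AbstractDataset):
        #         def train_stream(self):
        #             for _ in range(10):
        #                 yield {'images': [[0.0], [1.0]], 'labels': [0, 1]}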

    def _run_zeroth_epoch(self, streams: Iterable[str]) -> None:
"""
Run zeroth epoch on the specified streams.
Calls
- :py:meth:`emloop.hooks.AbstractHook.after_epoch`
:param streams: stream names to be evaluated
"""
self._epoch_profile.clear()
for stream_name in streams:
with self.get_stream(stream_name) as stream:
self.evaluate_stream(stream)
epoch_data = self._create_epoch_data(streams)
with Timer('after_epoch_hooks', self._epoch_profile):
for hook in self._hooks:
hook.after_epoch(epoch_id=0, epoch_data=epoch_data)
for hook in self._hooks:
hook.after_epoch_profile(epoch_id=0, profile=self._epoch_profile,
streams=streams)

    def _try_run(self, run_func: Callable[[], None]) -> None:
"""
Try running the given function (training/prediction).
Calls
- :py:meth:`emloop.hooks.AbstractHook.before_training`
- :py:meth:`emloop.hooks.AbstractHook.after_training`
:param run_func: function to be run
"""
# Initialization: before_training
for hook in self._hooks:
hook.before_training()
try:
run_func()
except TrainingTerminated as ex:
logging.info('Training terminated: %s', ex)
# After training: after_training
for hook in self._hooks:
hook.after_training()

    def run_training(self, trace: Optional[TrainingTrace]=None) -> None:
"""
Run the main loop in the training mode.
Calls
- :py:meth:`emloop.hooks.AbstractHook.after_epoch`
- :py:meth:`emloop.hooks.AbstractHook.after_epoch_profile`
"""
for stream_name in [self._train_stream_name] + self._extra_streams:
self.get_stream(stream_name)
def training():
logging.debug('Training started')
self._epochs_done = 0
# Zeroth epoch: after_epoch
if not self._skip_zeroth_epoch:
logging.info('Evaluating 0th epoch')
self._run_zeroth_epoch([self._train_stream_name] + self._extra_streams)
logging.info('0th epoch done\n\n')
# Training loop: after_epoch, after_epoch_profile
while True:
epoch_id = self._epochs_done + 1
logging.info('Training epoch %s', epoch_id)
self._epoch_profile.clear()
epoch_data = self._create_epoch_data()
with self.get_stream(self._train_stream_name) as stream:
self.train_by_stream(stream)
for stream_name in self._extra_streams:
with self.get_stream(stream_name) as stream:
self.evaluate_stream(stream)
with Timer('after_epoch_hooks', self._epoch_profile):
for hook in self._hooks:
hook.after_epoch(epoch_id=epoch_id, epoch_data=epoch_data)
for hook in self._hooks:
hook.after_epoch_profile(epoch_id=epoch_id, profile=self._epoch_profile,
streams=[self._train_stream_name] + self._extra_streams)
self._epochs_done = epoch_id
if trace is not None:
trace[TrainingTraceKeys.EPOCHS_DONE] = self._epochs_done
logging.info('Epoch %s done\n\n', epoch_id)
self._try_run(training)

    def run_evaluation(self, stream_name: str) -> None:
"""
Run the main loop with the given stream in the prediction mode.
:param stream_name: name of the stream to be evaluated
"""
def prediction():
logging.info('Running prediction')
self._run_zeroth_epoch([stream_name])
logging.info('Prediction done\n\n')
self._try_run(prediction)
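        # Usage sketch (the 'predict' stream name is illustrative; by the naming
        # convention above it maps to a dataset method `predict_stream`):
        #
        #     main_loop.run_evaluation('predict')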