#  Copyright (C) 2019  Robert J Weston, Oxford Robotics Institute
#  xmen
#  email:
#  github:
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  You should have received a copy of the GNU General Public License
#  along with this program. If not, see <https://www.gnu.org/licenses/>.
import time
import datetime
import glob
import re
import inspect
import copy
import warnings
import os
import logging

# Trigger levels, ordered from the innermost counter ('step') to the outermost ('supereon').
TRIGGERS = ['step', 'epoch', 'era', 'eon', 'supereon']
# Short name accepted for each trigger inside specification strings (e.g. '@10s').
BRIEF = dict(zip(TRIGGERS, ['s', 'e', 'er', 'eo', 'se']))
# Position of each trigger within TRIGGERS.
IDX = dict(zip(TRIGGERS, range(len(TRIGGERS))))

class Spec(object):
    """Parse a hook specification string.

    A specification is either ``"{regex}@{modulo}{trigger}"`` (a hook run automatically
    whenever the ``trigger`` count is divisible by ``modulo``) or just ``"{regex}"`` (a manual
    hook with no schedule). ``trigger`` may be a full trigger name or its abbreviation from
    ``BRIEF`` (e.g. ``'s'`` or ``'step'``).

    Attributes:
        regex (str or None): regular expression selecting variables; ``None`` when the part
            before ``'@'`` is empty.
        modulo (int or None): the schedule period; ``None`` for manual specifications.
        trigger (str or None): the full trigger name; ``None`` for manual specifications.
    """

    def __init__(self, spec):
        assert isinstance(spec, str), 'invalid specification found'
        regex, steps, trigger = None, None, None
        if '@' in spec:
            regex, string = spec.split('@')
            if regex == '':
                regex = None
            found = re.search('[0-9]+', string)
            assert found is not None, f'No steps found in {spec}'
            steps = found.group()
            # Accept both the full trigger name and its abbreviation.
            names = {}
            for long_name, short_name in BRIEF.items():
                names[long_name] = long_name
                names[short_name] = long_name
            trigger = names.get(string[found.end():], None)
            assert trigger is not None, f'Invalid or no trigger specified in {spec}'
        else:
            # A manual hook: the whole specification is the regex.
            regex = spec
        self.regex = regex
        self.modulo = int(steps) if steps is not None else None
        self.trigger = trigger

    def __str__(self):
        regex = self.regex if self.regex is not None else ""
        if self.modulo is None:
            return regex
        return f'{regex}@{self.modulo}{self.trigger}'

def read_modulo_string(string):
    """Split a modulo string ``"{regex}@{modulo}{trigger}"`` into its parts.

    Args:
        string (str): the modulo string. ``trigger`` may be a full trigger name or its
            abbreviation from ``BRIEF``.

    Returns:
        tuple: ``(regex, modulo, trigger)`` where ``regex`` is the (possibly empty) text before
        ``'@'``, ``modulo`` is an int and ``trigger`` is the full trigger name.

    Raises:
        AssertionError: if ``string`` is not a str or contains no step count.
        ValueError: if ``string`` does not contain exactly one ``'@'``.
        KeyError: if the trigger is not recognised.
    """
    assert isinstance(string, str), f'Expected a str but got {type(string)}'
    regex, schedule = string.split('@')
    found = re.search('[0-9]+', schedule)
    assert found is not None, f'No steps found in {string}'
    steps = found.group()
    names = {}
    for long_name, short_name in BRIEF.items():
        names[long_name] = long_name
        names[short_name] = long_name
    trigger = names[schedule[found.end():]]
    return regex, int(steps), trigger

class LastStep(Exception):
    """Exception recording a ``trigger`` name and a ``step`` value.

    Args:
        trigger: the trigger the exception refers to.
        step: the step value associated with ``trigger``.
    """

    def __init__(self, trigger, step):
        self.trigger, self.step = trigger, step

class Hook(Spec):
    """A base class defining a variable passing protocol with the Monitor class.

    Every time the monitor's count for a particular ``trigger`` is divisible by ``modulo``, the
    hook is passed all the variables matching ``regex`` from the current stack. Users therefore
    define hooks by overloading the hook's ``__call__`` method (with the same call signature).
    """

    def __call__(self, var_dict, monitor):
        """Handle the variables selected for this hook.

        Args:
            var_dict (dict): mapping of variable name to value for every matched variable.
            monitor (Monitor): the monitor invoking the hook.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError('All methods must implement the call method')
class EarlyStop(object):
    def __init__(self, condition='max', memory=5):
        """An early stop implementation.

        The object is truthy once ``memory`` values have been appended and the oldest value in
        memory is ``>=`` (``condition='max'``) or ``<=`` (``condition='min'``) every other value
        in memory.

        Example::

            import numpy as np
            import matplotlib.pyplot as plt
            plt.style.use('ggplot')

            es = EarlyStop('max', 36)
            X = np.linspace(0, 10, 100)
            Y = 1 - np.sin(X) * X
            y = []
            for x in X:
                y.append(1 - np.sin(x) * x)
                es.append(y[-1], an_example=True, step=1)
                if es:
                    print(f'Breaking at {x}')
                    break
            print(max(es))
            best, count, tags = max(es)
            plt.plot(X, Y)
            plt.plot(X[:len(y)], np.array(y))
            plt.scatter(X[count], best, color='red')

        Args:
            condition (str): one of ``'max'`` or ``'min'``.
            memory (int): the number of most recent values remembered (>= 1).
        """
        assert condition in ['max', 'min']
        assert memory >= 1
        from collections import deque
        self.condition = condition
        self.count = 0
        # Fixed-length windows; ``None`` marks slots that have not been filled yet.
        self.history = deque([None] * memory, maxlen=memory)
        self.counts = deque([None] * memory, maxlen=memory)
        self.tags = deque([None] * memory, maxlen=memory)

    def append(self, value, **tags):
        """Remember ``value`` (with its ``tags``), discarding the oldest entry."""
        self.history.append(value)
        self.counts.append(self.count)
        self.tags.append(tags)
        self.count += 1

    def __bool__(self):
        import operator
        op = {'max': operator.le, 'min': operator.ge}[self.condition]
        first = self.history[0]
        # Only meaningful once the window is full (the oldest slot is no longer None).
        return all(op(h, first) for h in self.history) if first is not None else False

    def __iter__(self):
        """Yield ``(value, count, tags)`` for every slot in memory, oldest first."""
        for h, s, t in zip(self.history, self.counts, self.tags):
            yield h, s, t

    def __repr__(self):
        select = {'max': max, 'min': min}[self.condition]
        # Skip the unfilled ``None`` placeholders, which cannot be compared with values.
        filled = [entry for entry in self if entry[0] is not None]
        best, count, tags = select(filled) if filled else (None, None, None)
        return f'EarlyStop({bool(self)}, memory={len(self.history)}, condition={self.condition}, ' \
               f'best={best}, count={count}, tags={tags})'


class EarlyStopper(Hook):
    """A hook stopping the monitor once an :class:`EarlyStop` condition on a metric is met.

    Args:
        metric (str): either a modulo string ``"{metric_regex}@{modulo}{trigger}"`` or just
            ``"{metric_regex}"`` (in which case ``modulo`` and ``trigger`` give the schedule).
        condition (str): ``'max'`` or ``'min'`` (see :class:`EarlyStop`).
        memory (int): the number of values remembered (see :class:`EarlyStop`).
        tags (str): an optional regex of extra variables stored as tags with each metric value.
        modulo (int): the schedule period used when ``metric`` contains no ``'@'``.
        trigger (str): the schedule trigger used when ``metric`` contains no ``'@'``.
    """

    def __init__(self, metric, condition='max', memory=5, tags='', modulo=None, trigger='step'):
        metric_regex, at, schedule = metric.partition('@')
        if at:
            schedule = '@' + schedule
        elif modulo is not None:
            schedule = f'@{modulo}{trigger}'
        else:
            schedule = ''
        spec_regex = metric_regex
        if tags not in ['', '^$', None]:
            # The hook must also collect the tag variables, but only metric_regex is the metric.
            spec_regex = tags + '|' + metric_regex
        super(EarlyStopper, self).__init__(spec_regex + schedule)
        self.early_stop = EarlyStop(condition, memory)
        self.metric = metric_regex

    def __call__(self, tags, monitor):
        key = next((k for k in tags if re.match(self.metric, k) is not None), None)
        metric = tags.pop(key, None) if key is not None else None
        if metric is None:
            raise RuntimeError(f'No Metric was found matching {self.metric}')
        # Remaining matched variables are stored as tags alongside the metric value.
        self.early_stop.append(metric, **tags)
        if self.early_stop:
            monitor.stop(f'----- {self.early_stop} ----')
class Checkpointer(Hook):
    """A hook saving the ``state_dict`` of matched variables with ``torch.save``.

    Checkpoints are written to ``{monitor.directory}/checkpoints/{name}/{monitor.step}.pt``.
    """

    def __init__(self, spec, to_keep=None, expand=True):
        """
        Args:
            spec (str): a specification string of the form ``"{regex}@{modulo}{trigger}"``.
            to_keep (int): the number of checkpoints kept for each variable. The oldest
                checkpoints are removed first. If None all checkpoints are kept.
            expand (bool): if True, dictionary variables are expanded into variables named
                ``'{name}_{key}'`` before checkpointing.
        """
        super(Checkpointer, self).__init__(spec)
        self.to_keep = to_keep
        self._checkpoint_buffer = {}
        self.expand = expand

    def __call__(self, var_dict, monitor):
        import torch
        from xmen.utils import get_version
        saved = []
        if monitor.directory is None:
            monitor.log(f'WARNING: Cannot checkpoint {list(var_dict.keys())} as monitor does not have a directory')
            return
        if self.expand:
            pops = []
            updates = {}
            for k, v in var_dict.items():
                if isinstance(v, dict):
                    pops.append(k)
                    updates.update({k + '_' + kk: vv for kk, vv in v.items()})
            for p in pops:
                var_dict.pop(p)
            var_dict.update(updates)
        for k, v in var_dict.items():
            if not hasattr(v, 'state_dict'):
                warnings.warn(f'Value for key = {k} does not have a state_dict')
                continue
            save_dir = os.path.join(monitor.directory, 'checkpoints', k)
            os.makedirs(save_dir, exist_ok=True)
            # Populate the buffer from disk the first time this variable is seen (e.g. on restart)
            if k not in self._checkpoint_buffer:
                files = glob.glob(os.path.join(save_dir, '*'))
                # Order files in ascending order by step (the sort must be on the int step)
                self._checkpoint_buffer[k] = sorted(
                    files, key=lambda ff: int(os.path.splitext(os.path.basename(ff))[0]))
            buffer = self._checkpoint_buffer[k]
            path = os.path.join(save_dir, f'{monitor.step}.pt')
            if path in buffer:
                # The same step is saved again: the file is overwritten below, keep one entry.
                buffer.remove(path)
            if self.to_keep is not None:
                # ``while``/``>=`` so the buffer also shrinks if it is already over the limit.
                while buffer and len(buffer) >= self.to_keep:
                    try:
                        os.remove(buffer.pop(0))
                    except FileNotFoundError:
                        pass
            buffer.append(path)
            save_dict = {name: count for name, count in monitor.triggers.items() if count != 0.}
            save_dict.update({
                'version': get_version(cls=type(v)),
                'date': datetime.datetime.now().strftime("%I:%M%p %B %d, %Y"),
                'state_dict': v.state_dict()})
            torch.save(save_dict, path)
            saved.append(k)
        monitor.log(f'saved {saved} at {monitor.directory}')
class XmenMessenger(Hook):
    """A hook for logging messages with an ``xmen.Experiment`` object.

    NOTE(review): the ``link`` calls in the examples were reconstructed from a damaged copy of
    this docstring; confirm the Experiment method name.

    Example 1::

        from xmen import Experiment
        from xmen.monitor import Monitor

        messenger = XmenMessenger('y.*->ex.*@10s')  # log all variables matching y to experiments matching ex
        m = Monitor(hooks=[messenger])
        y1, y2 = 0, 0
        ex1, ex2 = Experiment(), Experiment()
        ex1.link('/tmp', 'ex1')
        ex2.link('/tmp', 'ex2')
        for i in m(range(40)):
            y1 += 1
            y2 += 2
            if i % 10 == 1:
                print(', '.join(
                    [f"ex1: {k} = {ex1.messages.get(k, None)}" for k in ('y1', 'y2')] +
                    [f"ex2: {k} = {ex2.messages.get(k, None)}" for k in ('y1', 'y2')]))

        # Output
        # ex1: y1 = None, ex1: y2 = None, ex2: y1 = None, ex2: y2 = None
        # ex1: y1 = 10, ex1: y2 = 20, ex2: y1 = 10, ex2: y2 = 20
        # ex1: y1 = 20, ex1: y2 = 40, ex2: y1 = 20, ex2: y2 = 40
        # ex1: y1 = 30, ex1: y2 = 60, ex2: y1 = 30, ex2: y2 = 60

    Example 2::

        from xmen import Experiment
        from xmen.monitor import Monitor

        ex = Experiment()
        ex.link('/tmp', 'ex')
        m = Monitor(
            hooks=[
                XmenMessenger('^y$|^i$->^ex$@10s', keep='min', leader='^y$')])
        x = -50
        for i in m(range(100)):
            x += 1
            y = x ** 2
            if i % 10 == 1:
                print([ex.messages.get(k, None) for k in ('i', 'y')])

        # Output
        # [None, None]
        # [9, 1600]
        # [19, 900]
        # [29, 400]
        # [39, 100]
        # [49, 0]
        # [49, 0]
        # [49, 0]
        # [49, 0]
        # [49, 0]

    Example 3::

        from xmen import Experiment
        from xmen.monitor import Monitor

        ex = Experiment()
        ex.link('/tmp', 'ex')
        m = Monitor(
            hooks=[
                XmenMessenger('z->^ex$@10s', keep='min', leader='y', expand=True)])
        z = {'x': 5, 'y': 10}
        for i in m(range(100)):
            z['i'] = i
            z['x'] += 1
            z['y'] = z['x'] ** 2
            if i % 10 == 1:
                print([ex.messages.get(k, None) for k in ('i', 'y', 'x')])

        # [None, None, None]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]
        # [9, 225, 15]

    Example 4::

        from xmen import Experiment
        from xmen.monitor import Monitor

        ex = Experiment()
        ex.link('/tmp', 'ex')
        m = Monitor(
            hooks=[
                XmenMessenger('z->^ex$@10s', keep='min', leader='z_y', expand=True, prepend=True)])
        z = {'x': 5, 'y': 10}
        for i in m(range(100)):
            z['i'] = i
            z['x'] += 1
            z['y'] = z['x'] ** 2
            if i % 10 == 1:
                print([ex.messages.get(k, None) for k in ('z_i', 'z_y', 'z_x')])

        # same output as above
    """

    def __init__(self, spec, keep='latest', leader=None, expand=False, prepend=None):
        """
        Args:
            spec (str): hook specification in the form either
                ``{log_regex}->{exp_regex}@{modulo}{trigger}`` or ``{exp_regex}@{modulo}{trigger}``
                where ``exp_regex`` selects the experiments to message and ``log_regex`` selects
                additional variables to log with each experiment. In the second case only timing
                and step information will be logged.
            keep (str): One of ['latest', 'max', 'min'], used in the case of message collision
                with each experiment.
            leader (regex): If leader is not None then messages will be logged according to
                ``keep`` as judged by the variable in var_dict which matches ``leader``.
            expand (bool): If a dictionary variable with K keys matches ``log_regex`` then it is
                converted to K variables with names corresponding to the keys in the dict if
                ``expand == True``.
            prepend (bool): If prepend is truthy then, in the case above, the name of each
                variable will be prepended by the dict variable name, e.g. '{name}_{k}'.
        """
        assert keep in ['max', 'min', 'latest']
        self.keep = keep
        self.expand = expand
        self.leader = leader
        self.prepend = prepend
        self.log = r''
        if '->' in spec:
            self.log, spec = spec.split('->')
            # The hook must collect both the variables to log and the experiments.
            spec = self.log + '|' + spec
        super(XmenMessenger, self).__init__(spec)

    def __call__(self, var_dict, monitor):
        """Leave messages, timing and step information with experiments"""
        from xmen.experiment import Experiment
        # Matched variables that are not selected by the log regex are the experiments. With no
        # log regex (no '->' in the spec) every matched variable is an experiment.
        if self.log:
            names = [k for k in var_dict if re.match(self.log, k) is None]
        else:
            names = list(var_dict)
        if not names:
            return
        experiments = [var_dict.pop(k) for k in names]
        if self.expand:
            # Expand dictionaries in var_dict
            pop, add = [], {}
            for k, v in var_dict.items():
                if isinstance(v, dict):
                    p = k + '_' if self.prepend else ''
                    add.update({p + kk: vv for kk, vv in v.items()})
                    pop.append(k)
            for p in pop:
                var_dict.pop(p)
            var_dict.update(add)
        leader = None
        if self.leader is not None:
            leader = next((k for k in var_dict if re.match(self.leader, k) is not None), None)
            if leader is None:
                raise RuntimeError(f'No variable was found matching leader {self.leader}')
        for e, name in zip(experiments, names):
            if isinstance(e, Experiment):
                e.message(monitor.summary(verbose=1))
                e.message(var_dict, keep=self.keep, leader=leader)
                keys = list(var_dict.keys())
                monitor.log(f'Left messages {keys if keys != [] else ""} with {name} at {e.directory}')
class Timer(Hook):
    """A simple timing hook used to log any timers opened by the experiment monitor"""

    def __call__(self, var_dict, monitor):
        """Log every non-trigger, non-'last' entry of ``monitor.summary(verbose=1)``."""
        s = monitor.summary(verbose=1)
        keys = [k for k in s if k not in BRIEF.values() and k != 'last']
        monitor.log(' '.join(f'{k}={s[k]}' for k in keys))
class Probe(Hook):
    """A simple probing hook used to retrieve and log a system snapshot with the experiment"""

    def __call__(self, var_dict, monitor):
        """Log cpu / gpu usage for every matched variable exposing ``update_meta``.

        Variables without ``update_meta`` / ``_meta``, or whose ``update_meta`` raises
        ``AttributeError``, are skipped.
        """
        for k, v in var_dict.items():
            # Only the snapshot itself is best-effort; logging errors are not swallowed.
            try:
                v.update_meta(get_cpu=True, get_gpu=True, save=True)
                meta = v._meta
            except AttributeError:
                continue
            string = ''
            cpu = meta.get('cpu', None)
            if cpu is not None and len(cpu['usage']) > 0:
                cpu_use = sum(float(c.replace('%', '')) for c in cpu['usage']) / len(cpu['usage'])
                string += f"cpu={cpu_use}%"
            gpu = meta.get('gpu', None)
            if gpu is not None:
                string += ' '
                string += ' '.join(
                    f"{kk} = {vv['name']} {vv['load']} {vv['memory']} {vv['temperature']}"
                    for kk, vv in gpu.items())
                string += ' '
            monitor.log(string)
class Logger(Hook):
    """A simple logging hook used to log variables to stdout.

    Example::

        from xmen.monitor import Monitor, Logger

        m = Monitor(
            hooks=[
                Logger('x@2s', process_func=lambda x: '|'.join(x)),
                Logger('y@1e', format='.5f')])

        x = ['cat', 'dog']
        y = 0.
        for _ in m(range(3)):
            for i in m(range(5)):
                y += i

        # [01:17PM 18/11/20 0/3 2/15]: x = cat|dog
        # [01:17PM 18/11/20 0/3 4/15]: x = cat|dog
        # [01:17PM 18/11/20 1/3]: y = 10.00000
        # [01:17PM 18/11/20 1/3 6/15]: x = cat|dog
        # [01:17PM 18/11/20 1/3 8/15]: x = cat|dog
        # [01:17PM 18/11/20 1/3 10/15]: x = cat|dog
        # [01:17PM 18/11/20 2/3]: y = 20.00000
        # [01:17PM 18/11/20 2/3 12/15]: x = cat|dog
        # [01:17PM 18/11/20 2/3 14/15]: x = cat|dog
        # [01:17PM 18/11/20 3/3]: y = 30.00000
    """

    def __init__(self, spec, format='', process_func=None):
        """
        Args:
            spec (Spec): a specification string of the form "{regex}@{modulo}{trigger}".
                Variables matching ``regex`` will be logged when ``trigger`` % ``modulo`` == 0.
            format (str): a format string used as f"{var:format}" for logging variables
            process_func (callable): used to convert variables to a string for logging to
                stdout. Format will be applied after.
        """
        super(Logger, self).__init__(spec)
        self.format = format
        self.process_func = process_func

    def __call__(self, var_dict, monitor, *args, format=None, process_func=None):
        """Log variables (and any extra positional ``args``) to standard out.

        ``format`` and ``process_func`` override the values given at construction.
        """
        s = monitor.summary(verbose=1)
        elems = [f'{s["last"]}'] + [f'{s[k]}' for k in s if k in BRIEF.values()]
        string = '[' + ' '.join(elems) + ']:'
        process_func = self.process_func if process_func is None else process_func
        format = self.format if format is None else format
        for k, v in var_dict.items():
            if process_func is not None:
                v = process_func(v)
            string += f' {k} = {v:{format}}'
        for v in args:
            if process_func is not None:
                v = process_func(v)
            string += f' {v:{format}}'
        print(string)
class TensorboardLogger(Hook):
    """A hook for logging tensorboard summaries.

    Currently ``image``, ``scalar``, ``histogram``, ``figure``, ``video``, ``text``,
    ``pr_curve`` and ``mesh`` are all supported. Before being passed to the summary writer each
    variable is processed as follows:

        1. First all variables are passed to ``fn`` (if supplied).
        2. Variables of type list, tuple or dict with length K will be expanded to give K
           variables. The name of each variable will be the list name postpended with its index
           or the dictionary name postpended with its key (if ``prepend``).
        3. If the summary type is image or scalar then some additional processing is performed:

            - Scalars that are not already scalar are converted by calling ``var.mean()``.
            - Images must be either a 3D [C, H, W] or 4D [B, C, H, W] tensor (or numpy array).
              They are converted to a single [C, H, W] image using
              ``torchvision.utils.make_grid``. Additional options can be passed to ``make_grid``
              using the ``options`` parameter. Available options include:

                - ``'nrow' # images per row``
                - ``'padding'``
                - ``'normalize'``
                - ``'range'``
                - ``'scale_each'``
                - ``'pad_value'``

              (see torchvision.utils.make_grid for more info)

    Note:
        Tensorboard does not allow summaries to have the same name. If you want to log two
        different types of summary for the same variable then you will need to use the
        ``prefix`` argument.
    """

    SUPPORTED = ['image', 'scalar', 'histogram', 'figure', 'video', 'text', 'pr_curve', 'mesh']

    def __init__(self, type, spec, fn=None, prefix='', prepend=True, **options):
        """
        Args:
            type (str): The type of tensorboard summary to log. Should be one of
                ['image', 'scalar', 'histogram', 'figure', 'video', 'text', 'pr_curve', 'mesh']
            spec (str): A string in the form ``"{regex}@{modulo}{trigger}"``. Tensorboard
                summaries will be logged for all variables matching ``regex`` when
                ``monitor.{trigger} % modulo == 0``.
            fn (callable): A function used to convert each variable to a summary
            prefix (str): prepend all summaries with this string
            prepend (bool): If true prepend dictionary variable names with the dictionary name.
            **options: Keyword arguments passed to ``torchvision.utils.make_grid``

        Raises:
            NotImplementedError: if ``type`` is not supported.
        """
        super(TensorboardLogger, self).__init__(spec)
        self.fn = fn
        self.type = type
        self.prefix = prefix
        self.prepend = prepend
        if type not in self.SUPPORTED:
            raise NotImplementedError(
                f"Summary type '{type}' is not supported. Supported types are {self.SUPPORTED}.")
        self.options = options

    def __call__(self, var_dict, monitor):
        """Write a summary for every matched variable to ``{monitor.directory}/summaries``."""
        if len(var_dict) == 0:
            return
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=FutureWarning, append=True)
            import torch.utils.tensorboard as tb
        if monitor.directory is None:
            monitor.log(f'WARNING: Cannot log {list(var_dict.keys())} to tensorboard as '
                        f'monitor does not have a directory')
            return
        summary_dir = os.path.join(monitor.directory, 'summaries')
        os.makedirs(summary_dir, exist_ok=True)
        tb_writer = tb.SummaryWriter(summary_dir)
        try:
            monitor.log(f'saved tb {self.type} summaries for {list(var_dict.keys())} at {monitor.directory}')
            add_summary = getattr(tb_writer, 'add_' + self.type)
            for k, v in var_dict.items():
                k = self.prefix + k
                if v is None:
                    continue
                if self.fn is not None:
                    v = self.fn(v)
                if isinstance(v, dict):
                    p = k + '_' if self.prepend else ''
                    for kk, vv in v.items():
                        add_summary(p + kk, self.make_compatible(vv), global_step=monitor.step)
                elif isinstance(v, (list, tuple)):
                    for i, vv in enumerate(v):
                        add_summary(k + '_' + str(i), self.make_compatible(vv), global_step=monitor.step)
                else:
                    add_summary(k, self.make_compatible(v), global_step=monitor.step)
            tb_writer.flush()
        finally:
            # Always release the event file, even if a summary fails.
            tb_writer.close()

    def make_compatible(self, v):
        """Convert the variable v to a valid input for the summary writer"""
        import torch
        import numpy as np
        if hasattr(v, 'detach') and hasattr(v, 'clone'):
            v = v.detach().clone()
        if self.type == 'image':
            # torchvision is only needed for images, so import it here.
            from torchvision.utils import make_grid
            if isinstance(v, np.ndarray):
                v = torch.as_tensor(v)
            allowed = {'nrow', 'padding', 'normalize', 'range', 'scale_each', 'pad_value'}
            make_grid_params = {'normalize': True, 'scale_each': True}
            make_grid_params.update(
                {name: value for name, value in self.options.items() if name in allowed})
            if len(v.shape) == 4:
                v = make_grid(v.float(), **make_grid_params)
            elif len(v.shape) == 3:
                # Treat a single image as a batch of one; cast like the batched case so that
                # normalisation works for integer images.
                v = make_grid([v.float()], **make_grid_params)
        if self.type == 'scalar':
            if isinstance(v, (torch.Tensor, np.ndarray)) and len(v.shape) > 0:
                v = v.mean()
        return v
class StopWatch(object):
    """A simple timer class"""

    def __init__(self, name, length=None, time_format='.4f', date_format='%j %H:%M:%S'):
        """
        Args:
            name (str): the name of the stopwatch
            length (int): the number of steps to time for (can be None)
            time_format (str): display timings in this format
            date_format (str): display date in this format
        """
        self.name = name
        self.start_time = time.time()
        self.average = 0.
        self.delta = 0.  # duration of the most recent start()/stop() interval in seconds
        self.length = length
        self.since_start = 0.
        self.n = 0.
        self._t_fmt = time_format
        self._date_fmt = date_format
        self.reference = 0.

    def start(self):
        """Start the stop watch"""
        self.reference = time.time()
        return self

    def stop(self):
        """Stop the stopwatch, recording ``delta`` and updating the running ``average``."""
        now_time = time.time()
        self.delta = now_time - self.reference
        # Incremental mean over all completed intervals.
        self.average += (self.delta - self.average) / (self.n + 1)
        self.n += 1.
        return self

    def projected(self, n_more=None):
        """Get projected time to completion.

        If ``n_more is None`` then the time to completion will be inferred from ``length``.
        As a result at least one of ``n_more`` or ``length`` must be set.
        """
        assert n_more is not None or self.length is not None, 'One of n_more or length must be set'
        if n_more is None:
            n_more = (self.length - self.n)
        return self.average * n_more

    def wall_time(self):
        """The time since the stopwatch was created in seconds"""
        return time.time() - self.start_time

    def __repr__(self):
        """String representation"""
        string = f'{self.name} {self.delta:{self._t_fmt}}secs ({self.average:{self._t_fmt}} avg)'
        if self.length is not None:
            string += ' wall ' + time.strftime(self._date_fmt, time.gmtime(time.time() - self.start_time))
            string += ' finish ' + time.strftime(self._date_fmt, time.gmtime(self.projected()))
        return string
def load(load_dir, step=None, attempt=True, **modules):
    """Load torch objects from a checkpoint directory.

    Checkpoints are expected at ``{load_dir}/{module_name}/{step}.pt``. If ``step is None`` then
    the most recent checkpoint is loaded, else the checkpoint at exactly the specified step.

    If ``attempt == False`` then load raises:

        - ``RuntimeWarning`` if no checkpoints exist in ``load_dir``, if a module folder is
          empty, or if no checkpoint exists at ``step``.
        - ``ValueError`` if any of ``modules`` could not be loaded.

    Otherwise these conditions are only logged as warnings.

    Note:
        No check is made to ensure the meta information in each file is the same. The meta
        information returned corresponds to the first module loaded.

    Args:
        load_dir (str): the checkpoint directory.
        step (int): the step to load, or None for the most recent.
        attempt (bool): if True problems are logged rather than raised.
        **modules: objects with a ``load_state_dict`` method, keyed by checkpoint folder name.

    Returns:
        tuple: ``(modules, meta)`` where ``meta`` maps each trigger in ``TRIGGERS`` to its
        saved value.
    """
    def _fail(msg, error=RuntimeWarning):
        # Either log the problem or raise it, depending on ``attempt``.
        if attempt:
            logging.warning(msg)
        else:
            raise error(msg)

    def _checkpoint_step(path):
        # Checkpoint files are named '{step}.pt'.
        return int(os.path.splitext(os.path.basename(path))[0])

    if load_dir.endswith('/'):
        load_dir = load_dir[:-1]
    folders = glob.glob(os.path.join(load_dir, '*'))
    meta = {}
    if len(folders) == 0:
        _fail(f'No checkpoints exist in {load_dir}')
        return modules, meta

    found, steps = [], []
    for folder in folders:
        module_key = os.path.basename(folder)
        if module_key not in modules:
            continue
        files = glob.glob(os.path.join(folder, '*'))
        if not files:
            _fail(f'No checkpoint found for {folder}')
            continue
        # Order files in ascending order by step (the sort must be on the int step)
        files = sorted(files, key=_checkpoint_step)
        if step is None:
            file = files[-1]  # Take most recent
            load_step = _checkpoint_step(file)
        else:
            matching = [ff for ff in files if _checkpoint_step(ff) == step]
            if not matching:
                _fail(f'Could not find checkpoint at step {step} for {folder}')
                continue
            file, load_step = matching[0], step
        import torch
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(file)
        if not meta:
            meta.update({k: checkpoint.get(k, 0) for k in TRIGGERS})
        if 'state_dict' in checkpoint and hasattr(modules[module_key], 'load_state_dict'):
            modules[module_key].load_state_dict(checkpoint['state_dict'])
            found.append(module_key)
            steps.append(load_step)
    logging.info(f"Loaded {found} at step {steps}")
    missing = set(modules.keys()).difference(found)
    if len(missing) != 0:
        _fail(f'Keys {missing} were not found in the checkpoint', ValueError)
    return modules, meta
[docs]class Monitor(object): """Automate tracking and logging of experiments. Configured through arguments of the form ``f'{regex}@{modulo}{trigger}'``. Any variables in the local scope of the monitor which match the specified regular expression will be logged every modulo steps of the trigger. Triggers are incremented either manually (using the ``inc`` method) or automatically using nested iterators (see example). Supported triggers include ["step", "epoch", "era", "eon", "supereon"] or their abbreviations ["s", "e", "er", "eo", "se"]. If a hook is passed with modulo == None it can instead be triggered manually as ``monitor.to_hooks()``. Example:: from xmen import Experiment from xmen.monitor import Monitor X = Experiment(..., ...) a, b, c, d, e, f = 0, 0, 0, 0, 0, 0 def identity(x): return x def mult(x): return 10 * x m = Monitor( log=('a|b@2s', 'b@1e', 'c@1er', "d|e@1eo", "e@1se"), log_fn=(mult, identity, identity, identity, identity), log_format='.3f', msg='a->X@1s', time=('@2s', '@1e'), probe='X@10s', limit='@20s') for _ in m(range(2)): for _ in m(range(2)): for _ in m(range(3)): for _ in m(range(4)): for _ in m(range(5)): a += 1 b += 1 c += 1 d += 1 e += 1 """
[docs] def __init__(self, *, hooks=[], log=None, log_fn=None, log_format='', time=(), msg=None, msg_keep='latest', msg_leader=None, msg_expand=False, msg_prep=None, probe=None, limit=None): """ Args: hooks (Iterable[Hook]): User defined hooks used to extend the functionality of the Monitor class inheriting from ``Hook``. log (str, Iterable[str]): A modulo string of the form ``"f{regex}"`` or ``"{regex}@{steps}s"`` (or list of) giving the variables to log @ a particular logging frequency to stdout. log_fn (Callable, Iterable[Callable]): Converts the variable into a string for logging. log_format (str, Iterable[str]): A format used to format a string as ``f"{string}:{format}"`` time (str, Iterable[str]): A string of the form ``f"@{steps}"`` to log timing statistics at (or list of for different triggers). msg (str, Iterable[str]): A modulo string of the form ``"{regex}->{exp_regex}@{steps}s"`` (or list of) giving the variables to log as messages with the experiments matching ``exp_regex``. msg_keep (str, Iterable[str]): One of ['latest', 'max', 'min'] giving the protocol to use on message collision msg_leader (str, Iterable[str]): A regex to a single variable. If supplied then this variable will be treated as the leader and all other variables will be logged only if the keep condition is met for the leader msg_expand (bool, Iterable[str]): If True then dictionary variables with K keys will be expanded to give K variables msg_prep (bool, Iterable[str]): If True then each variable in the dictionary will be prepended by the dictionary name. probe (str, Iterable[str]): A string of the form ``f"{regex}@{steps}"`` to log resource use to each experiment that matches regex (or list of for different triggers). limit (str, Iterable[str]): A modulo string of the form ``f"@{modulo}{triger}" used to limit the number of iterations of an experiment at a particular trigger level. This is useful if an experiment is restarted for example. 
Note: All variables ``ckpt_keep``, ``msg_leader``, ``msg_expand`` and ``msg_prep`` can be supplied as a single or as a list of entries, one for each set of variables matching each modulo string in each case.""" self.hooks = [] self.triggers = {k: 0 for k in TRIGGERS} self.log_regex = None self.timers = {} self.message = None self._limit_reached = False self.limit = read_modulo_string(limit)[1:] if limit is not None else None if log is not None: if not isinstance(log, (list, tuple)): log = [log] if not isinstance(log_fn, (list, tuple)): log_fn = tuple([log_fn] * len(log)) if not isinstance(log_format, (list, tuple)): log_format = tuple([log_format] * len(log)) warn = f"Either one fn, pref and prep or one for each '{log}' must be set" assert all(len(f) == len(log) for f in (log_fn, log_format)), warn for l, f, fmt in zip(log, log_fn, log_format): self.hooks.append(Logger(l, fmt, f)) if time is not None: if not isinstance(time, (list, tuple)): time = [time] for t in time: self.hooks.append(Timer(t)) if msg is not None: if not isinstance(msg, (list, tuple)): msg = [msg] if not isinstance(msg_keep, (list, tuple)): msg_keep = tuple([msg_keep] * len(msg)) if not isinstance(msg_leader, (list, tuple)): msg_leader = tuple([msg_leader] * len(msg)) if not isinstance(msg_prep, (list, tuple)): msg_prep = tuple([msg_prep] * len(msg)) if not isinstance(msg_expand, (list, tuple)): msg_expand = tuple([msg_expand] * len(msg)) warn = f"Either one fn, pref and prep or one for each '{msg}' must be set" assert all(len(f) == len(msg) for f in (msg_keep, msg_leader, msg_expand, msg_prep)), warn for m, k, l, ex, pr in zip(msg, msg_keep, msg_leader, msg_expand, msg_prep): self.hooks.append(XmenMessenger(m, k, l, ex, pr)) if probe is not None: if not isinstance(probe, (list, tuple)): probe = [probe] for p in probe: self.hooks.append(Probe(p)) # Add user hooks self.n_user_hooks = len(hooks) self.hooks.extend(hooks) # -- Use Logger and Check pointer for manual logging # These special hooks are 
always available self._logger = Logger(r'.*')
@property def supereon(self): return self.triggers['supereon'] @property def eon(self): return self.triggers['eon'] @property def era(self): return self.triggers['era'] @property def epoch(self): return self.triggers['epoch'] @property def step(self): return self.triggers['step'] def stop(self, msg): self._limit_reached = True self.log(msg) def __call__(self, iter, length=None, back=2): """Manage incrementing triggers, logging and collecting timing statistics around current the passed iterator. Args: iter: An iterator back: The number of calls back up to the frame from which called once inside ``inc()``. length: The length of the iterator (useful i length is known but iterator does not have attribute __len__) """ if self._limit_reached: self.log(f'---- Stop Criterion @ ' + ', '.join( [f"{k}={v}" for k, v in self.triggers.items() if v != 0.]) + ' ----') return length = length if length is not None else len(iter) if hasattr(iter, '__len__') else None load, step = 'load', 'step' # Promote old timers if TRIGGERS[0] in self.timers: # Setup (all triggers start at 0 for i, (next_trigger, trigger) in enumerate( zip(reversed(TRIGGERS[1:len(self.timers) + 1]), reversed(TRIGGERS[:len(self.timers)]))): self.timers.update({next_trigger: copy.deepcopy(self.timers[trigger])}) # Get current trigger level if len(self.timers) == 0 or TRIGGERS[0] in self.timers: current = TRIGGERS[0] else: current = TRIGGERS[[i for i, t in enumerate(TRIGGERS) if t in self.timers][0] - 1] # Add timing and logging for current trigger level self.timers[current] = {load: StopWatch(load, length), step: StopWatch(step, length)} def trigger(): return [t for t in TRIGGERS if t in self.timers][0] self.timers[trigger()][load].start() for x in iter: if self.limit is not None and trigger() == self.limit[1] and self.triggers[trigger()] == self.limit[0]: self.stop(f'---- Limit reached for {self.limit[1]} {self.limit[0]} ----') # self._limit_reached = True # print(f'Limit reached for {self.limit[1]} 
{self.limit[0]}') break elif self._limit_reached: break else: self.timers[trigger()][load].stop() self.timers[trigger()][step].start() yield x self.timers[trigger()][step].stop(), back) self.timers[trigger()][load].start() # Remove timers and hooks for trigger trigger = trigger() self.timers.pop(trigger)
[docs] def inc(self, trigger, back=1): """Manually increment trigger. Will also run modulo hooks defined by the user.""" if trigger not in TRIGGERS: raise NotImplementedError(f'{trigger} is not in {TRIGGERS}') self.triggers[trigger] += 1 f = inspect.currentframe() for _ in range(back): f = f.f_back possible_vars = {k: v for k, v in f.f_locals.items()} for hook in self.hooks: if hook.modulo is not None and self.modulo(trigger, hook.modulo) and hook.trigger == trigger: matches = {} if hook.regex is not None: matches = { k: v for k, v in possible_vars.items() if re.match(hook.regex, k) is not None} hook(matches, self)
[docs] def to_hooks(self, **kwargs): """Pass all keyword arguments to manual hooks (hooks where modulo is None)""" for hook in self.hooks: if hook.modulo is None: matches = { k: v for k, v in kwargs.items() if re.match(hook.regex, k) is not None} hook(matches, self)
[docs] def log(self, x, format='', process_func=str): """Log string with time, epoch and step. Args: x: Value to be logged format: The format used to convert x to a string as '{x:{format}}' process_func: A callable able to convert x to a valid format for printing as {x:{format}}' """ self._logger({}, self, x, format=format, process_func=process_func)
[docs] def modulo(self, trigger, modulo, exclude_1st=False): """Check if trigger % modulo == 0. If exclude_1st is True then modulo will return False for triggers at 0""" is_trigger = self.triggers[trigger] % modulo == 0 if exclude_1st: is_trigger = is_trigger and self.triggers[trigger] > 0 return is_trigger
[docs] def summary(self, verbose=0): """Summarise the current state of the experiment monitor""" # Get statistics timings = {} summaries = {'last':"%I:%M%p %d/%m/%y")} if len(self.timers) > 0: to_date = lambda x: str(datetime.timedelta(seconds=x)) k, N, length, wall = zip(*[( k, self.timers[k]['step'].n, self.timers[k]['step'].length, self.timers[k]['step'].wall_time()) for k in reversed(TRIGGERS) if k in self.timers]) so_far, total = 0., [1.] offset = [l for l in length[1:]] + [1] wall = wall[0] if length[0] is not None: for k, (n, lnow, lnext) in enumerate(zip(N, length, offset)): so_far += n total += [total[k] * lnow] if lnext is None: break else: so_far *= lnext if so_far == 0.: # There is a case once an inner iterator terminates for the # first time meaning that the higher order iterator has yet to be triggered, # left == 0. and a ZeroDivisionError will occur. It could be assumed that the trigger is just # about to be triggered and left promoted to 1.0. In the # case of iterators with a lot of work after the inner iterator has terminated then # this timing estimate could be inaccurate. Instead when this case occurs (which is a minority) # I ignore the timing. 
lnext = None if lnext is not None: left = (wall / so_far) * (total[-1] - so_far) else: left = None else: left = None length = list(length) length[:len(total)] = total[1:] for i, (trigger, timer) in enumerate([(k, self.timers[k]) for k in reversed(TRIGGERS) if k in self.timers]): step_timer = timer['step'] val = f'{int(self.triggers[trigger])}' if step_timer.length is not None: val += f'/{int(length[i])}' # val += f'%{step_timer.length}' summaries.update({BRIEF[trigger]: val}) if verbose: # Get trigger with highest precedance trigger, timer = [(k, self.timers[k]) for k in TRIGGERS if k in self.timers][0] step_timer, load_timer = timer['step'], timer.get('load', None) if load_timer is not None: if != 0.: if step_timer.length is not None: timings.update( {'next': to_date( (step_timer.average + load_timer.average) * (step_timer.length - step_timer.n))}) timings.update( {'step': to_date(, 'm_step': to_date(step_timer.average), 'load': to_date(, 'm_load': to_date(load_timer.average)}) else: if step_timer.length is not None: timings.update( {'next': to_date( step_timer.average * (step_timer.length - step_timer.n))}) timings.update( {'step': to_date(, 'm_step': to_date(step_timer.average)}) summaries.update({'wall': to_date(wall)}) if left is not None: summaries.update({'end': to_date(left)}) summaries.update(timings) return summaries
def __repr__(self):
    """Describe the monitor by its trigger counts, followed by its verbose summary when non-empty."""
    state = self.summary(1)
    parts = [f'triggers: {self.triggers}']
    if state:
        parts.append(f'current state: {state}')
    return ', '.join(parts)
[docs]class TorchMonitor(Monitor): """Automate tracking, logging and saving of experiments with additional hooks for logging to tensorboard and checkpointing of experiments. Manual logging and checkpoint are also supported as ``monitor.log(...)`` and ``monitor.checkpoint(...)``"""
def __init__(self, directory=None, *,
             # User defined hooks (either) modulo or not modulo
             hooks=None,
             # Default saving for parameters
             ckpt=None, ckpt_keep=None,
             log=None, log_fn=None, log_format='',
             img=None, img_fn=None, img_pref='', img_prep=True, img_options=None,
             sca=None, sca_fn=None, sca_pref='', sca_prep=True,
             hist=None, hist_fn=None, hist_pref='', hist_prep=True,
             fig=None, fig_fn=None, fig_pref='', fig_prep=True,
             txt=None, txt_fn=None, txt_pref='', txt_prep=True,
             vid=None, vid_fn=None, vid_pref='', vid_prep=True,
             time=(),
             msg=None, msg_keep='latest', msg_leader=None, msg_expand=False, msg_prep=None,
             probe=None, limit=None):
    """
    Args:
        directory (str): The directory used to log checkpoints in.
        hooks (Iterable[Hook]): User defined hooks used to extend the functionality of the
            Monitor class. The sequence is copied, so the caller's list is never modified.
        ckpt (str, Iterable[str]): A modulo string of the form ``"{regex}"`` or
            ``"{regex}@{steps}s"`` (or list of) giving the variables to checkpoint @ a
            particular frequency. The regex must match objects inheriting from
            ``torch.nn.Module``.
        ckpt_keep (int, Iterable[int]): The number of checkpoints to keep. The most recent
            checkpoints will be kept. If None then all checkpoints will be kept.
        log (str, Iterable[str]): A modulo string of the form ``"{regex}"`` or
            ``"{regex}@{steps}s"`` (or list of) giving the variables to log @ a particular
            logging frequency to stdout.
        log_fn (Callable, Iterable[Callable]): Converts the variable into a string for logging.
        log_format (str, Iterable[str]): A format used to format a string as
            ``f"{string}:{format}"``.
        time (str, Iterable[str]): A string of the form ``f"@{steps}"`` to log timing
            statistics at (or list of for different triggers).
        img (str, Iterable[str]): A modulo string of the form ``"{regex}"`` or
            ``"{regex}@{steps}s"`` (or list of) giving the variables to add as tensorboard
            images @ a particular logging frequency.
        img_fn (Callable, Iterable[Callable]): Converts the variable into an image of shape
            [B, C, H, W] or [C, H, W] for tensorboard. See TensorBoardLogger for more details
            in terms of automatic-processing. If ``img`` is a list then ``img_fn`` can also be
            passed as list with a callable for each entry in ``img`` or can be passed as a
            single callable used for all entries.
        img_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        img_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        img_options (dict): Extra keyword arguments forwarded to every image
            TensorboardLogger. Defaults to no extra arguments.
        sca (str, Iterable[str]): A modulo string of the form ``"{regex}@{steps}s"`` (or list
            of) giving the variables to add as tensorboard scalars @ a particular logging
            frequency.
        sca_fn (Callable, Iterable[Callable]): Converts the variable into a scalar for
            tensorboard. See TensorBoardLogger for more details in terms of
            automatic-processing. If ``sca`` is a list then ``sca_fn`` can also be passed as
            list with a callable for each entry in ``sca`` or can be passed as a single
            callable used for all entries.
        sca_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        sca_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        hist (str, Iterable[str]): A modulo string of the form ``"{regex}@{steps}s"`` (or list
            of) giving the variables to add as tensorboard histograms @ a particular logging
            frequency. If no logging frequency is supplied then any variable logged in the
            experiment which matches ``regex`` will be logged to tensorboard each time it is
            passed to the logger. This is useful for logging variables at the end of an epoch
            for example.
        hist_fn (Callable, Iterable[Callable]): Preprocess variable before logging to
            tensorboard. If ``hist`` is a list then ``hist_fn`` can also be passed as list
            with a callable for each entry in ``hist`` or can be passed as a single callable
            used for all entries.
        hist_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        hist_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        fig (str, Iterable[str]): A modulo string of the form ``"{regex}@{steps}s"`` (or list
            of) giving the variables to add as tensorboard figures @ a particular logging
            frequency.
        fig_fn (Callable, Iterable[Callable]): Preprocess variable before logging to
            tensorboard into a plt.figure(). If ``fig`` is a list then ``fig_fn`` can also be
            passed as list with a callable for each entry in ``fig`` or can be passed as a
            single callable used for all entries.
        fig_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        fig_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        txt (str, Iterable[str]): A modulo string of the form ``"{regex}@{steps}s"`` (or list
            of) giving the variables to add as tensorboard text @ a particular logging
            frequency.
        txt_fn (Callable, Iterable[Callable]): Preprocess variable before logging to
            tensorboard. If ``txt`` is a list then ``txt_fn`` can also be passed as list with
            a callable for each entry in ``txt`` or can be passed as a single callable used
            for all entries.
        txt_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        txt_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        vid (str, Iterable[str]): A modulo string of the form ``"{regex}@{steps}s"`` (or list
            of) giving the variables to add as tensorboard videos @ a particular logging
            frequency.
        vid_fn (Callable, Iterable[Callable]): Preprocess variable to a tensor of shape
            [B, T, C, H, W] before logging to tensorboard. If ``vid`` is a list then
            ``vid_fn`` can also be passed as list with a callable for each entry in ``vid``
            or can be passed as a single callable used for all entries.
        vid_pref (str, Iterable[str]): Prefix all summaries in tensorboard with this string.
        vid_prep (bool, Iterable[bool]): If True then dictionary variables will be prepended
            by the name of the dictionary.
        msg (str, Iterable[str]): A modulo string of the form
            ``"{regex}->{exp_regex}@{steps}s"`` (or list of) giving the variables to log as
            messages with the experiments matching ``exp_regex``.
        msg_keep (str, Iterable[str]): One of ['latest', 'max', 'min'] giving the protocol to
            use on message collision.
        msg_leader (str, Iterable[str]): A regex to a single variable. If supplied then this
            variable will be treated as the leader and all other variables will be logged
            only if the keep condition is met for the leader.
        msg_expand (bool, Iterable[bool]): If True then dictionary variables with K keys will
            be expanded to give K variables.
        msg_prep (bool, Iterable[bool]): If True then each variable in the dictionary will be
            prepended by the dictionary name.
        probe (str, Iterable[str]): A string of the form ``f"{regex}@{steps}"`` to log
            resource use to each experiment that matches regex (or list of for different
            triggers).
        limit (str): A modulo string of the form ``f"@{modulo}{trigger}"`` used to limit the
            number of iterations of an experiment at a particular trigger level. This is
            useful if an experiment is restarted for example.

    Note:
        All variables ``..._fn``, ``..._pref`` and ``..._prep`` as well as ``ckpt_keep`` and
        ``msg_leader``, ``msg_expand`` and ``msg_prep`` can be supplied as a single or as a
        list of entries, one for each set of variables matching each modulo string in each
        case.

    Example 1::

        model, opt, ds = ..., ..., ...
        m = TorchMonitor(
            directory,
            ckpt=('model@1e', 'opt@100s'),  # Checkpoint the model once per epoch and opt every 100 steps
            log='^loss$@100s',  # Log the loss to stdout every 100 steps
            img='^x$@1000s',  # Log images to tensorboard every 1000 steps
            sca=('^loss$@100s', 'eval_.*@1e'),  # Log scalars to tensorboard
            time=('@100s', ),  # Generate timing statistics every 100 steps
            hooks=[  # Custom hooks are also supported
                MyVeryOwnHook(...)])

        # The only modification needed to the training loop are the m calls.
        # Nested loops corresponds to different triggers from inside out
        # we have ["step" or "s", "epoch" or "e", "era" or "er", "eon" or "eo", "supereon" or "se"]
        for epoch in m(range(10)):
            for x, y in m(ds):
                _y_ = model(x)
                opt.zero_grad()
                loss = loss_fn(y, _y_)
                loss.backward()
                opt.step()
                m.log('Manual Logging is also supported')
            eval_1, eval_2 = eval(model, ds)

        # Steps and epoch have been incremented
        assert m.step == len(ds) * 10
        assert m.epoch == 10

        # Lets reload the model at the 5th epoch
        m.load(step=5 * len(ds), model=model)
        # The step and epoch will be updated
        print(m.step, m.epoch)

    Example 2::

        from xmen.monitor import TorchMonitor
        import numpy as np
        import torch
        import os
        import matplotlib.pyplot as plt
        from torchvision.datasets.mnist import MNIST
        from torch.utils.data import DataLoader
        import torchvision.transforms as T

        plt.style.use('ggplot')
        ds = DataLoader(MNIST(os.getenv("HOME") + '/data/mnist', download=True,
              transform=T.Compose(
                  [T.Resize([64, 64]), T.CenterCrop([64, 64]),
                   T.ToTensor(), T.Normalize([0.5], [0.5])])), 8)

        m = TorchMonitor(
            directory='/tmp/tb_5',
            sca=['^z$|^X$@10s', '^a|^x$@5s'],
            img=['^mnist@10s', '^mnist@5s'], img_fn=[lambda x: x[:2], lambda x: x[:5]],
            img_pref=['2', '5'],
            hist='Z@1s',
            fig='fig@10s',
            txt='i@5s', txt_fn=lambda x: f'Hello at step {x}',
            vid='^mnist@10s', vid_fn=lambda x: (x.unsqueeze(0) - x.min()) / (x.max() - x.min())
        )

        # variables
        x = 0.
        a = [1, 2]
        z = {'x': 5, 'y': 10}
        for i, (mnist, _) in m(zip(range(31), ds)):
            # plot a figure
            fig = plt.figure(figsize=[10, 5])
            plt.plot(np.linspace(0, 1000), np.cos(np.linspace(0, 1000) * i))
            # random tensor
            Z = torch.randn([10, 3, 64, 64]) * i / 100
            # scalars
            x = (i - 15) ** 2
            z['i'] = i
            z['x'] += 1
            z['y'] = z['x'] ** 2
    """
    # Fresh containers per instance: the hooks list is appended to below, so a
    # shared mutable default would leak hooks between monitors.
    hooks = [] if hooks is None else list(hooks)
    img_options = {} if img_options is None else img_options
    super().__init__(hooks=hooks, log=log, log_fn=log_fn, log_format=log_format, time=time,
                     msg=msg, msg_keep=msg_keep, msg_leader=msg_leader, msg_expand=msg_expand,
                     msg_prep=msg_prep, probe=probe, limit=limit)
    # Used by load() to locate the default checkpoint directory.
    self.directory = directory

    def _per_entry(value, n):
        # Options may be given once (shared by all n entries) or once per entry.
        return value if isinstance(value, (list, tuple)) else tuple([value] * n)

    # -- Add modulo hooks to hooks
    if ckpt is not None:
        ckpts = ckpt if isinstance(ckpt, (list, tuple)) else [ckpt]
        # Normalise into a separate name so the user's ckpt_keep is still
        # available (unchanged) for the catch-all checkpointer below.
        keeps = _per_entry(ckpt_keep, len(ckpts))
        warn = f"Either one ckpt_keep or one for each '{ckpts}' must be set"
        assert len(keeps) == len(ckpts), warn
        for c, k in zip(ckpts, keeps):
            self.hooks.append(Checkpointer(c, to_keep=k))

    # tensorboard summaries: (summary kind, spec, fn, pref, prep, extra logger options),
    # in a fixed order so hooks are registered deterministically.
    summaries = (
        ('scalar', sca, sca_fn, sca_pref, sca_prep, {}),
        ('image', img, img_fn, img_pref, img_prep, img_options),
        ('histogram', hist, hist_fn, hist_pref, hist_prep, {}),
        ('figure', fig, fig_fn, fig_pref, fig_prep, {}),
        ('text', txt, txt_fn, txt_pref, txt_prep, {}),
        ('video', vid, vid_fn, vid_pref, vid_prep, {}),
    )
    for kind, spec, fn, pref, prep, options in summaries:
        if spec is None:
            continue
        specs = spec if isinstance(spec, (list, tuple)) else [spec]
        fns, prefs, preps = (_per_entry(v, len(specs)) for v in (fn, pref, prep))
        warn = f"Either one fn, pref and prep or one for each '{specs}' must be set"
        assert all(len(v) == len(specs) for v in (fns, prefs, preps)), warn
        for c, f, pr, pre in zip(specs, fns, prefs, preps):
            self.hooks.append(TensorboardLogger(kind, c, fn=f, prefix=pr, prepend=pre, **options))
    # Catch-all checkpointer used by checkpoint(); receives ckpt_keep exactly as
    # the caller supplied it (previously a tuple whenever ckpt was also given).
    self._checkpointer = Checkpointer(r'.*', to_keep=ckpt_keep)
def checkpoint(self, **kwargs):
    """Checkpoint the objects passed as ``name=variable_to_save`` keyword arguments.

    All keyword arguments are handed over as one dict to the monitor's catch-all
    checkpointer, together with the monitor itself. Presumably the monitor is
    passed so the current step/epoch can be stored alongside; confirm this in
    Checkpointer.
    """
    save = self._checkpointer
    save(kwargs, self)
def load(self, directory=None, step=None, attempt=True, update_triggers=True, **modules):
    """Load the torch.nn objects passed as ``name=variable_to_load`` keyword arguments.

    Args:
        directory (str): Directory to load checkpoints from. Defaults to the
            ``checkpoints`` sub-directory of the monitor's own ``directory``.
        step: Forwarded to the module-level ``load`` function. Presumably it
            selects which checkpoint to restore (TODO confirm).
        attempt (bool): If False, an exception is raised when the directory does
            not contain checkpoints corresponding to ``modules``.
        update_triggers (bool): If True, the monitor's trigger counts are updated
            with the trigger state returned by the loader.
        **modules: The objects to restore, keyed by the name they were saved under.

    Returns:
        dict: The modules returned by the module-level ``load`` function. If
        neither ``directory`` nor the monitor's own directory is set, a warning
        is logged and ``{}`` is returned.
    """
    if directory is None:
        if self.directory is None:
            self.log('WARNING: Cannot load as monitor does not have a directory')
            return {}
        directory = os.path.join(self.directory, 'checkpoints')
    # Module-level loader: returns the restored objects and the saved trigger state.
    modules, triggers = load(directory, step=step, attempt=attempt, **modules)
    if update_triggers:
        self.triggers.update(triggers)
    return modules