Examples

Introduction

All the examples in this section are defined in xmen.examples and can be run from the command line using the xmen command line interface; just add them using xmen --add (if they are not already).
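
Each example also exposes its own command line interface and can be run directly as a python module, for example:

python -m xmen.examples.hello_world --help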

Hello World

"""Basic Examples.
"""
#  Copyright (C) 2019  Robert J Weston, Oxford Robotics Institute
#
#  xmen
#  email:   robw@robots.ox.ac.uk
#  github: https://github.com/robw4/xmen/
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  You should have received a copy of the GNU General Public License
#   along with this program. If not, see <http://www.gnu.org/licenses/>.
import xmen


@xmen.autodoc
def hello_world(
    root: xmen.Root,   # experiments are assigned a root before being executed
    a: str = 'Hello',  # the first
    # argument
    b: str = 'World',   # the second argument
):
    """A hello world experiment designed to demonstrate
    defining experiments through the functional experiment api"""
    print(f'{a} {b}')

    ...     # whatever other experiment code you want

    with open(root.directory + '/out.txt', 'w') as f:
        f.write(f'{a} {b}')
    root.message({'a': a, 'b': b})


class HelloWorld(xmen.Experiment):
    """A hello world experiment designed to demonstrate
    defining experiments through the class experiment api"""
    # Parameters
    a: str = 'Hello'  # @p The first argument
    b: str = 'World'  # @p The second argument

    def run(self):
        print(f'{self.a} {self.b}!')
        self.message({'a': self.a, 'b': self.b})


if __name__ == '__main__':
    # optionally expose the command line interface if you
    # would like to run the experiment as a python script
    from xmen.functional import functional_experiment
    # generate experiment from function definition if defined
    # using the functional experiment (this step is not needed if
    # the experiment is defined as a class)
    Exp = functional_experiment(hello_world)
    # every experiment inherits main() allowing the experiment
    # to be configured and run from the command line.
    Exp().main()
    # try...
    # >> python -m xmen.examples.hello_world --help
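
The same steps that main() performs from the command line can also be scripted directly. A minimal sketch (the directory and parameter values here are illustrative only):

# configure, register and run the class-api hello world from python
from xmen.examples.hello_world import HelloWorld

exp = HelloWorld()
exp.a = 'Goodbye'                    # parameters can be changed whilst the status is 'default'
exp.register('/tmp/hello', 'run_1')  # experiments must be registered to a directory before running
exp()                                # run the experiment by calling it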

A little more detail

"""Using the class api to define experiments with inheritance."""

from xmen.experiment import Experiment
import os
import time
from typing import List


class BaseExperiment(Experiment):
    """A basic experiment demonstrating the features of the xmen api."""

    # Parameters are defined as attributes in the class body with the
    # @p identifier

    t = 'cat'    # @p
    w = 3        # @p parameter w has a help message whilst t does not
    h: int = 10  # @p h declared with typing is very concise and neat

    # Parameters can also be defined in the __init__ method
    def __init__(self, *args, **kwargs):
        super(BaseExperiment, self).__init__(*args, **kwargs)

        self.a: str = 'h'  # @p A parameter
        self.b: int = 17   # @p Another parameter

        # Normal attributes are still allowed
        self.c: int = 5    # This is not a parameter


class AnotherExperiment(BaseExperiment):
    m: str = 'Another value'  # @p Multiple inheritance example
    p: str = 'A parameter only in AnotherExperiment'  # @p


class AnExperiment(BaseExperiment):
                    #     |
                    # Experiments can inherit from other experiments
                    # parameters are inherited too
    """An experiment testing the xmen experiment API. The __docstring__ will
    appear in both the docstring of the class __and__ as the prolog in the
    command line interface."""

    # Feel free to define more parameters
    x: List[float] = [3., 2.]  # @p Parameters can be defined cleanly as class attributes
    y: float = 5  # @p This parameter has this help message
    # Parameters can be overridden
    a: float = 0.5  # a's default and type will be changed. Its help will be overridden
    b: int = 17  # @p b's help will be changed

    m: str = 'Defined in AnExperiment'  # @p m is defined in AnExperiment

    def run(self):
        # Experiment execution is defined in the run method
        print(f'The experiment state inside run is {self.status}')

        # recording messaging is super easy
        self.message({'time': time.time()})

        # Each experiment has its own unique directory. You are encouraged to
        # write out data accumulated through the execution (snapshots, logs etc.)
        # to this directory.
        with open(os.path.join(self.directory, 'logs.txt'), 'w') as f:
            f.write('This was written from a running experiment')

    def debug(self):
        self.a = 'In debug mode'
        return self

    @property
    def h(self):
        return 'h has been overloaded as a property and will no longer ' \
               'be considered as a parameter'


# Experiments can inherit from multiple classes
class MultiParentsExperiment(AnotherExperiment, AnExperiment):
    pass


if __name__ == '__main__':
    # documentation is automatically added to the class
    help(AnExperiment)

    # to run an experiment first we initialise it
    # the experiment is initialised in 'default' status
    print('\nInitialising')
    print('---------------------------')
    exp = AnExperiment()
    print(exp)

    # whilst the status is default the parameters of the
    # experiment can be changed
    print('\nConfiguring')
    print('-----------')
    exp = AnExperiment()
    exp.a = 'hello'
    exp.update({'t': 'dog', 'w': 100})

    # Note parameters are copied from the class during
    # instantiation. This way you don't need to worry
    # about accidentally changing the mutable class
    # types across the entire class.
    exp.x += [4.]
    print(exp)
    assert AnExperiment.x == [3., 2.]
    # If this is not desired (or necessary) instead
    # initialise using exp = AnExperiment(copy=False)

    # Experiments can inherit from multiple classes:
    print('\nMultiple Inheritance')
    print('----------------------')
    print(MultiParentsExperiment())
    print('\n Parameters defaults, helps and values are '
          'inherited according to python\'s method resolution order '
          '(i.e. left to right). Note that m has the value '
          'defined in AnotherExperiment')

    print('\nRegistering')
    print('-------------')
    # Before being run an experiment needs to be registered
    # to a directory
    exp.register('/tmp/an_experiment', 'first_experiment',
                 purpose='A bit of a test of the xmen experiment api')
    print(exp, end='\n')
    print('\nGIT and system information is automatically logged\n')

    # The parameters of the experiment can no longer be changed
    try:
        exp.a = 'cat'
    except AttributeError:
        print('Parameters can no longer be changed!', end='\n')
        pass

    # An experiment can be run either by...
    # (1) calling it
    print('\nRunning (1)')
    print('-------------')
    exp()

    # (2) using it as a context. Just define a main
    #     loop like you normally would
    print('\nRunning (2)')
    print('-------------')
    with exp as e:
        # Inside the experiment context the experiment status is 'running'
        print(f'Once again the experiment state is {e.status}')
        # Write the main loop just as you normally would
        # using the parameters already defined
        results = dict(sum=sum(e.x), max=max(e.x), min=min(e.x))
        # Write results to the experiment just as before
        e.message(results)

    # All the information about the current experiment is
    # automatically saved in the experiments root directory
    # for free
    print(f'\nEverything you might need to know is logged in {exp.directory}/params.yml')
    print('-----------------------------------------------------------------------------------------------')
    print('Note that GIT and system information is automatically logged\n'
          'along with the messages')
    with open('/tmp/an_experiment/first_experiment/params.yml', 'r') as f:
        print(f.read())

    print('\nRegistering and Configuring from the command line')
    print('---------------------------------------------------')
    # Alternatively configuring and registering an experiment
    # can be done from the command line as:
    exp = AnExperiment()
    args = exp.parse_args()

    print('\nRegistering, Configuring and running in a single line!')
    print('--------------------------------------------------------')
    # Or alternatively
    # the experiment can be _configured_, _registered_
    # and _run_ by including a single line of code:
    AnExperiment().main()
    # See ``python -m xmen.tests.experiment --help`` for
    # more information

The Monitor Class

The Monitor class is designed to facilitate easy logging of experiments. All the examples in this section can be found in xmen.examples.monitor and can be run as:

python -m xmen.examples.monitor.logger
python -m xmen.examples.monitor.messenger.basic
python -m xmen.examples.monitor.messenger.leader
python -m xmen.examples.monitor.full

Logging

"""Automatic Logging"""
from xmen.monitor import Monitor

m = Monitor(
    log=['x@2s', 'y@1e'],
    log_fn=[lambda _: '|'.join(_), lambda _: _],
    log_format=['', '.5f'])

x = ['cat', 'dog']
y = 0.
for _ in m(range(3)):
    for i in m(range(5)):
        y += i


# output
[05:36PM 18/11/20 0/3 2/15]: x = cat|do
[05:36PM 18/11/20 0/3 4/15]: x = cat|do
[05:36PM 18/11/20 1/3]: y = 10.0000
[05:36PM 18/11/20 1/3 6/15]: x = cat|do
[05:36PM 18/11/20 1/3 8/15]: x = cat|do
[05:36PM 18/11/20 1/3 10/15]: x = cat|do
[05:36PM 18/11/20 2/3]: y = 20.0000
[05:36PM 18/11/20 2/3 12/15]: x = cat|do
[05:36PM 18/11/20 2/3 14/15]: x = cat|do
[05:36PM 18/11/20 3/3]: y = 30.0000
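
Each entry in log pairs a regular expression matching variable names in the monitored scope with a trigger: here x@2s fires every 2 steps of the inner loop and y@1e once per iteration of the outer loop. A minimal sketch of the same pattern with illustrative names (log_fn and log_format are omitted and assumed optional):

from xmen.monitor import Monitor

m = Monitor(log=['loss@10s', 'acc@1e'])  # log loss every 10 inner steps, acc once per outer iteration
loss, acc = 1., 0.
for _ in m(range(2)):        # the outer loop supplies the 'e' trigger
    for _ in m(range(100)):  # the inner loop supplies the 's' trigger
        loss *= 0.99
    acc += 0.4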

Messaging

Messaging Multiple Experiments

"""automatic messaging"""
from xmen import Experiment
from xmen.monitor import Monitor

# link some experiments
ex1, ex2 = Experiment(), Experiment()
ex1.register('/tmp', 'ex1')
ex2.register('/tmp', 'ex2')

# set up the monitor
m = Monitor(msg='y.*->ex.*@10s')

y1, y2 = 0, 0
for i in m(range(40)):  # monitor loop
    y1 += 1
    y2 += 2
    if i % 10 == 1:
        # messages added automatically to ex1 and ex2
        print(', '.join(
            [f"ex1: {k} = {ex1.messages.get(k, None)}" for k in ('y1', 'y2')] +
            [f"ex2: {k} = {ex2.messages.get(k, None)}" for k in ('y1', 'y2')]))

# timing information is also logged
print('\nAll Messages')
for k, v in ex1.messages.items():
    print(k, v)




# output
ex1: y1 = None, ex1: y2 = None, ex2: y1 = None, ex2: y2 = None
[06:00PM 18/11/20 10/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 10/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 10, ex1: y2 = 20, ex2: y1 = 10, ex2: y2 = 20
[06:00PM 18/11/20 20/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 20/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 20, ex1: y2 = 40, ex2: y1 = 20, ex2: y2 = 40
[06:00PM 18/11/20 30/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 30/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 30, ex1: y2 = 60, ex2: y1 = 30, ex2: y2 = 60
[06:00PM 18/11/20 40/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 40/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2

All Messages
last 06:00PM 18/11/20
s 40/40
wall 0:00:00.682972
end 0:00:00
next 0:00:00
step 0:00:00.000007
m_step 0:00:00.000013
load 0:00:00.000005
m_load 0:00:00.000005
y1 40
y2 80
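
The msg specification uses the same trigger syntax with an additional ->: variables matching the regex on the left are left as messages with every registered experiment in scope matching the regex on the right. A minimal sketch with illustrative names:

from xmen import Experiment
from xmen.monitor import Monitor

ex = Experiment()
ex.register('/tmp', 'ex')

loss = 1.0
m = Monitor(msg='loss->ex@5s')  # message loss to ex every 5 steps
for _ in m(range(20)):
    loss *= 0.9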

Using a leader

"""using the leader argument"""
from xmen import Experiment
from xmen.monitor import TorchMonitor

ex = Experiment()
ex.register('/tmp', 'ex')
m = TorchMonitor(msg='^y$|^i$->^ex$@10s', msg_keep='min', msg_leader='^y$')

x = -50
for i in m(range(100)):
    x += 1
    y = x ** 2
    if i % 10 == 1:
        print([ex.messages.get(k, None) for k in ('i', 'y')])

# output
[None, None]
[06:04PM 18/11/20 10/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[9, 1600]
[06:04PM 18/11/20 20/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[19, 900]
[06:04PM 18/11/20 30/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[29, 400]
[06:04PM 18/11/20 40/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[39, 100]
[06:04PM 18/11/20 50/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 60/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 70/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 80/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 90/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 100/100]: Left messages ['i', 'y'] with ex at /tmp/ex
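
Here msg_keep='min' and msg_leader='^y$' mean the messages are only overwritten when the leader variable y reaches a new minimum: y bottoms out at 0 once x crosses zero (around step 50), after which the stored messages remain fixed at [49, 0].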

A Full example

from xmen import Experiment
from xmen.monitor import Monitor
import os

X = Experiment(os.path.join(os.environ['HOME'], 'tmp'), 'ex1')

a, b, c, d, e, f = 0, 0, 0, 0, 0, 0


def identity(x):
    return x


def mult(x):
    return 10 * x


m = Monitor(
    log=('a|b@2s', 'b@1e', 'c@1er', "d|e@1eo", "e@1se"),
    log_fn=(mult, identity, identity, identity, identity),
    log_format='.3f',
    msg='a->X@1s',
    time=('@2s', '@1e'),
    probe='X@10s',
    limit='@20s')
for _ in m(range(2)):
    for _ in m(range(2)):
        for _ in m(range(3)):
            for _ in m(range(4)):
                for _ in m(range(5)):
                    a += 1
                b += 1
            c += 1
        d += 1
    e += 1
# output
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 1/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: a = 20.000 b = 0.00
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: wall=0:00:00.050627 end=0:00:06.024641 next=0:00:00.000054 step=0:00:00.000006 m_step=0:00:00.000007 load=0:00:00.000010 m_load=0:00:00.000011
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 3/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: a = 40.000 b = 0.00
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: wall=0:00:00.135165 end=0:00:07.974762 next=0:00:00.000017 step=0:00:00.000006 m_step=0:00:00.000006 load=0:00:00.000011 m_load=0:00:00.000011
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 5/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48]: b = 1.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48]: wall=0:00:00.214666 end=0:00:10.089297 next=0:00:00.637951 step=0:00:00.212640 m_step=0:00:00.212640 load=0:00:00.000010 m_load=0:00:00.000010
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: a = 60.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: wall=0:00:00.215068 end=0:00:08.387665 next=0:00:00.000059 step=0:00:00.000006 m_step=0:00:00.000006 load=0:00:00.000009 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 7/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: a = 80.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: wall=0:00:00.298510 end=0:00:08.656785 next=0:00:00.000045 step=0:00:00.000011 m_step=0:00:00.000009 load=0:00:00.000016 m_load=0:00:00.000014
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 9/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: a = 100.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: wall=0:00:00.387885 end=0:00:08.921352 next=0:00:00 step=0:00:00.000010 m_step=0:00:00.000009 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: cpu=3.5%   0 = TITAN X (Pascal) 88.0% 2116.0MB 85.0°C   1 = GeForce GTX TITAN Black 100.0% 3774.0MB 74.0°C
[03:07PM 05/02/21 0/2 0/4 0/12 2/48]: b = 2.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48]: wall=0:00:02.976140 end=0:01:08.451221 next=0:00:02.973728 step=0:00:02.761070 m_step=0:00:01.486855 load=0:00:00.000008 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 11/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: a = 120.000 b = 20.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: wall=0:00:03.043639 end=0:00:57.829140 next=0:00:00.000060 step=0:00:00.000007 m_step=0:00:00.000008 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 13/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: a = 140.000 b = 20.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: wall=0:00:03.198814 end=0:00:51.637992 next=0:00:00.000021 step=0:00:00.000010 m_step=0:00:00.000008 load=0:00:00.000016 m_load=0:00:00.000013
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 15/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48]: b = 3.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48]: wall=0:00:03.350787 end=0:00:50.261807 next=0:00:01.115966 step=0:00:00.374161 m_step=0:00:01.115957 load=0:00:00.000008 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: a = 160.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: wall=0:00:03.351449 end=0:00:46.920286 next=0:00:00.000102 step=0:00:00.000010 m_step=0:00:00.000010 load=0:00:00.000015 m_load=0:00:00.000015
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 17/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: a = 180.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: wall=0:00:03.502008 end=0:00:43.191432 next=0:00:00.000042 step=0:00:00.000010 m_step=0:00:00.000009 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 19/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: a = 200.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: wall=0:00:03.647854 end=0:00:40.126398 next=0:00:00 step=0:00:00.000012 m_step=0:00:00.000009 load=0:00:00.000018 m_load=0:00:00.000014
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: cpu=3.5%   0 = TITAN X (Pascal) 96.0% 2116.0MB 86.0°C   1 = GeForce GTX TITAN Black 100.0% 3774.0MB 75.0°C
[03:07PM 05/02/21 0/2 0/4 0/12 4/48]: b = 4.00
[03:07PM 05/02/21 0/2 0/4 0/12 4/48]: wall=0:00:04.948039 end=0:00:54.428427 next=0:00:00 step=0:00:01.596759 m_step=0:00:01.236157 load=0:00:00.000013 m_load=0:00:00.000010
[03:07PM 05/02/21 0/2 0/4 1/12]: c = 1.00
[03:07PM 05/02/21 0/2 0/4 1/12 4/48 20/240]: ---- Limit reached for step 20 ----
[03:07PM 05/02/21 0/2 0/4 1/12 5/48]: b = 5.00
[03:07PM 05/02/21 0/2 0/4 1/12 5/48]: wall=0:00:04.949144 end=0:00:42.562637 next=0:00:00.000554 step=0:00:00.000172 m_step=0:00:00.000172 load=0:00:00.000013 m_load=0:00:00.000013
[03:07PM 05/02/21 0/2 0/4 2/12]: c = 2.00
[03:07PM 05/02/21 0/2 1/4]: d = 1.000 e = 0.00
[03:07PM 05/02/21 1/2]: e = 1.00

The TorchMonitor Class

The TorchMonitor class extends the functionality of Monitor, additionally allowing torch modules to be automatically saved and variables to be logged to tensorboard.

python -m xmen.examples.monitor.torch_monitor

python -m xmen.examples.monitor.checkpoint

Automatic Checkpointing

"""Automatic check-pointing"""
from xmen.monitor import TorchMonitor
from torch.nn import Conv2d
from torch.optim import Adam

model = Conv2d(2, 3, 3)
model2 = Conv2d(2, 3, 3)
optimiser = Adam(model.parameters())

m = TorchMonitor(
    '/tmp/checkpoint',
    ckpt=['^model$@5s', 'opt@1e', '^model2$@20s'],
    ckpt_keep=[5, 1, None])
for _ in m(range(10)):
    for _ in m(range(20)):
        # Do something
        ...

# result
tree /tmp/checkpoint
└── checkpoints
    ├── model
    │   ├── 180.pt
    │   ├── 185.pt
    │   ├── 190.pt
    │   ├── 195.pt
    │   └── 200.pt
    ├── model2
    │   ├── 100.pt
    │   ├── 120.pt
    │   ├── 140.pt
    │   ├── 160.pt
    │   ├── 180.pt
    │   ├── 20.pt
    │   ├── 200.pt
    │   ├── 40.pt
    │   ├── 60.pt
    │   └── 80.pt
    └── optimiser
        └── 200.pt
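
Each file is an ordinary torch serialisation and can be inspected directly. A minimal sketch (this assumes the checkpoints load with torch.load; the exact structure of the saved object is determined by TorchMonitor):

import torch

# load the model checkpoint saved at step 200 in the tree above
ckpt = torch.load('/tmp/checkpoint/checkpoints/model/200.pt', map_location='cpu')
print(type(ckpt))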

Tensorboard Logging

"""automatic tensorboard logging"""
from xmen.monitor import TorchMonitor
import numpy as np
import torch
import os
import matplotlib.pyplot as plt
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T

plt.style.use('ggplot')
ds = DataLoader(MNIST(os.getenv("HOME") + '/data/mnist', download=True,
      transform=T.Compose(
          [T.Resize([64, 64]), T.CenterCrop([64, 64]),
           T.ToTensor(), T.Normalize([0.5], [0.5])])), 8)

m = TorchMonitor(
    directory='/tmp/tb',
    sca=['^z$|^X$@10s', '^a|^x$@5s'],
    img=['^mnist@10s', '^mnist@5s'], img_fn=[lambda x: x[:2], lambda x: x[:5]], img_pref=['2', '5'],
    hist='Z@1s',
    fig='fig@10s',
    txt='i@5s', txt_fn=lambda x: f'Hello at step {x}',
    vid='^mnist@10s', vid_fn=lambda x: (x.unsqueeze(0) - x.min()) / (x.max() - x.min())
)

# variables
x = 0.
a = [1, 2]
z = {'x': 5, 'y': 10}
for i, (mnist, _) in m(zip(range(31), ds)):
    # plot a figure
    fig = plt.figure(figsize=[10, 5])
    plt.plot(np.linspace(0, 1000), np.cos(np.linspace(0, 1000) * i))
    # random tensor
    Z = torch.randn([10, 3, 64, 64]) * i / 100
    # scalars
    x = (i - 15) ** 2
    z['i'] = i
    z['x'] += 1
    z['y'] = z['x'] ** 2
# to visualise results run
tensorboard --logdir /tmp/tb

Pytorch experiments with Xmen

All examples in this section are defined in xmen.examples.torch and can be run from the command line using the xmen command line interface; just add them using xmen --add (if they are not already). Pytorch will need to be installed in order to run these examples.

DCGAN using the functional api

"""A functional implementation of dcgan"""
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  You should have received a copy of the GNU General Public License
#   along with this program. If not, see <http://www.gnu.org/licenses/>.
#  Copyright (C) 2019  Robert J Weston, Oxford Robotics Institute
#
#  xmen
#  email:   robw@robots.ox.ac.uk
#  github: https://github.com/robw4/xmen/
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
from typing import Tuple
import xmen
import os

try:
    import torch
except ImportError:
    print('In order to run this script first add pytorch to the python path')


def get_datasets(cy, cz, b, ngpus, ncpus, ns, data_root, hw, **kwargs):
    """Returns a dictionary of iterable datasets for modes 'train' (label image pairs)
    and 'inference' (just inputs spanning the prediction space)"""
    from torch.utils.data import DataLoader
    from torchvision.datasets.mnist import MNIST
    from torch.distributions import Normal
    import torchvision.transforms as T

    def to_target(y):
        Y = torch.zeros([cy])
        Y[y] = 1.
        return Y.reshape([cy, 1, 1])
    # Generate test samples
    y = torch.stack([to_target(i) for i in range(cy)])
    y = y.unsqueeze(0).expand(
        [ns, cy, cy, 1, 1]).reshape(  # Expand across last dim
        [-1, cy, 1, 1])  # Batch
    z = Normal(0., 1.).sample([y.shape[0], cz, 1, 1])
    return {
        'train': DataLoader(
            MNIST(data_root, download=True,
                  transform=T.Compose(
                    [T.Resize(hw), T.CenterCrop(hw),
                     T.ToTensor(), T.Normalize([0.5], [0.5])]),
                  target_transform=to_target),
            batch_size=b * ngpus,
            shuffle=True, num_workers=ncpus),
        'inference': list(
            zip(y.unsqueeze(0), z.unsqueeze(0)))}  # Turn into batches


@xmen.autodoc
def dcgan(
    root: xmen.Root,  #
        # the first argument is always an experiment instance.
        # it can be marked as unused (name it _) following python
        # convention, and can be named to suit the use case,
        # e.g. logger, root, experiment ...
    b: int = 128,  # the batch size per gpu
    hw0: Tuple[int, int] = (4, 4),  # the height and width of the image
    nl: int = 4,  # the number of levels in the discriminator.
    data_root: str = os.getenv("HOME") + '/data/mnist',  # @p the root data directory
    cx: int = 1,  # the dimensionality of the image input
    cy: int = 10,  # the dimensionality of the conditioning vector
    cf: int = 512,  # the number of features after the first conv in the discriminator
    cz: int = 100,  # the dimensionality of the noise vector
    ncpus: int = 8,  # the number of threads to use for data loading
    ngpus: int = 1,  # the number of gpus to run the model on
    epochs: int = 20,  # no. of epochs to train for
    lr: float = 0.0002,  # learning rate
    betas: Tuple[float, float] = (0.5, 0.999),  # the beta parameters for the Adam optimiser
    # monitoring parameters
    checkpoint: str = 'nn_.*@1e',  # checkpoint at this modulo string
    log: str = 'loss_.*@20s',  # log scalars
    sca: str = 'loss_.*@20s',  # tensorboard scalars
    img: str = '_x_|x$@20s',  # tensorboard images
    nimg: int = 64,  # the maximum number of images to display to tensorboard
    ns: int = 5  # the number of samples to generate at inference
):
    """Train a conditional GAN to predict MNIST digits.

    To visualise the results run::

        tensorboard --logdir ...

    """
    from xmen.monitor import TorchMonitor, TensorboardLogger
    from xmen.examples.models import weights_init, set_requires_grad, GeneratorNet, DiscriminatorNet
    from torch.distributions import Normal
    from torch.distributions.one_hot_categorical import OneHotCategorical
    from torch.optim import Adam
    import logging

    hw = [d * 2 ** nl for d in hw0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logger = logging.getLogger()
    logger.setLevel('INFO')
    # dataset
    datasets = get_datasets(
        cy, cz, b, ngpus, ncpus, ns, data_root, hw)
    # models
    nn_g = GeneratorNet(cy, cz, cx, cf, hw0, nl)
    nn_d = DiscriminatorNet(cx, cy, cf, hw0, nl)
    nn_g = nn_g.to(device).float().apply(weights_init)
    nn_d = nn_d.to(device).float().apply(weights_init)
    # distributions
    pz = Normal(torch.zeros([cz]), torch.ones([cz]))
    py = OneHotCategorical(probs=torch.ones([cy]) / cy)
    # optimisers
    op_d = Adam(nn_d.parameters(), lr=lr, betas=betas)
    op_g = Adam(nn_g.parameters(), lr=lr, betas=betas)
    # monitor
    m = TorchMonitor(
        root.directory, ckpt=checkpoint,
        log=log, sca=sca, img=img,
        time=('@20s', '@1e'),
        msg='root@100s',
        img_fn=lambda x: x[:min(nimg, x.shape[0])],
        hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])

    for _ in m(range(epochs)):
        # (1) train
        for x, y in m(datasets['train']):
            # process input
            x, y = x.to(device), y.to(device).float()
            b = x.shape[0]
            # discriminator step
            set_requires_grad([nn_d], True)
            op_d.zero_grad()
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
            loss_d.backward()
            op_d.step()
            # generator step
            op_g.zero_grad()
            y = py.sample([b]).reshape([b, cy, 1, 1]).to(device)
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            set_requires_grad([nn_d], False)
            loss_g = nn_d((_x_, y), True)
            loss_g.backward()
            op_g.step()
        # (2) inference
        if 'inference' in datasets:
            with torch.no_grad():
                for yi, zi in datasets['inference']:
                    yi, zi = yi.to(device), zi.to(device)
                    _xi_ = nn_g(yi, zi)
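
As in the hello world example, the functional experiment can be exposed as a command line interface by converting it to a class. A minimal sketch following that example:

if __name__ == '__main__':
    from xmen.functional import functional_experiment
    Exp = functional_experiment(dcgan)
    Exp().main()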

DCGAN using the class api

"""A class implementation of dcgan"""
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  You should have received a copy of the GNU General Public License
#   along with this program. If not, see <http://www.gnu.org/licenses/>.
#  Copyright (C) 2019  Robert J Weston, Oxford Robotics Institute
#
#  xmen
#  email:   robw@robots.ox.ac.uk
#  github: https://github.com/robw4/xmen/
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.

from typing import Tuple
import tempfile
import torch
from torch.distributions import Normal
from xmen.experiment import Experiment


class Dcgan(Experiment):
    """Train a conditional GAN to predict MNIST digits.

    To visualise the results run::

        tensorboard --logdir ...

    """
    b: int = 128  # @p the batch size per gpu
    hw0: Tuple[int, int] = (4, 4)  # @p the height and width of the image
    nl: int = 4  # @p The number of levels in the discriminator.
    data_root: str = tempfile.gettempdir()  # @p the root data directory
    cx: int = 1  # @p the dimensionality of the image input
    cy: int = 10  # @p the dimensionality of the conditioning vector
    cf: int = 512  # @p the number of features after the first conv in the discriminator
    cz: int = 100  # @p the dimensionality of the noise vector
    ncpus: int = 8  # @p the number of threads to use for data loading
    ngpus: int = 1  # @p the number of gpus to run the model on
    epochs: int = 100  # @p no. of epochs to train for
    gan: str = 'lsgan'  # @p the gan type to use (one of ['vanilla', 'lsgan'])
    lr: float = 0.0002  # @p learning rate
    betas: Tuple[float, float] = (0.5, 0.999)  # @p The beta parameters for the Adam optimiser
    # Monitoring parameters
    checkpoint: str = 'nn_.*@1e'  # @p checkpoint at this modulo string
    log: str = 'loss_.*@20s'  # @p log scalars
    sca: str = 'loss_.*@20s'  # @p tensorboard scalars
    img: str = '_x_|x$@20s'  # @p tensorboard images
    time: Tuple[str, str] = ('@20s', '@1e')  # @p timing modulos
    nimg: int = 64  # @p The maximum number of images to display to tensorboard
    ns: int = 5  # @p The number of samples to generate at inference

    @property
    def hw(self): return [d * 2 ** self.nl for d in self.hw0]

    @property
    def device(self): return 'cuda' if torch.cuda.is_available() else 'cpu'

    def datasets(self):
        """Returns a dictionary of iterable datasets for modes 'train' (label image pairs)
        and 'inference' (just inputs spanning the prediction space)"""
        from torch.utils.data import DataLoader
        from torchvision.datasets.mnist import MNIST
        import torchvision.transforms as T

        def to_target(y):
            Y = torch.zeros([self.cy])
            Y[y] = 1.
            return Y.reshape([self.cy, 1, 1])
        # Generate test samples
        y = torch.stack([to_target(i) for i in range(self.cy)])
        y = y.unsqueeze(0).expand(
            [self.ns, self.cy, self.cy, 1, 1]).reshape(  # Expand across last dim
            [-1, self.cy, 1, 1])  # Batch
        z = Normal(0., 1.).sample([y.shape[0], self.cz, 1, 1])
        return {
            'train': DataLoader(
                MNIST(self.data_root, download=True,
                      transform=T.Compose(
                        [T.Resize(self.hw), T.CenterCrop(self.hw),
                         T.ToTensor(), T.Normalize([0.5], [0.5])]),
                      target_transform=to_target),
                batch_size=self.b * self.ngpus,
                shuffle=True, num_workers=self.ncpus),
            'inference': list(
                zip(y.unsqueeze(0), z.unsqueeze(0)))}  # Turn into batches

    def build(self):
        """Build generator, discriminator and optimisers."""
        from torch.optim import Adam
        from xmen.examples.models import GeneratorNet, DiscriminatorNet
        nn_g = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        op_g = Adam(nn_g.parameters(), lr=self.lr, betas=self.betas)
        nn_d = DiscriminatorNet(self.cx, self.cy, self.cf, self.hw0, self.nl)
        op_d = Adam(nn_d.parameters(), lr=self.lr, betas=self.betas)
        return nn_g, nn_d, op_g, op_d

    def run(self):
        """Train on the MNIST dataset for a fixed number of epochs, running an
        inference loop after every epoch."""
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.models import weights_init, set_requires_grad
        from torch.distributions import Normal
        from torch.distributions.one_hot_categorical import OneHotCategorical
        # set up
        datasets = self.datasets()
        pz = Normal(torch.zeros([self.cz]), torch.ones([self.cz]))
        py = OneHotCategorical(probs=torch.ones([self.cy]) / self.cy)
        nn_g, nn_d, op_g, op_d = self.build()
        nn_g = nn_g.to(self.device).float().apply(weights_init)
        nn_d = nn_d.to(self.device).float().apply(weights_init)
        # training loop
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            time=('@20s', '@1e'),
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])
        for _ in m(range(self.epochs)):
            for x, y in m(datasets['train']):
                # process input
                x, y = x.to(self.device), y.to(self.device).float()
                b = x.shape[0]
                # discriminator step
                set_requires_grad([nn_d], True)
                op_d.zero_grad()
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
                loss_d.backward()
                op_d.step()
                # generator step
                op_g.zero_grad()
                y = py.sample([b]).reshape([b, self.cy, 1, 1]).to(self.device)
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                set_requires_grad([nn_d], False)
                loss_g = nn_d((_x_, y), True)
                loss_g.backward()
                op_g.step()
            # inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        yi, zi = yi.to(self.device), zi.to(self.device)
                        _xi_ = nn_g(yi, zi)

Generative modelling with inheritance

"""Deep generative models with inheritance"""
#  Copyright (C) 2019  Robert J Weston, Oxford Robotics Institute
#
#  xmen
#  email:   robw@robots.ox.ac.uk
#  github: https://github.com/robw4/xmen/
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
from typing import Tuple
import torch
from torch.distributions import Normal
from xmen.experiment import Experiment
import os

# ------------------------------------------------------
# ---- BASE EXPERIMENTS --------------------------------
# ------------------------------------------------------
class BaseGenerative(Experiment):
    """An abstract class defining parameters and some useful properties common to
     both the VAE and cGAN implementations.

    Inherited classes must overload:
    - datasets()
    - build()
    - run()

    Inherited classes can optionally overload:
    - parameter defaults and documentation

    Note:
        The output size is given as hw0 * 2 ** nl (e.g. (4, 2) * 2 ** 4 = (64, 32))
    """
    b: int = 256  # @p the batch size per gpu
    hw0: Tuple[int, int] = (4, 4)  # @p the height and width of the image
    nl: int = 4  # @p The number of levels in the discriminator.
    data_root: str = os.getenv("HOME") + '/data/mnist'  # @p the root data directory
    cx: int = 1  # @p the dimensionality of the image input
    cy: int = 10  # @p the dimensionality of the conditioning vector
    cf: int = 512  # @p the number of features after the first conv in the discriminator
    cz: int = 100  # @p the dimensionality of the noise vector
    ncpus: int = 8  # @p the number of threads to use for data loading
    ngpus: int = 1  # @p the number of gpus to run the model on
    epochs: int = 100  # @p no. of epochs to train for
    gan: str = 'lsgan'  # @p the gan type to use (one of ['vanilla', 'lsgan'])
    lr: float = 0.0002  # @p learning rate
    betas: Tuple[float, float] = (0.5, 0.999)  # @p The beta parameters for the Adam optimiser
    # Monitoring parameters
    checkpoint: str = 'nn_.*@1e'  # @p checkpoint at this modulo string
    log: str = 'loss_.*@20s'  # @p log scalars
    sca: str = 'loss_.*@20s'  # @p tensorboard scalars
    img: str = '_x_|x$@20s'  # @p tensorboard images
    time: Tuple[str, str] = ('@20s', '@1e')  # @p timing modulos
    nimg: int = 64  # @p The maximum number of images to display to tensorboard

    test_samples = None

    # Useful properties
    @property
    def hw(self): return [d * 2 ** self.nl for d in self.hw0]

    @property
    def device(self): return 'cuda' if torch.cuda.is_available() else 'cpu'

    def datasets(self): raise NotImplementedError

    def run(self): raise NotImplementedError

    def build(self): raise NotImplementedError


class BaseCVae(BaseGenerative):
    """BaseCVae is an intermediary class defining the model and training loop but not the dataset.

    Inherited classes should overload:
    - datasets()
    """
    nlprior: int = 1  # @p The number of hidden layers in the prior network
    w_kl: float = 1.0  # @p The weighting on the KL divergence term
    ns: int = 5  # @p The number of samples to generate at inference
    predictive: str = 'prior'  # @p Use either the 'prior' or 'posterior' predictive at inference
    b = 32
    log = 'loss.*|log_px|kl_qz_pz@20s'
    sca = 'loss.*|log_px|kl_qz_pz@20s'
    img = '_x_|x$@20s'
    time = ('@10s', '@1e')

    def build(self):
        from torch.optim import Adam
        from itertools import chain
        from xmen.examples.torch.models import GeneratorNet, PosteriorNet, PriorNet
        nn_gen = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        nn_post = PosteriorNet(self.cx, self.cy, self.cz, self.cf, self.hw0, self.nl)
        nn_prior = PriorNet(self.cy, self.cz, self.cf, self.nlprior)
        opt = Adam(chain(nn_gen.parameters(), nn_post.parameters(), nn_prior.parameters()),
                   self.lr, self.betas)
        return nn_gen, nn_post, nn_prior, opt

    def run(self):
        from torch.distributions import Normal, kl_divergence
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.torch.models import weights_init
        # construct model
        datasets = self.datasets()
        nn_gen, nn_post, nn_prior, opt = self.build()
        nn_gen, nn_post, nn_prior = (
            v.to(self.device).float().apply(weights_init) for v in (
            nn_gen, nn_post, nn_prior))
        if self.ngpus > 1:
            nn_gen, nn_post, nn_prior = (torch.nn.DataParallel(n) for n in (
                nn_gen, nn_post, nn_prior))
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            time=('@20s', '@1e'),
            hooks=[TensorboardLogger('image', '_xi_$@20s', nrow=self.ns)])
        for _ in m(range(self.epochs)):
            # Training
            for x, y in m(datasets['train']):
                x, y = x.to(self.device), y.to(self.device).float()
                qz = Normal(*nn_post(x, y))
                z = qz.sample()
                # Likelihood
                _x_ = nn_gen(y, z)
                px = Normal(_x_, 0.1)
                pz = Normal(*nn_prior(y.reshape([-1, self.cy])))
                log_px = px.log_prob(x).sum()
                kl_qz_pz = kl_divergence(qz, pz).sum()
                loss = self.w_kl * kl_qz_pz - log_px
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        if self.predictive == 'posterior':
                            qz = Normal(*nn_prior(yi.reshape([-1, self.cy])))
                            zi = qz.sample()
                        _xi_ = nn_gen(yi, zi)


class BaseGAN(BaseGenerative):
    """BaseGAN is an intermediary class defining the model and training loop but not the dataset.

    Inherited classes should overload:
    - datasets()
    - distributions()
    """
    # overload previous parameter default (documentation is maintained)
    b = 128
    # define new parameters
    ns: int = 5  # @p The number of samples to generate at inference

    def build(self):
        """Build the model from the current configuration"""
        from torch.optim import Adam
        from xmen.examples.torch.models import GeneratorNet, DiscriminatorNet
        nn_g = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        op_g = Adam(nn_g.parameters(), lr=self.lr, betas=self.betas)
        nn_d = DiscriminatorNet(self.cx, self.cy, self.cf, self.hw0, self.nl)
        op_d = Adam(nn_d.parameters(), lr=self.lr, betas=self.betas)
        return nn_g, nn_d, op_g, op_d

    def distributions(self): raise NotImplementedError

    def run(self):
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.torch.models import set_requires_grad, weights_init
        # get datasets
        datasets = self.datasets()
        py, pz = self.distributions()
        nn_g, nn_d, op_g, op_d = self.build()
        nn_g = nn_g.to(self.device).float().apply(weights_init)
        nn_d = nn_d.to(self.device).float().apply(weights_init)
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            time=('@20s', '@1e'),
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=self.ns)])
        for _ in m(range(self.epochs)):
            for x, y in m(datasets['train']):
                # process input
                x, y = x.to(self.device), y.to(self.device).float()
                b = x.shape[0]
                # discriminator step
                set_requires_grad([nn_d], True)
                op_d.zero_grad()
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
                loss_d.backward()
                op_d.step()
                # generator step
                op_g.zero_grad()
                y = py.sample([b]).reshape([b, self.cy, 1, 1]).to(self.device)
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                set_requires_grad([nn_d], False)
                loss_g = nn_d((_x_, y), True)
                loss_g.backward()
                op_g.step()
            # Inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        yi, zi = yi.to(self.device), zi.to(self.device)
                        _xi_ = nn_g(yi, zi)


class BaseMnist(BaseGenerative):
    """cDCGAN training on the MNIST dataset"""
    # Update defaults
    hw0, nl = (4, 4), 3  # Default size = 32 x 32
    data_root = os.getenv("HOME") + '/data/mnist'
    cx, cy, cz, cf = 1, 10, 100, 512
    ns = 10  # Number of samples to generate during inference

    def datasets(self):
        """Configure the datasets used for training"""
        from torch.utils.data import DataLoader
        from torchvision.datasets.mnist import MNIST
        import torchvision.transforms as T

        def to_target(y):
            Y = torch.zeros([self.cy])
            Y[y] = 1.
            return Y.reshape([self.cy, 1, 1])

        transform = T.Compose(
            [T.Resize(self.hw), T.CenterCrop(self.hw),
             T.ToTensor(), T.Normalize([0.5], [0.5])])
        y = torch.stack([to_target(i) for i in range(self.cy)])
        y = y.unsqueeze(0).expand(
            [self.ns, self.cy, self.cy, 1, 1]).reshape(  # Expand across last dim
            [-1, self.cy, 1, 1])  # Batch
        z = Normal(0., 1.).sample([y.shape[0], self.cz, 1, 1])
        return {
            'train': DataLoader(
                MNIST(self.data_root, download=True,
                      transform=transform,
                      target_transform=to_target),
                batch_size=self.b * self.ngpus,
                shuffle=True, num_workers=self.ncpus),
            'inference': list(
                zip(y.unsqueeze(0), z.unsqueeze(0)))}  # Turn into batches


# ------------------------------------------------------
# ---- RUNABLE EXPERIMENTS -----------------------------
# ------------------------------------------------------
class InheritedMnistGAN(BaseMnist, BaseGAN):
    """Train a cDCGAN on MNIST"""
    epochs = 20
    ncpus, ngpus = 0, 1
    b = 128

    def distributions(self):
        """Generate one hot and normal samples"""
        from torch.distributions import Normal
        from torch.distributions.one_hot_categorical import OneHotCategorical
        pz = Normal(torch.zeros([self.cz]), torch.ones([self.cz]))
        py = OneHotCategorical(probs=torch.ones([self.cy]) / self.cy)
        return py, pz


class InheritedMnistVae(BaseMnist, BaseCVae):
    """Train a conditional VAE on MNIST"""
    ns = 10
    w_kl = 1.5
    lr = 0.00005
    ncpus, ngpus = 0, 2
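
Like any xmen experiment, the runnable classes above inherit main(), so each can be exposed as a command line interface; a minimal sketch:

if __name__ == '__main__':
    InheritedMnistGAN().main()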

Xmen Meets Pytorch Lightning

Note

This is currently experimental!

"""CURRENTLY EXPERIMENTAL"""
from pytorch_lightning import LightningModule
import torch
import torch.nn.functional as F
from typing import List, Any

from xmen.lightning import TensorBoardLogger, Trainer


class LitMNIST(LightningModule):
    """Example pytorch lightning module taken from the docs"""

    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x_in):
        batch_size, channels, width, height = x_in.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x_in.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log_dict(
            {'loss': loss, 'x': x})
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_epoch_end(
        self, outputs: List[Any]
    ) -> None:
        self.log('loss_val', torch.stack(outputs).mean())

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def test_epoch_end(
        self, outputs: List[Any]
    ) -> None:
        self.log('loss_val', torch.stack(outputs).mean())


def lit_experiment(
        root,
        batch_size=64,    # The batch size of the experiment
        epochs=5,  # Number of epochs to train for
):
    """Xmen meets pytorch_lightning"""
    import xmen
    import os

    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms

    # prepare transforms standard to MNIST
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])

    # data
    mnist_train = MNIST('/tmp/xmen', train=True, download=True, transform=transform)
    mnist_train = DataLoader(mnist_train, batch_size=batch_size)
    mnist_val = MNIST(os.getcwd(), train=False, download=True, transform=transform)
    mnist_val = DataLoader(mnist_val, batch_size=batch_size)

    model = LitMNIST()
    trainer = Trainer(
        default_root_dir=root.directory,
        max_epochs=epochs,
        logger=TensorBoardLogger(
            root=root,
            log=['loss@100s', 'loss_val'],
            sca=['loss@100s', 'loss_val'],
            img='x@100s',
            time='@500s',
            msg='loss@50s'
        )
    )
    trainer.fit(model, mnist_train, mnist_val)
    trainer.test(model, mnist_val)


if __name__ == '__main__':
    from xmen.functional import functional_experiment
    Exp, _ = functional_experiment(lit_experiment)
    Exp().main()