Examples¶
Introduction¶
All examples in this section are defined in xmen.examples
and can be run from the command line using the xmen
command line interface; just add them using xmen --add
(if they have not been added already).
Hello World¶
"""Basic Examples.
"""
# Copyright (C) 2019 Robert J Weston, Oxford Robotics Institute
#
# xmen
# email: robw@robots.ox.ac.uk
# github: https://github.com/robw4/xmen/
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import xmen
@xmen.autodoc
def hello_world(
    root: xmen.Root, # experiments are assigned a root before being executed
    a: str = 'Hello', # the first
    # argument
    b: str = 'World', # the second argument
):
    """A hello world experiment designed to demonstrate
    defining experiments through the functional experiment api"""
    # Build the greeting once; it is both printed and written to disk.
    greeting = f'{a} {b}'
    print(greeting)
    ...  # whatever other experiment code you want
    # Every experiment owns a directory (root.directory); drop the greeting there.
    out_path = root.directory + '/out.txt'
    with open(out_path, 'w') as handle:
        handle.write(greeting)
    # Record both arguments as messages on the experiment.
    root.message(dict(a=a, b=b))
class HelloWorld(xmen.Experiment):
    """A hello world experiment designed to demonstrate
    defining experiments through the class experiment api"""
    # Parameters
    a: str = 'Hello' # @p The first argument
    b: str = 'World' # @p The second argument

    def run(self):
        """Print the greeting (with a trailing '!') and leave ``a`` and ``b``
        as messages on this experiment."""
        first, second = self.a, self.b
        print(f'{first} {second}!')
        self.message(dict(a=first, b=second))
if __name__ == '__main__':
    # optionally expose the command line interface if you
    # would like to run the experiment as a python script
    from xmen.functional import functional_experiment
    # generate an experiment class from the function definition
    # (this step is not needed if the experiment is defined
    # as a class, e.g. HelloWorld above)
    Exp = functional_experiment(hello_world)
    # every experiment inherits main() allowing the experiment
    # to be configured and run from the command line.
    Exp().main()
    # try...
    # >> python -m xmen.examples.hello_world --help
A little more detail¶
"""Using the class api to define experiments with inheritance."""
from xmen.experiment import Experiment
import os
import time
from typing import List
class BaseExperiment(Experiment):
    """A basic experiment demonstrating the features of the xmen api."""
    # Parameters are defined as attributes in the class body with the
    # @p identifier
    t = 'cat' # @p
    w = 3 # @p parameter w has a help message whilst t does not
    h: int = 10 # @p h declared with typing is very concise and neat
    # Parameters can also be defined in the __init__ method
    def __init__(self, *args, **kwargs):
        # let the Experiment base class initialise first; the instance-level
        # parameters below are declared afterwards
        super(BaseExperiment, self).__init__(*args, **kwargs)
        self.a: str = 'h' # @p A parameter
        self.b: int = 17 # @p Another parameter
        # Normal attributes are still allowed
        self.c: int = 5 # This is not a parameter
class AnotherExperiment(BaseExperiment):
    # A second subclass of BaseExperiment; it is combined with AnExperiment
    # in MultiParentsExperiment to demonstrate multiple inheritance.
    m: str = 'Another value' # @p Multiple inheritance example
    p: str = 'A parameter only in Another Experiment ' # @p
class AnExperiment(BaseExperiment):
    # |
    # Experiments can inherit from other experiments
    # parameters are inherited too
    """An experiment testing the xmen experiment API. The __docstring__ will
    appear in both the docstring of the class __and__ as the prolog in the
    command line interface."""
    # Feel free to define more parameters
    x: List[float] = [3., 2.] # @p Parameters can be defined cleanly as class attributes
    y: float = 5 # @p This parameter will have this
    # Parameters can be overridden
    a: float = 0.5 # a's default and type will be changed. Its help will be overridden
    b: int = 17 # @p b's help will be changed
    m: str = 'Defined in AnExperiment' # @p m is defined in AnExperiment

    def run(self):
        """Report the current status, leave a timestamp message and write a
        log file into this experiment's directory."""
        # Experiment execution is defined in the run method
        print(f'The experiment state inside run is {self.status}')
        # recording messages is super easy
        self.message({'time': time.time()})
        # Each experiment has its own unique directory. You are encouraged to
        # write out data accumulated through the execution (snapshots, logs etc.)
        # to this directory.
        with open(os.path.join(self.directory, 'logs.txt'), 'w') as f:
            f.write('This was written from a running experiment')

    def debug(self):
        """Set ``a`` to a debug value and return ``self``."""
        self.a = 'In debug mode'
        return self

    @property
    def h(self):
        """Override the inherited ``h`` parameter with a read-only property."""
        # The two literals are implicitly concatenated, so the first one must
        # end with a space; the message also needs "be" to read correctly.
        return ('h has been overloaded as property and will no longer '
                'be considered as a parameter')
# Experiments can inherit from multiple classes
class MultiParentsExperiment(AnotherExperiment, AnExperiment):
    # Parameter defaults, helps and values follow the method resolution order
    # (left to right), so AnotherExperiment's ``m`` wins over AnExperiment's.
    pass
if __name__ == '__main__':
    # documentation is automatically added to the class
    help(AnExperiment)

    # to run an experiment first we initialise it;
    # the experiment is initialised in 'default' status
    print('\nInitialising')
    print('---------------------------')
    exp = AnExperiment()
    print(exp)

    # whilst the status is default the parameters of the
    # experiment can be changed
    print('\nConfiguring')
    print('------------ ')
    exp = AnExperiment()
    exp.a = 'hello'
    exp.update({'t': 'dog', 'w': 100})
    # Note parameters are copied from the class during
    # instantiation. This way you don't need to worry
    # about accidentally changing the mutable class
    # attributes across the entire class.
    exp.x += [4.]
    print(exp)
    assert AnExperiment.x == [3., 2.]
    # If this is not desired (or necessary) initialise with
    # exp = AnExperiment(copy=False)

    # Experiments can inherit from multiple classes:
    print('\nMultiple Inheritance')
    print('----------------------')
    print(MultiParentsExperiment())
    print('\n Parameters defaults, helps and values are '
          'inherited according to experiments method resolution order '
          '(i.e left to right). Note that m has the value '
          'defined in Another Experiment')

    print('\nRegistering')
    print('-------------')
    # Before being run an experiment needs to be registered
    # to a directory
    exp.register('/tmp/an_experiment', 'first_experiment',
                 purpose='A bit of a test of the xmen experiment api')
    print(exp, end='\n')
    print('\nGIT, and system information is automatically logged\n')

    # The parameters of the experiment can no longer be changed
    try:
        exp.a = 'cat'
    except AttributeError:
        print('Parameters can no longer be changed!', end='\n')

    # An experiment can be run either by...
    # (1) calling it
    print('\nRunning (1)')
    print('-------------')
    exp()

    # (2) using it as a context. Just define a main
    # loop like you normally would
    print('\nRunning (2)')
    print('-------------')
    with exp as e:
        # Inside the experiment context the experiment status is 'running'
        print(f'Once again the experiment state is {e.status}')
        # Write the main loop just as you normally would
        # using the parameters already defined
        results = dict(sum=sum(e.x), max=max(e.x), min=min(e.x))
        # Write results to the experiment just as before
        e.message(results)

    # All the information about the current experiment is
    # automatically saved in the experiment's directory
    # for free
    params_file = os.path.join(exp.directory, 'params.yml')
    print(f'\nEverything you might need to know is logged in {exp.directory}/params.yml')
    print('-----------------------------------------------------------------------------------------------')
    print('Note that GIT, and system information is automatically logged\n'
          'along with the messages')
    # read from the registered directory rather than repeating its path by hand
    with open(params_file, 'r') as f:
        print(f.read())

    print('\nRegistering and Configuring from the command line')
    print('---------------------------------------------------')
    # Alternatively configuring and registering an experiment
    # can be done from the command line as:
    exp = AnExperiment()
    args = exp.parse_args()

    print('\nRegistering, Configuring and running in a single line!')
    print('--------------------------------------------------------')
    # Or alternatively
    # the experiment can be _configured_, _registered_
    # and _run_ by including a single line of code:
    AnExperiment().main()
    # See ``experiments xmen.tests.experiment --help`` for
    # more information
The Monitor Class¶
The Monitor
class is designed to facilitate easy logging of experiments. All the examples in this section can be found in xmen.examples.monitor
and can be run with python as:
python -m xmen.examples.monitor.logger
python -m xmen.examples.monitor.messenger.basic
python -m xmen.examples.monitor.messenger.leader
python -m xmen.examples.monitor.full
Logging¶
"""Automatic Logging"""
from xmen.monitor import Monitor
m = Monitor(
log=['x@2s', 'y@1e'],
log_fn=[lambda _: '|'.join(_), lambda _: _],
log_format=['', '.5f'])
x = ['cat', 'dog']
y = 0.
for _ in m(range(3)):
for i in m(range(5)):
y += i
# output
[05:36PM 18/11/20 0/3 2/15]: x = cat|do
[05:36PM 18/11/20 0/3 4/15]: x = cat|do
[05:36PM 18/11/20 1/3]: y = 10.0000
[05:36PM 18/11/20 1/3 6/15]: x = cat|do
[05:36PM 18/11/20 1/3 8/15]: x = cat|do
[05:36PM 18/11/20 1/3 10/15]: x = cat|do
[05:36PM 18/11/20 2/3]: y = 20.0000
[05:36PM 18/11/20 2/3 12/15]: x = cat|do
[05:36PM 18/11/20 2/3 14/15]: x = cat|do
[05:36PM 18/11/20 3/3]: y = 30.0000
Messaging¶
Messaging Multiple Experiments¶
"""automatic messaging"""
from xmen import Experiment
from xmen.monitor import Monitor
# link some experiments
ex1, ex2 = Experiment(), Experiment()
ex1.register('/tmp', 'ex1')
ex2.register('/tmp', 'ex2')
# open_socket monitor
m = Monitor(msg='y.*->ex.*@10s')
y1, y2 = 0, 0
for i in m(range(40)): # monitor loop
y1 += 1
y2 += 2
if i % 10 == 1:
# messages added automatically to ex1 and ex2
print(', '.join(
[f"ex1: {k} = {ex1.messages.get(k, None)}" for k in ('y1', 'y2')] +
[f"ex2: {k} = {ex1.messages.get(k, None)}" for k in ('y1', 'y2')]))
# timing information is also logged
print('\nAll Messages')
for k, v in ex1.messages.items():
print(k, v)
# output
ex1: y1 = None, ex1: y2 = None, ex2: y1 = None, ex2: y2 = None
[06:00PM 18/11/20 10/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 10/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 10, ex1: y2 = 20, ex2: y1 = 10, ex2: y2 = 20
[06:00PM 18/11/20 20/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 20/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 20, ex1: y2 = 40, ex2: y1 = 20, ex2: y2 = 40
[06:00PM 18/11/20 30/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 30/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
ex1: y1 = 30, ex1: y2 = 60, ex2: y1 = 30, ex2: y2 = 60
[06:00PM 18/11/20 40/40]: Left messages ['y1', 'y2'] with ex1 at /tmp/ex1
[06:00PM 18/11/20 40/40]: Left messages ['y1', 'y2'] with ex2 at /tmp/ex2
All Messages
last 06:00PM 18/11/20
s 40/40
wall 0:00:00.682972
end 0:00:00
next 0:00:00
step 0:00:00.000007
m_step 0:00:00.000013
load 0:00:00.000005
m_load 0:00:00.000005
y1 40
y2 80
Using a leader¶
"""using the leader argument"""
from xmen import Experiment
from xmen.monitor import TorchMonitor
ex = Experiment()
ex.register('/tmp', 'ex')
m = TorchMonitor(msg='^y$|^i$->^ex$@10s', msg_keep='min', msg_leader='^y$')
x = -50
for i in m(range(100)):
x += 1
y = x ** 2
if i % 10 == 1:
print([ex.messages.get(k, None) for k in ('i', 'y')])
# output
[None, None]
[06:04PM 18/11/20 10/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[9, 1600]
[06:04PM 18/11/20 20/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[19, 900]
[06:04PM 18/11/20 30/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[29, 400]
[06:04PM 18/11/20 40/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[39, 100]
[06:04PM 18/11/20 50/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 60/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 70/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 80/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 90/100]: Left messages ['i', 'y'] with ex at /tmp/ex
[49, 0]
[06:04PM 18/11/20 100/100]: Left messages ['i', 'y'] with ex at /tmp/ex
A Full example¶
from xmen import Experiment
from xmen.monitor import Monitor
import os

X = Experiment(os.path.join(os.environ['HOME'], 'tmp'), 'ex1')
a, b, c, d, e, f = 0, 0, 0, 0, 0, 0  # f is not used below


def identity(x):
    """Return ``x`` unchanged."""
    return x


def mult(x):
    """Scale ``x`` by 10 (a is printed as 20.000 after 2 steps in the output)."""
    return 10 * x


# One Monitor drives all five nested loops. Each spec is 'pattern@modulo'.
# Judging by the output below, 's' counts innermost-loop steps and 'e' counts
# iterations of an enclosing loop. The 'r', 'o' and 'se' suffixes appear to select
# progressively outer loops. NOTE(review): confirm against the Monitor docs.
m = Monitor(
    log=('a|b@2s', 'b@1e', 'c@1er', "d|e@1eo", "e@1se"),
    log_fn=(mult, identity, identity, identity, identity),  # one fn per log entry
    log_format='.3f',
    msg='a->X@1s',        # leave ``a`` as a message with X every step
    time=('@2s', '@1e'),  # timing summaries every 2 steps and every epoch
    probe='X@10s',        # cpu/gpu usage every 10 steps (see output)
    limit='@20s')         # stop after 20 steps ("Limit reached for step 20")
# Nesting reconstructed from the output: a is incremented every innermost step.
# b, c, d and e are incremented after each successive enclosing loop completes.
for _ in m(range(2)):
    for _ in m(range(2)):
        for _ in m(range(3)):
            for _ in m(range(4)):
                for _ in m(range(5)):
                    a += 1
                b += 1
            c += 1
        d += 1
    e += 1
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 1/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: a = 20.000 b = 0.00
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: wall=0:00:00.050627 end=0:00:06.024641 next=0:00:00.000054 step=0:00:00.000006 m_step=0:00:00.000007 load=0:00:00.000010 m_load=0:00:00.000011
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 2/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 3/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: a = 40.000 b = 0.00
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: wall=0:00:00.135165 end=0:00:07.974762 next=0:00:00.000017 step=0:00:00.000006 m_step=0:00:00.000006 load=0:00:00.000011 m_load=0:00:00.000011
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 4/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 0/48 5/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48]: b = 1.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48]: wall=0:00:00.214666 end=0:00:10.089297 next=0:00:00.637951 step=0:00:00.212640 m_step=0:00:00.212640 load=0:00:00.000010 m_load=0:00:00.000010
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: a = 60.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: wall=0:00:00.215068 end=0:00:08.387665 next=0:00:00.000059 step=0:00:00.000006 m_step=0:00:00.000006 load=0:00:00.000009 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 6/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 7/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: a = 80.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: wall=0:00:00.298510 end=0:00:08.656785 next=0:00:00.000045 step=0:00:00.000011 m_step=0:00:00.000009 load=0:00:00.000016 m_load=0:00:00.000014
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 8/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 9/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: a = 100.000 b = 10.00
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: wall=0:00:00.387885 end=0:00:08.921352 next=0:00:00 step=0:00:00.000010 m_step=0:00:00.000009 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 1/48 10/240]: cpu=3.5% 0 = TITAN X (Pascal) 88.0% 2116.0MB 85.0°C 1 = GeForce GTX TITAN Black 100.0% 3774.0MB 74.0°C
[03:07PM 05/02/21 0/2 0/4 0/12 2/48]: b = 2.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48]: wall=0:00:02.976140 end=0:01:08.451221 next=0:00:02.973728 step=0:00:02.761070 m_step=0:00:01.486855 load=0:00:00.000008 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 11/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: a = 120.000 b = 20.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: wall=0:00:03.043639 end=0:00:57.829140 next=0:00:00.000060 step=0:00:00.000007 m_step=0:00:00.000008 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 12/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 13/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: a = 140.000 b = 20.00
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: wall=0:00:03.198814 end=0:00:51.637992 next=0:00:00.000021 step=0:00:00.000010 m_step=0:00:00.000008 load=0:00:00.000016 m_load=0:00:00.000013
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 14/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 2/48 15/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48]: b = 3.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48]: wall=0:00:03.350787 end=0:00:50.261807 next=0:00:01.115966 step=0:00:00.374161 m_step=0:00:01.115957 load=0:00:00.000008 m_load=0:00:00.000009
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: a = 160.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: wall=0:00:03.351449 end=0:00:46.920286 next=0:00:00.000102 step=0:00:00.000010 m_step=0:00:00.000010 load=0:00:00.000015 m_load=0:00:00.000015
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 16/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 17/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: a = 180.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: wall=0:00:03.502008 end=0:00:43.191432 next=0:00:00.000042 step=0:00:00.000010 m_step=0:00:00.000009 load=0:00:00.000010 m_load=0:00:00.000012
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 18/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 19/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: a = 200.000 b = 30.00
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: wall=0:00:03.647854 end=0:00:40.126398 next=0:00:00 step=0:00:00.000012 m_step=0:00:00.000009 load=0:00:00.000018 m_load=0:00:00.000014
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: Left messages ['a'] with X at /home/robw/tmp_23
[03:07PM 05/02/21 0/2 0/4 0/12 3/48 20/240]: cpu=3.5% 0 = TITAN X (Pascal) 96.0% 2116.0MB 86.0°C 1 = GeForce GTX TITAN Black 100.0% 3774.0MB 75.0°C
[03:07PM 05/02/21 0/2 0/4 0/12 4/48]: b = 4.00
[03:07PM 05/02/21 0/2 0/4 0/12 4/48]: wall=0:00:04.948039 end=0:00:54.428427 next=0:00:00 step=0:00:01.596759 m_step=0:00:01.236157 load=0:00:00.000013 m_load=0:00:00.000010
[03:07PM 05/02/21 0/2 0/4 1/12]: c = 1.00
[03:07PM 05/02/21 0/2 0/4 1/12 4/48 20/240]: ---- Limit reached for step 20 ----
[03:07PM 05/02/21 0/2 0/4 1/12 5/48]: b = 5.00
[03:07PM 05/02/21 0/2 0/4 1/12 5/48]: wall=0:00:04.949144 end=0:00:42.562637 next=0:00:00.000554 step=0:00:00.000172 m_step=0:00:00.000172 load=0:00:00.000013 m_load=0:00:00.000013
[03:07PM 05/02/21 0/2 0/4 2/12]: c = 2.00
[03:07PM 05/02/21 0/2 1/4]: d = 1.000 e = 0.00
[03:07PM 05/02/21 1/2]: e = 1.00
The TorchMonitor Class¶
The TorchMonitor
class adds to the functionality of Monitor,
also allowing torch modules to be automatically saved and variables to be logged to tensorboard. The examples in this section can be run as:
python -m xmen.examples.monitor.torch_monitor
python -m xmen.examples.monitor.checkpoint
Automatic Checkpointing¶
"""Automatic check-pointing"""
from xmen.monitor import TorchMonitor
from torch.nn import Conv2d
from torch.optim import Adam
model = Conv2d(2, 3, 3)
model2 = Conv2d(2, 3, 3)
optimiser = Adam(model.parameters())
m = TorchMonitor(
'/tmp/checkpoint',
ckpt=['^model$@5s', 'opt@1e', '^model2$@20s'],
ckpt_keep=[5, 1, None])
for _ in m(range(10)):
for _ in m(range(20)):
# Do something
...
# result
tree /tmp/checkpoint
└── checkpoints
├── model
│ ├── 180.pt
│ ├── 185.pt
│ ├── 190.pt
│ ├── 195.pt
│ └── 200.pt
├── model2
│ ├── 100.pt
│ ├── 120.pt
│ ├── 140.pt
│ ├── 160.pt
│ ├── 180.pt
│ ├── 20.pt
│ ├── 200.pt
│ ├── 40.pt
│ ├── 60.pt
│ └── 80.pt
└── optimiser
└── 200.pt
Tensorboard Logging¶
"""automatic tensorboard logging"""
from xmen.monitor import TorchMonitor
import numpy as np
import torch
import os
import matplotlib.pyplot as plt
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
plt.style.use('ggplot')
ds = DataLoader(MNIST(os.getenv("HOME") + '/data/mnist', download=True,
transform=T.Compose(
[T.Resize([64, 64]), T.CenterCrop([64, 64]),
T.ToTensor(), T.Normalize([0.5], [0.5])])), 8)
m = TorchMonitor(
directory='/tmp/tb',
sca=['^z$|^X$@10s', '^a|^x$@5s'],
img=['^mnist@10s', '^mnist@5s'], img_fn=[lambda x: x[:2], lambda x: x[:5]], img_pref=['2', '5'],
hist='Z@1s',
fig='fig@10s',
txt='i@5s', txt_fn=lambda x: f'Hello at step {x}',
vid='^mnist@10s', vid_fn=lambda x: (x.unsqueeze(0) - x.min()) / (x.max() - x.min())
)
# variables
x = 0.
a = [1, 2]
z = {'x': 5, 'y': 10}
for i, (mnist, _) in m(zip(range(31), ds)):
# plot a figure
fig = plt.figure(figsize=[10, 5])
plt.plot(np.linspace(0, 1000), np.cos(np.linspace(0, 1000) * i))
# random tensor
Z = torch.randn([10, 3, 64, 64]) * i / 100
# scalars
x = (i - 15) ** 2
z['i'] = i
z['x'] += 1
z['y'] = z['x'] ** 2
# to visualise results run
tensorboard --logdir /tmp/tb
Pytorch experiments with Xmen¶
All examples in this section are defined in xmen.examples.torch
and can be run from the command line using the xmen
command line interface; just add them using xmen --add
(if they have not been added already). PyTorch will need to be installed in order to run these examples.
DCGAN using the functional api¶
"""A functional implementation of dcgan"""
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2019 Robert J Weston, Oxford Robotics Institute
#
# xmen
# email: robw@robots.ox.ac.uk
# github: https://github.com/robw4/xmen/
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
from typing import Tuple
import xmen
import os
try:
import torch
except ImportError:
print('In order to run this script first add pytorch to the experiments path')
def get_datasets(cy, cz, b, ngpus, ncpus, ns, data_root, hw, **kwargs):
    """Build the data sources for the functional dcgan experiment.

    Returns a dict with two entries:
      * 'train' - a DataLoader over MNIST yielding (image, target) pairs, where
        each target is a one-hot tensor of shape [cy, 1, 1];
      * 'inference' - a one-element list holding a single (y, z) batch in
        which every class appears ``ns`` times, paired with N(0, 1) noise of
        shape [cz, 1, 1].
    Extra keyword arguments are accepted and ignored.
    """
    from torch.utils.data import DataLoader
    from torchvision.datasets.mnist import MNIST
    from torch.distributions import Normal
    import torchvision.transforms as T

    def to_target(label):
        # one-hot encode an integer label as a [cy, 1, 1] tensor
        one_hot = torch.zeros([cy])
        one_hot[label] = 1.
        return one_hot.reshape([cy, 1, 1])

    # [cy, cy, 1, 1] (one row per class) -> [ns, cy, cy, 1, 1] -> [ns * cy, cy, 1, 1]
    per_class = torch.stack([to_target(k) for k in range(cy)])
    y = per_class.unsqueeze(0).expand([ns, cy, cy, 1, 1]).reshape([-1, cy, 1, 1])
    z = Normal(0., 1.).sample([y.shape[0], cz, 1, 1])

    preprocess = T.Compose(
        [T.Resize(hw), T.CenterCrop(hw),
         T.ToTensor(), T.Normalize([0.5], [0.5])])
    train = DataLoader(
        MNIST(data_root, download=True,
              transform=preprocess,
              target_transform=to_target),
        batch_size=b * ngpus,
        shuffle=True, num_workers=ncpus)
    # a single inference batch: add a leading dim, then pair y with z
    inference = list(zip(y.unsqueeze(0), z.unsqueeze(0)))
    return {'train': train, 'inference': inference}
@xmen.autodoc
def dcgan(
    root: xmen.Root, #
    # first argument is always an experiment instance.
    # can be unused (specify with _) in experiments
    # syntax practice. can be named whatever depending
    # on use case. Eg. logger, root, experiment ...
    b: int = 128, # the batch size per gpu
    hw0: Tuple[int, int] = (4, 4), # the height and width of the image
    nl: int = 4, # the number of levels in the discriminator.
    data_root: str = os.getenv("HOME") + '/data/mnist', # @p the root data directory
    cx: int = 1,
    cy: int = 10, # the dimensionality of the conditioning vector
    cf: int = 512, # the number of features after the first conv in the discriminator
    cz: int = 100, # the dimensionality of the noise vector
    ncpus: int = 8, # the number of threads to use for data loading
    ngpus: int = 1, # the number of gpus to run the model on
    epochs: int = 20, # no. of epochs to train for
    lr: float = 0.0002, # learning rate
    betas: Tuple[float, float] = (0.5, 0.999), # the beta parameters for the
    # monitoring parameters
    checkpoint: str = 'nn_.*@1e', # checkpoint at this modulo string
    log: str = 'loss_.*@20s', # log scalars
    sca: str = 'loss_.*@20s', # tensorboard scalars
    img: str = '_x_|x$@20s', # tensorboard images
    nimg: int = 64, # the maximum number of images to display to tensorboard
    ns: int = 5 # the number of samples to generate at inference)
):
    """Train a conditional GAN to predict MNIST digits.
    To visualise the results run::
        tensorboard --logdir ...
    """
    # NOTE(review): the signature comments above appear to be read by
    # xmen.autodoc as parameter help, so they are deliberately left untouched.
    # The monitor below matches its patterns against local variable names
    # (nn_g, nn_d, loss_d, loss_g, x, _x_, _xi_), so do not rename them.
    from xmen.monitor import TorchMonitor, TensorboardLogger
    from xmen.examples.models import weights_init, set_requires_grad, GeneratorNet, DiscriminatorNet
    from torch.distributions import Normal
    from torch.distributions.one_hot_categorical import OneHotCategorical
    from torch.optim import Adam
    import logging
    # output image size: hw0 scaled up by 2 ** nl
    hw = [d * 2 ** nl for d in hw0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # root logger at INFO, presumably so that the monitor's log lines are shown
    logger = logging.getLogger()
    logger.setLevel('INFO')
    # dataset
    datasets = get_datasets(
        cy, cz, b, ngpus, ncpus, ns, data_root, hw)
    # models
    nn_g = GeneratorNet(cy, cz, cx, cf, hw0, nl)
    nn_d = DiscriminatorNet(cx, cy, cf, hw0, nl)
    nn_g = nn_g.to(device).float().apply(weights_init)
    nn_d = nn_d.to(device).float().apply(weights_init)
    # distributions: standard normal noise prior and a uniform prior over classes
    pz = Normal(torch.zeros([cz]), torch.ones([cz]))
    py = OneHotCategorical(probs=torch.ones([cy]) / cy)
    # optimisers
    op_d = Adam(nn_d.parameters(), lr=lr, betas=betas)
    op_g = Adam(nn_g.parameters(), lr=lr, betas=betas)
    # monitor: checkpoints, console logs, tensorboard scalars/images (capped to
    # nimg images), timing, and a hook writing the inference samples _xi_ once
    # per epoch (nrow=10 is presumably the number of images per grid row)
    m = TorchMonitor(
        root.directory, ckpt=checkpoint,
        log=log, sca=sca, img=img,
        time=('@20s', '@1e'),
        msg='root@100s',
        img_fn=lambda x: x[:min(nimg, x.shape[0])],
        hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])
    for _ in m(range(epochs)):
        # (1) train
        for x, y in m(datasets['train']):
            # process input
            x, y = x.to(device), y.to(device).float()
            # actual size of this batch (the last batch may be smaller); this
            # rebinds the ``b`` parameter, which is not needed after get_datasets
            b = x.shape[0]
            # discriminator step
            set_requires_grad([nn_d], True)
            op_d.zero_grad()
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            # NOTE(review): nn_d appears to return a loss given a real/fake flag;
            # fakes are detached so this step does not update the generator
            loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
            loss_d.backward()
            op_d.step()
            # generator step: fresh random class labels and noise
            op_g.zero_grad()
            y = py.sample([b]).reshape([b, cy, 1, 1]).to(device)
            z = pz.sample([b]).reshape([b, cz, 1, 1]).to(device)
            _x_ = nn_g(y, z)
            # freeze the discriminator while back-propagating the generator loss
            set_requires_grad([nn_d], False)
            loss_g = nn_d((_x_, y), True)
            loss_g.backward()
            op_g.step()
        # (2) inference
        if 'inference' in datasets:
            with torch.no_grad():
                for yi, zi in datasets['inference']:
                    yi, zi = yi.to(device), zi.to(device)
                    _xi_ = nn_g(yi, zi)
DCGAN using the class api¶
"""A class implementation of dcgan"""
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2019 Robert J Weston, Oxford Robotics Institute
#
# xmen
# email: robw@robots.ox.ac.uk
# github: https://github.com/robw4/xmen/
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
from typing import Tuple
import torch
from torch.distributions import Normal
from xmen.experiment import Experiment
class Dcgan(Experiment):
    """Train a conditional GAN to predict MNIST digits.
    To visualise the results run::
        tensorboard --logdir ...
    """
    # tempfile is imported in the class body so that it is in scope when the
    # data_root default below is evaluated (it also becomes a class attribute).
    import tempfile
    b: int = 128 # @p the batch size per gpu
    hw0: Tuple[int, int] = (4, 4) # @p the height and width of the image
    nl: int = 4 # @p The number of levels in the discriminator.
    data_root: str = tempfile.gettempdir() # @p the root data directory
    cx: int = 1 # @p the dimensionality of the image input
    cy: int = 10 # @p the dimensionality of the conditioning vector
    cf: int = 512 # @p the number of features after the first conv in the discriminator
    cz: int = 100 # @p the dimensionality of the noise vector
    ncpus: int = 8 # @p the number of threads to use for data loading
    ngpus: int = 1 # @p the number of gpus to run the model on
    epochs: int = 100 # @p no. of epochs to train for
    gan: str = 'lsgan' # @p the gan type to use (one of ['vanilla', 'lsgan'])
    lr: float = 0.0002 # @p learning rate
    betas: Tuple[float, float] = (0.5, 0.999) # @p The beta parameters for the
    # Monitoring parameters
    checkpoint: str = 'nn_.*@1e' # @p
    log: str = 'loss_.*@20s' # @p log scalars
    sca: str = 'loss_.*@20s' # @p tensorboard scalars
    img: str = '_x_|x$@20s' # @p tensorboard images
    time: str = ('@20s', '@1e') # @p timing modulos
    nimg: int = 64 # @p The maximum number of images to display to tensorboard
    ns: int = 5 # @p The number of samples to generate at inference

    @property
    def hw(self):
        """Output image size: each entry of ``hw0`` multiplied by ``2 ** nl``."""
        return [d * 2 ** self.nl for d in self.hw0]

    @property
    def device(self):
        """'cuda' if a CUDA device is available, otherwise 'cpu'."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def datasets(self):
        """Return a dictionary of iterable datasets for the modes 'train'
        (image / one-hot label pairs) and 'inference' (a single batch of
        conditioning vectors and noise covering every class ``ns`` times)."""
        from torch.utils.data import DataLoader
        from torchvision.datasets.mnist import MNIST
        import torchvision.transforms as T

        def to_target(y):
            # one-hot encode an integer label as a [cy, 1, 1] tensor
            Y = torch.zeros([self.cy])
            Y[y] = 1.
            return Y.reshape([self.cy, 1, 1])

        # Generate test samples: [cy, cy, 1, 1] repeated ns times -> [ns * cy, cy, 1, 1]
        y = torch.stack([to_target(i) for i in range(self.cy)])
        y = y.unsqueeze(0).expand(
            [self.ns, self.cy, self.cy, 1, 1]).reshape(  # Expand across last dim
            [-1, self.cy, 1, 1])  # Batch
        z = Normal(0., 1.).sample([y.shape[0], self.cz, 1, 1])
        return {
            'train': DataLoader(
                MNIST(self.data_root, download=True,
                      transform=T.Compose(
                          [T.Resize(self.hw), T.CenterCrop(self.hw),
                           T.ToTensor(), T.Normalize([0.5], [0.5])]),
                      target_transform=to_target),
                batch_size=self.b * self.ngpus,
                shuffle=True, num_workers=self.ncpus),
            'inference': list(
                zip(y.unsqueeze(0), z.unsqueeze(0)))}  # Turn into batches

    def build(self):
        """Build generator, discriminator and optimisers."""
        from torch.optim import Adam
        from xmen.examples.models import GeneratorNet, DiscriminatorNet
        nn_g = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        op_g = Adam(nn_g.parameters(), lr=self.lr, betas=self.betas)
        nn_d = DiscriminatorNet(self.cx, self.cy, self.cf, self.hw0, self.nl)
        op_d = Adam(nn_d.parameters(), lr=self.lr, betas=self.betas)
        return nn_g, nn_d, op_g, op_d

    def run(self):
        """Train the mnist dataset for a fixed number of epochs running an
        inference loop after every epoch.

        The monitor is configured from the ``checkpoint``, ``log``, ``sca``,
        ``img``, ``time`` and ``nimg`` parameters. Its patterns appear to be
        matched against the local variable names used below (nn_g, nn_d,
        loss_d, loss_g, x, _x_, _xi_), so those names should not be changed.

        NOTE(review): the ``gan`` parameter is not read here. TODO: confirm
        whether it should be forwarded to the discriminator.
        """
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.models import weights_init, set_requires_grad
        from torch.distributions import Normal
        from torch.distributions.one_hot_categorical import OneHotCategorical
        # data, priors and models
        datasets = self.datasets()
        pz = Normal(torch.zeros([self.cz]), torch.ones([self.cz]))
        py = OneHotCategorical(probs=torch.ones([self.cy]) / self.cy)
        nn_g, nn_d, op_g, op_d = self.build()
        nn_g = nn_g.to(self.device).float().apply(weights_init)
        nn_d = nn_d.to(self.device).float().apply(weights_init)
        # training loop; the timing modulos come from the ``time`` parameter
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            time=self.time,
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=10)])
        for _ in m(range(self.epochs)):
            for x, y in m(datasets['train']):
                # process input
                x, y = x.to(self.device), y.to(self.device).float()
                # actual size of this batch (the last one may be smaller)
                b = x.shape[0]
                # discriminator step (fakes detached so G is not updated here)
                set_requires_grad([nn_d], True)
                op_d.zero_grad()
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
                loss_d.backward()
                op_d.step()
                # generator step with fresh random labels and noise; the
                # discriminator is frozen while the generator loss is propagated
                op_g.zero_grad()
                y = py.sample([b]).reshape([b, self.cy, 1, 1]).to(self.device)
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                set_requires_grad([nn_d], False)
                loss_g = nn_d((_x_, y), True)
                loss_g.backward()
                op_g.step()
            # inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        yi, zi = yi.to(self.device), zi.to(self.device)
                        _xi_ = nn_g(yi, zi)
Generative modelling with inheritance¶
"""Deep generative models with inheritance"""
# Copyright (C) 2019 Robert J Weston, Oxford Robotics Institute
#
# xmen
# email: robw@robots.ox.ac.uk
# github: https://github.com/robw4/xmen/
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
from typing import Tuple
import torch
from torch.distributions import Normal
from xmen.experiment import Experiment
import os
# ------------------------------------------------------
# ---- BASE EXPERIMENTS --------------------------------
# ------------------------------------------------------
class BaseGenerative(Experiment):
    """An abstract class defining parameters and some useful properties common to
    both the VAE and cGAN implementations.

    Inherited classes must overload:
        - datasets()
        - build()
        - run()

    Inherited classes can optionally overload:
        - parameter defaults and documentation

    Note:
        The output size is given as hw0 * 2 ** nl (eg (4, 2) * 2 ** 4 = (64, 32)).
    """
    b: int = 256  # @p the batch size per gpu
    hw0: Tuple[int, int] = (4, 4)  # @p the height and width of the image before upsampling (output size is hw0 * 2 ** nl)
    nl: int = 4  # @p The number of levels in the generator and discriminator.
    data_root: str = os.path.expanduser('~') + '/data/mnist'  # @p the root data directory
    cx: int = 1  # @p the dimensionality of the image input
    cy: int = 10  # @p the dimensionality of the conditioning vector
    cf: int = 512  # @p the number of features after the first conv in the discriminator
    cz: int = 100  # @p the dimensionality of the noise vector
    ncpus: int = 8  # @p the number of threads to use for data loading
    ngpus: int = 1  # @p the number of gpus to run the model on
    epochs: int = 100  # @p no. of epochs to train for
    gan: str = 'lsgan'  # @p the gan type to use (one of ['vanilla', 'lsgan'])
    lr: float = 0.0002  # @p learning rate
    betas: Tuple[float, float] = (0.5, 0.999)  # @p The beta parameters for the Adam optimizers
    # Monitoring parameters
    checkpoint: str = 'nn_.*@1e'  # @p
    log: str = 'loss_.*@20s'  # @p log scalars
    sca: str = 'loss_.*@20s'  # @p tensorboard scalars
    img: str = '_x_|x$@20s'  # @p tensorboard images
    time: str = ('@20s', '@1e')  # @p timing modulos
    nimg: int = 64  # @p The maximum number of images to display to tensorboard
    test_samples = None

    # Useful properties
    @property
    def hw(self):
        """The output image size: each entry of ``hw0`` multiplied by ``2 ** nl``."""
        return [d * 2 ** self.nl for d in self.hw0]

    @property
    def device(self):
        """``'cuda'`` when a GPU is available, otherwise ``'cpu'``."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    # Abstract interface: these must be overloaded by inherited classes. They raise
    # (rather than return) NotImplementedError so a missing override fails loudly.
    def datasets(self):
        """Return a dict of datasets keyed by split name (e.g. ``'train'``)."""
        raise NotImplementedError

    def run(self):
        """Run the experiment."""
        raise NotImplementedError

    def build(self):
        """Build and return the networks and optimisers for the experiment."""
        raise NotImplementedError
class BaseCVae(BaseGenerative):
    """BaseCVae is an intermediary class defining the model and training loop but not the dataset.

    Inherited classes should overload:
        - datasets()
    """
    nlprior: int = 1  # @p The number of hidden layers in the prior network
    w_kl: float = 1.0  # @p The weighting on the KL divergence term
    ns: int = 5  # @p The number of samples to generate at inference
    predictive: str = 'prior'  # @p Use either the 'prior' or 'posterior' predictive at inference
    b = 32
    log = 'loss.*|log_px|kl_qz_pz@20s'
    sca = 'loss.*|log_px|kl_qz_pz@20s'
    img = '_x_|x$@20s'
    time = ('@10s', '@1e')

    def build(self):
        """Build the conditional VAE from the current configuration.

        Returns:
            (nn_gen, nn_post, nn_prior, opt): the generator, posterior and prior
            networks, plus a single Adam optimiser over all of their parameters.
        """
        from torch.optim import Adam
        from itertools import chain
        from xmen.examples.torch.models import GeneratorNet, PosteriorNet, PriorNet
        nn_gen = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        nn_post = PosteriorNet(self.cx, self.cy, self.cz, self.cf, self.hw0, self.nl)
        nn_prior = PriorNet(self.cy, self.cz, self.cf, self.nlprior)
        opt = Adam(chain(nn_gen.parameters(), nn_post.parameters(), nn_prior.parameters()),
                   self.lr, self.betas)
        return nn_gen, nn_post, nn_prior, opt

    def run(self):
        """Train the conditional VAE and decode the inference set every epoch.

        NOTE(review): the monitor's log/sca/img/ckpt patterns appear to be matched
        against local variable names in this frame (``loss``, ``log_px``,
        ``kl_qz_pz``, ``_x_``, ``x``, ``_xi_``, ``nn_*``); keep those names stable.
        """
        from torch.distributions import Normal, kl_divergence
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.torch.models import weights_init
        # construct model
        datasets = self.datasets()
        nn_gen, nn_post, nn_prior, opt = self.build()
        nn_gen, nn_post, nn_prior = (
            v.to(self.device).float().apply(weights_init) for v in (
                nn_gen, nn_post, nn_prior))
        if self.ngpus > 1:
            nn_gen, nn_post, nn_prior = (torch.nn.DataParallel(n) for n in (
                nn_gen, nn_post, nn_prior))
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            # honour the ``time`` parameter (overridden to ('@10s', '@1e') above)
            time=self.time,
            hooks=[TensorboardLogger('image', '_xi_$@20s', nrow=self.ns)])
        for _ in m(range(self.epochs)):
            # Training
            for x, y in m(datasets['train']):
                x, y = x.to(self.device), y.to(self.device).float()
                qz = Normal(*nn_post(x, y))
                # rsample (reparameterisation) keeps z differentiable so the
                # reconstruction term also trains the posterior network;
                # ``sample()`` would detach z from the graph.
                z = qz.rsample()
                # Likelihood
                _x_ = nn_gen(y, z)
                px = Normal(_x_, 0.1)
                pz = Normal(*nn_prior(y.reshape([-1, self.cy])))
                log_px = px.log_prob(x).sum()
                kl_qz_pz = kl_divergence(qz, pz).sum()
                loss = self.w_kl * kl_qz_pz - log_px
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        # the networks live on self.device, so the inputs must too
                        yi, zi = yi.to(self.device), zi.to(self.device)
                        if self.predictive == 'posterior':
                            # NOTE(review): 'posterior' draws z from the prior
                            # network; 'prior' uses the noise supplied by the
                            # inference dataset. Confirm this naming is intended.
                            qz = Normal(*nn_prior(yi.reshape([-1, self.cy])))
                            zi = qz.sample()
                        _xi_ = nn_gen(yi, zi)
class BaseGAN(BaseGenerative):
    """BaseGan is an intermediary class defining the model and training loop but not the dataset.

    Inherited classes should overload:
        - datasets()
        - distributions()
    """
    # overload previous parameter default (documentation is maintained)
    b = 128
    # define new parameters
    ns: int = 5  # @p The number of samples to generate at inference

    def build(self):
        """Build the model from the current configuration.

        Returns:
            (nn_g, nn_d, op_g, op_d): the generator and discriminator networks and
            a separate Adam optimiser for each.
        """
        from torch.optim import Adam
        from xmen.examples.torch.models import GeneratorNet, DiscriminatorNet
        nn_g = GeneratorNet(self.cy, self.cz, self.cx, self.cf, self.hw0, self.nl)
        op_g = Adam(nn_g.parameters(), lr=self.lr, betas=self.betas)
        nn_d = DiscriminatorNet(self.cx, self.cy, self.cf, self.hw0, self.nl)
        op_d = Adam(nn_d.parameters(), lr=self.lr, betas=self.betas)
        return nn_g, nn_d, op_g, op_d

    def distributions(self):
        """Return ``(py, pz)``: the label and noise distributions sampled in ``run``."""
        raise NotImplementedError

    def run(self):
        """Train the conditional GAN, alternating discriminator and generator steps.

        NOTE(review): the monitor's log/sca/img/ckpt patterns appear to be matched
        against local variable names in this frame (``loss_d``, ``loss_g``, ``_x_``,
        ``x``, ``_xi_``, ``nn_*``); keep those names stable. Unlike BaseCVae, the
        networks are not wrapped in DataParallel when ngpus > 1.
        """
        from xmen.monitor import TorchMonitor, TensorboardLogger
        from xmen.examples.torch.models import set_requires_grad, weights_init
        # Get datasets
        datasets = self.datasets()
        py, pz = self.distributions()
        nn_g, nn_d, op_g, op_d = self.build()
        nn_g = nn_g.to(self.device).float().apply(weights_init)
        nn_d = nn_d.to(self.device).float().apply(weights_init)
        m = TorchMonitor(
            self.directory, ckpt=self.checkpoint,
            log=self.log, sca=self.sca, img=self.img,
            # honour the ``time`` parameter instead of a hard-coded value
            time=self.time,
            img_fn=lambda x: x[:min(self.nimg, x.shape[0])],
            hooks=[TensorboardLogger('image', '_xi_$@1e', nrow=self.ns)])
        for _ in m(range(self.epochs)):
            for x, y in m(datasets['train']):
                # process input
                x, y = x.to(self.device), y.to(self.device).float()
                b = x.shape[0]
                # discriminator step (presumably True/False mark real/fake
                # targets; confirm against DiscriminatorNet)
                set_requires_grad([nn_d], True)
                op_d.zero_grad()
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                loss_d = nn_d((x, y), True) + nn_d((_x_.detach(), y.detach()), False)
                loss_d.backward()
                op_d.step()
                # generator step: labels are freshly sampled from py rather than
                # taken from the data batch
                op_g.zero_grad()
                y = py.sample([b]).reshape([b, self.cy, 1, 1]).to(self.device)
                z = pz.sample([b]).reshape([b, self.cz, 1, 1]).to(self.device)
                _x_ = nn_g(y, z)
                # freeze the discriminator so only the generator accumulates grads
                set_requires_grad([nn_d], False)
                loss_g = nn_d((_x_, y), True)
                loss_g.backward()
                op_g.step()
            # Inference
            if 'inference' in datasets:
                with torch.no_grad():
                    for yi, zi in datasets['inference']:
                        yi, zi = yi.to(self.device), zi.to(self.device)
                        _xi_ = nn_g(yi, zi)
class BaseMnist(BaseGenerative):
    """MNIST configuration shared by the MNIST GAN and VAE experiments.

    Sets MNIST-appropriate parameter defaults and implements ``datasets()``.
    """
    # Update defaults
    hw0, nl = (4, 4), 3  # Default size = 32 x 32
    # expanduser avoids a TypeError at import time when HOME is unset
    data_root = os.path.expanduser('~') + '/data/mnist'
    cx, cy, cz, cf = 1, 10, 100, 512
    ns = 10  # Number of samples to generate during inference

    def datasets(self):
        """Configure the datasets used for training and inference.

        Returns:
            A dict with:
              - ``'train'``: a shuffled DataLoader over the MNIST training set
                (downloaded to ``data_root`` if needed). Images are resized and
                center-cropped to ``hw`` and normalised with mean 0.5 / std 0.5;
                targets are one-hot vectors of shape ``[cy, 1, 1]``.
              - ``'inference'``: a list holding one ``(y, z)`` batch, where ``y``
                is the ``cy`` one-hot labels tiled ``ns`` times
                (``[ns * cy, cy, 1, 1]``) and ``z`` is standard normal noise of
                shape ``[ns * cy, cz, 1, 1]``.
        """
        from torch.utils.data import DataLoader
        from torchvision.datasets.mnist import MNIST
        import torchvision.transforms as T

        def to_target(y):
            # integer class label -> one-hot tensor of shape [cy, 1, 1]
            Y = torch.zeros([self.cy])
            Y[y] = 1.
            return Y.reshape([self.cy, 1, 1])

        transform = T.Compose(
            [T.Resize(self.hw), T.CenterCrop(self.hw),
             T.ToTensor(), T.Normalize([0.5], [0.5])])
        # one one-hot target per class: [cy, cy, 1, 1]
        y = torch.stack([to_target(i) for i in range(self.cy)])
        y = y.unsqueeze(0).expand(
            [self.ns, self.cy, self.cy, 1, 1]).reshape(  # Repeat ns times
            [-1, self.cy, 1, 1])  # Flatten into a single batch
        z = Normal(0., 1.).sample([y.shape[0], self.cz, 1, 1])
        return {
            'train': DataLoader(
                MNIST(self.data_root, download=True,
                      transform=transform,
                      target_transform=to_target),
                batch_size=self.b * self.ngpus,
                shuffle=True, num_workers=self.ncpus),
            'inference': list(
                zip(y.unsqueeze(0), z.unsqueeze(0)))}  # Turn into batches
# ------------------------------------------------------
# ---- RUNABLE EXPERIMENTS -----------------------------
# ------------------------------------------------------
class InheritedMnistGAN(BaseMnist, BaseGAN):
    """Runnable experiment: a conditional DCGAN trained on MNIST."""
    epochs = 20
    ncpus, ngpus = 0, 1
    b = 128

    def distributions(self):
        """Return the label and noise distributions sampled during training.

        Returns:
            (py, pz): a uniform one-hot categorical over the ``cy`` classes, and a
            standard normal with ``cz`` independent dimensions.
        """
        from torch.distributions import Normal
        from torch.distributions.one_hot_categorical import OneHotCategorical
        uniform_probs = torch.ones([self.cy]) / self.cy
        label_dist = OneHotCategorical(probs=uniform_probs)
        noise_mean, noise_std = torch.zeros([self.cz]), torch.ones([self.cz])
        noise_dist = Normal(noise_mean, noise_std)
        return label_dist, noise_dist
class InheritedMnistVae(BaseMnist, BaseCVae):
    """Train a conditional VAE on MNIST"""
    ns = 10
    # Up-weight the KL divergence term. BaseCVae.run computes
    # ``loss = self.w_kl * kl_qz_pz - log_px``, so ``w_kl`` is the attribute it reads.
    w_kl = 1.5
    w_beta = w_kl  # alias kept for backward compatibility; BaseCVae does not read it
    lr = 0.00005
    ncpus, ngpus = 0, 2
Xmen Meets PyTorch Lightning
Note: this example is currently experimental!
"""CURRENTLY EXPERIMENTAL"""
from pytorch_lightning import LightningModule
import torch
import torch.nn.functional as F
from typing import List, Any
from xmen.lightning import TensorBoardLogger, Trainer
class LitMNIST(LightningModule):
    """MNIST classifier adapted from the example in the PyTorch Lightning docs.

    A three-layer fully-connected network that returns per-class
    log-probabilities; it is trained and evaluated with the NLL loss.
    """

    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, height, width); they are
        # flattened to 28 * 28 features before the first layer
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x_in):
        """Map a 4-d image batch to log-probabilities over the 10 classes."""
        # unpacking the size doubles as a check that the input is 4-d
        batch_size, _channels, _height, _width = x_in.size()
        hidden = x_in.view(batch_size, -1)  # (b, 1, 28, 28) -> (b, 1*28*28)
        for layer in (self.layer_1, self.layer_2):
            hidden = F.relu(layer(hidden))
        return F.log_softmax(self.layer_3(hidden), dim=1)

    def configure_optimizers(self):
        """Optimise every parameter with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _nll(self, batch):
        """Return the NLL loss of the model on an ``(inputs, targets)`` batch."""
        inputs, targets = batch
        return F.nll_loss(self(inputs), targets)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it alongside the input images."""
        x, _ = batch
        loss = self._nll(batch)
        self.log_dict(
            {'loss': loss, 'x': x})
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch."""
        return self._nll(batch)

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        """Log the mean of the per-batch validation losses as ``loss_val``."""
        mean_loss = torch.stack(outputs).mean()
        self.log('loss_val', mean_loss)

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch."""
        return self._nll(batch)

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Log the mean of the per-batch test losses.

        NOTE(review): this logs under the same ``loss_val`` key as validation.
        """
        mean_loss = torch.stack(outputs).mean()
        self.log('loss_val', mean_loss)
def lit_experiment(
        root,
        batch_size=64, # The batch size of the experiment
        epochs=5, # Number of epochs to train for
):
    """Xmen meets pytorch_lightning

    Trains ``LitMNIST`` with a Lightning ``Trainer`` whose logger is xmen's
    ``TensorBoardLogger``, writing into ``root.directory``.

    Args:
        root: the xmen root assigned to the experiment before it is executed.
        batch_size: the batch size used by every data loader.
        epochs: the maximum number of training epochs.
    """
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms
    # download both splits to one location rather than into the working directory
    data_root = '/tmp/xmen'
    # prepare transforms standard to MNIST
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    # data
    mnist_train = MNIST(data_root, train=True, download=True, transform=transform)
    # shuffle so that each epoch visits the training batches in a new order
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    mnist_val = MNIST(data_root, train=False, download=True, transform=transform)
    mnist_val = DataLoader(mnist_val, batch_size=batch_size)
    model = LitMNIST()
    trainer = Trainer(
        default_root_dir=root.directory,
        max_epochs=epochs,
        logger=TensorBoardLogger(
            root=root,
            log=['loss@100s', 'loss_val'],
            sca=['loss@100s', 'loss_val'],
            img='x@100s',
            time='@500s',
            msg='loss@50s'
        )
    )
    trainer.fit(model, mnist_train, mnist_val)
    # the MNIST test split serves as both the validation and the test set
    trainer.test(model, mnist_val)
if __name__ == '__main__':
    # Wrap the function in an xmen experiment class and expose its command line.
    from xmen.functional import functional_experiment
    (Exp, _) = functional_experiment(lit_experiment)
    experiment = Exp()
    experiment.main()