Welcome to AshPy’s documentation!

Welcome To AshPy!

Warning

AshPy is still a work in progress and may change substantially before the first proper release. The API is not mature enough to be considered stable, but we’ll try to keep breaking changes to a minimum.

AshPy


AshPy is a TensorFlow 2.1 library for (distributed) training, evaluation, model selection, and fast prototyping. It is designed to ease the burden of setting up all the nuances of the architectures built to train complex custom deep learning models.

Quick Example | Features | Set Up | Usage | Dataset Output Format | Test

Quick Example

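import operator

import tensorflow as tf

from ashpy.losses import ClassifierLoss
from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer

# define a distribution strategy
strategy = tf.distribute.MirroredStrategy()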

# work inside the scope of the created strategy
with strategy.scope():

    # get the MNIST dataset
    train, validation = tf.keras.datasets.mnist.load_data()

    # process data if needed
    def process(images, labels):
        data_images = tf.data.Dataset.from_tensor_slices((images)).map(
            lambda x: tf.reshape(x, (28 * 28,))
        )
        data_images = data_images.map(
            lambda x: tf.image.convert_image_dtype(x, tf.float32)
        )
        data_labels = tf.data.Dataset.from_tensor_slices((labels))
        dataset = tf.data.Dataset.zip((data_images, data_labels))
        dataset = dataset.batch(1024 * 1)
        return dataset

    # apply the process function to the data
    train, validation = (
        process(train[0], train[1]),
        process(validation[0], validation[1]),
    )

    # create the model
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(10, activation=tf.nn.sigmoid),
            tf.keras.layers.Dense(10),
        ]
    )

    # define the optimizer
    optimizer = tf.optimizers.Adam(1e-3)

    # the loss is provided by the AshPy library
    loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
    logdir = "testlog"
    epochs = 10

    # the metrics are provided by the AshPy library
    # and every metric with model_selection_operator != None performs
    # model selection, saving the best model in a different folder per metric.
    metrics = [
        ClassifierMetric(
            tf.metrics.Accuracy(), model_selection_operator=operator.gt
        ),
        ClassifierMetric(
            tf.metrics.BinaryAccuracy(), model_selection_operator=operator.gt
        ),
    ]

    # define the AshPy trainer
    trainer = ClassifierTrainer(
        model, optimizer, loss, epochs, metrics, logdir=logdir
    )

    # run the training process
    trainer(train, validation)

Features

AshPy is a library designed to ease the burden of setting up all the nuances of the architectures built to train complex custom deep learning models. It provides both fully convolutional and fully connected models such as:

  • autoencoder
  • decoder
  • encoder

and a fully convolutional:

  • unet

Moreover, it provides ready-made trainers for a classifier model and for GANs. For the latter, it offers a basic GAN architecture with a Generator-Discriminator structure and an enhanced GAN architecture made up of an Encoder-Generator-Discriminator structure.


AshPy is developed around the concepts of Executor, Context, Metric, and Strategies, which represent its foundations.

Executor: An Executor is a class that helps to generalize a training loop. With an Executor you can construct, for example, a custom loss function and put whatever computation you need inside it. You should define a call function inside your class and decorate it with the @Executor.reduce header. Inside the call function you can take advantage of a context.

Context: A Context is a class in which all the models, metrics, dataset and mode of your network are set. Passing the context around means that you can access everything you need at any time in order to perform any type of computation.

Metric: A Metric is a class from which you can inherit to create a custom metric that automatically keeps track of the best performance of the model during training and automatically saves the best model, which is what is called model selection.

Strategies: If you want to distribute your training across multiple GPUs, the tf.distribute.Strategy TensorFlow API lets you distribute your models and training code with minimal code changes. AshPy implements this type of strategy internally and checks that the distribution strategy is applied correctly. All you need to do is the following:

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():

    generator = ConvGenerator(
        layer_spec_input_res=(7, 7),
        layer_spec_target_res=(28, 28),
        kernel_size=(5, 5),
        initial_filters=256,
        filters_cap=16,
        channels=1,
    )
    # rest of the code
    # with trainer definition and so on

i.e., create the strategy and put the rest of the code inside its scope.
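
To make the Executor and Context concepts above more concrete, here is a minimal, framework-free sketch. ToyContext and ToyLossExecutor are hypothetical names invented for this illustration; they are not AshPy classes and do not reproduce AshPy's API. They only show the idea of a callable loss object that pulls what it needs from a context passed to it.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence


@dataclass
class ToyContext:
    """Hypothetical stand-in for a Context: bundles what a computation needs."""

    model: Callable[[Sequence[float]], List[float]]
    mode: str = "train"
    metrics: Dict[str, float] = field(default_factory=dict)


class ToyLossExecutor:
    """Hypothetical stand-in for an Executor: wraps a loss function and
    decides how to evaluate it using the objects stored in the context."""

    def __init__(self, loss_fn: Callable[[List[float], Sequence[float]], float]):
        self._loss_fn = loss_fn

    def __call__(self, context: ToyContext, features, labels) -> float:
        # The model is not passed explicitly: it is read from the context.
        predictions = context.model(features)
        return self._loss_fn(predictions, labels)


def mean_squared_error(predictions, labels) -> float:
    """Average of the squared differences between predictions and labels."""
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


context = ToyContext(model=lambda xs: [2.0 * x for x in xs])
loss = ToyLossExecutor(mean_squared_error)
value = loss(context, features=[1.0, 2.0], labels=[2.0, 5.0])
print(value)  # 0.5
```

The design point is that the loss object only depends on the context, so the same loss can be reused with a different model or mode by building a different context.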

In general, AshPy aims to provide:

  • Rapid model prototyping
  • Enforcement of best practices and API consistency
  • Removal of duplicated and boilerplate code
  • General usability by new projects

NOTE: We invite you to read the full documentation on the official website.

This README aims to help you understand what you need to do to set up AshPy on your system and, with some examples, what you need to do to set up a complete training of your network. Moreover, it explains some fundamental modules you need to understand to fully exploit the potential of the library.

Set up

Pip install

pip install ashpy

Source install

Clone this repo, go inside the downloaded folder and install with:

pip install -e .

Usage

Let’s quickly start with some examples.

Classifier

Let’s say we want to train a classifier.

import operator
import tensorflow as tf
from ashpy.metrics import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer
from ashpy.losses.classifier import ClassifierLoss

def toy_dataset():
    inputs = tf.expand_dims(tf.range(1, 1000.0), -1)
    labels = tf.expand_dims([1 if tf.equal(tf.math.mod(tf.squeeze(i), 2), 0) else 0 for i in inputs], -1)
    return tf.data.Dataset.from_tensor_slices((inputs,labels)).shuffle(10).batch(2)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation=tf.nn.sigmoid),
    tf.keras.layers.Dense(2)
])

optimizer = tf.optimizers.Adam(1e-3)
loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
logdir = "testlog"
epochs = 2

metrics = [
    ClassifierMetric(tf.metrics.Accuracy(), model_selection_operator=operator.gt),
    ClassifierMetric(tf.metrics.BinaryAccuracy(), model_selection_operator=operator.gt),
]

trainer = ClassifierTrainer(model, optimizer, loss, epochs, metrics, logdir=logdir)

train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)

Skipping the toy_dataset() function that creates a toy dataset, we’ll look at the code step by step.

So, first of all we define a model and its optimizer. Here, the model is a very simple sequential Keras model defined as:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation=tf.nn.sigmoid),
    tf.keras.layers.Dense(2)
])

optimizer = tf.optimizers.Adam(1e-3)

Then we define the loss:

loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))

The ClassifierLoss defined above is built on an internal class called “Executor”. The Executor is a class that lets you define, alongside the desired loss, the function that you want to use to “evaluate” that loss with all the needed parameters.

This works in conjunction with the following line (we will speak about the “metrics” and the other few definition lines in a minute):

trainer = ClassifierTrainer(model, optimizer, loss, epochs, metrics, logdir=logdir)

where a ClassifierTrainer is an object designed to run a specific training procedure adjusted, in this case, for a classifier.

The arguments of this function are the model, the optimizer, the loss, the number of epochs, the metrics and the logdir. We have already seen the definition of the model, the optimizer and of the loss. The definition of epochs, metrics and logdir happens here:

logdir = "testlog"
epochs = 2

metrics = [
    ClassifierMetric(tf.metrics.Accuracy(), model_selection_operator=operator.gt),
    ClassifierMetric(tf.metrics.BinaryAccuracy(), model_selection_operator=operator.gt),
]

What we need to underline here is the definition of the metrics: as you can see, they are defined through a specific class, ClassifierMetric. As with the ClassifierTrainer, the ClassifierMetric is a class designed specifically for the classifier. If you want to create a different metric, you should inherit from the Metric class provided by the AshPy library. Metrics of this kind are useful because you can indicate a processing function to apply to the predictions (e.g., tf.argmax) and an operator (e.g., operator.gt, the “greater than” operator) if you want to activate model selection during training based on that particular metric.
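
The interplay between the processing function and the model selection operator can be sketched without any framework. The class ToySelectionMetric below is hypothetical and is not how ClassifierMetric is implemented; it only illustrates the idea that predictions are first processed (here with a pure-Python argmax), the metric value is computed, and the operator decides whether the new value beats the best one seen so far, which is the moment a trainer would save the model.

```python
import operator
from typing import Callable, Optional, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (pure-Python analogue of an argmax on a vector)."""
    return max(range(len(scores)), key=lambda i: scores[i])


class ToySelectionMetric:
    """Hypothetical metric: accuracy plus optional model selection."""

    def __init__(
        self,
        processing_fn: Callable[[Sequence[float]], int] = argmax,
        model_selection_operator: Optional[Callable[[float, float], bool]] = None,
    ) -> None:
        self._processing_fn = processing_fn
        self._operator = model_selection_operator
        self.best: Optional[float] = None

    def accuracy(self, predictions, labels) -> float:
        """Apply the processing function, then compare against the labels."""
        classes = [self._processing_fn(p) for p in predictions]
        correct = sum(int(c == y) for c, y in zip(classes, labels))
        return correct / len(labels)

    def is_improvement(self, value: float) -> bool:
        """Return True (and remember the value) when it beats the best so far."""
        if self._operator is None:
            return False  # model selection disabled for this metric
        if self.best is None or self._operator(value, self.best):
            self.best = value
            return True
        return False


metric = ToySelectionMetric(model_selection_operator=operator.gt)
labels = [1, 0, 1]
epoch_logits = [
    [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]],  # epoch 1: 2 of 3 correct
    [[0.2, 0.8], [0.8, 0.2], [0.3, 0.7]],  # epoch 2: 3 of 3 correct
    [[0.9, 0.1], [0.1, 0.9], [0.3, 0.7]],  # epoch 3: 1 of 3 correct
]
saved = [metric.is_improvement(metric.accuracy(logits, labels)) for logits in epoch_logits]
print(saved)  # [True, True, False]
```

With operator.gt a higher value is better; for a metric where lower is better, operator.lt would be the natural choice.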

Finally, once the datasets have been set, you can start the training procedure by calling the trainer object:

train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)

GAN - Generative Adversarial Network

AshPy is equipped with two types of GAN network architectures:

  • A plain GAN network with the classic structure Generator - Discriminator.
  • A more elaborate GAN network architecture with the classic Generator - Discriminator structure plus an Encoder model (BiGAN like).

As in the previous classifier example, let’s first look at a complete “toy” example of a simple plain GAN. At the end we will briefly touch upon the differences with the GAN with an Encoder.

import operator
import tensorflow as tf
from ashpy.models.gans import ConvGenerator, ConvDiscriminator
from ashpy.metrics import InceptionScore
from ashpy.losses.gan import DiscriminatorMinMax, GeneratorBCE
from ashpy.trainers import AdversarialTrainer

generator = ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

# Losses
generator_bce = GeneratorBCE()
minmax = DiscriminatorMinMax()

# Real data
batch_size = 2
mnist_x, mnist_y = tf.zeros((100,28,28)), tf.zeros((100,))

# Trainer
epochs = 2
logdir = "testlog/adversarial"

metrics = [
    InceptionScore(
        # Fake inception model
        ConvDiscriminator(
            layer_spec_input_res=(299, 299),
            layer_spec_target_res=(7, 7),
            kernel_size=(5, 5),
            initial_filters=16,
            filters_cap=32,
            output_shape=10,
        ),
        model_selection_operator=operator.gt,
        logdir=logdir,
    )
]

trainer = AdversarialTrainer(
    generator,
    discriminator,
    tf.optimizers.Adam(1e-4),
    tf.optimizers.Adam(1e-4),
    generator_bce,
    minmax,
    epochs,
    metrics,
    logdir,
)

# Dataset
noise_dataset = tf.data.Dataset.from_tensors(0).repeat().map(
    lambda _: tf.random.normal(shape=(100,), dtype=tf.float32, mean=0.0, stddev=1)
).batch(batch_size).prefetch(1)

# take only 2 samples to speed up tests
real_data = tf.data.Dataset.from_tensor_slices(
        (tf.expand_dims(mnist_x, -1), tf.expand_dims(mnist_y, -1))
    ).take(batch_size).batch(batch_size).prefetch(1)

# Add noise in the same dataset, just by mapping.
# The return type of the dataset must be: tuple(tuple(a,b), noise)
dataset = real_data.map(lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100))))

trainer(dataset)

First we define the generator and discriminator of the GAN architecture:

generator = ConvGenerator(
    layer_spec_input_res=(7, 7),
    layer_spec_target_res=(28, 28),
    kernel_size=(5, 5),
    initial_filters=32,
    filters_cap=16,
    channels=1,
)

discriminator = ConvDiscriminator(
    layer_spec_input_res=(28, 28),
    layer_spec_target_res=(7, 7),
    kernel_size=(5, 5),
    initial_filters=16,
    filters_cap=32,
    output_shape=1,
)

and then we define the losses:

# Losses
generator_bce = GeneratorBCE()
minmax = DiscriminatorMinMax()

where GeneratorBCE() and DiscriminatorMinMax() are losses defined by inheriting from Executor. Again, as we saw in the classifier example, you can customize losses of this type (the ones inheriting from Executor).

The metrics are defined as follows:

metrics = [
    InceptionScore(
        # Fake inception model
        ConvDiscriminator(
            layer_spec_input_res=(299, 299),
            layer_spec_target_res=(7, 7),
            kernel_size=(5, 5),
            initial_filters=16,
            filters_cap=32,
            output_shape=10,
        ),
        model_selection_operator=operator.gt,
        logdir=logdir,
    )
]

and in particular here we have the InceptionScore metric constructed on the fly with the ConvDiscriminator class provided by AshPy.

Finally, the actual trainer is constructed and then called:

trainer = AdversarialTrainer(
    generator,
    discriminator,
    tf.optimizers.Adam(1e-4),
    tf.optimizers.Adam(1e-4),
    generator_bce,
    minmax,
    epochs,
    metrics,
    logdir,
)
trainer(dataset)

The main difference with a GAN architecture with an Encoder is that we would have the encoder loss:

encoder_bce = EncoderBCE()

an encoder accuracy metric:

metrics = [EncodingAccuracy(classifier, model_selection_operator=operator.gt, logdir=logdir)]

and an EncoderTrainer:

trainer = EncoderTrainer(
    generator,
    discriminator,
    encoder,
    tf.optimizers.Adam(1e-4),
    tf.optimizers.Adam(1e-5),
    tf.optimizers.Adam(1e-6),
    generator_bce,
    minmax,
    encoder_bce,
    epochs,
    metrics=metrics,
    logdir=logdir,
)

Note that the EncoderTrainer indicates a trainer of a GAN network with an Encoder and not a trainer of an Encoder itself.

Dataset Output Format

In order to standardize GAN training, AshPy requires the input dataset to be in a common format. In particular, the dataset return type must always be in the format shown below, where the first element of the tuple is the discriminator input and the second is the generator input.

tuple(tuple(a,b), noise)

Where a is the input sample, b is the label/condition (if any, otherwise fill it with 0), and noise is the input latent vector.

To train a Pix2Pix-like architecture, which has no noise as ConvGenerator input, just return the values in the format (tuple(a,b), b), since the condition is the generator input.
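
The two element layouts described above can be illustrated with plain Python tuples. This is only a sketch of the nesting: the helper names gan_elements, pix2pix_elements and has_gan_format are made up for this example, and a real pipeline would produce tensors through tf.data rather than Python lists.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def gan_elements(
    samples: Sequence[List[float]],
    labels: Sequence[int],
    latent_dim: int,
    seed: int = 0,
) -> Iterator[Tuple[Tuple[List[float], int], List[float]]]:
    """Yield elements shaped as ((a, b), noise)."""
    rng = random.Random(seed)
    for a, b in zip(samples, labels):
        noise = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        yield (a, b), noise


def pix2pix_elements(samples, conditions):
    """Yield elements shaped as ((a, b), b): the condition is the generator input."""
    for a, b in zip(samples, conditions):
        yield (a, b), b


def has_gan_format(element) -> bool:
    """Check for an outer 2-tuple whose first item is itself a 2-tuple."""
    return (
        isinstance(element, tuple)
        and len(element) == 2
        and isinstance(element[0], tuple)
        and len(element[0]) == 2
    )


# No labels available: the condition slot is filled with 0.
elements = list(gan_elements([[0.1, 0.2], [0.3, 0.4]], [0, 0], latent_dim=3))
paired = list(pix2pix_elements([[0.1, 0.2]], [[0.9, 0.8]]))
print(all(has_gan_format(e) for e in elements + paired))  # True
```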

Test

In order to run the tests (with the doctests), linting and docs generation simply use tox.

tox

Write The Docs!

AshPy is built with a documentation-driven approach, which means that a solid, automated documentation procedure is a mandatory requirement.

The core components of our systems are:

  • Sphinx for the documentation generation
  • reStructuredText as the markup language
  • Google Style docstrings for in-code documentation
  • vale and vale-styles
  • Automatic internal deployment via GitLab Pages CI/CD integration

The goal of this document is threefold:

  1. Explain the Documentation Architecture, the steps taken to automate it, and defend such choices
  2. Serve as a future reference for other projects
  3. Act as an example for the Guide format and a demo of Sphinx + reST superpowers
  4. Convince you of the need to always be on the lookout for errors, even in a perfect system.

The Whys

Why Sphinx?

Sphinx is the most used documentation framework for Python. Developed for the standard library itself, it is now adopted by the best-known third-party libraries. What makes Sphinx so great is the combination of extensibility via themes, extensions and whatnot, coupled with a plethora of built-in functionalities that make writing docs a breeze.

Some examples from the Sphinx site:

  • Output formats: HTML (including Windows HTML Help), LaTeX (for printable PDF versions), ePub, Texinfo, manual pages, plain text
  • Extensive cross-references: semantic markup and automatic links for functions, classes, citations, glossary terms and similar pieces of information
  • Hierarchical structure: easy definition of a document tree, with automatic links to siblings, parents and children
  • Automatic indices: general index as well as language-specific module indices
  • Code handling: automatic highlighting using the Pygments highlighter
  • Extensions: automatic testing of code snippets, inclusion of docstrings from Python modules (API docs), and more
  • Contributed extensions: more than 50 extensions contributed by users in a second repository; most of them installable from PyPI

Why reST?

More than why reST, the real question is Why not Markdown?

While Markdown can be easier and slightly quicker to write, it does not offer the same level of fine grained control, necessary for an effort as complex as technical writing, without sacrificing portability.

Eric Holscher, one of the greatest documentation advocates out there, has an aptly named article: Why You Shouldn’t Use “Markdown” for Documentation. Go and read his articles, they are beautiful.

Why Google Style for Docstrings?

Google Docstrings are, to us, the best way to organically combine code and documentation. Leveraging Napoleon, a Sphinx extension offering automatic documentation support for both NumPy and Google docstring styles, we can write easy-to-read docstrings and still be able to use the autodoc and autosummary directives.
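
As an illustration of the style, here is a small function documented with a Google Style docstring. The function clip is a made-up example, not part of AshPy; the section names (Args, Returns, Raises) are the ones Napoleon recognizes.

```python
def clip(value: float, lower: float, upper: float) -> float:
    """Clamp a value to a closed interval.

    Args:
        value: The number to clamp.
        lower: Smallest value allowed.
        upper: Largest value allowed.

    Returns:
        ``value`` if it lies within ``[lower, upper]``, otherwise the
        nearest bound.

    Raises:
        ValueError: If ``lower`` is greater than ``upper``.
    """
    if lower > upper:
        raise ValueError("lower must not exceed upper")
    return max(lower, min(value, upper))


print(clip(5.0, 0.0, 3.0))  # 3.0
```

The docstring stays readable in the source file, and the same text can be rendered by autodoc without any extra markup.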

Documentation Architecture

Tutorials, Guides, Complex Examples

Any form of documentation that is not generated from the codebase should go here. The parent/entry point reStructuredText file should be added in docs/src and then referenced in index.rst.

API Reference

API reference contains the full API documentation automatically generated from our codebase. The only manual step required is adding the module you want to document to the api.rst located inside docs/source.

Automate all the docs!

Classes, Functions, Exceptions: Annotate them normally, they do not require anything else.

Autosummary & submodules with imports: A painful story

Exposing Python objects to their parent module by importing them in its __init__.py file breaks the autosummary directives when combined with the automatic generation of stub files. Currently there is no way of making autosummary aware of the imported objects, so if you want to document that piece of the API you need to find a workaround.

Example

Suppose we have the following structure:

keras/
   |---> __init__.py
   |
   |---> models.py

And that these two files contain, respectively:

  • __init__.py

from .models import Model

__all__ = ["Model"]

  • models.py

class Model:
   pass

Calling the autosummary directive (with the toctree option) on keras will not generate stub files for keras.Model causing it to not show in the Table of Contents of our API reference.

To circumvent this limitation it is ideal to insert some manual labour into the keras docstring.

  • __init__.py
"""
Documentation example.

.. rubric:: Classes

.. autosummary:: Classes
   :toctree: _autosummary
   :nosignatures:

   keras.Model

.. rubric:: Submodules

.. autosummary:: keras.models
   :toctree: _autosummary
   :nosignatures:
   :template: autosummary/submodule.rst

   keras.models
"""
from .models import Model

__all__ = ["Model"]

This way autosummary will produce the proper API documentation. The same approach also applies when exposing functions, exceptions, and modules.

Note

The autosummary/submodule.rst template is used when annotating submodules.
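
Since this workaround means writing similar reST blocks by hand for every package, the text can be generated instead. The helper below, autosummary_block, is a hypothetical convenience written for this guide and is not part of Sphinx; it simply renders a rubric followed by an autosummary directive with the same options used in the example above, and its output can be pasted into the module docstring.

```python
from typing import Iterable


def autosummary_block(rubric: str, names: Iterable[str], template: str = "") -> str:
    """Render a rubric plus an autosummary directive listing `names`."""
    lines = [
        f".. rubric:: {rubric}",
        "",
        ".. autosummary::",
        "   :toctree: _autosummary",
        "   :nosignatures:",
    ]
    if template:
        # Only submodules need the dedicated template in the example above.
        lines.append(f"   :template: {template}")
    lines.append("")  # blank line between the options and the listed objects
    lines.extend(f"   {name}" for name in names)
    return "\n".join(lines)


docstring = "\n\n".join(
    [
        "Documentation example.",
        autosummary_block("Classes", ["keras.Model"]),
        autosummary_block(
            "Submodules", ["keras.models"], template="autosummary/submodule.rst"
        ),
    ]
)
print(docstring)
```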

Inheritance Diagrams

Inheritance Diagrams are drawn using sphinx.ext.inheritance_diagram and sphinx.ext.graphviz.

The autosummary template for classes has been modified in order to automatically generate an inheritance diagram just below the title.

An Inheritance Diagrams page is manually created in order to showcase all the diagrams on a single page. The page gives a quick overview of the relations between the classes of each module.

Getting Started

Datasets

AshPy supports tf.data.Dataset format.

We highly encourage you to use TensorFlow Datasets to manage and use your datasets in a handy way.

pip install tfds-nightly

Classification

In order to create a dataset for classification:

import tensorflow_datasets as tfds

from ashpy.trainers import ClassifierTrainer

def extract_fn(example):
    return example["image"], example["label"]

def main():
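    # MNIST in TensorFlow Datasets ships "train" and "test" splits;
    # the test split is used for validation here
    ds_train, ds_validation = tfds.load(name="mnist", split=["train", "test"])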

    # build the input pipeline
    ds_train = ds_train.batch(BATCH_SIZE).prefetch(1)
    ds_train = ds_train.map(extract_fn)

    # same for validation
    ...

    # define model, loss, optimizer
    ...

    # define the classifier trainer
    trainer = ClassifierTrainer(model, optimizer, loss, epochs, metrics, logdir=logdir)

    # train
    trainer.train(ds_train, ds_validation)

GANs

In order to create a dataset for a (Conditional) GAN:

import tensorflow_datasets as tfds

from ashpy.trainers import AdversarialTrainer

def extract_fn(example):
    # the ashpy input must be (real, condition), condition
    return (example["image"], example["label"]), example["label"]

def main():
    ds_train = tfds.load(name="mnist", split="train")

    # build the input pipeline
    ds_train = ds_train.batch(BATCH_SIZE).prefetch(1)
    ds_train = ds_train.map(extract_fn)

    # define models, losses, optimizers
    ...

    # define the adversarial trainer
    trainer = AdversarialTrainer(generator,
        discriminator,
        generator_optimizer,
        discriminator_optimizer,
        generator_loss,
        discriminator_loss,
        epochs,
        metrics,
        logdir,
    )

    # train
    trainer.train(ds_train)

Models

AshPy supports Keras models as inputs. You can use an AshPy predefined model or you can implement your own model.

Using an AshPy model

import tensorflow_datasets as tfds

from ashpy.trainers import AdversarialTrainer
from ashpy.models import UNet

def main():

    # create the dataset and the input pipeline

    # define models, loss, optimizer
    # (here the UNet plays the role of the generator)
    generator = UNet(
        input_res,
        min_res,
        kernel_size,
        initial_filters,
        filters_cap,
        channels,
        use_dropout_encoder,
        use_dropout_decoder,
        dropout_prob,
        use_attention,
    )

    # define the adversarial trainer
    trainer = AdversarialTrainer(generator,
        discriminator,
        generator_optimizer,
        discriminator_optimizer,
        generator_loss,
        discriminator_loss,
        epochs,
        metrics,
        logdir,
    )

    # train
    trainer.train(ds_train)

Creating a Model

It’s very easy to create a simple model, since AshPy’s models are Keras’ models.

import tensorflow as tf

from ashpy.layers import Attention, InstanceNormalization


def downsample(
    filters,
    apply_normalization=True,
    attention=False,
    activation=tf.keras.layers.LeakyReLU(alpha=0.2),
    size=3,
):
    initializer = tf.random_normal_initializer(0.0, 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding="same",
            kernel_initializer=initializer,
            use_bias=not apply_normalization,
        )
    )

    if apply_normalization:
        result.add(InstanceNormalization())

    result.add(activation)

    if attention:
        result.add(Attention(filters))

    return result


def upsample(
    filters,
    apply_dropout=False,
    apply_normalization=True,
    attention=False,
    activation=tf.keras.layers.ReLU(),
    size=3,
):
    initializer = tf.random_normal_initializer(0.0, 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.UpSampling2D(size=(2, 2)))
    result.add(tf.keras.layers.ZeroPadding2D(padding=(1, 1)))

    result.add(
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=1,
            padding="valid",
            kernel_initializer=initializer,
            use_bias=False,
        )
    )

    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))

    if apply_normalization:
        result.add(InstanceNormalization())

    result.add(activation)

    if attention:
        result.add(Attention(filters))

    return result


def Generator(attention, output_channels=3):
    down_stack = [
        downsample(32, apply_normalization=False),  # 256
        downsample(32),  # 128
        downsample(64, attention=attention),  # 64
        downsample(64),  # 32
        downsample(64),  # 16
        downsample(128),  # 8
        downsample(128),  # 4
        downsample(256),  # 2
        downsample(512, apply_normalization=False),  # 1
    ]

    up_stack = [
        upsample(256, apply_dropout=True),  # 2
        upsample(128, apply_dropout=True),  # 4
        upsample(128, apply_dropout=True),  # 8
        upsample(64),  # 16
        upsample(64),  # 32
        upsample(64, attention=attention),  # 64
        upsample(32),  # 128
        upsample(32),  # 256
        upsample(32),  # 512
    ]

    inputs = tf.keras.layers.Input(shape=[None, None, 1])
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    last = upsample(
        output_channels,
        activation=tf.keras.layers.Activation(tf.nn.tanh),
        apply_normalization=False,
    )

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In this way we have created a new model to be used inside AshPy.

Inheriting from ashpy.models.Conv2DInterface

The third possibility you have to create a new model is to inherit from the ashpy.models.convolutional.interfaces.Conv2DInterface.

This class offers the basic methods to implement in a simple way a new model.

Creating a new Trainer

AshPy has different generic trainers. Trainers implement the basic training loop together with distribution strategy management and logging. For now, the only distribution strategy handled is tf.distribute.MirroredStrategy.
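
To give an idea of what "basic training loop plus logging" means, here is a framework-free sketch. ToyTrainer is a hypothetical class written for this guide; it does not inherit from any AshPy trainer and leaves out distribution strategies entirely. It fits a single weight with stochastic gradient descent and logs the average loss once per epoch.

```python
from typing import Callable, List, Sequence, Tuple


class ToyTrainer:
    """Hypothetical trainer: owns the loop, the update rule and the logging."""

    def __init__(
        self,
        learning_rate: float,
        epochs: int,
        log: Callable[[str], None] = print,
    ) -> None:
        self.weight = 0.0
        self._lr = learning_rate
        self._epochs = epochs
        self._log = log

    def _train_step(self, x: float, y: float) -> float:
        """One gradient step on the squared error of the model y = weight * x."""
        error = self.weight * x - y
        self.weight -= self._lr * 2.0 * error * x  # d(error^2)/d(weight)
        return error ** 2

    def __call__(self, dataset: Sequence[Tuple[float, float]]) -> List[float]:
        """Run the training loop and return the average loss of each epoch."""
        history = []
        for epoch in range(1, self._epochs + 1):
            epoch_loss = sum(self._train_step(x, y) for x, y in dataset) / len(dataset)
            self._log(f"epoch {epoch}: loss={epoch_loss:.4f}")
            history.append(epoch_loss)
        return history


trainer = ToyTrainer(learning_rate=0.05, epochs=20)
history = trainer([(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)])
```

As with the AshPy trainers shown earlier, the object is configured once and then called with the dataset.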


Complete Examples

Classifier

# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example of Multi-GPU classifier trainer."""

import operator

import tensorflow as tf

from ashpy.losses import ClassifierLoss
from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer


def main():
    """
    Train a multi-GPU classifier.

    How to use AshPy to train a classifier, measure the
    performance and perform model selection.
    """
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        training_set, validation_set = tf.keras.datasets.mnist.load_data()

        def process(images, labels):
            data_images = tf.data.Dataset.from_tensor_slices((images)).map(
                lambda x: tf.reshape(x, (28 * 28,))
            )
            data_images = data_images.map(
                lambda x: tf.image.convert_image_dtype(x, tf.float32)
            )
            data_labels = tf.data.Dataset.from_tensor_slices((labels))
            dataset = tf.data.Dataset.zip((data_images, data_labels))
            dataset = dataset.batch(1024 * 1)
            return dataset

        training_set, validation_set = (
            process(training_set[0], training_set[1]),
            process(validation_set[0], validation_set[1]),
        )

        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(10, activation=tf.nn.sigmoid),
                tf.keras.layers.Dense(10),
            ]
        )
        optimizer = tf.optimizers.Adam(1e-3)
        loss = ClassifierLoss(tf.losses.SparseCategoricalCrossentropy(from_logits=True))
        logdir = "testlog"
        epochs = 10

        metrics = [
            ClassifierMetric(
                tf.metrics.Accuracy(), model_selection_operator=operator.gt
            ),
            ClassifierMetric(
                tf.metrics.BinaryAccuracy(), model_selection_operator=operator.gt
            ),
        ]

        trainer = ClassifierTrainer(
            model=model,
            optimizer=optimizer,
            loss=loss,
            epochs=epochs,
            metrics=metrics,
            logdir=logdir,
        )
        trainer(training_set, validation_set)


if __name__ == "__main__":
    main()

GANs

BiGAN
# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bigan dummy implementation."""

import operator

import tensorflow as tf
from tensorflow import keras

from ashpy.losses import DiscriminatorMinMax, EncoderBCE, GeneratorBCE
from ashpy.metrics import EncodingAccuracy
from ashpy.trainers import EncoderTrainer


def main():
    """Define the trainer and the models."""

    def real_gen():
        """Define generator of real values."""
        for _ in tf.range(100):
            yield ((10.0,), (0,))

    num_classes = 1
    latent_dim = 100

    generator = keras.Sequential([keras.layers.Dense(1)])

    left_input = tf.keras.layers.Input(shape=(1,))
    left = tf.keras.layers.Dense(10, activation=tf.nn.elu)(left_input)

    right_input = tf.keras.layers.Input(shape=(latent_dim,))
    right = tf.keras.layers.Dense(10, activation=tf.nn.elu)(right_input)

    net = tf.keras.layers.Concatenate()([left, right])
    out = tf.keras.layers.Dense(1)(net)

    discriminator = tf.keras.Model(inputs=[left_input, right_input], outputs=[out])

    encoder = keras.Sequential([keras.layers.Dense(latent_dim)])
    generator_bce = GeneratorBCE()
    encoder_bce = EncoderBCE()
    minmax = DiscriminatorMinMax()

    epochs = 100
    logdir = "log/adversarial/encoder"

    # Fake pre-trained classifier
    classifier = tf.keras.Sequential(
        [tf.keras.layers.Dense(10), tf.keras.layers.Dense(num_classes)]
    )

    metrics = [
        EncodingAccuracy(
            classifier, model_selection_operator=operator.gt, logdir=logdir
        )
    ]

    trainer = EncoderTrainer(
        generator=generator,
        discriminator=discriminator,
        encoder=encoder,
        generator_optimizer=tf.optimizers.Adam(1e-4),
        discriminator_optimizer=tf.optimizers.Adam(1e-5),
        encoder_optimizer=tf.optimizers.Adam(1e-6),
        generator_loss=generator_bce,
        discriminator_loss=minmax,
        encoder_loss=encoder_bce,
        epochs=epochs,
        metrics=metrics,
        logdir=logdir,
    )

    batch_size = 10
    discriminator_input = tf.data.Dataset.from_generator(
        real_gen, (tf.float32, tf.int64), ((1), (1))
    ).batch(batch_size)

    dataset = discriminator_input.map(
        lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
    )

    trainer(dataset)


if __name__ == "__main__":
    main()
MNIST
# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adversarial trainer example."""

import tensorflow as tf
from tensorflow import keras  # pylint: disable=no-name-in-module

from ashpy.losses import DiscriminatorMinMax, GeneratorBCE
from ashpy.metrics import InceptionScore
from ashpy.models.gans import ConvDiscriminator, ConvGenerator
from ashpy.trainers import AdversarialTrainer


def main():
    """Adversarial trainer example."""
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():

        generator = ConvGenerator(
            layer_spec_input_res=(7, 7),
            layer_spec_target_res=(28, 28),
            kernel_size=(5, 5),
            initial_filters=256,
            filters_cap=16,
            channels=1,
        )

        discriminator = ConvDiscriminator(
            layer_spec_input_res=(28, 28),
            layer_spec_target_res=(7, 7),
            kernel_size=(5, 5),
            initial_filters=32,
            filters_cap=128,
            output_shape=1,
        )

        # Losses
        generator_bce = GeneratorBCE()
        minmax = DiscriminatorMinMax()

        # Trainer
        logdir = "log/adversarial"

        # InceptionScore: keep commented until the issues
        # https://github.com/tensorflow/tensorflow/issues/28599
        # https://github.com/tensorflow/hub/issues/295
        # Haven't been solved and merged into tf2

        metrics = [
            # InceptionScore(
            #    InceptionScore.get_or_train_inception(
            #        mnist_dataset,
            #        "mnist",
            #        num_classes=10,
            #        epochs=1,
            #        fine_tuning=False,
            #        logdir=logdir,
            #    ),
            #    model_selection_operator=operator.gt,
            #    logdir=logdir,
            # )
        ]

        epochs = 50
        trainer = AdversarialTrainer(
            generator=generator,
            discriminator=discriminator,
            generator_optimizer=tf.optimizers.Adam(1e-4),
            discriminator_optimizer=tf.optimizers.Adam(1e-4),
            generator_loss=generator_bce,
            discriminator_loss=minmax,
            epochs=epochs,
            metrics=metrics,
            logdir=logdir,
        )

        batch_size = 512

        # Real data
        mnist_x, mnist_y = keras.datasets.mnist.load_data()[0]

        def iterator():
            """Define an iterator that yields one (image, label) pair at a time."""
            for image, label in zip(mnist_x, mnist_y):
                yield tf.image.convert_image_dtype(
                    tf.expand_dims(image, -1), tf.float32
                ), tf.expand_dims(label, -1)

        real_data = (
            tf.data.Dataset.from_generator(
                iterator, (tf.float32, tf.int64), ((28, 28, 1), (1,))
            )
            .batch(batch_size)
            .prefetch(1)
        )

        # Add noise in the same dataset, just by mapping.
        # The return type of the dataset must be: tuple(tuple(a,b), noise)
        dataset = real_data.map(
            lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, 100)))
        )

        trainer(dataset)


if __name__ == "__main__":
    main()
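
The comment in the listing above states that the dataset must yield elements shaped as tuple(tuple(a, b), noise). The sketch below reproduces only that nesting with plain Python lists, so it runs without TensorFlow. `add_noise` and the toy batch are hypothetical names used for illustration and are not part of AshPy. Unlike the listing, which draws a fixed (batch_size, 100) tensor, this sketch sizes the noise from the batch it receives.

```python
import random

LATENT_DIM = 100  # same latent size used in the listing above


def add_noise(x, y, latent_dim=LATENT_DIM, rng=random):
    """Return ((x, y), noise), with one noise vector per element of the batch."""
    noise = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in x]
    return (x, y), noise


# A toy "batch" of two samples standing in for real images and labels.
batch_x = [[0.0, 0.5], [1.0, 0.25]]
batch_y = [[3], [7]]

(real_x, real_y), noise = add_noise(batch_x, batch_y)
```

The first element of the pair carries the real data and its label (or condition), the second carries the generator input.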
Facades (Pix2Pix)
# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Pix2Pix on Facades Datasets dummy implementation.

Input Pipeline taken from: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
"""
from pathlib import Path

import tensorflow as tf

from ashpy import LogEvalMode
from ashpy.losses.gan import (
    AdversarialLossType,
    Pix2PixLoss,
    get_adversarial_loss_discriminator,
)
from ashpy.models.convolutional.discriminators import PatchDiscriminator
from ashpy.models.convolutional.unet import FUNet
from ashpy.trainers.gan import AdversarialTrainer

_URL = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"

PATH_TO_ZIP = tf.keras.utils.get_file("facades.tar.gz", origin=_URL, extract=True)
PATH = Path(PATH_TO_ZIP).parent / "facades"

BUFFER_SIZE = 100
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256


def load(image_file):
    """Load the image from file path."""
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    width = tf.shape(image)[1]

    width = width // 2
    real_image = image[:, :width, :]
    input_image = image[:, width:, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image


def resize(input_image, real_image, height, width):
    """Resize input_image and real_image to height x width."""
    input_image = tf.image.resize(
        input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    real_image = tf.image.resize(
        real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

    return input_image, real_image


def random_crop(input_image, real_image):
    """Random crop both input_image and real_image."""
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(
        stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3]
    )

    return cropped_image[0], cropped_image[1]


def normalize(input_image, real_image):
    """Normalize images in [-1, 1]."""
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image


def load_image_train(image_file):
    """Load and process the image_file to be ready for the training."""
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image


@tf.function
def random_jitter(input_image, real_image):
    """Apply random jitter to both input_image and real_image."""
    # resizing to 286 x 286 x 3
    input_image, real_image = resize(input_image, real_image, 286, 286)

    # randomly cropping to 256 x 256 x 3
    input_image, real_image = random_crop(input_image, real_image)

    if tf.random.uniform(()) > 0.5:
        # random mirroring
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image


def main(
    kernel_size=5,
    learning_rate_d=2e-4,
    learning_rate_g=2e-4,
    g_input_res=IMG_WIDTH,
    g_min_res=1,
    g_initial_filters=64,
    g_filters_cap=512,
    use_dropout_encoder=False,
    use_dropout_decoder=True,
    d_target_res=32,
    d_initial_filters=64,
    d_filters_cap=512,
    use_dropout_discriminator=False,
    dataset_name="facades",
    resolution=256,
    epochs=100_000,
    dropout_prob=0.3,
    l1_loss_weight=100,
    gan_loss_weight=1,
    use_attention_d=False,
    use_attention_g=False,
    channels=3,
    gan_loss_type=AdversarialLossType.LSGAN,
):
    """Define Trainer and models."""
    generator = FUNet(
        input_res=g_input_res,
        min_res=g_min_res,
        kernel_size=kernel_size,
        initial_filters=g_initial_filters,
        filters_cap=g_filters_cap,
        channels=channels,  # color_to_label_tensor.shape[0],
        use_dropout_encoder=use_dropout_encoder,
        use_dropout_decoder=use_dropout_decoder,
        dropout_prob=dropout_prob,
        use_attention=use_attention_g,
    )
    discriminator = PatchDiscriminator(
        input_res=resolution,
        min_res=d_target_res,
        initial_filters=d_initial_filters,
        kernel_size=kernel_size,
        filters_cap=d_filters_cap,
        use_dropout=use_dropout_discriminator,
        dropout_prob=dropout_prob,
        use_attention=use_attention_d,
    )

    discriminator_loss = get_adversarial_loss_discriminator(gan_loss_type)()
    generator_loss = Pix2PixLoss(
        l1_loss_weight=l1_loss_weight,
        adversarial_loss_weight=gan_loss_weight,
        adversarial_loss_type=gan_loss_type,
    )

    metrics = []
    logdir = Path("log") / dataset_name / "run2"

    if not logdir.exists():
        logdir.mkdir(parents=True)

    trainer = AdversarialTrainer(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=tf.optimizers.Adam(learning_rate_g, beta_1=0.5),
        discriminator_optimizer=tf.optimizers.Adam(learning_rate_d, beta_1=0.5),
        generator_loss=generator_loss,
        discriminator_loss=discriminator_loss,
        epochs=epochs,
        metrics=metrics,
        logdir=logdir,
        log_eval_mode=LogEvalMode.TEST,
    )

    # PATH is a pathlib.Path, so build the glob with "/" and convert it to str
    train_dataset = tf.data.Dataset.list_files(str(PATH / "train" / "*.jpg"))
    train_dataset = train_dataset.shuffle(BUFFER_SIZE)
    train_dataset = train_dataset.map(load_image_train)
    train_dataset = train_dataset.batch(BATCH_SIZE)

    train_dataset = train_dataset.map(lambda x, y: ((y, x), x))

    trainer(
        # generator_input,
        train_dataset
    )


if __name__ == "__main__":
    main()
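
Two preprocessing steps in the listing above are easy to check by hand: `load` splits each image in half along its width, and `normalize` maps pixel values from [0, 255] to [-1, 1]. The following is a minimal sketch of both steps on a single row of pixel values, written in plain Python so it runs without TensorFlow. `split_pair` is a hypothetical helper that mirrors the slicing done in `load`.

```python
def split_pair(row):
    """Split one row of pixels into (input, real) halves, as `load` does along the width."""
    half = len(row) // 2
    real, inp = row[:half], row[half:]
    return inp, real


def normalize(value):
    """Map a pixel value from [0, 255] to [-1, 1]."""
    return (value / 127.5) - 1


row = [0, 255, 127.5, 51]
input_half, real_half = split_pair(row)
normalized_real = [normalize(value) for value in real_half]
normalized_input = [normalize(value) for value in input_half]
```

The left half of each image becomes the real image and the right half becomes the input, which matches the order returned by `load`.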

Advanced AshPy

Custom Metrics

AshPy Trainers can accept metrics that they will use for both logging and automatic model selection.

Implementing a custom Metric in AshPy can be done via two approaches:

  1. Your metric is already available as a tf.keras.metrics.Metric and you want to use it as is.
  2. You need to write the implementation of the Metric from scratch or you need to alter the default behavior we provide for AshPy Metrics.

Wrapping Keras Metrics

In case number (1) what you want to do is to search for one of the Metrics provided by AshPy and use it as a wrapper around the one you wish to use.

Note

Passing an operator function to the AshPy Metric will enable model selection using the metric value.

The example below shows how to implement the Precision metric for a ClassifierTrainer.

import operator
from pathlib import Path

from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer
from tensorflow.keras.metrics import Precision

# Wrap the Keras metric; the operator enables model selection on its value
precision = ClassifierMetric(
    metric=Precision(),
    model_selection_operator=operator.gt,
    logdir=Path().cwd() / "log",
)

trainer = ClassifierTrainer(
    ...
    metrics=[precision],
    ...
)

You can apply this technique to any object derived from and behaving as a tf.keras.metrics.Metric (e.g. the Metrics present in TensorFlow Addons).
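
To make the role of the model selection operator concrete, here is a minimal sketch in plain Python of how an operator such as `operator.gt` or `operator.lt` can decide whether a new metric value beats the best one seen so far. The helpers `is_improvement` and `select_best` are hypothetical; how AshPy stores and compares the best value internally is an assumption here, not something this sketch documents.

```python
import operator


def is_improvement(new_value, best_value, selection_operator):
    """Return True when `new_value` should replace `best_value`."""
    if best_value is None:  # nothing recorded yet, so the first value always wins
        return True
    return selection_operator(new_value, best_value)


def select_best(values, selection_operator):
    """Walk through per-epoch metric values and keep the best one seen so far."""
    best = None
    for value in values:
        if is_improvement(value, best, selection_operator):
            best = value
    return best


precision_per_epoch = [0.71, 0.78, 0.76]
loss_per_epoch = [0.90, 0.55, 0.60]

best_precision = select_best(precision_per_epoch, operator.gt)  # higher is better
best_loss = select_best(loss_per_epoch, operator.lt)  # lower is better
```

Choosing `operator.gt` suits metrics where higher is better (precision, accuracy), while `operator.lt` suits metrics where lower is better (a loss).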

Creating your own Metric

As an example of a custom Metric we present the analysis of the ashpy.metrics.classifier.ClassifierLoss.

class ClassifierLoss(Metric):
    """A handy way to measure the classification loss."""

    def __init__(
        self,
        name: str = "loss",
        model_selection_operator: Callable = None,
        logdir: Union[Path, str] = Path().cwd() / "log",
    ) -> None:
        """
        Initialize the Metric.

        Args:
            name (str): Name of the metric.
            model_selection_operator (:py:obj:`typing.Callable`): The operation that will
                be used when `model_selection` is triggered to compare the metrics,
                used by the `update_state`.
                Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.
                .. note::
                    Model selection is done ONLY if an operator is specified here.
            logdir (str): Path to the log dir, defaults to a `log` folder in the current
                directory.

        """
        super().__init__(
            name=name,
            metric=tf.metrics.Mean(name=name, dtype=tf.float32),
            model_selection_operator=model_selection_operator,
            logdir=logdir,
        )

    def update_state(self, context: ClassifierContext) -> None:
        """
        Update the internal state of the metric, using the information from the context object.
        Args:
            context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
                holding all the information the Metric needs.

        """
        updater = lambda value: lambda: self._metric.update_state(value)
        for features, labels in context.dataset:
            loss = context.loss(
                context,
                features=features,
                labels=labels,
                training=context.log_eval_mode == LogEvalMode.TRAIN,
            )
            self._distribute_strategy.experimental_run_v2(updater(loss))

Warning

The name argument of the ashpy.metrics.metric.Metric.__init__() is a str identifier which should be unique across all the metrics used by your Trainer.

Custom Computation inside Metric.update_state()
  • This method is invoked during the training and receives a Context.
  • In this example, since we are working under the ClassifierTrainer, we are using a ClassifierContext. For more information on the Context family of objects see AshPy Internals.
  • Inside this update_state() we won’t be doing any fancy computation: we just retrieve the loss value from the ClassifierContext and then call the updater lambda through the fetched distribution strategy.
  • The active distribution strategy is automatically retrieved during the super() call; this is what lets every object derived from an ashpy.metrics.Metric work in a distributed environment as well.
  • ashpy.metrics.metric.Metric.metric (here referenced as self._metric) is the primitive tf.keras.metrics.Metric whose update_state() method we will be using to simplify our operations.
  • Custom computation will almost always be done via iteration over the data offered by the Context.
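
The nested lambda used in update_state() can look odd at first: the outer lambda binds the current loss value, and the inner lambda is the zero-argument function handed to the strategy. The sketch below isolates that pattern in plain Python. `RunningMean` and `run` are hypothetical stand-ins for tf.metrics.Mean and for the strategy's run method; they only imitate the call pattern, not the distributed behavior.

```python
class RunningMean:
    """Hypothetical stand-in for tf.metrics.Mean, keeping a running average."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def run(fn):
    """Hypothetical stand-in for a strategy's run method: call a zero-argument function."""
    return fn()


metric = RunningMean()

# Outer lambda binds `value`; inner lambda is the zero-argument function to run.
updater = lambda value: lambda: metric.update_state(value)

for loss in [0.5, 0.25, 0.75]:
    run(updater(loss))

mean_loss = metric.result()
```

Because each call to `updater(loss)` creates a fresh closure, every scheduled update carries its own loss value rather than whatever the loop variable holds later.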

For a much more complex (but also more exhaustive) example, have a look at the source code of ashpy.metrics.SlicedWassersteinDistance.

Custom Callbacks

Our Callback is built on the same base structure as a tf.keras.callbacks.Callback exposing methods acting as hooks for the same events.

  • on_train_start
  • on_epoch_start
  • on_batch_start
  • on_batch_end
  • on_epoch_end
  • on_train_end

Inside the ashpy.callbacks module we offer two primitive Callback classes to inherit from.

  1. ashpy.callbacks.Callback: the most basic form of callback and the building block for all the others.
  2. CounterCallback: derived from ashpy.callbacks.Callback; it contains built-in logic for triggering an event at a desired frequency.
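
The frequency logic described for CounterCallback can be sketched in a few lines of plain Python. This is only an illustration of the idea under simplifying assumptions: `TinyCounterCallback`, its `on_event` method, and the two-member `Event` enum are hypothetical and do not reproduce AshPy's actual classes or method names.

```python
from enum import Enum, auto


class Event(Enum):
    """Hypothetical subset of training events."""

    ON_BATCH_END = auto()
    ON_EPOCH_END = auto()


class TinyCounterCallback:
    """Call `fn(context)` every `event_freq` occurrences of `event`."""

    def __init__(self, event, fn, name, event_freq=1):
        self.event = event
        self.fn = fn
        self.name = name
        self.event_freq = event_freq
        self._counter = 0

    def on_event(self, event, context):
        if event is not self.event:
            return  # not the event this callback listens to
        self._counter += 1
        if self._counter % self.event_freq == 0:
            self.fn(context)


logged_epochs = []
callback = TinyCounterCallback(
    event=Event.ON_EPOCH_END,
    fn=lambda context: logged_epochs.append(context["epoch"]),
    name="tiny_logger",
    event_freq=2,
)

for epoch in range(1, 7):
    callback.on_event(Event.ON_BATCH_END, {"epoch": epoch})  # ignored
    callback.on_event(Event.ON_EPOCH_END, {"epoch": epoch})
```

With `event_freq=2` the function fires on every second matching event, so only epochs 2, 4 and 6 are recorded.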

Let’s take a look at the following example, the callback used to log GAN outputs to TensorBoard: ashpy.callbacks.gan.LogImageGANCallback.

class LogImageGANCallback(CounterCallback):
    def __init__(
        self,
        event: Event = Event.ON_EPOCH_END,
        name: str = "log_image_gan_callback",
        event_freq: int = 1,
    ) -> None:
        """
        Initialize the LogImageGANCallback.

        Args:
            event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
            event_freq (int): frequency of logging.
            name (str): name of the callback.

        """
        super(LogImageGANCallback, self).__init__(
            event=event, fn=self._log_fn, name=name, event_freq=event_freq
        )

    def _log_fn(self, context: GANContext) -> None:
        """
        Log output of the generator to Tensorboard.

        Args:
            context (:py:class:`ashpy.contexts.gan.GANContext`): current context.

        """
        if context.log_eval_mode == LogEvalMode.TEST:
            out = context.generator_model(context.generator_inputs, training=False)
        elif context.log_eval_mode == LogEvalMode.TRAIN:
            out = context.fake_samples
        else:
            raise ValueError("Invalid LogEvalMode")

        log("generator", out, context.global_step)

Let’s start with the __init__() function. As with a custom ashpy.metrics.Metric, when inheriting from either Callback or CounterCallback you should respect the common part of the signature:

  • event: In AshPy we use an Enum - ashpy.callbacks.Event - to choose the event type you want the Callback to be triggered on.
  • name: Unique str identifier for the Callback
  • event_freq: Simple int specifying the frequency.
  • fn: A callable(); this is the function that gets triggered. Inside AshPy we converged on using a private method called _log_fn() in each of our derived Callbacks. Whatever approach you choose, the function fed to fn should take a Context as input. For more information on the Context family of objects see AshPy Internals.

Warning

The name argument of the ashpy.callbacks.callback.Callback.__init__() is a str identifier which should be unique across all the callbacks used by your Trainer.

AshPy Internals

The two main concepts of AshPy internals are Context and Executor.

Context

A Context is an object that contains all the needed information, where "needed" depends on the application. In AshPy the Context concept links a generic training loop with the loss function calculation and the model evaluation. A Context is a class in which all the models, metrics, dataset and mode of your network are set. Passing the context around means that you can access everything you need, at any time, to perform any type of computation.

In AshPy we have (so far) three types of contexts:

Classifier Context

The ClassifierContext is rather straightforward containing only:

  • classifier_model
  • loss
  • dataset
  • metrics
  • log_eval_mode
  • global_step
  • ckpt

This way the loss function (an Executor) can use the context to get the model and the information needed to feed it correctly.
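
As a rough illustration of this idea, the sketch below bundles a model, a loss and a dataset into one object and lets the loss pull what it needs from it. `TinyClassifierContext` and `squared_error` are hypothetical, trimmed-down stand-ins written in plain Python; the real ClassifierContext holds more fields and works with TensorFlow objects.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class TinyClassifierContext:
    """Hypothetical, trimmed-down context: everything a loss needs in one object."""

    classifier_model: Callable[[float], float]
    loss: Callable[["TinyClassifierContext", float, float], float]
    dataset: List[Tuple[float, float]]
    metrics: list = field(default_factory=list)
    global_step: int = 0


def squared_error(context, features, labels):
    """A loss that pulls the model from the context instead of taking it as an argument."""
    prediction = context.classifier_model(features)
    return (prediction - labels) ** 2


context = TinyClassifierContext(
    classifier_model=lambda x: 2.0 * x,
    loss=squared_error,
    dataset=[(1.0, 2.0), (2.0, 5.0)],
)

losses = [context.loss(context, features, labels) for features, labels in context.dataset]
```

The call `context.loss(context, features, labels)` mirrors the shape of the `context.loss(context, features=..., labels=...)` call shown in the ClassifierLoss metric earlier.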

GAN Context

The basic GANContext is composed of:

  • dataset
  • generator_model
  • discriminator_model
  • generator_loss
  • discriminator_loss
  • metrics
  • log_eval_mode
  • global_step
  • ckpt

As we can see we have all information needed to define our training and evaluation loop.

GANEncoder Context

The GANEncoderContext extends the GANContext: it contains all the information of the base class plus:

  • Encoder Model
  • Encoder Loss

Executor

The Executor is the main concept behind the loss function implementation in AshPy. An Executor is a class that helps generalize a training loop. With an Executor you can construct, for example, a custom loss function and put every computation you need inside it. You should define a call function inside your class and, if needed, decorate it with @Executor.reduce_loss.

Inside the call function you can take advantage of a context.

Executors can be summed up, subtracted and multiplied by scalars.
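
One common way to obtain this kind of composability is operator overloading. The sketch below shows the idea in plain Python with a hypothetical `TinyExecutor` class; it is an assumption about the general technique, not a copy of how ashpy.losses.executor implements it.

```python
class TinyExecutor:
    """Hypothetical loss wrapper supporting +, - and scalar * composition."""

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, context):
        return self._fn(context)

    def __add__(self, other):
        return TinyExecutor(lambda context: self(context) + other(context))

    def __sub__(self, other):
        return TinyExecutor(lambda context: self(context) - other(context))

    def __mul__(self, scalar):
        return TinyExecutor(lambda context: scalar * self(context))

    __rmul__ = __mul__  # allow `100 * executor` as well as `executor * 100`


adversarial = TinyExecutor(lambda context: context["adversarial"])
l1 = TinyExecutor(lambda context: context["l1"])

# Weighted sum: an adversarial term plus a weighted L1 term.
combined = adversarial + 100 * l1

value = combined({"adversarial": 0.5, "l1": 0.25})
```

Each operation returns a new executor, so a composite loss can be passed around and called exactly like a simple one.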

An Executor also takes care of the distribution strategy by reducing the loss appropriately (see the TensorFlow Guide).

An Executor Example

In this example we will see the implementation of the Generator Binary CrossEntropy loss.

The __init__ method is straightforward: we only need to instantiate a tf.losses.BinaryCrossentropy object and pass it to our parent:

class GeneratorBCE(GANExecutor):

    def __init__(self, from_logits=True):
        self.name = "GeneratorBCE"
        # call super passing the BinaryCrossentropy as function
        super().__init__(tf.losses.BinaryCrossentropy(from_logits=from_logits))

Then we need to implement the call function respecting the signature:

def call(self, context, *, fake, condition, training, **kwargs):

    # we need a function that gives us the correct inputs given the discriminator model
    fake_inputs = self.get_discriminator_inputs(
        context=context, fake_or_real=fake, condition=condition, training=training
    )

    # get the discriminator predictions from the discriminator model
    d_fake = context.discriminator_model(fake_inputs, training=training)

    # get the target prediction for the generator
    value = self._fn(tf.ones_like(d_fake), d_fake)

    # mean everything
    return tf.reduce_mean(value)

The function get_discriminator_inputs() returns the correct discriminator inputs using the context. The discriminator input can be the output of the generator (unconditioned case) or the output of the generator together with the condition (conditioned case).

Then call() uses the discriminator model inside the context in order to obtain the output of the discriminator when evaluated on the fake_inputs.

After that the self._fn() (BinaryCrossentropy) is used to get the value of the loss. This loss is then averaged.
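
To see what comparing the discriminator output against an all-ones target means numerically, here is a minimal sketch of binary cross-entropy computed from raw logits in plain Python. `bce_from_logits` and `generator_loss` are hypothetical helpers; they use the standard numerically stable formula for sigmoid cross-entropy and do not call TensorFlow.

```python
import math


def bce_from_logits(target, logit):
    """Binary cross-entropy for one example, computed from a raw logit."""
    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def generator_loss(d_fake_logits):
    """Average BCE against an all-ones target, mirroring `tf.ones_like(d_fake)`."""
    per_example = [bce_from_logits(1.0, z) for z in d_fake_logits]
    return sum(per_example) / len(per_example)


rejected = generator_loss([-4.0, -4.0])  # discriminator rejects the fakes
fooled = generator_loss([4.0, 4.0])  # discriminator accepts the fakes
```

The loss is small when the discriminator scores the fake samples as real and large when it rejects them, which is the signal the generator is trained on.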

In this way the executor computes the loss function correctly.

This is fine as long as we do not want to use our code in a distribution strategy.

If we want to use our executor in a distribution strategy the only modifications are:

@Executor.reduce_loss
def call(self, context, *, fake, condition, training, **kwargs):

    # we need a function that gives us the correct inputs given the discriminator model
    fake_inputs = self.get_discriminator_inputs(
        context=context, fake_or_real=fake, condition=condition, training=training
    )

    # get the discriminator predictions from the discriminator model
    d_fake = context.discriminator_model(fake_inputs, training=training)

    # get the target prediction for the generator
    value = self._fn(tf.ones_like(d_fake), d_fake)

    # mean only over the axis 1
    return tf.reduce_mean(value, axis=1)

The important things are:

  • Executor.reduce_loss decoration: uses the Executor decorator in order to correctly reduce the loss.
  • tf.reduce_mean(value, axis=1) (last line): we perform the mean only over axis 1. The output of the call function should be a tf.Tensor with shape (N, 1) or (N,), because the decorator performs the mean over axis 0.
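
The split of responsibilities between the call function and the decorator can be sketched with nested lists in plain Python. The `reduce_loss` decorator below is hypothetical and only averages over axis 0; the real Executor.reduce_loss presumably also accounts for the distribution strategy, which is not modelled here.

```python
import functools


def reduce_loss(call):
    """Hypothetical decorator: average the per-example values returned by `call`."""

    @functools.wraps(call)
    def wrapper(*args, **kwargs):
        per_example = call(*args, **kwargs)  # expected shape: (N,)
        return sum(per_example) / len(per_example)  # mean over axis 0

    return wrapper


def mean_axis_1(matrix):
    """Mean over axis 1 of an (N, M) nested list, giving N per-example values."""
    return [sum(row) / len(row) for row in matrix]


@reduce_loss
def call(values):
    # Reduce only over axis 1 and leave axis 0 to the decorator.
    return mean_axis_1(values)


loss = call([[1.0, 3.0], [2.0, 6.0]])
```

Keeping the batch dimension intact inside call is what gives the decorator the per-example values it needs to perform the final reduction.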

API Reference

ashpy.ashtypes Custom Type-Aliases.
ashpy.callbacks Callbacks in order to gain control over the training loop.
ashpy.contexts Contexts help gaining an easier control over the model selection and testing process of the models.
ashpy.keras Custom extensions of standard Keras components.
ashpy.layers Collection of layers.
ashpy.losses Collection of Losses.
ashpy.metrics Collection of Metrics.
ashpy.models Collection of Models.
ashpy.modes Various modalities used to configure certain ash behaviours.
ashpy.restorers Restorers allow for easy restoration of tracked objects from tf.train.Checkpoint.
ashpy.trainers Trainers help reducing boilerplate code by bootstrapping models training.

Dependencies Graph

ashpy.callbacks

Inheritance diagram of ashpy.callbacks.callback, ashpy.callbacks.counter_callback

Events

Inheritance diagram of ashpy.callbacks.events

ashpy.models

Convolutional

Inheritance diagram of ashpy.models.convolutional.interfaces, ashpy.models.convolutional.encoders, ashpy.models.convolutional.decoders

GANs

GANs models are just aliases.


ashpy.trainers

Adversarial

Inheritance diagram of ashpy.trainers.gan

Classifier

Inheritance diagram of ashpy.trainers.classifier

ashpy.restorers

Inheritance diagram of ashpy.restorers.restorer, ashpy.restorers.gan, ashpy.restorers.classifier

ashpy.layers

Layers

Inheritance diagram of ashpy.layers.attention, ashpy.layers.instance_normalization

ashpy.losses

Classifier Losses

Inheritance diagram of ashpy.losses.executor, ashpy.losses.classifier

GAN Losses

Inheritance diagram of ashpy.losses.executor, ashpy.losses.gan

ashpy.metrics

Classifier Metrics

Inheritance diagram of ashpy.metrics.classifier, ashpy.metrics.metric

GAN Metrics

Inheritance diagram of ashpy.metrics.gan, ashpy.metrics.metric

About

AshPy is an open-source project available on GitHub under the Apache license.

The Framework is created and maintained primarily by the ML & CV Lab @ Zuru Tech.

Please contact ml@zuru.tech for questions, information or suggestions.

Indices and tables