sonnet - module reference

This Python module contains Neural Network Modules for TensorFlow.

Each module is a Python object which conceptually "owns" any variables required in that part of the Neural Network. The __call__ method on the object is used to connect that Module into the Graph, and it may be called repeatedly, with variable sharing taking place automatically.

Everything public should be imported by this top-level __init__.py so that the library can be used as follows:

import sonnet as snt

linear = snt.Linear(...)
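
For example, connecting the same module instance more than once reuses its variables automatically (a minimal sketch; the shapes are illustrative):

import sonnet as snt
import tensorflow as tf

linear = snt.Linear(output_size=10)
x1 = tf.placeholder(tf.float32, [None, 5])
x2 = tf.placeholder(tf.float32, [None, 5])
out1 = linear(x1)  # First connection creates the variables.
out2 = linear(x2)  # Second connection reuses the same variables.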

Other Functions and Classes

class ACTCore

Adaptive computation time core.

Implementation of the model described in the paper "Adaptive Computation Time for Recurrent Neural Networks", https://arxiv.org/abs/1603.08983.

The ACTCore incorporates the pondering RNN of ACT, with a different computation time for each element in the minibatch. Each pondering step is performed by the core passed to the constructor of ACTCore.

The output of the ACTCore is the tuple (act_out, (iteration, remainder)), where

  • iteration counts the number of pondering steps in each batch element;
  • remainder is the remainder as defined in the ACT paper;
  • act_out is the weighted average output of all pondering steps (see ACT paper for more info).
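
A minimal usage sketch, assuming an snt.LSTM pondering core whose state is a (hidden, cell) pair; the sizes and the input tensor are illustrative:

import sonnet as snt
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [20, 8, 32])  # [time, batch, input_size]
lstm = snt.LSTM(hidden_size=64)
act_core = snt.ACTCore(
    core=lstm,
    output_size=64,
    threshold=0.99,
    get_state_for_halting=lambda state: state[0])  # Halt on the hidden state.

(act_out, (iterations, remainder)), final_state = tf.nn.dynamic_rnn(
    act_core, inputs, time_major=True, dtype=tf.float32)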

ACTCore.__init__(core, output_size, threshold, get_state_for_halting, max_steps=0, name='act_core')

Constructor.

Args:
  • core: A sonnet.RNNCore object. This should take a single Tensor as input and output a single flat Tensor.
  • output_size: An integer. The size of each output in the sequence.
  • threshold: A float between 0 and 1. Probability to reach for ACT to stop pondering.
  • get_state_for_halting: A callable that can take the core state and return the input to the halting function.
  • max_steps: Integer >= 0 that controls the maximum number of ponder steps. If equal to 0, the limit is disabled.
  • name: A string. The name of this module.
Raises:
  • ValueError: if threshold is not between 0 and 1.
  • ValueError: if core has either nested outputs or outputs that are not one-dimensional.

ACTCore.__call__(x, prev_state)

Connects the core to the graph.

Args:
  • x: Input Tensor of shape (batch_size, input_size).
  • prev_state: Previous state. This could be a Tensor, or a tuple of Tensors.
Returns:

The tuple (output, state) for this core.

Raises:
  • ValueError: if the Tensor x does not have rank 2.

ACTCore.batch_size

ACTCore.connected_subgraphs

Returns the subgraphs created by this module so far.

ACTCore.defun()

Wraps this module's call method in a callable graph function.

ACTCore.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

ACTCore.dtype

ACTCore.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

ACTCore.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.graph

Returns the Graph instance which the module is connected to, or None.

ACTCore.initial_state(*args, **kwargs)

ACTCore.is_connected

Returns true iff the Module has been connected to the Graph at least once.

ACTCore.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.module_name

Returns the name of the Module.

ACTCore.name_scopes

Returns a tuple of all name_scopes generated by this module.

ACTCore.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.output_size

ACTCore.scope_name

Returns the full name of the Module's variable scope.

ACTCore.state_size

ACTCore.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ACTCore.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.
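
For instance, given the act_core sketched earlier (batch size is illustrative), a nested state_size yields a matching nested structure of zero tensors:

initial_state = act_core.zero_state(batch_size=16, dtype=tf.float32)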

class AbstractModule

Superclass for Sonnet Modules.

This class defines the functionality that every module should implement, principally the _build method which is wrapped using tf.make_template and called from __call__. Every time the module is called it will be connected into the graph but using the same shared set of variables, thanks to the template.

For this to work correctly, the _build implementation in the derived class must access all variables using tf.get_variable, not tf.Variable. The same set of variables must be created each time; if this is not the case, an Error will be raised.

Every subclass must call this class' __init__ at the start of their __init__, passing the relevant name. If this step is omitted, variable sharing will not work.
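
A minimal subclass sketch following these requirements (the layer itself is illustrative, not part of the library):

import sonnet as snt
import tensorflow as tf

class MyLinear(snt.AbstractModule):
  """A bare-bones linear layer built on AbstractModule."""

  def __init__(self, output_size, name='my_linear'):
    super(MyLinear, self).__init__(name=name)  # Must be called first.
    self._output_size = output_size

  def _build(self, inputs):
    # Variables must be created with tf.get_variable so the template
    # can share them across connections.
    input_size = inputs.get_shape().as_list()[-1]
    w = tf.get_variable('w', shape=[input_size, self._output_size])
    return tf.matmul(inputs, w)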

AbstractModule.__init__(_sentinel=None, custom_getter=None, name=None)

Performs the initialisation necessary for all AbstractModule instances.

Every subclass of AbstractModule must begin their constructor with a call to this constructor, i.e.

super(MySubModule, self).__init__(custom_getter=custom_getter, name=name).

If you instantiate sub-modules in __init__ you must create them within the _enter_variable_scope context manager to ensure they are in the module's variable scope. Alternatively, instantiate sub-modules in _build.

Args:
  • _sentinel: Variable that only carries a non-None value if __init__ was called without named parameters. If this is the case, a deprecation warning is issued in the form of a ValueError.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of this module. Used to construct the Templated _build function. If None the module's class name is used (converted to snake case).
Raises:
  • TypeError: If name is not a string.
  • TypeError: If a given custom_getter is not callable.
  • ValueError: If __init__ was called without named arguments.

AbstractModule.__call__(*args, **kwargs)

Add elements to the Graph, computing output Tensors from input Tensors.

Subclasses must implement this method, which will be wrapped in a Template.

Args:
  • *args: Input Tensors.
  • **kwargs: Additional Python flags controlling connection.
Returns:

output Tensor(s).

AbstractModule.connected_subgraphs

Returns the subgraphs created by this module so far.

AbstractModule.defun()

Wraps this module's call method in a callable graph function.

AbstractModule.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

AbstractModule.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

AbstractModule.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.graph

Returns the Graph instance which the module is connected to, or None.

AbstractModule.is_connected

Returns true iff the Module has been connected to the Graph at least once.

AbstractModule.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.module_name

Returns the name of the Module.

AbstractModule.name_scopes

Returns a tuple of all name_scopes generated by this module.

AbstractModule.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.scope_name

Returns the full name of the Module's variable scope.

AbstractModule.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AbstractModule.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class AddBias

AddBias module.

AddBias.__init__(output_shape=None, bias_dims=None, initializers=None, partitioners=None, regularizers=None, name='add')

Constructs an AddBias module that supports broadcasting.

Args:
  • output_shape: Output dimensionality. output_shape can be either None, a tuple, or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user need only ensure that output_shape can be called, returning a tuple, when build is called. If output_shape is left as None, the size will be inferred directly from the input.
  • bias_dims: List of which dimensions to retain from the input shape when constructing the bias. The remaining dimensions will be broadcast over (given size 1), and leading dimensions will be removed completely. For example, for an input of [batch_size, dim1_size, dim2_size, dim3_size] and bias_dims=[1, 3], the resulting bias will have shape [dim1_size, 1, dim3_size]. The default is to retain all dimensions apart from the minibatch dimension. Trying to retain the bias shape over the minibatch dimension, e.g. bias_dims=[0], will result in an error at build time. See the 'Example Usage' section below for more information.
  • initializers: Optional dict containing ops to initialize the biases (with key 'b'). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing a partitioner to partition the bias (with key 'b'). As a default, no partitioner is used.
  • regularizers: Optional dict containing regularizers of the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Name of the module.

Example Usage:

# Create a 4D input Tensor.
input = tf.random_normal(
    shape=(batch_size, dim1_size, dim2_size, dim3_size))

# Create a scalar bias:
scalar_bias = snt.AddBias(bias_dims=[])
scalar_bias_output = scalar_bias(input)
scalar_bias.b.get_shape()  # ()

# Create a bias over all non-minibatch dimensions:
all_bias = snt.AddBias()  # or snt.AddBias(bias_dims=None)
all_bias_output = all_bias(input)
all_bias.b.get_shape()  # (dim1_size, dim2_size, dim3_size)

# Create a bias over the last non-minibatch dimension:
last_bias = snt.AddBias(bias_dims=[-1])
last_bias_output = last_bias(input)
last_bias.b.get_shape()  # (dim3_size)

# Create a bias over the first non-minibatch dimension:
first_bias = snt.AddBias(bias_dims=[1])
first_bias_output = first_bias(input)
first_bias.b.get_shape()  # (dim1_size, 1, 1)

# Subtract and later add the same learned bias:
bias = snt.AddBias()
hidden1 = bias(input, multiplier=-1)
# ...
reconstructed_input = bias(hidden4)

Raises:
  • KeyError: If initializers contains any keys other than 'b'.
  • KeyError: If partitioners contains any keys other than 'b'.
  • KeyError: If regularizers contains any keys other than 'b'.
  • TypeError: If any of the given initializers are not callable.
  • TypeError: If any of the given partitioners are not callable.
  • TypeError: If any of the given regularizers are not callable.

AddBias.__call__(inputs, multiplier=1)

Connects the Add module into the graph, with input Tensor inputs.

Args:
  • inputs: A Tensor of size [batch_size, input_size1, ...].
  • multiplier: A scalar or Tensor which the bias term is multiplied by before adding it to inputs. Anything which works in the expression bias * multiplier is acceptable here. This may be useful if you want to add a bias in one place and subtract the same bias in another place via multiplier=-1.
Returns:

A Tensor of size [batch_size, input_size1, ...].

Raises:
  • base.IncompatibleShapeError: If the input is not a >= 2D Tensor.
  • base.IncompatibleShapeError: If connecting the module into the graph any time after the first time, and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the output_shape has been specified but it does not match the input_shape.
  • base.ParentNotBuiltError: If the module is transposed and the original untransposed module has not been built.

AddBias.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.

AddBias.connected_subgraphs

Returns the subgraphs created by this module so far.

AddBias.defun()

Wraps this module's call method in a callable graph function.

AddBias.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

AddBias.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

AddBias.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.graph

Returns the Graph instance which the module is connected to, or None.

AddBias.input_shape

Returns shape of input Tensor passed at last call to build.

AddBias.is_connected

Returns true iff the Module has been connected to the Graph at least once.

AddBias.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.module_name

Returns the name of the Module.

AddBias.name_scopes

Returns a tuple of all name_scopes generated by this module.

AddBias.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.scope_name

Returns the full name of the Module's variable scope.

AddBias.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.transpose(name=None)

Returns transposed AddBias module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.module_name.
Returns:

Transposed AddBias module.

AddBias.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AddBias.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class AffineGridWarper

Affine Grid Warper class.

The affine grid warper generates a reference grid of n-dimensional points and warps it via an affine transformation model determined by an input parameter Tensor. Some of the transformation parameters can be fixed at construction time via an AffineWarpConstraints object.

AffineGridWarper.__init__(source_shape, output_shape, constraints=None, name='affine_grid_warper')

Constructs an AffineGridWarper.

source_shape and output_shape are used to define the size of the source and output signal domains, as opposed to the shape of the respective Tensors. For example, for an image of size width=W and height=H, {source,output}_shape=[H, W]; for a volume of size width=W, height=H and depth=D, {source,output}_shape=[H, W, D].

Args:
  • source_shape: Iterable of integers determining the size of the source signal domain.
  • output_shape: Iterable of integers determining the size of the destination resampled signal domain.
  • constraints: Either a double list of shape [N, N+1] defining constraints on the entries of a matrix defining an affine transformation in N dimensions, or an AffineWarpConstraints object. If the double list is passed, a numeric value bakes in a constraint on the corresponding entry in the transformation matrix, whereas None implies that the corresponding entry will be specified at run time.
  • name: Name of module.
Raises:
  • Error: If constraints fully define the affine transformation; or if input grid shape and constraints have different dimensionality.
  • TypeError: If output_shape and source_shape are not both iterable.
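
A construction sketch (shapes are illustrative; an unconstrained affine transform in 2 dimensions has 2 x 3 = 6 parameters):

import sonnet as snt
import tensorflow as tf

warper = snt.AffineGridWarper(source_shape=[28, 28], output_shape=[14, 14])
affine_params = tf.placeholder(tf.float32, [None, 6])
warped_grid = warper(affine_params)  # A batch of grids over the output domain.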

AffineGridWarper.__call__(inputs)

Assembles the module network and adds it to the graph.

The internal computation graph is assembled according to the set of constraints provided at construction time.

Args:
  • inputs: Tensor containing a batch of transformation parameters.
Returns:

A batch of warped grids.

Raises:
  • Error: If the input tensor size is not consistent with the constraints passed at construction time.

AffineGridWarper.connected_subgraphs

Returns the subgraphs created by this module so far.

AffineGridWarper.constraints

AffineGridWarper.defun()

Wraps this module's call method in a callable graph function.

AffineGridWarper.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

AffineGridWarper.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

AffineGridWarper.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.graph

Returns the Graph instance which the module is connected to, or None.

AffineGridWarper.inverse(name=None)

Returns a sonnet module to compute inverse affine transforms.

The function first assembles a network that given the constraints of the current AffineGridWarper and a set of input parameters, retrieves the coefficients of the corresponding inverse affine transform, then feeds its output into a new AffineGridWarper setup to correctly warp the output space into the source space.

Args:
  • name: Name of module implementing the inverse grid transformation.
Returns:

A sonnet module performing the inverse affine transform of a reference grid of points via an AffineGridWarper module.

Raises:
  • tf.errors.UnimplementedError: If the function is called on a non-2D instance of AffineGridWarper.

AffineGridWarper.is_connected

Returns true iff the Module has been connected to the Graph at least once.

AffineGridWarper.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.module_name

Returns the name of the Module.

AffineGridWarper.n_coeff

Returns number of coefficients of warping function.

AffineGridWarper.name_scopes

Returns a tuple of all name_scopes generated by this module.

AffineGridWarper.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.output_shape

Returns a tuple containing the shape of the output grid.

AffineGridWarper.psi

Returns a list of features used to compute the grid warp.

AffineGridWarper.scope_name

Returns the full name of the Module's variable scope.

AffineGridWarper.source_shape

Returns a tuple containing the shape of the source signal.

AffineGridWarper.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AffineGridWarper.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class AffineWarpConstraints

Affine warp constraints class.

AffineWarpConstraints allow for very succinct definitions of constraints on the values of entries in affine transform matrices.

AffineWarpConstraints.__init__(constraints=((None, None, None), (None, None, None)))

Creates a constraint definition for an affine transformation.

Args:
  • constraints: A doubly-nested iterable of shape [N, N+1] defining constraints on the entries of a matrix that represents an affine transformation in N dimensions. A numeric value bakes in a constraint on the corresponding entry in the transformation matrix, whereas None implies that the corresponding entry will be specified at run time.
Raises:
  • TypeError: If constraints is not a nested iterable.
  • ValueError: If the double iterable constraints has inconsistent dimensions.
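
For example, fixing the translation entries of a 2D transform to zero while leaving the other entries free (a sketch):

import sonnet as snt

constraints = snt.AffineWarpConstraints(
    constraints=((None, None, 0), (None, None, 0)))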

AffineWarpConstraints.combine_with(additional_constraints)

Combines two sets of constraints into a coherent single set.

AffineWarpConstraints.constraints

AffineWarpConstraints.mask

AffineWarpConstraints.no_constraints(cls, num_dim=2)

Empty set of constraints for a num_dim-dimensional affine transform.

AffineWarpConstraints.no_shear_2d(cls)

Assigns constraints on shear components of affine transform in 2d.

AffineWarpConstraints.no_shear_3d(cls)

Assigns constraints on shear components of affine transform in 3d.

AffineWarpConstraints.num_dim

AffineWarpConstraints.num_free_params

AffineWarpConstraints.scale_2d(cls, x=None, y=None)

Assigns constraints on scaling components of affine transform in 2d.

AffineWarpConstraints.scale_3d(cls, x=None, y=None, z=None)

Assigns constraints on scaling components of affine transform in 3d.

AffineWarpConstraints.shear_2d(cls, x=None, y=None)

Assigns constraints on shear components of affine transform in 2d.

AffineWarpConstraints.translation_2d(cls, x=None, y=None)

Assigns constraints on translation components of affine transform in 2d.

AffineWarpConstraints.translation_3d(cls, x=None, y=None, z=None)

Assigns constraints on translation components of affine transform in 3d.
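
The classmethods above can be composed via combine_with; for instance (a sketch, assuming the keyword semantics shown in the signatures):

no_shear = snt.AffineWarpConstraints.no_shear_2d()
fixed_scale = snt.AffineWarpConstraints.scale_2d(x=1, y=1)
combined = no_shear.combine_with(fixed_scale)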

class AttentiveRead

A module for reading with attention.

This module reads a weighted sum of embeddings from memory, where each memory slot's weight is based on the logit returned by an attention embedding module. A mask may be given to ignore some memory slots (e.g. when attending over variable-length sequences).

AttentiveRead.__init__(attention_logit_mod, name='attention')

Initialize AttentiveRead module.

Args:
  • attention_logit_mod: Module that produces a logit corresponding to a memory slot's compatibility. Must map a [batch_size * memory_size, memory_word_size + query_word_size]-shaped Tensor to a [batch_size * memory_size, 1]-shaped Tensor.
  • name: string. Name for module.

AttentiveRead.__call__(memory, query, memory_mask=None)

Perform a differentiable read.

Args:
  • memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of dtype float32. This represents, for each example and memory slot, a single embedding to attend over.
  • query: [batch_size, query_word_size]-shaped Tensor of dtype float32. Represents, for each example, a single embedding representing a query.
  • memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype bool. An entry of False indicates that a memory slot should not enter the resulting weighted sum. If None, all memory is used.
Returns:

An AttentionOutput instance containing:

  • read: [batch_size, memory_word_size]-shaped Tensor of dtype float32. This represents, for each example, a weighted sum of the contents of the memory.
  • weights: [batch_size, memory_size]-shaped Tensor of dtype float32. This represents, for each example and memory slot, the attention weights used to compute the read.
  • weight_logits: [batch_size, memory_size]-shaped Tensor of dtype float32. This represents, for each example and memory slot, the logits of the attention weights, that is, weights is calculated by taking the softmax of the weight logits.
Raises:
  • UnderspecifiedError: if memory_word_size or query_word_size cannot be inferred.
  • IncompatibleShapeError: if memory, query, memory_mask, or output of attention_logit_mod do not match expected shapes.
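
A usage sketch, assuming an snt.Linear module as the attention logit module (any module with the documented [batch_size * memory_size, 1] output shape works; shapes are illustrative):

import sonnet as snt
import tensorflow as tf

memory = tf.placeholder(tf.float32, [None, 30, 64])  # [batch, slots, word]
query = tf.placeholder(tf.float32, [None, 16])       # [batch, query_word]
attention = snt.AttentiveRead(snt.Linear(output_size=1))
attention_output = attention(memory, query)
read = attention_output.read        # [batch_size, memory_word_size]
weights = attention_output.weights  # [batch_size, memory_size]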

AttentiveRead.connected_subgraphs

Returns the subgraphs created by this module so far.

AttentiveRead.defun()

Wraps this module's call method in a callable graph function.

AttentiveRead.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

AttentiveRead.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

AttentiveRead.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.graph

Returns the Graph instance which the module is connected to, or None.

AttentiveRead.is_connected

Returns true iff the Module has been connected to the Graph at least once.

AttentiveRead.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.module_name

Returns the name of the Module.

AttentiveRead.name_scopes

Returns a tuple of all name_scopes generated by this module.

AttentiveRead.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.scope_name

Returns the full name of the Module's variable scope.

AttentiveRead.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

AttentiveRead.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BatchApply

Merges a number of leading dimensions of an input tensor to manipulate it.

Merges a number of leading dimensions of a tensor into a single dimension, connects the provided module, then splits the leading dimension of the result to match the input.

Input tensors whose rank is smaller than the number of dimensions to collapse (e.g. scalar values, which are tensors of rank 0) are passed unaltered to the provided module.

This is useful for applying some module to each timestep of a Time x Batch x N tensor. If a module is hard-coded to only support 2D (Batch x N) input, then the full 3D Tensor cannot be provided. BatchApply will 'merge' the first two dimensions of the sequence tensor by reshaping it to a (Time * Batch) x N Tensor, to which the internal module can then be applied. The result of that operation is reshaped so that its first dimensions are split to match the leading dimensions of the input.
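
For example, applying a 2D-only module to every timestep of a 3D tensor (a sketch with illustrative shapes):

import sonnet as snt
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [20, 8, 32])  # [Time, Batch, N]
linear = snt.Linear(output_size=16)               # Only accepts [Batch, N].
outputs = snt.BatchApply(linear)(inputs)          # Shape [20, 8, 16].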

BatchApply.__init__(module_or_op, n_dims=2, input_example_index=0, name='batch_apply')

Constructor of the module.

Args:
  • module_or_op: Module or tensorflow op to apply to an input tensor.
  • n_dims: Number of dimensions to merge before using module on the input of BatchApply.
  • input_example_index: Index of the input that has the same shape for the first n_dims dimensions as the module_or_op output(s). This is used for unflattening the output(s) if static shape inference is not possible.
  • name: Name of the module.
Raises:
  • TypeError: If n_dims is not an integer.
  • ValueError: If n_dims is not greater than zero.

BatchApply.__call__(*args, **kwargs)

Connects the BatchApply module into the graph.

Args:
  • *args: a Tensor or a nested list or dictionary of Tensors. The input tensors will have their first dimensions merged, then an op or a module will be called on the input. The first dimension of the output tensor(s) will be split again based on the leading dimensions of the first input tensor.
  • **kwargs: Dictionary of named arguments; used in the same way as *args.
Returns:

A Tensor or nested list or dictionary of Tensors as a result of applying the process above. ("None" return values are also supported.)

BatchApply.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchApply.defun()

Wraps this module's call method in a callable graph function.

BatchApply.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchApply.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchApply.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.graph

Returns the Graph instance which the module is connected to, or None.

BatchApply.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchApply.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.module_name

Returns the name of the Module.

BatchApply.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchApply.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.scope_name

Returns the full name of the Module's variable scope.

BatchApply.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchApply.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BatchFlatten

Flattens the input Tensor, preserving the batch dimension(s).

BatchFlatten.__init__(preserve_dims=1, name='batch_flatten')

Constructs a BatchFlatten module.

Args:
  • preserve_dims: Number of leading dimensions that will not be reshaped. For example, given an input Tensor with shape [B, H, W, C]:
    • preserve_dims=1 will return a Tensor with shape [B, H*W*C].
    • preserve_dims=2 will return a Tensor with shape [B, H, W*C].
    • preserve_dims=3 will return the input itself, shape [B, H, W, C].
    • preserve_dims=4 will return a Tensor with shape [B, H, W, C, 1].
    • preserve_dims>=5 will throw an error on build.
    The preserved dimensions can be unknown at building time.
  • name: Name of the module.

BatchFlatten.__call__(inputs)

Connects the module into the graph, with input Tensor inputs.

Args:
  • inputs: A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_preserve_dims+1, ...].
Returns:

A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_reshape_1, b_reshape_2, ...], with reshaping defined by the constructor shape parameter.

Raises:
  • ValueError: If output shape is incompatible with input shape; or if shape array contains non-numeric entries; or if shape array contains more than 1 wildcard -1; or if the input array contains unknown, non-preserved dimensions (except when the unknown dimension is the only non-preserved dimension and doesn't actually need reshaping).
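
For instance (a sketch; shapes are illustrative):

import sonnet as snt
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 28, 28, 3])
flat = snt.BatchFlatten()(images)                           # [None, 2352]
partially_flat = snt.BatchFlatten(preserve_dims=2)(images)  # [None, 28, 84]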

BatchFlatten.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchFlatten.defun()

Wraps this module's call method in a callable graph function.

BatchFlatten.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchFlatten.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchFlatten.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.graph

Returns the Graph instance which the module is connected to, or None.

BatchFlatten.input_shape

BatchFlatten.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchFlatten.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.module_name

Returns the name of the Module.

BatchFlatten.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchFlatten.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.scope_name

Returns the full name of the Module's variable scope.

BatchFlatten.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.transpose(name=None)

Returns transpose batch reshape.

BatchFlatten.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchFlatten.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BatchNorm

Batch normalization module, including optional affine transformation.

This module maintains exponential moving averages of the mean and variance, which can be optionally used to normalize at test time.

At training time, batch statistics (mean, variance) are not shared between separate connections. The moving averages are shared between separate connections. At both training and test time, the optional affine transformation (* gamma + beta) is shared between separate connections.

This is also the case for distributed replica training, where the batch statistics are not aggregated across replicas, but the moving averages are shared globally.

When connecting the module to the graph, is_training=True means that

  • Update ops are created to update the moving averages with the current batch's statistics.
  • Features are normalized using the current batch's statistics. The test_local_stats setting is ignored. The moving averages are not used.

whereas is_training=False means that

  • Update ops are not created.
  • Features are normalized using either:
    • The test batch statistics if test_local_stats=True (default).
    • The moving averages if test_local_stats=False.

Local batch statistics are used by default at test time, but the moving averages can be used by specifying a flag when connecting. One often wants to use local batch statistics at test time to track progress while the model is being trained, as this ensures that moving-average updates do not affect the training curves. Once training is finished, it's often advantageous to use moving average statistics, since this makes evaluation agnostic to the batch size, and might even lead to small improvements over the local batch statistics.

You can update the moving averages automatically by setting update_ops_collection=None, or manually by running the ops in the given collection, which defaults to tf.GraphKeys.UPDATE_OPS.

For example, to run the updates automatically:

bn = BatchNorm(update_ops_collection=None)
train_net = bn(train_inputs, is_training=True)

This does, however, have the effect of blocking the forward pass of the network until the update ops have been run, and may incur a small performance penalty.

For example, to run the updates manually:

bn = BatchNorm()
train_net = bn(train_inputs, is_training=True)

...

update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
train_op = tf.group(train_op, update_ops)

Then, whenever train_op is run, so too are the moving average update ops.

Some batch normalization caveats:

  • Batch normalization will remove the effect of adding a bias, so e.g. use_bias=False should be used for an immediately preceding snt.Linear module (see the sketch after this list).
  • If your data batches aren't i.i.d. then batch normalization can allow your network to 'cheat' by using the batch statistics to peek at the rest of the batch. This can exhibit itself as a higher test score with test_local_stats=True than test_local_stats=False.
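
Following the first caveat, a sketch of pairing BatchNorm with a bias-free snt.Linear (the surrounding names, such as inputs and is_training, are illustrative):

linear = snt.Linear(output_size=128, use_bias=False)
bn = snt.BatchNorm()
outputs = tf.nn.relu(bn(linear(inputs), is_training=is_training))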

BatchNorm.__init__(axis=None, offset=True, scale=False, decay_rate=0.999, eps=0.001, initializers=None, partitioners=None, regularizers=None, update_ops_collection='update_ops', fused=False, name='batch_norm')

Constructs a BatchNorm module.

By default reduces over all input tensor dimensions apart from the final dimension. This has the effect of treating pixels in 1D/2D/3D images as additional elements of the minibatch.

If this is not the desired behaviour, the user can specify the tensor indices to reduce over with axis.

Args:
  • axis: Optional iterable of indices of dimensions to reduce over. By default this is None, and all dimensions except the last are reduced over.
  • offset: Optional boolean to specify whether or not to apply a trained component-wise bias after the batch normalization and scaling.
  • scale: Optional boolean to specify whether or not to apply a trained component-wise scale after the batch normalization.
  • decay_rate: Decay rate of the exponential moving averages of the mean and variance.
  • eps: Small number to avoid dividing by zero when dividing by the standard deviation.
  • initializers: Optional dict containing ops to initialize the weights of the affine transform (gamma and beta).
  • partitioners: Optional dict containing partitioners to partition the weights of the affine transform (gamma and beta).
  • regularizers: Optional dict containing regularizers for the weights of the affine transform ('gamma' and 'beta'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • update_ops_collection: Name of TensorFlow variable collection to add the moving average update ops to. If None, we instead add the update ops as control dependencies of the output of the module. This may result in some slowdown, as the feed-forward of the network is now blocked. By default, tf.GraphKeys.UPDATE_OPS.
  • fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
  • name: Name of the module.
Raises:
  • KeyError: If initializers contains any keys other than gamma, beta, moving_mean or moving_variance.
  • KeyError: If partitioners or regularizers contains any keys other than gamma or beta.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.

BatchNorm.__call__(input_batch, is_training, test_local_stats=True)

Connects the BatchNorm module into the graph.

Args:
  • input_batch: A Tensor of arbitrary dimension. By default, the final dimension is not reduced over when computing the minibatch statistics.
  • is_training: A boolean to indicate if the module should be connected in training mode, meaning the moving averages are updated. Can be a Tensor.
  • test_local_stats: A boolean to indicate if local batch statistics should be used when is_training=False. If not, moving averages are used. By default True. Can be a Tensor.
Returns:

A tensor with the same shape as input_batch.

Raises:
  • base.IncompatibleShapeError: If axis is not valid for the input shape or has negative entries.
  • base.NotSupportedError: If input_batch has data type of tf.bfloat16.
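
Since is_training can be a Tensor, a single connection can serve both phases by feeding a boolean at run time. A minimal sketch, assuming input_batch is a Tensor defined elsewhere:

bn = BatchNorm()
is_training = tf.placeholder(tf.bool)  # fed as True during training, False at test
outputs = bn(input_batch, is_training=is_training)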

BatchNorm.beta

BatchNorm.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchNorm.defun()

Wraps this module's call method in a callable graph function.

BatchNorm.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchNorm.gamma

BatchNorm.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchNorm.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.graph

Returns the Graph instance which the module is connected to, or None.

BatchNorm.initializers

BatchNorm.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchNorm.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.module_name

Returns the name of the Module.

BatchNorm.moving_mean

BatchNorm.moving_variance

BatchNorm.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchNorm.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.partitioners

BatchNorm.regularizers

BatchNorm.scope_name

Returns the full name of the Module's variable scope.

BatchNorm.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNorm.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BatchNormLSTM

LSTM recurrent network cell with optional peepholes, batch normalization.

The base implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting at the beginning of training.

Peep-hole connections

Peep-hole connections may optionally be used by specifying a flag in the constructor. These connections can help increase the precision of output timing; for more details see:

https://research.google.com/pubs/archive/43905.pdf

Batch normalization

The batch norm transformation (in training mode) is batchnorm(x) = gamma * (x - mean(x)) / stddev(x) + beta, where gamma is a learnt scaling factor and beta is a learnt offset.

Batch normalization may optionally be used at different places in the LSTM by specifying flag(s) in the constructor. These are applied when calculating the gate activations and cell-to-hidden transformation. The set-up is based on

https://arxiv.org/pdf/1603.09025.pdf

Batch normalization: where to apply?

Batch norm can be applied in three different places in the LSTM:

  • (h) To the W_h h_{t-1} contribution to the gates from the previous hiddens.
  • (x) To the W_x x_t contribution to the gates from the current input.
  • (c) To the cell value c_t when calculating the output h_t from the cell.

(The notation here is consistent with the Recurrent Batch Normalization paper). Each of these can be controlled individually, because batch norm is expensive, and not all are necessary. The paper doesn't mention the relative effects of these different batch norms; however, experimentation with a shallow LSTM for the permuted_mnist sequence task suggests that (h) is the most important and the other two can be left off. For other tasks or deeper (stacked) LSTMs, other batch norm combinations may be more effective.

Batch normalization: collecting stats (training vs test)

When switching to testing (see LSTM.with_batch_norm_control), we can use a mean and stddev learnt from the training data instead of using the statistics from the test data. (This both increases test accuracy, because the statistics have less variance, and, if the test data does not have the same distribution as the training data, ensures that the effective network does not change when switching to testing.)

This does, however, introduce a slight subtlety. The first few time steps of the RNN tend to have varying statistics (mean and variance) before settling down to a steady value. Therefore, in general, better performance is obtained by using separate statistics for the first few time steps, and then using the final set of statistics for all subsequent time steps. This is controlled by the parameter max_unique_stats. (We can't have an unbounded number of distinct statistics, both for technical reasons and for the case where test sequences are longer than anything seen in training.)

You may be fine leaving it at its default value of 1. Small values (like 10) may achieve better performance on some tasks when testing with cached statistics.
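
For example, a minimal construction sketch (the hidden size of 32 and max_unique_stats of 10 are illustrative only):

lstm = snt.BatchNormLSTM(
    hidden_size=32,
    use_batch_norm_h=True,   # (h): often the most effective place for batch norm
    use_batch_norm_x=False,
    use_batch_norm_c=False,
    max_unique_stats=10)     # separate statistics for the first 10 time steps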

Attributes:
  • state_size: Tuple of tf.TensorShapes indicating the size of state tensors.
  • output_size: tf.TensorShape indicating the size of the core output.
  • use_peepholes: Boolean indicating whether peephole connections are used.
  • use_batch_norm_h: Boolean indicating whether batch norm (h) is enabled.
  • use_batch_norm_x: Boolean indicating whether batch norm (x) is enabled.
  • use_batch_norm_c: Boolean indicating whether batch norm (c) is enabled.

BatchNormLSTM.__init__(hidden_size, forget_bias=1.0, initializers=None, partitioners=None, regularizers=None, use_peepholes=False, use_batch_norm_h=True, use_batch_norm_x=False, use_batch_norm_c=False, max_unique_stats=1, hidden_clip_value=None, cell_clip_value=None, custom_getter=None, name='batch_norm_lstm')

Construct BatchNormLSTM.

Args:
  • hidden_size: (int) Hidden size dimensionality.
  • forget_bias: (float) Bias for the forget activation.
  • initializers: Dict containing ops to initialize the weights. This dictionary may contain any of the keys returned by BatchNormLSTM.get_possible_initializer_keys. The gamma and beta variables control batch normalization values for different batch norm transformations inside the cell; see the paper for details.
  • partitioners: Optional dict containing partitioners to partition the weights and biases. As a default, no partitioners are used. This dict may contain any of the keys returned by BatchNormLSTM.get_possible_initializer_keys.
  • regularizers: Optional dict containing regularizers for the weights and biases. As a default, no regularizers are used. This dict may contain any of the keys returned by BatchNormLSTM.get_possible_initializer_keys.
  • use_peepholes: Boolean that indicates whether peephole connections are used.
  • use_batch_norm_h: Boolean that indicates whether to apply batch normalization at the previous_hidden -> gates contribution. If you are experimenting with batch norm then this may be the most effective to use, and is enabled by default.
  • use_batch_norm_x: Boolean that indicates whether to apply batch normalization at the input -> gates contribution.
  • use_batch_norm_c: Boolean that indicates whether to apply batch normalization at the cell -> output contribution.
  • max_unique_stats: The maximum number of steps to use unique batch norm statistics for. (See module description above for more details.)
  • hidden_clip_value: Optional number; if set, then the LSTM hidden state vector is clipped by this value.
  • cell_clip_value: Optional number; if set, then the LSTM cell vector is clipped by this value.
  • custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. See the tf.get_variable documentation for more details.
  • name: Name of the module.
Raises:
  • KeyError: if initializers contains any keys not returned by BatchNormLSTM.get_possible_initializer_keys.
  • KeyError: if partitioners contains any keys not returned by BatchNormLSTM.get_possible_initializer_keys.
  • KeyError: if regularizers contains any keys not returned by BatchNormLSTM.get_possible_initializer_keys.
  • ValueError: if a peephole initializer is passed in the initializer list, but use_peepholes is False.
  • ValueError: if a batch norm initializer is passed in the initializer list, but batch norm is disabled.
  • ValueError: if none of the use_batch_norm_* options are True.
  • ValueError: if max_unique_stats is < 1.

BatchNormLSTM.__call__(inputs, prev_state, is_training=None, test_local_stats=True)

Connects the LSTM module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as inputs and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • inputs: Tensor of size [batch_size, input_size].
  • prev_state: Tuple (prev_hidden, prev_cell), or if batch norm is enabled and max_unique_stats > 1, then (prev_hidden, prev_cell, time_step). Here, prev_hidden and prev_cell are tensors of size [batch_size, hidden_size], and time_step is used to indicate the current RNN step.
  • is_training: Boolean indicating whether we are in training mode (as opposed to testing mode), passed to the batch norm modules. Note to use this you must wrap the cell via the with_batch_norm_control function.
  • test_local_stats: Boolean indicating whether to use local batch statistics in test mode. See the BatchNorm documentation for more on this.
Returns:

A tuple (output, next_state) where 'output' is a Tensor of size [batch_size, hidden_size] and 'next_state' is a tuple (next_hidden, next_cell) or (next_hidden, next_cell, time_step + 1), where next_hidden and next_cell have size [batch_size, hidden_size].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations.

BatchNormLSTM.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchNormLSTM.defun()

Wraps this module's call method in a callable graph function.

BatchNormLSTM.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchNormLSTM.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.get_possible_initializer_keys(cls, use_peepholes=False, use_batch_norm_h=True, use_batch_norm_x=False, use_batch_norm_c=False)

Returns the keys the dictionary of variable initializers may contain.

The set of all possible initializer keys are:
  • w_gates: weight for gates
  • b_gates: bias of gates
  • w_f_diag: weight for prev_cell -> forget gate peephole
  • w_i_diag: weight for prev_cell -> input gate peephole
  • w_o_diag: weight for prev_cell -> output gate peephole
  • gamma_h: batch norm scaling for previous_hidden -> gates
  • gamma_x: batch norm scaling for input -> gates
  • gamma_c: batch norm scaling for cell -> output
  • beta_c: batch norm bias for cell -> output
Args:
  • cls: The class.
  • use_peepholes: Boolean that indicates whether peephole connections are used.
  • use_batch_norm_h: Boolean that indicates whether to apply batch normalization at the previous_hidden -> gates contribution. If you are experimenting with batch norm then this may be the most effective to turn on.
  • use_batch_norm_x: Boolean that indicates whether to apply batch normalization at the input -> gates contribution.
  • use_batch_norm_c: Boolean that indicates whether to apply batch normalization at the cell -> output contribution.
Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchNormLSTM.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.graph

Returns the Graph instance which the module is connected to, or None.

BatchNormLSTM.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None)

Builds the default start state tensor of zeros.

Args:
  • batch_size: An int, float or scalar Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state.
  • trainable_initializers: An optional pair of initializers for the initial hidden state and cell state.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor tuple ([batch_size, state_size], [batch_size, state_size], ?) filled with zeros, with the third entry present when batch norm is enabled with max_unique_stats > 1, with value 0 (representing the time step).
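
A minimal sketch of obtaining the start state (the sizes are illustrative):

lstm = snt.BatchNormLSTM(hidden_size=32, max_unique_stats=10)
# Because max_unique_stats > 1, the state is (hidden, cell, time_step).
hidden, cell, time_step = lstm.initial_state(batch_size=16)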

BatchNormLSTM.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchNormLSTM.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.module_name

Returns the name of the Module.

BatchNormLSTM.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchNormLSTM.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.output_size

tf.TensorShape indicating the size of the core output.

BatchNormLSTM.scope_name

Returns the full name of the Module's variable scope.

BatchNormLSTM.state_size

Tuple of tf.TensorShapes indicating the size of state tensors.

BatchNormLSTM.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.use_batch_norm_c

Boolean indicating whether batch norm for cell -> output is enabled.

BatchNormLSTM.use_batch_norm_h

Boolean indicating whether batch norm for hidden -> gates is enabled.

BatchNormLSTM.use_batch_norm_x

Boolean indicating whether batch norm for input -> gates is enabled.

BatchNormLSTM.use_peepholes

Boolean indicating whether peephole connections are used.

BatchNormLSTM.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormLSTM.with_batch_norm_control(is_training, test_local_stats=True)

Wraps this RNNCore with the additional control input to the BatchNorms.

Example usage:

lstm = snt.BatchNormLSTM(4)
is_training = tf.placeholder(tf.bool)
rnn_input = ...
my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

Args:
  • is_training: Boolean that indicates whether we are in training mode or testing mode. When in training mode, the batch norm statistics are taken from the given batch, and moving statistics are updated. When in testing mode, the moving statistics are not updated, and in addition if test_local_stats is False then the moving statistics are used for the batch statistics. See the BatchNorm module for more details.
  • test_local_stats: Boolean scalar indicating whether to use local batch statistics in test mode.
Returns:

snt.RNNCore wrapping this class with the extra input(s) added.

BatchNormLSTM.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class BatchNormV2

Batch normalization module, including optional affine transformation.

This module maintains exponential moving averages of the mean and variance, which can be optionally used to normalize at test time.

At training time, batch statistics (mean, variance) are not shared between separate connections. The moving averages are shared between separate connections. At both training and test time, the optional affine transformation (* gamma + beta) is shared between separate connections.

This is also the case for distributed replica training, where the batch statistics are not aggregated across replicas, but the moving averages are shared globally.

When connecting the module to the graph, is_training=True means that

  • Update ops are created to update the moving averages with the current batch's statistics.
  • Features are normalized using the current batch's statistics. The test_local_stats setting is ignored. The moving averages are not used.

whereas is_training=False means that

  • Update ops are not created.
  • Features are normalized using either:
    • The moving averages if test_local_stats=False (default).
    • The test batch statistics if test_local_stats=True.

The moving averages are used by default at test time, but local batch statistics can be used by specifying a flag when connecting. One often wants to use local batch statistics at test time to track progress while the model is training, as this ensures that moving average updates do not affect the training curves. Once training is finished, it's often advantageous to use moving average statistics, since they make evaluation agnostic to the batch size and might even lead to small improvements over the local batch statistics.

The moving averages will be updated automatically by default, but not if update_ops_collection is provided: in that case they will only be updated when the ops in that collection are run.

For example, to run the updates automatically:

bn = BatchNormV2()
train_net = bn(train_inputs, is_training=True)

This does, however, have the effect of blocking the forward pass of the network until the update ops have been run, and may incur a small performance penalty.

For example, to run the updates manually:

bn = BatchNormV2(update_ops_collection=tf.GraphKeys.UPDATE_OPS)
train_net = bn(train_inputs, is_training=True)

...

update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
train_op = tf.group(train_op, update_ops)

Then, whenever train_op is run, the moving average update ops are run as well.

Some batch normalization caveats:

  • Batch normalization will remove the effect of adding a bias, so e.g. use_bias=False should be used for an immediately preceding snt.Linear module.
  • If your data batches aren't i.i.d. then batch normalization can allow your network to 'cheat' by using the batch statistics to peek at the rest of the batch. This can exhibit itself as a higher test score with test_local_stats=True than test_local_stats=False.

BatchNormV2.__init__(data_format=None, offset=True, scale=False, decay_rate=0.999, eps=0.001, initializers=None, partitioners=None, regularizers=None, update_ops_collection=None, fused=True, name='batch_norm')

Constructs a BatchNormV2 module.

Reduces over all input tensor dimensions apart from the channel dimension. This has the effect of treating pixels in 1D/2D/3D images as additional elements of the minibatch.

Args:
  • data_format: The data format. Can be "NC", "NWC", "NCW", "NHWC", "NCHW", "NDHWC", or "NCDHW". If not provided we assume the channel dimension is last.
  • offset: Optional boolean to specify whether or not to apply a trained component-wise bias after the batch normalization and scaling.
  • scale: Optional boolean to specify whether or not to apply a trained component-wise scale after the batch normalization.
  • decay_rate: Decay rate of the exponential moving averages of the mean and variance.
  • eps: Small number to avoid dividing by zero when dividing by the standard deviation.
  • initializers: Optional dict containing ops to initialize the weights of the affine transform (gamma and beta).
  • partitioners: Optional dict containing partitioners to partition the weights of the affine transform (gamma and beta).
  • regularizers: Optional dict containing regularizers for the weights of the affine transform ("gamma" and "beta"). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • update_ops_collection: Optional name of TensorFlow variable collection to add the moving average update ops to. If not provided, we instead add the update ops as control dependencies of the output of the module. This may result in some slowdown, as the feed-forward of the network is now blocked.
  • fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
  • name: Name of the module.
Raises:
  • KeyError: If initializers contains any keys other than gamma, beta, moving_mean or moving_variance.
  • KeyError: If partitioners or regularizers contains any keys other than gamma or beta.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If data_format is invalid.
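
For example, a minimal sketch of normalizing channels-first images (the images Tensor is assumed to exist elsewhere):

bn = BatchNormV2(data_format="NCHW")
# images: [batch, channels, height, width]; statistics are computed per channel.
train_net = bn(images, is_training=True)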

BatchNormV2.__call__(input_batch, is_training, test_local_stats=False)

Connects the BatchNormV2 module into the graph.

Args:
  • input_batch: A Tensor whose rank matches len(data_format).
  • is_training: A boolean to indicate if the module should be connected in training mode, meaning the moving averages are updated. Can be a Tensor.
  • test_local_stats: A boolean to indicate if local batch statistics should be used when is_training=False. If not, moving averages are used. By default False. Can be a Tensor.
Returns:

A tensor with the same shape as input_batch.

Raises:
  • base.IncompatibleShapeError: If data_format is not valid for the input shape.
  • base.NotSupportedError: If input_batch has data type of tf.bfloat16.

BatchNormV2.beta

BatchNormV2.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchNormV2.defun()

Wraps this module's call method in a callable graph function.

BatchNormV2.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchNormV2.gamma

BatchNormV2.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchNormV2.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.graph

Returns the Graph instance which the module is connected to, or None.

BatchNormV2.initializers

BatchNormV2.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchNormV2.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.module_name

Returns the name of the Module.

BatchNormV2.moving_mean

BatchNormV2.moving_variance

BatchNormV2.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchNormV2.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.partitioners

BatchNormV2.regularizers

BatchNormV2.scope_name

Returns the full name of the Module's variable scope.

BatchNormV2.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchNormV2.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BatchReshape

Reshapes input Tensor, preserving the batch dimension.

BatchReshape.__init__(shape, preserve_dims=1, name='batch_reshape')

Constructs a BatchReshape module.

Args:
  • shape: Shape to reshape the input Tensor to while preserving its first preserve_dims dimensions; shape can be either a tuple/list, or a callable that returns the actual shape. The callable does not need to be ready to return something meaningful at construction time, but it will be required to be able to do so when the module is connected to the graph. When the special value -1 appears in shape the corresponding size is automatically inferred. Note that -1 can only appear once in shape. To flatten all non-batch dimensions, the snt.BatchFlatten module can also be used.
  • preserve_dims: Number of leading dimensions that will not be reshaped. For example, given an input Tensor with shape [B, H, W, C, D] and argument shape equal to (-1, D) (these cases are illustrated in the sketch below):
    • preserve_dims=1 will return a Tensor with shape [B, H*W*C, D].
    • preserve_dims=2 will return a Tensor with shape [B, H, W*C, D].
    • preserve_dims=3 will return a Tensor with shape [B, H, W, C, D].
    • preserve_dims=4 will return a Tensor with shape [B, H, W, C, 1, D].
    • preserve_dims>=5 will throw an error on build unless D=1.
    The preserved dimensions can be unknown at building time.
  • name: Name of the module.
Raises:
  • ValueError: If preserve_dims <= 0.
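
A minimal sketch of the [B, H, W, C, D] case above with preserve_dims=1 (the placeholder shape is illustrative):

x = tf.placeholder(tf.float32, [None, 2, 3, 4, 5])   # [B, H, W, C, D]
mod = snt.BatchReshape(shape=(-1, 5), preserve_dims=1)
y = mod(x)   # -1 is inferred as H*W*C = 24, so y has shape [B, 24, 5]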

BatchReshape.__call__(inputs)

Connects the module into the graph, with input Tensor inputs.

Args:
  • inputs: A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_preserve_dims+1, ...].
Returns:

A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_reshape_1, b_reshape_2, ...], with reshaping defined by the constructor shape parameter.

Raises:
  • ValueError: If output shape is incompatible with input shape; or if shape array contains non numeric entries; or if shape array contains more than 1 wildcard -1; or if the input array contains unknown, non-preserved dimensions (except when the unknown dimension is the only non-preserved dimension and doesn't actually need reshaping).

BatchReshape.connected_subgraphs

Returns the subgraphs created by this module so far.

BatchReshape.defun()

Wraps this module's call method in a callable graph function.

BatchReshape.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BatchReshape.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BatchReshape.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.graph

Returns the Graph instance which the module is connected to, or None.

BatchReshape.input_shape

BatchReshape.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BatchReshape.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.module_name

Returns the name of the Module.

BatchReshape.name_scopes

Returns a tuple of all name_scopes generated by this module.

BatchReshape.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.scope_name

Returns the full name of the Module's variable scope.

BatchReshape.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.transpose(name=None)

Returns the transposed BatchReshape module, which reverses this module's reshape.
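
A minimal sketch of the round trip, assuming x is a Tensor of shape [B, 24]:

reshape = snt.BatchReshape([6, 4])
y = reshape(x)                        # [B, 24] -> [B, 6, 4]
x_recovered = reshape.transpose()(y)  # [B, 6, 4] -> [B, 24]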

BatchReshape.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BatchReshape.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class BidirectionalRNN

Bidirectional RNNCore that processes the sequence forwards and backwards.

Based upon the encoder implementation in: https://arxiv.org/abs/1409.0473

The interface of this module differs from the typical ones found in the RNNCore family. The primary difference is that it is conditioned on the full input sequence, in order to produce a full sequence of outputs and states, concatenated along the feature dimension across the forward and backward cores.
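
For example, a minimal sketch using two snt.LSTM cores (the sizes and the time-major placeholder shape are illustrative):

forward_core = snt.LSTM(hidden_size=16, name="forward_lstm")
backward_core = snt.LSTM(hidden_size=16, name="backward_lstm")
bidir = snt.BidirectionalRNN(forward_core, backward_core)

input_sequence = tf.placeholder(tf.float32, [10, 4, 8])  # (time, batch, features)
state = bidir.initial_state(batch_size=4)
result = bidir(input_sequence, state)
forward_outputs = result["outputs"]["forward"]  # see the dict layout below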

BidirectionalRNN.__init__(forward_core, backward_core, name='bidir_rnn')

Construct a Bidirectional RNN core.

Args:
  • forward_core: callable RNNCore module that computes forward states.
  • backward_core: callable RNNCore module that computes backward states.
  • name: name of the module.
Raises:
  • ValueError: if not all the modules are recurrent.

BidirectionalRNN.__call__(input_sequence, state)

Connects the BidirectionalRNN module into the graph.

Args:
  • input_sequence: tensor (time, batch, [feature_1, ..]). It must be time_major.
  • state: tuple of states for the forward and backward cores.
Returns:

A dict with forward/backward states and output sequences:

"outputs":{
    "forward": ...,
    "backward": ...},
"state": {
    "forward": ...,
    "backward": ...}
Raises:
  • ValueError: if the time dimension is not statically known.

BidirectionalRNN.connected_subgraphs

Returns the subgraphs created by this module so far.

BidirectionalRNN.defun()

Wraps this module's call method in a callable graph function.

BidirectionalRNN.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

BidirectionalRNN.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

BidirectionalRNN.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.graph

Returns the Graph instance which the module is connected to, or None.

BidirectionalRNN.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None)

Builds the default start state for a BidirectionalRNN.

The Bidirectional RNN flattens the states of its forward and backward cores and concatenates them.

Args:
  • batch_size: An int, float or scalar Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

Tuple of initial states from forward and backward RNNs.

BidirectionalRNN.is_connected

Returns true iff the Module has been connected to the Graph at least once.

BidirectionalRNN.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.module_name

Returns the name of the Module.

BidirectionalRNN.name_scopes

Returns a tuple of all name_scopes generated by this module.

BidirectionalRNN.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.output_size

Flattened output size of cores.

BidirectionalRNN.scope_name

Returns the full name of the Module's variable scope.

BidirectionalRNN.state_size

Flattened state size of cores.

BidirectionalRNN.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

BidirectionalRNN.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class CausalConv1D

1D convolution module, including optional bias.

This is deprecated; please use the padding=CAUSAL argument to Conv1D instead.

This acts as a light wrapper around _ConvND ensuring that the outputs at index i only depend on indices smaller than i (also known as a causal convolution). For further details on the theoretical background, refer to:

https://arxiv.org/abs/1610.10099
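
Since this module is deprecated, a minimal sketch of the recommended replacement (the shapes are illustrative):

conv = snt.Conv1D(output_channels=8, kernel_shape=3, padding=snt.CAUSAL)
x = tf.placeholder(tf.float32, [None, 100, 4])  # NWC: (batch, width, channels)
y = conv(x)  # the output at each position does not depend on future inputs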

CausalConv1D.__init__(output_channels, kernel_shape, stride=1, rate=1, use_bias=True, initializers=None, partitioners=None, regularizers=None, mask=None, padding='CAUSAL', data_format='NWC', custom_getter=None, name='causal_conv_1d')

Constructs a CausalConv1D module.

This is deprecated; please use the padding=CAUSAL argument to Conv1D instead.

Args:
  • output_channels: Number of output channels. output_channels can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_channels can be called, returning an integer, when build is called.
  • kernel_shape: Sequence of kernel sizes (of size 1), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 1), or integer that is used to define stride in all dimensions.
  • rate: Sequence of dilation rates (of size 1), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any of stride is also > 1.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • mask: A convertible to a 3D tensor which is multiplied component-wise with the weights (Optional).
  • padding: Padding algorithm. Should be snt.CAUSAL.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NWC), or the second dimension (NCW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given rate is not an integer; or if the given rate is not a sequence of two integers.
  • base.IncompatibleShapeError: If a mask is a TensorFlow Tensor with a not fully defined shape.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • TypeError: If mask is given and it is not convertible to a Tensor.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_1D_DATA_FORMATS).

CausalConv1D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and either of types tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

CausalConv1D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

CausalConv1D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

CausalConv1D.connected_subgraphs

Returns the subgraphs created by this module so far.

CausalConv1D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

CausalConv1D.data_format

Returns the data format.

CausalConv1D.defun()

Wraps this module's call method in a callable graph function.

CausalConv1D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

CausalConv1D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.get_possible_initializer_keys(cls, use_bias=True)

CausalConv1D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.graph

Returns the Graph instance which the module is connected to, or None.

CausalConv1D.has_bias

Returns True if bias Variable is present in the module.

CausalConv1D.initializers

Returns the initializers dictionary.

CausalConv1D.input_channels

Returns the number of input channels.

CausalConv1D.input_shape

Returns the input shape.

CausalConv1D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

CausalConv1D.kernel_shape

Returns the kernel shape.

CausalConv1D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.mask

Returns the mask.

CausalConv1D.module_name

Returns the name of the Module.

CausalConv1D.name_scopes

Returns a tuple of all name_scopes generated by this module.

CausalConv1D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.output_channels

Returns the number of output channels.

CausalConv1D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

CausalConv1D.paddings

Returns a tuple with the padding algorithm used for each dimension.

CausalConv1D.partitioners

Returns the partitioners dictionary.

CausalConv1D.rate

Returns the dilation rate.

CausalConv1D.regularizers

Returns the regularizers dictionary.

CausalConv1D.scope_name

Returns the full name of the Module's variable scope.

CausalConv1D.stride

Returns the stride.

CausalConv1D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

CausalConv1D.w

Returns the Variable containing the weight matrix.

class ConcatLinear

Linear transformation of a number of concatenated inputs.

This class ensures that, at initialisation, the relative importance of all inputs is similar even if they have very different sizes. This assumes that all inputs have roughly the same range of values.

For example, the following naive code concatenates a list of inputs and applies a linear transform directly:

inp = tf.concat(input_list, axis=-1)
return snt.Linear(output_size)(inp)

The issue with the above code is that if input_list is made of two Tensors of very different shapes such as [batch_size, 1] and [batch_size, 128], then almost no signal will be received from the first Tensor. This class works around this problem by using a weight matrix with relatively larger coefficients for the first Tensor than for the second one.
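
A minimal usage sketch (assuming graph-mode TensorFlow 1.x; the sizes are illustrative):

import sonnet as snt
import tensorflow as tf

# Two inputs of very different widths.
small = tf.placeholder(tf.float32, shape=[None, 1])
large = tf.placeholder(tf.float32, shape=[None, 128])

# ConcatLinear scales the initialisation per input, so the narrow input
# still contributes a comparable amount of signal at initialisation.
concat_linear = snt.ConcatLinear(output_size=64)
outputs = concat_linear([small, large])  # Shape: [batch_size, 64].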

ConcatLinear.__init__(output_size, use_bias=True, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='concat_linear')

Constructs a ConcatLinear module.

Args:
  • output_size: Output dimensionality. output_size can be either an integer or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_size can be called, returning an integer, when build is called.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing initializers to initialize the weights (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the weights (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.

ConcatLinear.__call__(inputs_list)

Connects the module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided here must have the same final dimensions as when called the first time, in order for the existing variables to be the correct size for the multiplication. The batch size may differ for each connection.

Args:
  • inputs_list: A list of rank-2 Tensors, each with a leading batch dimension.
Returns:

A 2D Tensor of size [batch_size, output_size].

ConcatLinear.connected_subgraphs

Returns the subgraphs created by this module so far.

ConcatLinear.defun()

Wraps this module's call method in a callable graph function.

ConcatLinear.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

ConcatLinear.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

ConcatLinear.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.graph

Returns the Graph instance which the module is connected to, or None.

ConcatLinear.is_connected

Returns true iff the Module has been connected to the Graph at least once.

ConcatLinear.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.module_name

Returns the name of the Module.

ConcatLinear.name_scopes

Returns a tuple of all name_scopes generated by this module.

ConcatLinear.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.scope_name

Returns the full name of the Module's variable scope.

ConcatLinear.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ConcatLinear.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class Conv1D

1D convolution module, including optional bias.

This acts as a light wrapper around the class _ConvND.

Conv1D.__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, mask=None, data_format='NWC', custom_getter=None, name='conv_1d')

Constructs a Conv1D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. output_channels can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_channels can be called, returning an integer, when build is called.
  • kernel_shape: Sequence of kernel sizes (of size 1), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 1), or integer that is used to define stride in all dimensions.
  • rate: Sequence of dilation rates (of size 1), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any element of stride is also > 1.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 1.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past").
    If the same padding is used for all dimensions and it is one of SAME or VALID, it is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • mask: A convertible to a 3D tensor which is multiplied component-wise with the weights (Optional).
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NWC), or the second dimension (NCW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:

  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of one integer.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of one integer.
  • base.IncompatibleShapeError: If the given rate is not an integer; or if the given rate is not a sequence of one integer.
  • base.IncompatibleShapeError: If a mask is a TensorFlow Tensor with a not fully defined shape.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • TypeError: If mask is given and it is not convertible to a Tensor.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_1D_DATA_FORMATS).
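
For illustration, a dilated causal convolution might be built as follows (a sketch assuming graph-mode TensorFlow 1.x; shapes are arbitrary):

import sonnet as snt
import tensorflow as tf

# Input in the default NWC format: [batch_size, sequence_length, channels].
inputs = tf.placeholder(tf.float32, shape=[None, 100, 8])

# rate > 1 gives a dilated convolution, and therefore requires stride == 1.
conv = snt.Conv1D(output_channels=16, kernel_shape=3, rate=2,
                  padding=snt.CAUSAL)
outputs = conv(inputs)  # [batch_size, 100, 16], no dependence on the future.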

Conv1D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An ND Tensor of the same rank as data_format, of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

An ND Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv1D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:

  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv1D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

Conv1D.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv1D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv1D.data_format

Returns the data format.

Conv1D.defun()

Wraps this module's call method in a callable graph function.

Conv1D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv1D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.get_possible_initializer_keys(cls, use_bias=True)

Conv1D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.graph

Returns the Graph instance which the module is connected to, or None.

Conv1D.has_bias

Returns True if bias Variable is present in the module.

Conv1D.initializers

Returns the initializers dictionary.

Conv1D.input_channels

Returns the number of input channels.

Conv1D.input_shape

Returns the input shape.

Conv1D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv1D.kernel_shape

Returns the kernel shape.

Conv1D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.mask

Returns the mask.

Conv1D.module_name

Returns the name of the Module.

Conv1D.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv1D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.output_channels

Returns the number of output channels.

Conv1D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

Conv1D.paddings

Returns a tuple with the padding algorithm used for each dimension.

Conv1D.partitioners

Returns the partitioners dictionary.

Conv1D.rate

Returns the dilation rate.

Conv1D.regularizers

Returns the regularizers dictionary.

Conv1D.scope_name

Returns the full name of the Module's variable scope.

Conv1D.stride

Returns the stride.

Conv1D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.transpose(name=None)

Returns matching Conv1DTranspose module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.name.
Returns:

Conv1DTranspose module.

Raises:

  • base.NotSupportedError: If rate in any dimension > 1.

Conv1D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1D.w

Returns the Variable containing the weight matrix.

class Conv1DLSTM

1D convolutional LSTM.

Conv1DLSTM.__init__(name='conv_1d_lstm', **kwargs)

Construct Conv1DLSTM. See snt.ConvLSTM for more details.

Conv1DLSTM.__call__(inputs, state)

Conv1DLSTM.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv1DLSTM.convolutions

Conv1DLSTM.defun()

Wraps this module's call method in a callable graph function.

Conv1DLSTM.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv1DLSTM.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.get_possible_initializer_keys(cls, use_bias=True)

Conv1DLSTM.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.graph

Returns the Graph instance which the module is connected to, or None.

Conv1DLSTM.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with the same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with the same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
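
For example, a trainable initial state for this core might be built as follows (a sketch; the Conv1DLSTM constructor arguments are forwarded to snt.ConvLSTM, and input_shape is assumed to exclude the batch dimension):

import sonnet as snt
import tensorflow as tf

core = snt.Conv1DLSTM(input_shape=[100, 8], output_channels=16,
                      kernel_shape=3)

batch_size = 4
# trainable=True turns the initial state into learned variables; custom
# trainable_initializers / trainable_regularizers could also be passed.
state = core.initial_state(batch_size, dtype=tf.float32, trainable=True)

inputs = tf.placeholder(tf.float32, shape=[batch_size, 100, 8])
output, next_state = core(inputs, state)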

Conv1DLSTM.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv1DLSTM.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.module_name

Returns the name of the Module.

Conv1DLSTM.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv1DLSTM.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.output_size

tf.TensorShape indicating the size of the core output.

Conv1DLSTM.scope_name

Returns the full name of the Module's variable scope.

Conv1DLSTM.state_size

Tuple of tf.TensorShapes indicating the size of state tensors.

Conv1DLSTM.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.use_layer_norm

Boolean indicating whether layer norm is enabled.

Conv1DLSTM.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DLSTM.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.
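
As a sketch (assuming the usual (hidden, cell) state pair of a ConvLSTM and the same illustrative core as in the initial_state example above):

import sonnet as snt
import tensorflow as tf

core = snt.Conv1DLSTM(input_shape=[100, 8], output_channels=16,
                      kernel_shape=3)

# state_size is a tuple of two TensorShapes, so zero_state returns a
# matching tuple of zero-filled tensors, one per state component.
hidden, cell = core.zero_state(batch_size=4, dtype=tf.float32)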

class Conv1DTranspose

1D transposed / reverse / up convolution module, including bias.

This performs a 1D transpose convolution by lightly wrapping the TensorFlow op tf.nn.conv2d_transpose, setting the size of the height dimension of the image to 1.

Conv1DTranspose.__init__(output_channels, output_shape=None, kernel_shape=None, stride=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NWC', custom_getter=None, name='conv_1d_transpose')

Constructs a Conv1DTranspose module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. Can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure output_channels can be called, returning an integer, when build is called.
  • output_shape: Output shape of transpose convolution. Can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_shape can be called, returning an iterable of format (out_length) when build is called. If a None value is given, a default shape is automatically calculated (see docstring of _default_transpose_size function for more details).
  • kernel_shape: Sequence of kernel sizes (of size 1), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 1), or integer that is used to define stride in all dimensions.
  • padding: Padding algorithm, either snt.SAME or snt.VALID.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NWC), or the second dimension (NCW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:

  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of two or four integers.
  • ValueError: If the given padding is not snt.VALID or snt.SAME.
  • ValueError: If the given kernel_shape is None.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_1D_DATA_FORMATS).
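
For illustration, an upsampling sketch (assuming graph-mode TensorFlow 1.x; with output_shape left as None, the output length is inferred from the input shape and the stride):

import sonnet as snt
import tensorflow as tf

# Input in the default NWC format: [batch_size, length, channels].
inputs = tf.placeholder(tf.float32, shape=[None, 50, 16])

# stride=2 roughly doubles the length dimension.
upconv = snt.Conv1DTranspose(output_channels=8, kernel_shape=4, stride=2)
outputs = upconv(inputs)  # [batch_size, 100, 8].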

Conv1DTranspose.__call__(inputs)

Connects the _ConvNDTranspose module into the graph.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same final N dimensions, in order for the existing variables to be the correct size for the multiplication. The batch size may differ for each connection.

Args:
  • inputs: A Tensor whose shape matches data_format, of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

A Tensor whose shape matches data_format, of type tf.float16, tf.bfloat16 or tf.float32.

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If output_shape is an iterable and is not in the format (out_length).
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv1DTranspose.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:

  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv1DTranspose.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv1DTranspose.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv1DTranspose.defun()

Wraps this module's call method in a callable graph function.

Conv1DTranspose.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv1DTranspose.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.get_possible_initializer_keys(cls, use_bias=True)

Conv1DTranspose.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.graph

Returns the Graph instance which the module is connected to, or None.

Conv1DTranspose.has_bias

Returns True if bias Variable is present in the module.

Conv1DTranspose.initializers

Returns the initializers dictionary.

Conv1DTranspose.input_channels

Returns the number of input channels.

Conv1DTranspose.input_shape

Returns the input shape.

Conv1DTranspose.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv1DTranspose.kernel_shape

Returns the kernel shape.

Conv1DTranspose.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.module_name

Returns the name of the Module.

Conv1DTranspose.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv1DTranspose.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.output_channels

Returns the number of output channels.

Conv1DTranspose.output_shape

Returns the output shape.

Conv1DTranspose.padding

Returns the padding algorithm.

Conv1DTranspose.partitioners

Returns the partitioners dictionary.

Conv1DTranspose.regularizers

Returns the regularizers dictionary.

Conv1DTranspose.scope_name

Returns the full name of the Module's variable scope.

Conv1DTranspose.stride

Returns the stride.

Conv1DTranspose.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.transpose(name=None)

Returns matching Conv1D module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.name.
Returns:

Conv1D module.

Conv1DTranspose.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv1DTranspose.w

Returns the Variable containing the weight matrix.

class Conv2D

Spatial convolution and dilated convolution module, including bias.

This acts as a light wrapper around the class _ConvND.

Conv2D.__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, mask=None, data_format='NHWC', custom_getter=None, name='conv_2d')

Constructs a Conv2D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. output_channels can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_channels can be called, returning an integer, when build is called.
  • kernel_shape: Sequence of kernel sizes (of size 2), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 2), or integer that is used to define stride in all dimensions.
  • rate: Sequence of dilation rates (of size 2), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard 2D convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any element of stride is also > 1.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 2.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past").
    If the same padding is used for all dimensions and it is one of SAME or VALID, it is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • mask: A convertible to a 4D tensor which is multiplied component-wise with the weights (Optional).
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NHWC), or the second dimension (NCHW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:

  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given rate is not an integer; or if the given rate is not a sequence of two integers.
  • base.IncompatibleShapeError: If a mask is given and its rank is neither 2 nor 4, or if it is a TensorFlow Tensor with a not fully defined shape.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • TypeError: If mask is given and it is not convertible to a Tensor.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_2D_DATA_FORMATS).

Conv2D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An ND Tensor of the same rank as data_format, of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

An ND Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv2D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:

  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv2D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

Conv2D.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv2D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv2D.data_format

Returns the data format.

Conv2D.defun()

Wraps this module's call method in a callable graph function.

Conv2D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv2D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.get_possible_initializer_keys(cls, use_bias=True)

Conv2D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.graph

Returns the Graph instance which the module is connected to, or None.

Conv2D.has_bias

Returns True if bias Variable is present in the module.

Conv2D.initializers

Returns the initializers dictionary.

Conv2D.input_channels

Returns the number of input channels.

Conv2D.input_shape

Returns the input shape.

Conv2D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv2D.kernel_shape

Returns the kernel shape.

Conv2D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.mask

Returns the mask.

Conv2D.module_name

Returns the name of the Module.

Conv2D.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv2D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.output_channels

Returns the number of output channels.

Conv2D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

Conv2D.paddings

Returns a tuple with the padding algorithm used for each dimension.

Conv2D.partitioners

Returns the partitioners dictionary.

Conv2D.rate

Returns the dilation rate.

Conv2D.regularizers

Returns the regularizers dictionary.

Conv2D.scope_name

Returns the full name of the Module's variable scope.

Conv2D.stride

Returns the stride.

Conv2D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.transpose(name=None)

Returns matching Conv2DTranspose module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.name.
Returns:

Conv2DTranspose module.

Raises:

  • base.NotSupportedError: If rate in any dimension > 1.

Conv2D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2D.w

Returns the Variable containing the weight matrix.

class Conv2DLSTM

2D convolutional LSTM.

Conv2DLSTM.__init__(name='conv_2d_lstm', **kwargs)

Construct Conv2DLSTM. See snt.ConvLSTM for more details.
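
A minimal connection sketch (assuming, as with snt.ConvLSTM, that input_shape excludes the batch dimension; shapes are illustrative):

import sonnet as snt
import tensorflow as tf

core = snt.Conv2DLSTM(input_shape=[32, 32, 3], output_channels=16,
                      kernel_shape=3)

inputs = tf.placeholder(tf.float32, shape=[8, 32, 32, 3])
state = core.initial_state(batch_size=8)
output, next_state = core(inputs, state)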

Conv2DLSTM.__call__(inputs, state)

Conv2DLSTM.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv2DLSTM.convolutions

Conv2DLSTM.defun()

Wraps this module's call method in a callable graph function.

Conv2DLSTM.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv2DLSTM.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.get_possible_initializer_keys(cls, use_bias=True)

Conv2DLSTM.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict the query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.graph

Returns the Graph instance which the module is connected to, or None.

Conv2DLSTM.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with the same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with the same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
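
For illustration, a minimal sketch of a trainable initial state (the shapes, and the constructor arguments assumed from snt.ConvLSTM, are purely illustrative):

import tensorflow as tf
import sonnet as snt

# input_shape excludes the batch dimension.
lstm = snt.Conv2DLSTM(input_shape=(8, 8, 3), output_channels=16,
                      kernel_shape=(3, 3))
# With trainable=True the start state is learned with the model.
state = lstm.initial_state(batch_size=32, trainable=True)
inputs = tf.placeholder(tf.float32, [32, 8, 8, 3])
output, next_state = lstm(inputs, state)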

Conv2DLSTM.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv2DLSTM.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.module_name

Returns the name of the Module.

Conv2DLSTM.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv2DLSTM.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.output_size

tf.TensorShape indicating the size of the core output.

Conv2DLSTM.scope_name

Returns the full name of the Module's variable scope.

Conv2DLSTM.state_size

Tuple of tf.TensorShapes indicating the size of state tensors.

Conv2DLSTM.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.use_layer_norm

Boolean indicating whether layer norm is enabled.

Conv2DLSTM.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DLSTM.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.
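
Continuing the illustrative Conv2DLSTM above, and assuming the usual (hidden, cell) state structure of snt.ConvLSTM, the zero state is simply a pair of zero-filled Tensors:

hidden, cell = lstm.zero_state(batch_size=32, dtype=tf.float32)
# Each Tensor has shape [32, 8, 8, 16], matching the core's state_size.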

class Conv2DTranspose

Spatial transposed / reverse / up 2D convolution module, including bias.

This acts as a light wrapper around the TensorFlow op tf.nn.conv2d_transpose abstracting away variable creation and sharing.

Conv2DTranspose.__init__(output_channels, output_shape=None, kernel_shape=None, stride=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NHWC', custom_getter=None, name='conv_2d_transpose')

Constructs a Conv2DTranspose module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. Can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure output_channels can be called, returning an integer, when build is called.
  • output_shape: Output shape of transpose convolution. Can be either an iterable of integers or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_shape can be called, returning an iterable of format (out_height, out_width) when build is called. Note that output_shape defines the size of output signal domain, as opposed to the shape of the output Tensor. If a None value is given, a default shape is automatically calculated (see docstring of _default_transpose_size function for more details).
  • kernel_shape: Sequence of kernel sizes (of size 2), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 2), or integer that is used to define stride in all dimensions.
  • padding: Padding algorithm, either snt.SAME or snt.VALID.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NHWC), or the second dimension ("NCHW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • base.IncompatibleShapeError: If the given kernel shape is neither an integer nor a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is neither an integer nor a sequence of two or four integers.
  • ValueError: If the given padding is not snt.VALID or snt.SAME.
  • ValueError: If the given kernel_shape is None.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_2D_DATA_FORMATS).
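
A brief usage sketch with purely illustrative sizes, doubling the spatial resolution of a feature map:

import tensorflow as tf
import sonnet as snt

up_conv = snt.Conv2DTranspose(output_channels=8, output_shape=(32, 32),
                              kernel_shape=3, stride=2, padding=snt.SAME)
inputs = tf.placeholder(tf.float32, [10, 16, 16, 4])  # NHWC
outputs = up_conv(inputs)  # => shape [10, 32, 32, 8]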

Conv2DTranspose.__call__(inputs)

Connects the _ConvNDTranspose module into the graph.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same final N dimensions, in order for the existing variables to be the correct size for the multiplication. The batch size may differ for each connection.

Args:
  • inputs: A Tensor whose rank and layout match data_format, of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

A Tensor whose rank and layout match data_format, of type tf.float16, tf.bfloat16 or tf.float32.

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If output_shape is an iterable and is not in the format (out_height, out_width).
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv2DTranspose.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv2DTranspose.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv2DTranspose.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv2DTranspose.defun()

Wraps this module's call method in a callable graph function.

Conv2DTranspose.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv2DTranspose.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.get_possible_initializer_keys(cls, use_bias=True)

Conv2DTranspose.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.graph

Returns the Graph instance which the module is connected to, or None.

Conv2DTranspose.has_bias

Returns True if bias Variable is present in the module.

Conv2DTranspose.initializers

Returns the initializers dictionary.

Conv2DTranspose.input_channels

Returns the number of input channels.

Conv2DTranspose.input_shape

Returns the input shape.

Conv2DTranspose.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv2DTranspose.kernel_shape

Returns the kernel shape.

Conv2DTranspose.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.module_name

Returns the name of the Module.

Conv2DTranspose.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv2DTranspose.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.output_channels

Returns the number of output channels.

Conv2DTranspose.output_shape

Returns the output shape.

Conv2DTranspose.padding

Returns the padding algorithm.

Conv2DTranspose.partitioners

Returns the partitioners dictionary.

Conv2DTranspose.regularizers

Returns the regularizers dictionary.

Conv2DTranspose.scope_name

Returns the full name of the Module's variable scope.

Conv2DTranspose.stride

Returns the stride.

Conv2DTranspose.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.transpose(name=None)

Returns matching Conv2D module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.name.
Returns:

Conv2D module.
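
For example (sizes illustrative), the matching Conv2D maps the transposed convolution's output back to its input geometry; the new module's output_channels is resolved from this module once it is connected:

import tensorflow as tf
import sonnet as snt

up_conv = snt.Conv2DTranspose(output_channels=8, output_shape=(32, 32),
                              kernel_shape=3, stride=2)
outputs = up_conv(tf.placeholder(tf.float32, [10, 16, 16, 4]))
down_conv = up_conv.transpose()  # a matching Conv2D
recovered = down_conv(outputs)   # => shape [10, 16, 16, 4]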

Conv2DTranspose.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv2DTranspose.w

Returns the Variable containing the weight matrix.

class Conv3D

Volumetric convolution module, including optional bias.

This acts as a light wrapper around the class _ConvND.

Conv3D.__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, mask=None, data_format='NDHWC', custom_getter=None, name='conv_3d')

Constructs a Conv3D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. output_channels can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_channels can be called, returning an integer, when build is called.
  • kernel_shape: Sequence of kernel sizes (of size 3), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 3), or integer that is used to define stride in all dimensions.
  • rate: Sequence of dilation rates (of size 3), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard 3D convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any of stride is also > 1.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 3.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past"). If you use the same padding for all dimensions, and it is one of SAME or VALID, then this is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • mask: An object convertible to a 5D tensor which is multiplied component-wise with the weights (Optional).
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NDHWC), or the second dimension (NCDHW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • base.IncompatibleShapeError: If the given kernel shape is not an integer, or is not a sequence of three integers.
  • base.IncompatibleShapeError: If the given stride is not an integer, or is not a sequence of three or five integers.
  • base.IncompatibleShapeError: If the given rate is not an integer, or is not a sequence of three integers.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_3D_DATA_FORMATS).
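
A minimal connection sketch, with illustrative sizes:

import tensorflow as tf
import sonnet as snt

conv3d = snt.Conv3D(output_channels=8, kernel_shape=3)
volumes = tf.placeholder(tf.float32, [10, 16, 16, 16, 1])  # NDHWC
features = conv3d(volumes)  # => [10, 16, 16, 16, 8] with SAME padding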

Conv3D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and either of types tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv3D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv3D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

Conv3D.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv3D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv3D.data_format

Returns the data format.

Conv3D.defun()

Wraps this module's call method in a callable graph function.

Conv3D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv3D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.get_possible_initializer_keys(cls, use_bias=True)

Conv3D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.graph

Returns the Graph instance which the module is connected to, or None.

Conv3D.has_bias

Returns True if bias Variable is present in the module.

Conv3D.initializers

Returns the initializers dictionary.

Conv3D.input_channels

Returns the number of input channels.

Conv3D.input_shape

Returns the input shape.

Conv3D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv3D.kernel_shape

Returns the kernel shape.

Conv3D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.mask

Returns the mask.

Conv3D.module_name

Returns the name of the Module.

Conv3D.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv3D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.output_channels

Returns the number of output channels.

Conv3D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

Conv3D.paddings

Returns a tuple with the padding algorithm used for each dimension.
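
A sketch of mixed per-dimension paddings (assuming, as the constructor documents, that a length-3 sequence of paddings is accepted; snt.CAUSAL here is applied along the depth axis only):

import sonnet as snt

conv = snt.Conv3D(output_channels=8, kernel_shape=3,
                  padding=(snt.CAUSAL, snt.SAME, snt.SAME))
print(conv.paddings)  # one padding constant per spatial dimension
# conv.padding would raise ValueError here, as the paddings differ.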

Conv3D.partitioners

Returns the partitioners dictionary.

Conv3D.rate

Returns the dilation rate.

Conv3D.regularizers

Returns the regularizers dictionary.

Conv3D.scope_name

Returns the full name of the Module's variable scope.

Conv3D.stride

Returns the stride.

Conv3D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.transpose(name=None)

Returns matching Conv3DTranspose module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.name.
Returns:

Conv3DTranspose module.

Raises:
  • base.NotSupportedError: If rate in any dimension > 1.

Conv3D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3D.w

Returns the Variable containing the weight matrix.

class Conv3DTranspose

Volumetric transposed / reverse / up 3D convolution module, including bias.

This acts as a light wrapper around the TensorFlow op tf.nn.conv3d_transpose abstracting away variable creation and sharing.

Conv3DTranspose.__init__(output_channels, output_shape=None, kernel_shape=None, stride=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NDHWC', custom_getter=None, name='conv_3d_transpose')

Constructs a Conv3DTranspose module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. output_channels can be either a number or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure output_channels can be called, returning an integer, when build is called.
  • output_shape: Output shape of transpose convolution. Can be either an iterable of integers or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_shape can be called, returning an iterable of format (out_depth, out_height, out_width) when build is called. Note that output_shape defines the size of output signal domain, as opposed to the shape of the output Tensor. If a None value is given, a default shape is automatically calculated (see docstring of _default_transpose_size function for more details).
  • kernel_shape: Sequence of kernel sizes (of size 3), or integer that is used to define kernel size in all dimensions.
  • stride: Sequence of kernel strides (of size 3), or integer that is used to define stride in all dimensions.
  • padding: Padding algorithm, either snt.SAME or snt.VALID.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NDHWC), or the second dimension (NCDHW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • base.IncompatibleShapeError: If the given kernel shape is neither an integer nor a sequence of three integers.
  • base.IncompatibleShapeError: If the given stride is neither an integer nor a sequence of three or five integers.
  • ValueError: If the given padding is not snt.VALID or snt.SAME.
  • ValueError: If the given kernel_shape is None.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_3D_DATA_FORMATS).

Conv3DTranspose.__call__(inputs)

Connects the _ConvNDTranspose module into the graph.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same final N dimensions, in order for the existing variables to be the correct size for the multiplication. The batch size may differ for each connection.

Args:
  • inputs: A Tensor whose rank and layout match data_format, of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

A Tensor whose rank and layout match data_format, of type tf.float16, tf.bfloat16 or tf.float32.

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If output_shape is an iterable and is not in the format (out_depth, out_height, out_width).
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

Conv3DTranspose.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Conv3DTranspose.connected_subgraphs

Returns the subgraphs created by this module so far.

Conv3DTranspose.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

Conv3DTranspose.defun()

Wraps this module's call method in a callable graph function.

Conv3DTranspose.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Conv3DTranspose.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.get_possible_initializer_keys(cls, use_bias=True)

Conv3DTranspose.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.graph

Returns the Graph instance which the module is connected to, or None.

Conv3DTranspose.has_bias

Returns True if bias Variable is present in the module.

Conv3DTranspose.initializers

Returns the initializers dictionary.

Conv3DTranspose.input_channels

Returns the number of input channels.

Conv3DTranspose.input_shape

Returns the input shape.

Conv3DTranspose.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Conv3DTranspose.kernel_shape

Returns the kernel shape.

Conv3DTranspose.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.module_name

Returns the name of the Module.

Conv3DTranspose.name_scopes

Returns a tuple of all name_scopes generated by this module.

Conv3DTranspose.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.output_channels

Returns the number of output channels.

Conv3DTranspose.output_shape

Returns the output shape.

Conv3DTranspose.padding

Returns the padding algorithm.

Conv3DTranspose.partitioners

Returns the partitioners dictionary.

Conv3DTranspose.regularizers

Returns the regularizers dictionary.

Conv3DTranspose.scope_name

Returns the full name of the Module's variable scope.

Conv3DTranspose.stride

Returns the stride.

Conv3DTranspose.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.transpose(name=None)

Returns transposed Conv3DTranspose module, i.e. a Conv3D module.

Conv3DTranspose.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Conv3DTranspose.w

Returns the Variable containing the weight matrix.

class DeepRNN

RNN core that passes data through a number of internal modules or ops.

This module is constructed by passing an iterable of externally constructed modules or ops. The DeepRNN takes (input, prev_state) as input and passes the input through each internal module in the order they were presented, using elements from prev_state as necessary for internal recurrent cores. The output is (output, next_state) in common with other RNN cores. By default, skip connections from the input to all internal modules and from each intermediate output to the final output are used.

E.g.:

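# Assumes `input` and `prev_state` are defined elsewhere, e.g. via initial_state.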
lstm1 = snt.LSTM(hidden_size=256)
lstm2 = snt.LSTM(hidden_size=256)
deep_rnn = snt.DeepRNN([lstm1, lstm2])
output, next_state = deep_rnn(input, prev_state)

The computation set up inside the DeepRNN has the same effect as:

prev_state1, prev_state2 = prev_state
lstm1_output, next_state1 = lstm1(input, prev_state1)
lstm2_output, next_state2 = lstm2(
    tf.concat([input, lstm1_output], 1), prev_state2)

next_state = (next_state1, next_state2)
output = tf.concat([lstm1_output, lstm2_output], 1)

Every internal module receives the preceding module's output and the entire core's input. The output is created by concatenating each internal module's output. In the case of internal recurrent elements, corresponding elements of the state are used such that state[i] is passed to the i'th internal recurrent element. Note that the state of a DeepRNN is always a tuple, which will contain the same number of elements as there are internal recurrent cores. If no internal modules are recurrent, the state of the DeepRNN as a whole is the empty tuple. Wrapping non-recurrent modules into a DeepRNN can be useful to produce something API compatible with a "real" recurrent module, simplifying code that handles the cores.

Without skip connections the previous example would become the following (note the only difference is the addition of skip_connections=False):

# ... declare other modules as above
deep_rnn = snt.DeepRNN([lin, tanh, lstm], skip_connections=False)
output, next_state = deep_rnn(input, prev_state)

which is equivalent to:

lin_output = lin(input)
tanh_output = tanh(lin_output)
lstm_output, lstm_next_state = lstm(tanh_output, prev_state[0])

next_state = (lstm_next_state,)
output = lstm_output

Note: when using skip connections, all the cores should be recurrent.

DeepRNN.__init__(cores, skip_connections=True, concat_final_output_if_skip=True, name='deep_rnn')

Construct a Deep RNN core.

Args:
  • cores: iterable of modules or ops.
  • skip_connections: a boolean that indicates whether to use skip connections. This means that the input is fed to all the layers, after being concatenated on the last dimension with the output of the previous layer. The output of the module will be the concatenation of all the outputs of the internal modules.
  • concat_final_output_if_skip: A boolean that indicates whether the outputs of intermediate layers should be concatenated into the timestep-wise output of the core. By default this is True. If this is set to False, then the core output is that of the final layer, i.e. that of cores[-1].
  • name: name of the module.
Raises:
  • ValueError: if cores is not an iterable, or if skip_connections is True and not all the modules are recurrent.

DeepRNN.__call__(inputs, prev_state, **kwargs)

Connects the DeepRNN module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as inputs and prev_state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • inputs: a nested tuple of Tensors of arbitrary dimensionality, with at least an initial batch dimension.
  • prev_state: a tuple of prev_states that corresponds to the state of each one of the cores of the DeepRNN.
  • **kwargs: optional kwargs to be passed to the _build of all sub-modules. E.g. is_training=True. Note all sub-modules must accept the given kwarg.
Returns:
  • output: a nested tuple of Tensors of arbitrary dimensionality, with at least an initial batch dimension.
  • next_state: a tuple of next_states that corresponds to the updated state of each one of the cores of the DeepRNN.
Raises:
  • ValueError: if connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations. This may happen if one connects a module any time after the first time that does not have the configuration of skip connections as the first time.

DeepRNN.connected_subgraphs

Returns the subgraphs created by this module so far.

DeepRNN.defun()

Wraps this module's call method in a callable graph function.

DeepRNN.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

DeepRNN.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

DeepRNN.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.graph

Returns the Graph instance which the module is connected to, or None.

DeepRNN.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None)

Builds the default start state for a DeepRNN.

Args:
  • batch_size: An int, float or scalar Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the number of passed initializers is not the same as the number of recurrent cores.
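
A sketch of a learned start state driven through tf.nn.dynamic_rnn (all sizes hypothetical):

import tensorflow as tf
import sonnet as snt

deep_rnn = snt.DeepRNN([snt.LSTM(hidden_size=64), snt.LSTM(hidden_size=64)])
inputs = tf.placeholder(tf.float32, [16, 100, 32])  # [batch, time, features]
# Trainable initial state variables are learned with the other parameters.
initial_state = deep_rnn.initial_state(batch_size=16, trainable=True)
output, final_state = tf.nn.dynamic_rnn(
    deep_rnn, inputs, initial_state=initial_state)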

DeepRNN.is_connected

Returns true iff the Module has been connected to the Graph at least once.

DeepRNN.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.module_name

Returns the name of the Module.

DeepRNN.name_scopes

Returns a tuple of all name_scopes generated by this module.

DeepRNN.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.output_size

DeepRNN.scope_name

Returns the full name of the Module's variable scope.

DeepRNN.state_size

DeepRNN.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DeepRNN.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class DepthwiseConv2D

Spatial depthwise 2D convolution module, including bias.

This acts as a light wrapper around the TensorFlow ops tf.nn.depthwise_conv2d, abstracting away variable creation and sharing.

DepthwiseConv2D.__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NHWC', custom_getter=None, name='conv_2d_depthwise')

Constructs a DepthwiseConv2D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • channel_multiplier: Number of channels to expand convolution to. Must be an integer. Must be > 0. When channel_multiplier is set to 1, apply a different filter to each input channel producing one output channel per input channel. Numbers larger than 1 cause multiple different filters to be applied to each input channel, with their outputs being concatenated together, producing channel_multiplier * input_channels output channels.
  • kernel_shape: Iterable with 2 elements in the following layout: [filter_height, filter_width] or integer that is used to define the list in all dimensions.
  • stride: Iterable with 2 or 4 elements of kernel strides, or an integer that is used to define stride in all dimensions. Layout of the list: with 4 elements, [1, stride_height, stride_width, 1]; with 2 elements, [stride_height, stride_width].
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 2.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past"). If you use the same padding for all dimensions, and it is one of SAME or VALID, then this is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners for the filters (with key 'w') and the biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NHWC), or the second dimension ("NCHW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • ValueError: If channel_multiplier isn't of type (numbers.Integral or tf.Dimension).
  • ValueError: If channel_multiplier is less than 1.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_2D_DATA_FORMATS).
  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of two integers.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
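
To illustrate the channel expansion (sizes illustrative):

import tensorflow as tf
import sonnet as snt

# 3 input channels x channel_multiplier=2 => 6 output channels.
depthwise = snt.DepthwiseConv2D(channel_multiplier=2, kernel_shape=3)
images = tf.placeholder(tf.float32, [10, 28, 28, 3])  # NHWC
features = depthwise(images)  # => [10, 28, 28, 6] with SAME padding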

DepthwiseConv2D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and either of types tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

DepthwiseConv2D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

DepthwiseConv2D.channel_multiplier

Returns the channel multiplier argument.

DepthwiseConv2D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

DepthwiseConv2D.connected_subgraphs

Returns the subgraphs created by this module so far.

DepthwiseConv2D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

DepthwiseConv2D.data_format

Returns the data format.

DepthwiseConv2D.defun()

Wraps this module's call method in a callable graph function.

DepthwiseConv2D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

DepthwiseConv2D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.get_possible_initializer_keys(cls, use_bias=True)

DepthwiseConv2D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.graph

Returns the Graph instance which the module is connected to, or None.

DepthwiseConv2D.has_bias

Returns True if bias Variable is present in the module.

DepthwiseConv2D.initializers

Returns the initializers dictionary.

DepthwiseConv2D.input_channels

Returns the number of input channels.

DepthwiseConv2D.input_shape

Returns the input shape.

DepthwiseConv2D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

DepthwiseConv2D.kernel_shape

Returns the kernel shape.

DepthwiseConv2D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.mask

Returns the mask.

DepthwiseConv2D.module_name

Returns the name of the Module.

DepthwiseConv2D.name_scopes

Returns a tuple of all name_scopes generated by this module.

DepthwiseConv2D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.output_channels

Returns the number of output channels.

DepthwiseConv2D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

DepthwiseConv2D.paddings

Returns a tuple with the padding algorithm used for each dimension.

DepthwiseConv2D.partitioners

Returns the partitioners dictionary.

DepthwiseConv2D.rate

Returns the dilation rate.

DepthwiseConv2D.regularizers

Returns the regularizers dictionary.

DepthwiseConv2D.scope_name

Returns the full name of the Module's variable scope.

DepthwiseConv2D.stride

Returns the stride.

DepthwiseConv2D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

DepthwiseConv2D.w

Returns the Variable containing the weight matrix.

class DifferentGraphError

Error raised when trying to connect a Sonnet module to multiple Graphs.

class Embed

Module for embedding tokens in a low-dimensional space.

Embed.__init__(vocab_size=None, embed_dim=None, existing_vocab=None, densify_gradients=False, initializers=None, partitioners=None, regularizers=None, trainable=True, custom_getter=None, name='embed')

Constructs an Embed module.

Args:
  • vocab_size: int. Number of unique tokens to embed. If not provided, an existing vocabulary matrix from which vocab_size can be inferred must be provided as existing_vocab.
  • embed_dim: int or None. Number of dimensions to assign to each embedding. If not specified, a sensible default is chosen based on vocab_size. If an existing vocabulary matrix initializes the module, this should not be provided as it will be inferred.
  • existing_vocab: a [vocab_size, embed_dim] vocabulary matrix. Will be converted to a tf.float32 tensor. If provided, neither vocab_size nor embed_dim should be provided, as they are inferred.
  • densify_gradients: if True, we convert the embedding gradient from an indexed-slices to a regular tensor before sending it back to the parameter server. This avoids excess computation on the parameter server. Use this option for moderately sized embeddings, e.g., a vocabulary size on the order of up to thousands. For embeddings larger than these, e.g. a vocabulary size on the order of tens or hundreds of thousands, set this to False.
  • initializers: Optional dict containing initializers for embeddings (with key 'embeddings'). As a default, embeddings are initialized via a truncated normal distribution.
  • partitioners: Optional dict containing partitioners for embeddings (with key 'embeddings'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for embeddings (with key 'embeddings'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • trainable: if True, the embeddings will be updated during training. If False, they are fixed to their initial values. If trainable=False and a regularizer is given, the resulting loss stays constant.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: string. Name for this module.
Raises:
  • ValueError: if neither vocab_size nor existing_vocab is provided, or if existing_vocab is provided along with vocab_size, embed_dim, initializers, partitioners or regularizers (as these should be inferred).
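
For illustration, a minimal sketch of the two construction modes; the pretrained matrix below is a hypothetical stand-in, not part of the API:

import numpy as np
import sonnet as snt

# Mode 1: sizes given explicitly; embeddings are trained from scratch.
embed = snt.Embed(vocab_size=10000, embed_dim=64)

# Mode 2: sizes inferred from an existing matrix (hypothetical pretrained
# weights); vocab_size and embed_dim must not be passed in this case.
pretrained = np.random.rand(10000, 64).astype(np.float32)
embed_fixed = snt.Embed(existing_vocab=pretrained, trainable=False)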

Embed.__call__(ids)

Lookup embeddings.

Looks up an embedding vector for each value in ids. All ids must be within [0, vocab_size), else an InvalidArgumentError is raised at runtime.

Args:
  • ids: Tensor of dtype int64.
Returns:

Tensor of shape tf.shape(ids) + [embed_dim] and dtype float32.
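
A minimal lookup sketch, assuming illustrative placeholder shapes:

import tensorflow as tf
import sonnet as snt

embed = snt.Embed(vocab_size=1000, embed_dim=16)
ids = tf.placeholder(tf.int64, shape=[8, 10])  # a batch of token id sequences
vectors = embed(ids)  # shape [8, 10, 16], dtype tf.float32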

Embed.connected_subgraphs

Returns the subgraphs created by this module so far.

Embed.defun()

Wraps this module's call method in a callable graph function.

Embed.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Embed.embed_dim

Size of embedding vectors.

Embed.embeddings

Returns the Variable containing embeddings.

Returns:

A 2D Variable containing one embedding vector per row, constructed in the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.

Embed.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

Embed.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.graph

Returns the Graph instance which the module is connected to, or None.

Embed.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Embed.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.module_name

Returns the name of the Module.

Embed.name_scopes

Returns a tuple of all name_scopes generated by this module.

Embed.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.scope_name

Returns the full name of the Module's variable scope.

Embed.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Embed.vocab_size

Size of input vocabulary.

class Error

Base class for all errors from snt.

This is thrown to indicate a Neural Network specific problem, e.g. wrong module arity, module is not connected to the graph when it should be, tried to wire together incompatible modules, etc.

class FlattenTrailingDimensions

Flattens trailing dimensions of a Tensor.

FlattenTrailingDimensions.__init__(dim_from, name='batch_dim_from')

Constructs a FlattenTrailingDimensions module.

For example, given an input Tensor with shape [B, H, W, C]:

  • dim_from=1 will return a Tensor with shape [B, H*W*C].
  • dim_from=2 will return a Tensor with shape [B, H, W*C].
  • dim_from=3 will return the input itself.
  • dim_from=4 will return a Tensor with shape [B, H, W, C, 1].
  • dim_from>=5 will generate a ValueError when building the module.

The preserved dimensions can be unknown at building time.

Equivalent to BatchFlatten(preserve_dims=dim_from, name=name).

Args:
  • dim_from: All dimensions after and including dim_from will be flattened into a single dimension.
  • name: Name of the module.
Raises:
  • ValueError: If dim_from <= 0.
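
The shape arithmetic above can be checked with a small sketch (shapes are illustrative):

import tensorflow as tf
import sonnet as snt

x = tf.placeholder(tf.float32, shape=[8, 4, 4, 3])  # [B, H, W, C]
flatten = snt.FlattenTrailingDimensions(dim_from=1)
y = flatten(x)  # shape [8, 48], i.e. [B, H*W*C]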

FlattenTrailingDimensions.__call__(inputs)

Connects the module into the graph, with input Tensor inputs.

Args:
  • inputs: A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_preserve_dims+1, ...].
Returns:

A Tensor of shape [b_1, b_2, ..., b_preserve_dims, b_reshape_1, b_reshape_2, ...], with reshaping defined by the constructor shape parameter.

Raises:
  • ValueError: If output shape is incompatible with input shape; or if shape array contains non numeric entries; or if shape array contains more than 1 wildcard -1; or if the input array contains unknown, non-preserved dimensions (except when the unknown dimension is the only non-preserved dimension and doesn't actually need reshaping).

FlattenTrailingDimensions.connected_subgraphs

Returns the subgraphs created by this module so far.

FlattenTrailingDimensions.defun()

Wraps this module's call method in a callable graph function.

FlattenTrailingDimensions.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

FlattenTrailingDimensions.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

FlattenTrailingDimensions.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.graph

Returns the Graph instance which the module is connected to, or None.

FlattenTrailingDimensions.input_shape

FlattenTrailingDimensions.is_connected

Returns true iff the Module has been connected to the Graph at least once.

FlattenTrailingDimensions.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.module_name

Returns the name of the Module.

FlattenTrailingDimensions.name_scopes

Returns a tuple of all name_scopes generated by this module.

FlattenTrailingDimensions.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.scope_name

Returns the full name of the Module's variable scope.

FlattenTrailingDimensions.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.transpose(name=None)

Returns the transpose batch reshape module, i.e. one that undoes this module's flattening.

FlattenTrailingDimensions.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

FlattenTrailingDimensions.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class GRU

GRU recurrent network cell.

The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf.

Attributes:
  • state_size: Integer indicating the size of state tensor.
  • output_size: Integer indicating the size of the core output.

GRU.__init__(hidden_size, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='gru')

Construct GRU.

Args:
  • hidden_size: (int) Hidden size dimensionality.
  • initializers: Dict containing ops to initialize the weights. This dict may contain any of the keys returned by GRU.get_possible_initializer_keys.
  • partitioners: Optional dict containing partitioners to partition the weights and biases. As a default, no partitioners are used. This dict may contain any of the keys returned by GRU.get_possible_initializer_keys.
  • regularizers: Optional dict containing regularizers for the weights and biases. As a default, no regularizers are used. This dict may contain any of the keys returned by GRU.get_possible_initializer_keys.
  • custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. See the tf.get_variable documentation for more details.
  • name: Name of the module.
Raises:
  • KeyError: if initializers contains any keys not returned by GRU.get_possible_initializer_keys.
  • KeyError: if partitioners contains any keys not returned by GRU.get_possible_initializer_keys.
  • KeyError: if regularizers contains any keys not returned by GRU.get_possible_initializer_keys.

GRU.__call__(inputs, prev_state)

Connects the GRU module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as inputs and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • inputs: Tensor of size [batch_size, input_size].
  • prev_state: Tensor of size [batch_size, hidden_size].
Returns:

A tuple (output, next_state) where output is a Tensor of size [batch_size, hidden_size] and next_state is a Tensor of size [batch_size, hidden_size].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations.
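
A minimal connection sketch, assuming illustrative placeholder shapes:

import tensorflow as tf
import sonnet as snt

gru = snt.GRU(hidden_size=32)
inputs = tf.placeholder(tf.float32, shape=[16, 8])  # [batch_size, input_size]
prev_state = gru.initial_state(batch_size=16)
output, next_state = gru(inputs, prev_state)  # both of shape [16, 32]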

GRU.connected_subgraphs

Returns the subgraphs created by this module so far.

GRU.defun()

Wraps this module's call method in a callable graph function.

GRU.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

GRU.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

The set of all possible initializer keys is:
  • wz: weight for input -> update cell
  • uz: weight for prev_state -> update cell
  • bz: bias for update cell
  • wr: weight for input -> reset cell
  • ur: weight for prev_state -> reset cell
  • br: bias for reset cell
  • wh: weight for input -> candidate activation
  • uh: weight for prev_state -> candidate activation
  • bh: bias for candidate activation
Returns:

Set with strings corresponding to the strings that may be passed to the constructor.
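
For example, the keys can be inspected and then used to target individual weights; the initializer chosen here is just an illustration:

import tensorflow as tf
import sonnet as snt

print(snt.GRU.get_possible_initializer_keys())
# e.g. {'wz', 'uz', 'bz', 'wr', 'ur', 'br', 'wh', 'uh', 'bh'}
gru = snt.GRU(hidden_size=32,
              initializers={'wz': tf.truncated_normal_initializer(stddev=0.1)})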

GRU.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.graph

Returns the Graph instance which the module is connected to, or None.

GRU.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with the same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with the same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
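
A hedged sketch of a learnable initial state; the zeros initializer is just an example choice (GRU state is a single tensor, so a single initializer function matches its structure):

import tensorflow as tf
import sonnet as snt

gru = snt.GRU(hidden_size=32)
init_state = gru.initial_state(
    batch_size=16,
    trainable=True,
    trainable_initializers=tf.zeros_initializer())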

GRU.is_connected

Returns true iff the Module has been connected to the Graph at least once.

GRU.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.module_name

Returns the name of the Module.

GRU.name_scopes

Returns a tuple of all name_scopes generated by this module.

GRU.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.output_size

GRU.scope_name

Returns the full name of the Module's variable scope.

GRU.state_size

GRU.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GRU.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class GridWarper

Grid warper interface class.

An object implementing the GridWarper interface generates a reference grid of feature points at construction time, and warps it via a parametric transformation model, specified at run time by an input parameter Tensor. Grid warpers must then implement a create_features function used to generate the reference grid to be warped in the forward pass (according to a determined warping model).

GridWarper.__init__(source_shape, output_shape, num_coeff, name, **kwargs)

Constructs a GridWarper module and initializes the source grid params.

source_shape and output_shape are used to define the size of the source and output signal domains, as opposed to the shape of the respective Tensors. For example, for an image of size width=W and height=H, {source,output}_shape=[H, W]; for a volume of size width=W, height=H and depth=D, {source,output}_shape=[H, W, D].

Args:
  • source_shape: Iterable of integers determining the size of the source signal domain.
  • output_shape: Iterable of integers determining the size of the destination resampled signal domain.
  • num_coeff: Number of coefficients parametrizing the grid warp. For example, a 2D affine transformation will be defined by the 6 parameters populating the corresponding 2x3 affine matrix.
  • name: Name of Module.
  • **kwargs: Extra kwargs to be forwarded to the create_features function, instantiating the source grid parameters.
Raises:
  • Error: If len(output_shape) > len(source_shape).
  • TypeError: If output_shape and source_shape are not both iterable.
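
Since GridWarper is an interface, it is connected via a concrete subclass; a minimal sketch using snt.AffineGridWarper (shapes illustrative; six parameters correspond to the 2x3 affine matrix mentioned above):

import tensorflow as tf
import sonnet as snt

warper = snt.AffineGridWarper(source_shape=[24, 24], output_shape=[12, 12])
affine_params = tf.placeholder(tf.float32, shape=[None, 6])
grid = warper(affine_params)  # warped grid of feature points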

GridWarper.__call__(*args, **kwargs)

Add elements to the Graph, computing output Tensors from input Tensors.

Subclasses must implement this method, which will be wrapped in a Template.

Args:
  • *args: Input Tensors.
  • **kwargs: Additional Python flags controlling connection.
Returns:

output Tensor(s).

GridWarper.connected_subgraphs

Returns the subgraphs created by this module so far.

GridWarper.defun()

Wraps this module's call method in a callable graph function.

GridWarper.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

GridWarper.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

GridWarper.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.graph

Returns the Graph instance which the module is connected to, or None.

GridWarper.is_connected

Returns true iff the Module has been connected to the Graph at least once.

GridWarper.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.module_name

Returns the name of the Module.

GridWarper.n_coeff

Returns the number of coefficients of the warping function.

GridWarper.name_scopes

Returns a tuple of all name_scopes generated by this module.

GridWarper.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.output_shape

Returns a tuple containing the shape of the output grid.

GridWarper.psi

Returns a list of features used to compute the grid warp.

GridWarper.scope_name

Returns the full name of the Module's variable scope.

GridWarper.source_shape

Returns a tuple containing the shape of the source signal.

GridWarper.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

GridWarper.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class HighwayCore

Recurrent Highway Network cell.

The implementation is based on https://arxiv.org/pdf/1607.03474v5.pdf. As per the first lines of section 5 of the reference paper, 1 - T is used instead of a dedicated C gate.

Attributes:
  • state_size: Integer indicating the size of state tensor.
  • output_size: Integer indicating the size of the core output.

HighwayCore.__init__(hidden_size, num_layers, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='highwaycore')

Construct a new Recurrent Highway core.

Args:
  • hidden_size: (int) Hidden size dimensionality.
  • num_layers: (int) Number of highway layers.
  • initializers: Dict containing ops to initialize the weights. This dict may contain any of the keys returned by HighwayCore.get_possible_initializer_keys.
  • partitioners: Optional dict containing partitioners to partition the weights and biases. As a default, no partitioners are used. This dict may contain any of the keys returned by HighwayCore.get_possible_initializer_keys.
  • regularizers: Optional dict containing regularizers for the weights and biases. As a default, no regularizers are used. This dict may contain any of the keys returned by HighwayCore.get_possible_initializer_keys.
  • custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. See the tf.get_variable documentation for more details.
  • name: Name of the module.
Raises:
  • KeyError: if initializers contains any keys not returned by HighwayCore.get_possible_initializer_keys.
  • KeyError: if partitioners contains any keys not returned by HighwayCore.get_possible_initializer_keys.
  • KeyError: if regularizers contains any keys not returned by HighwayCore.get_possible_initializer_keys.

HighwayCore.__call__(inputs, prev_state)

Connects the highway core module into the graph.

Args:
  • inputs: Tensor of size [batch_size, input_size].
  • prev_state: Tensor of size [batch_size, hidden_size].
Returns:

A tuple (output, next_state) where output is a Tensor of size [batch_size, hidden_size] and next_state is a Tensor of size [batch_size, hidden_size].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations.
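
A minimal connection sketch, assuming illustrative placeholder shapes:

import tensorflow as tf
import sonnet as snt

core = snt.HighwayCore(hidden_size=64, num_layers=2)
inputs = tf.placeholder(tf.float32, shape=[16, 10])  # [batch_size, input_size]
prev_state = core.initial_state(batch_size=16)
output, next_state = core(inputs, prev_state)  # both of shape [16, 64]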

HighwayCore.connected_subgraphs

Returns the subgraphs created by this module so far.

HighwayCore.defun()

Wraps this module's call method in a callable graph function.

HighwayCore.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

HighwayCore.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.get_possible_initializer_keys(cls, num_layers)

Returns the keys the dictionary of variable initializers may contain.

The set of all possible initializer keys is:
  • wt: weight for input -> T gate
  • wh: weight for input -> H gate
  • wtL: weight for prev state -> T gate for layer L (indexed from 0)
  • whL: weight for prev state -> H gate for layer L (indexed from 0)
  • btL: bias for prev state -> T gate for layer L (indexed from 0)
  • bhL: bias for prev state -> H gate for layer L (indexed from 0)
Args:
  • num_layers: (int) Number of highway layers.
Returns:

Set with strings corresponding to the strings that may be passed to the constructor.
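
For example (the per-layer key spellings below follow the "indexed from 0" pattern described above, so they are spelled out as an assumption):

import sonnet as snt

keys = snt.HighwayCore.get_possible_initializer_keys(num_layers=2)
# expected to include 'wt' and 'wh', plus per-layer keys such as
# 'wt0', 'wh0', 'bt0', 'bh0', 'wt1', 'wh1', 'bt1', 'bh1'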

HighwayCore.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.graph

Returns the Graph instance which the module is connected to, or None.

HighwayCore.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with the same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with the same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.

HighwayCore.is_connected

Returns true iff the Module been connected to the Graph at least once.

HighwayCore.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.module_name

Returns the name of the Module.

HighwayCore.name_scopes

Returns a tuple of all name_scopes generated by this module.

HighwayCore.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.output_size

HighwayCore.scope_name

Returns the full name of the Module's variable scope.

HighwayCore.state_size

HighwayCore.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

HighwayCore.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class InPlaneConv2D

Applies an in-plane convolution to each channel with tied filter weights.

This acts as a light wrapper around the TensorFlow op tf.nn.depthwise_conv2d; it differs from the DepthwiseConv2D module in that it has tied weights (i.e. the same filter) for all the in-out channel pairs.

InPlaneConv2D.__init__(kernel_shape, stride=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NHWC', custom_getter=None, name='in_plane_conv2d')

Constructs an InPlaneConv2D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • kernel_shape: Iterable with 2 elements in the layout [filter_height, filter_width]; or integer that is used to define the list in all dimensions.
  • stride: Iterable with 2 or 4 elements of kernel strides, or integer that is used to define stride in all dimensions.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 2.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past").
    If you use the same padding for all dimensions, and it is one of SAME or VALID, then this is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition the filters (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NHWC), or the second dimension (NCHW).
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_2D_DATA_FORMATS).
  • base.IncompatibleShapeError: If the given kernel shape is not an integer; or if the given kernel shape is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer; or if the given stride is not a sequence of two integers.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
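
A minimal sketch (image shape illustrative): because a single filter is shared across channels, the output keeps the input's channel count:

import tensorflow as tf
import sonnet as snt

conv = snt.InPlaneConv2D(kernel_shape=3)
images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])  # NHWC
outputs = conv(images)  # shape [None, 64, 64, 3] with SAME padding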

InPlaneConv2D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

InPlaneConv2D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

InPlaneConv2D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

InPlaneConv2D.connected_subgraphs

Returns the subgraphs created by this module so far.

InPlaneConv2D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

InPlaneConv2D.data_format

Returns the data format.

InPlaneConv2D.defun()

Wraps this module's call method in a callable graph function.

InPlaneConv2D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

InPlaneConv2D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.get_possible_initializer_keys(cls, use_bias=True)

InPlaneConv2D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.graph

Returns the Graph instance which the module is connected to, or None.

InPlaneConv2D.has_bias

Returns True if bias Variable is present in the module.

InPlaneConv2D.initializers

Returns the initializers dictionary.

InPlaneConv2D.input_channels

Returns the number of input channels.

InPlaneConv2D.input_shape

Returns the input shape.

InPlaneConv2D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

InPlaneConv2D.kernel_shape

Returns the kernel shape.

InPlaneConv2D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.mask

Returns the mask.

InPlaneConv2D.module_name

Returns the name of the Module.

InPlaneConv2D.name_scopes

Returns a tuple of all name_scopes generated by this module.

InPlaneConv2D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.output_channels

Returns the number of output channels.

InPlaneConv2D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

InPlaneConv2D.paddings

Returns a tuple with the padding algorithm used for each dimension.

InPlaneConv2D.partitioners

Returns the partitioners dictionary.

InPlaneConv2D.rate

Returns the dilation rate.

InPlaneConv2D.regularizers

Returns the regularizers dictionary.

InPlaneConv2D.scope_name

Returns the full name of the Module's variable scope.

InPlaneConv2D.stride

Returns the stride.

InPlaneConv2D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

InPlaneConv2D.w

Returns the Variable containing the weight matrix.

class IncompatibleShapeError

Error raised when the shape of the input at build time is incompatible.

class LSTM

LSTM recurrent network cell with optional peepholes & layer normalization.

The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

Layer normalization

This is described in https://arxiv.org/pdf/1607.06450.pdf

Peep-hole connections

Peep-hole connections may optionally be used by specifying a flag in the constructor. These connections can help increase the precision of output timing; for more details see:

https://research.google.com/pubs/archive/43905.pdf

Recurrent projections

Projection of the recurrent state, to reduce model parameters and speed up computation. For more details see:

https://arxiv.org/abs/1402.1128

Attributes:
  • state_size: Tuple of tf.TensorShapes indicating the size of state tensors.
  • output_size: tf.TensorShape indicating the size of the core output.
  • use_peepholes: Boolean indicating whether peephole connections are used.

LSTM.__init__(hidden_size, forget_bias=1.0, initializers=None, partitioners=None, regularizers=None, use_peepholes=False, use_layer_norm=False, hidden_clip_value=None, projection_size=None, cell_clip_value=None, custom_getter=None, name='lstm')

Construct LSTM.

Args:
  • hidden_size: (int) Hidden size dimensionality.
  • forget_bias: (float) Bias for the forget activation.
  • initializers: Dict containing ops to initialize the weights. This dictionary may contain any of the keys returned by LSTM.get_possible_initializer_keys.
  • partitioners: Optional dict containing partitioners to partition the weights and biases. As a default, no partitioners are used. This dict may contain any of the keys returned by LSTM.get_possible_initializer_keys.
  • regularizers: Optional dict containing regularizers for the weights and biases. As a default, no regularizers are used. This dict may contain any of the keys returned by LSTM.get_possible_initializer_keys.
  • use_peepholes: Boolean that indicates whether peephole connections are used.
  • use_layer_norm: Boolean that indicates whether to apply layer normalization.
  • hidden_clip_value: Optional number; if set, then the LSTM hidden state vector is clipped by this value.
  • projection_size: Optional number; if set, then the LSTM hidden state is projected to this size via a learnable projection matrix.
  • cell_clip_value: Optional number; if set, then the LSTM cell vector is clipped by this value.
  • custom_getter: Callable that takes as a first argument the true getter, and allows overwriting the internal get_variable method. See the tf.get_variable documentation for more details.
  • name: Name of the module.
Raises:
  • KeyError: if initializers contains any keys not returned by LSTM.get_possible_initializer_keys.
  • KeyError: if partitioners contains any keys not returned by LSTM.get_possible_initializer_keys.
  • KeyError: if regularizers contains any keys not returned by LSTM.get_possible_initializer_keys.
  • ValueError: if a peephole initializer is passed in the initializer list, but use_peepholes is False.

LSTM.__call__(inputs, prev_state)

Connects the LSTM module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as inputs and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • inputs: Tensor of size [batch_size, input_size].
  • prev_state: Tuple (prev_hidden, prev_cell).
Returns:

A tuple (output, next_state) where 'output' is a Tensor of size [batch_size, hidden_size] and 'next_state' is an LSTMState namedtuple (next_hidden, next_cell), where next_hidden and next_cell have size [batch_size, hidden_size]. If projection_size is specified, then next_hidden will have size [batch_size, projection_size].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations.
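
For illustration, a minimal sketch of unrolling an LSTM over a short sequence (the sizes and placeholder shapes are assumptions for this example). Since RNNCores also provide zero_state, mirroring the standard RNNCell interface, they can typically be unrolled with tf.nn.dynamic_rnn as well:

import tensorflow as tf
import sonnet as snt

lstm = snt.LSTM(hidden_size=32)
inputs = tf.placeholder(tf.float32, shape=[16, 10, 64])  # [batch, time, features]
state = lstm.initial_state(batch_size=16)
outputs = []
for t in range(10):
  # Each call reuses the same weight variables.
  output, state = lstm(inputs[:, t, :], state)
  outputs.append(output)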

LSTM.connected_subgraphs

Returns the subgraphs created by this module so far.

LSTM.defun()

Wraps this module's call method in a callable graph function.

LSTM.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

LSTM.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.get_possible_initializer_keys(cls, use_peepholes=False, use_projection=False)

Returns the keys the dictionary of variable initializers may contain.

The set of all possible initializer keys is:
  • w_gates: weight for gates
  • b_gates: bias of gates
  • w_f_diag: weight for prev_cell -> forget gate peephole
  • w_i_diag: weight for prev_cell -> input gate peephole
  • w_o_diag: weight for prev_cell -> output gate peephole
Args:
  • cls: The class.
  • use_peepholes: Boolean that indicates whether peephole connections are used.
  • use_projection: Boolean that indicates whether a recurrent projection layer is used.
Returns:

Set with strings corresponding to the strings that may be passed to the constructor.
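
For example, the available keys can be queried on the class and matching initializers passed to the constructor (the particular initializer choices below are assumptions for this sketch):

keys = snt.LSTM.get_possible_initializer_keys(use_peepholes=True)
# `keys` contains 'w_gates' and 'b_gates', plus the three peephole keys.

lstm = snt.LSTM(
    hidden_size=32,
    use_peepholes=True,
    initializers={
        'w_gates': tf.truncated_normal_initializer(stddev=0.1),
        'b_gates': tf.zeros_initializer(),
    })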

LSTM.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.graph

Returns the Graph instance which the module is connected to, or None.

LSTM.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
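
A minimal sketch of a learnable initial state; with trainable=True and no explicit initializers, default (zero) initializers are used for the initial state variables:

lstm = snt.LSTM(hidden_size=32)
# Creates trainable variables holding the initial hidden and cell states,
# tiled out to the requested batch size.
state = lstm.initial_state(batch_size=16, trainable=True)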

LSTM.is_connected

Returns true iff the Module has been connected to the Graph at least once.

LSTM.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.module_name

Returns the name of the Module.

LSTM.name_scopes

Returns a tuple of all name_scopes generated by this module.

LSTM.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.output_size

tf.TensorShape indicating the size of the core output.

LSTM.scope_name

Returns the full name of the Module's variable scope.

LSTM.state_size

Tuple of tf.TensorShapes indicating the size of state tensors.

LSTM.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.use_layer_norm

Boolean indicating whether layer norm is enabled.

LSTM.use_peepholes

Boolean indicating whether peephole connections are used.

LSTM.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTM.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class LSTMBlockCell

Wraps the TensorFlow LSTMBlockCell as a Sonnet RNNCore.

LSTMBlockCell.__init__(num_units, forget_bias=1.0, cell_clip=None, use_peephole=False, dtype=None, reuse=None, name='lstm_cell')

Initialize the basic LSTM cell.

Args:
  • num_units: int, The number of units in the LSTM cell.
  • forget_bias: float, The bias added to forget gates (see above).
  • cell_clip: An optional float. Defaults to None (no clipping).
  • use_peephole: Whether to use peephole connections or not.
  • dtype: the variable dtype of this layer. Default to tf.float32.
  • reuse: (optional) boolean describing whether to reuse variables in an existing scope. If not True, and the existing scope already has the given variables, an error is raised.
  • name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. By default this is "lstm_cell", for variable-name compatibility with tf.nn.rnn_cell.LSTMCell.

When restoring from CudnnLSTM-trained checkpoints, CudnnCompatibleLSTMBlockCell must be used instead.
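
A minimal usage sketch (the sizes here are assumptions for the example):

cell = snt.LSTMBlockCell(num_units=64)
inputs = tf.placeholder(tf.float32, shape=[8, 128])
state = cell.initial_state(batch_size=8)
output, next_state = cell(inputs, state)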

LSTMBlockCell.__call__(inputs, prev_state)

LSTMBlockCell.connected_subgraphs

Returns the subgraphs created by this module so far.

LSTMBlockCell.defun()

Wraps this module's call method in a callable graph function.

LSTMBlockCell.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

LSTMBlockCell.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

LSTMBlockCell.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.graph

Returns the Graph instance which the module is connected to, or None.

LSTMBlockCell.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.

LSTMBlockCell.is_connected

Returns true iff the Module has been connected to the Graph at least once.

LSTMBlockCell.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.module_name

Returns the name of the Module.

LSTMBlockCell.name_scopes

Returns a tuple of all name_scopes generated by this module.

LSTMBlockCell.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.output_size

LSTMBlockCell.scope_name

Returns the full name of the Module's variable scope.

LSTMBlockCell.state_size

LSTMBlockCell.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LSTMBlockCell.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class LSTMState

LSTMState(hidden, cell)

LSTMState.cell

Alias for field number 1

LSTMState.hidden

Alias for field number 0

class LayerNorm

Layer normalization module.

Implementation based on: https://arxiv.org/abs/1607.06450

This module transforms input x into:

outputs = gamma * (x - mu) / sigma + beta

where mu and sigma are respectively the mean and standard deviation of x. Gamma and beta are trainable parameters for scaling and shifting respectively.

Since the axes over which normalization is performed are configurable, this also subsumes instance normalization.

LayerNorm.__init__(axis=None, offset=True, scale=True, eps=1e-05, initializers=None, partitioners=None, regularizers=None, name='layer_norm')

Constructs a LayerNorm module.

Args:
  • axis: Optional dimension or iterable of indices of dimensions to normalize and reduce over. By default (None), all dimensions except the first/batch dimension are reduced over. If the input tensor represents an image and normalization is performed over all dimensions except the batch and channel dimensions (e.g. for image format NHWC, axis=[1, 2]), this module corresponds to Instance Normalization (https://arxiv.org/abs/1607.08022).
  • offset: Optional boolean to specify whether or not to apply a trained component-wise bias after the layer normalization and scaling.
  • scale: Optional boolean to specify whether or not to apply a trained component-wise scale after the layer normalization.
  • eps: small epsilon to avoid division by zero variance. Defaults to 1e-5 as used in the paper.
  • initializers: Dict containing ops to initialize the scale (with key 'gamma') and bias (with key 'beta').
  • partitioners: Optional dict containing partitioners to partition the scale (with key 'gamma') and bias (with key 'beta'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the scale (with key 'gamma') and bias (with key 'beta'). As a default, no regularizers are used.
  • name: name of the module.
Raises:
  • KeyError: If initializers, partitioners or regularizers contain any keys other than gamma or beta.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
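
For illustration, a sketch of both the default behaviour and the instance-normalization configuration described above (shapes are assumptions for the example):

# Layer normalization over all non-batch dimensions (axis=None).
h = tf.placeholder(tf.float32, shape=[16, 128])
h_norm = snt.LayerNorm()(h)

# Instance normalization for NHWC images: normalize over height and width.
images = tf.placeholder(tf.float32, shape=[16, 32, 32, 3])
images_norm = snt.LayerNorm(axis=[1, 2])(images)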

LayerNorm.__call__(inputs)

Connects the LayerNorm module into the graph.

Args:
  • inputs: a Tensor of dimensionality >= 2.
Returns:
  • normalized: layer normalized outputs with same shape as inputs.
Raises:
  • base.NotSupportedError: If inputs has fewer than 2 dimensions.

LayerNorm.beta

LayerNorm.connected_subgraphs

Returns the subgraphs created by this module so far.

LayerNorm.defun()

Wraps this module's call method in a callable graph function.

LayerNorm.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

LayerNorm.gamma

LayerNorm.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

LayerNorm.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.graph

Returns the Graph instance which the module is connected to, or None.

LayerNorm.initializers

LayerNorm.is_connected

Returns true iff the Module has been connected to the Graph at least once.

LayerNorm.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.module_name

Returns the name of the Module.

LayerNorm.name_scopes

Returns a tuple of all name_scopes generated by this module.

LayerNorm.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.partitioners

LayerNorm.regularizers

LayerNorm.scope_name

Returns the full name of the Module's variable scope.

LayerNorm.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

LayerNorm.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class Linear

Linear module, optionally including bias.

Linear.__init__(output_size, use_bias=True, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='linear')

Constructs a Linear module.

Args:
  • output_size: Output dimensionality. output_size can be either an integer or a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that output_size can be called, returning an integer, when build is called.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing initializers to initialize the weights (with key 'w') or biases (with key 'b'). The default initializer for the weights is a truncated normal initializer, which is commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf). The default initializer for the bias is a zero initializer.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the weights (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • KeyError: If initializers, partitioners or regularizers contains any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.

Linear.__call__(inputs)

Connects the Linear module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the Tensor provided here must have the same final dimension, in order for the existing variables to be the correct size for the multiplication. The batch size may differ for each connection.

Args:
  • inputs: A 2D Tensor of size [batch_size, input_size].
Returns:

A 2D Tensor of size [batch_size, output_size].

Raises:
  • base.IncompatibleShapeError: If the input is not a 2-D Tensor with the size of the second dimension specified.
  • base.IncompatibleShapeError: If reconnecting an already connected module into the graph, and the shape of the input is not compatible with previous inputs.
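
For illustration, a minimal sketch of variable sharing across connections (shapes are assumptions for the example); note the differing batch sizes but identical final dimension:

linear = snt.Linear(output_size=128)
train_inputs = tf.placeholder(tf.float32, shape=[64, 32])
test_inputs = tf.placeholder(tf.float32, shape=[100, 32])

train_outputs = linear(train_inputs)  # creates linear.w and linear.b
test_outputs = linear(test_inputs)    # reuses the same variables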

Linear.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

Linear.clone(name=None)

Returns a cloned Linear module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

Cloned Linear module.

Linear.connected_subgraphs

Returns the subgraphs created by this module so far.

Linear.defun()

Wraps this module's call method in a callable graph function.

Linear.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Linear.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.get_possible_initializer_keys(cls, use_bias=True)

Linear.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.graph

Returns the Graph instance which the module is connected to, or None.

Linear.has_bias

Returns True if bias Variable is present in the module.

Linear.initializers

Returns the initializers dictionary.

Linear.input_shape

Returns shape of input Tensor passed at last call to build.

Linear.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Linear.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.module_name

Returns the name of the Module.

Linear.name_scopes

Returns a tuple of all name_scopes generated by this module.

Linear.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.output_size

Returns the module output size.

Linear.partitioners

Returns the partitioners dictionary.

Linear.regularizers

Returns the regularizers dictionary.

Linear.scope_name

Returns the full name of the Module's variable scope.

Linear.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.transpose(name=None)

Returns transposed Linear module.

Args:
  • name: Optional string assigning name of transpose module. The default name is constructed by appending "_transpose" to self.module_name.
Returns:

Transposed Linear module.
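
For example, a sketch of an autoencoder-style pairing (shapes are assumptions). The parent is connected first so that the transposed output size can be inferred from its input shape; note that the transposed module is a new Linear with its own variables, so any weight tying has to be arranged separately:

encoder = snt.Linear(output_size=32)
code = encoder(tf.placeholder(tf.float32, shape=[16, 784]))

decoder = encoder.transpose()  # output_size is deferred; resolves to 784
reconstruction = decoder(code)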

Linear.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Linear.w

Returns the Variable containing the weight matrix.

Returns:

Variable object containing the weights, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.

class MergeDims

Merges a tensor or nested list of tensors along a range of dimensions.

Tensors are reshaped by specifying the range of dimensions to merge. Hence, the reshape can be performed without knowing in advance the rank of the input tensor.

For example, merging dimensions 1, 2 and 3 together can be performed by calling:

output = MergeDims(start=1, size=3)(x)
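
For an assumed input of rank 5, this produces the following shapes:

x = tf.random_uniform(shape=[2, 3, 4, 5, 6])
output = snt.MergeDims(start=1, size=3)(x)
# output has shape [2, 3 * 4 * 5, 6] == [2, 60, 6].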

A nested list of tensors can be merged:

x = [tf.random_uniform(shape=[5, 5]), [tf.random_uniform(shape=[3, 3, 3])]]
output = MergeDims(start=0, size=2)(x)

MergeDims.__init__(start, size, name='merge_dims')

Constructs the MergeDims module.

Args:
  • start: Start of the range of dimensions to merge.
  • size: Size of the range of dimensions to merge.
  • name: The name of the module.
Raises:
  • ValueError: If size is not strictly greater than 1.

MergeDims.__call__(inputs)

Connects the MergeDims module into the graph.

Args:
  • inputs: Tensor or a nested list of Tensors to merge. Its rank must be greater than or equal to start + size.
Returns:

The merged Tensor or a nested list of merged Tensors.

Raises:
  • ValueError: If any of the input tensors has insufficient rank.

MergeDims.connected_subgraphs

Returns the subgraphs created by this module so far.

MergeDims.defun()

Wraps this module's call method in a callable graph function.

MergeDims.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

MergeDims.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

MergeDims.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.graph

Returns the Graph instance which the module is connected to, or None.

MergeDims.is_connected

Returns true iff the Module has been connected to the Graph at least once.

MergeDims.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.module_name

Returns the name of the Module.

MergeDims.name_scopes

Returns a tuple of all name_scopes generated by this module.

MergeDims.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.scope_name

Returns the full name of the Module's variable scope.

MergeDims.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

MergeDims.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class ModelRNN

RNNCore that ignores input and uses a model to compute its next state.

ModelRNN.__init__(model, name='model_rnn')

Construct a Basic RNN core.

Args:
  • model: callable that computes the next state.
  • name: name of the module.
Raises:
  • TypeError: if model is not a callable object or if it is an RNNCore.
  • AttributeError: if model does not have an output_size attribute.

ModelRNN.__call__(inputs, prev_state)

Connects the ModelRNN module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as inputs and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • inputs: Tensor input to the ModelRNN (ignored).
  • prev_state: Tensor of size model.output_size.
Returns:
  • output: Tensor of size model.output_size.
  • next_state: Tensor of size model.output_size.
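
A minimal sketch; any callable with an output_size attribute can serve as the model (using an assumed snt.Linear here, which satisfies both requirements):

model = snt.Linear(output_size=16)  # callable, with an `output_size` attribute
core = snt.ModelRNN(model)

inputs = tf.zeros([8, 1])       # ignored by ModelRNN
prev_state = tf.zeros([8, 16])  # must match model.output_size
output, next_state = core(inputs, prev_state)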

ModelRNN.connected_subgraphs

Returns the subgraphs created by this module so far.

ModelRNN.defun()

Wraps this module's call method in a callable graph function.

ModelRNN.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

ModelRNN.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

ModelRNN.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.graph

Returns the Graph instance which the module is connected to, or None.

ModelRNN.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.

ModelRNN.is_connected

Returns true iff the Module has been connected to the Graph at least once.

ModelRNN.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.module_name

Returns the name of the Module.

ModelRNN.name_scopes

Returns a tuple of all name_scopes generated by this module.

ModelRNN.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.output_size

ModelRNN.scope_name

Returns the full name of the Module's variable scope.

ModelRNN.state_size

ModelRNN.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ModelRNN.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class Module

Module wrapping a function provided by the user.

Module.__init__(build, custom_getter=None, name=None)

Constructs a module with a given build function.

The Module class can be used to wrap a function assembling a network into a module.

For example, the following code implements a simple one-hidden-layer MLP model by defining a function called make_model and using a Module instance to wrap it.

def make_model(inputs):
  lin1 = snt.Linear(name="lin1", output_size=10)(inputs)
  relu1 = tf.nn.relu(lin1, name="relu1")
  lin2 = snt.Linear(name="lin2", output_size=20)(relu1)
  return lin2

model = snt.Module(name='simple_mlp', build=make_model)
outputs = model(inputs)

The partial function from functools can be used to bake configuration parameters into the function at construction time, as shown in the following example.

from functools import partial

def make_model(inputs, output_sizes):
  lin1 = snt.Linear(name="lin1", output_size=output_sizes[0])(inputs)
  relu1 = tf.nn.relu(lin1, name="relu1")
  lin2 = snt.Linear(name="lin2", output_size=output_sizes[1])(relu1)
  return lin2

model = snt.Module(name='simple_mlp',
                   build=partial(make_model, output_sizes=[10, 20]))
outputs = model(inputs)
Args:
  • build: Callable to be invoked when connecting the module to the graph. The build function is invoked when the module is called, and its role is to specify how to add elements to the Graph, and how to compute output Tensors from input Tensors. The build function signature can include the following parameters: *args - Input Tensors. **kwargs - Additional Python parameters controlling connection.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Module name. If set to None (the default), the name will be set to that of the build callable converted to snake_case. If build has no name, the name will be 'module'.
Raises:
  • TypeError: If build is not callable.
  • TypeError: If a given custom_getter is not callable.

Module.__call__(*args, **kwargs)

Forwards call to the passed-in build function.

Module.connected_subgraphs

Returns the subgraphs created by this module so far.

Module.defun()

Wraps this module's call method in a callable graph function.

Module.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Module.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

Module.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.graph

Returns the Graph instance which the module is connected to, or None.

Module.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Module.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.module_name

Returns the name of the Module.

Module.name_scopes

Returns a tuple of all name_scopes generated by this module.

Module.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.scope_name

Returns the full name of the Module's variable scope.

Module.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Module.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class ModuleInfoError

Error raised when Sonnet ModuleInfo cannot be serialized.

class NotConnectedError

Error raised when operating on a module that has not yet been connected.

Some module properties / methods are valid to access before the module has been connected into the graph, but some are not. This Error is raised when the user attempts to do anything not valid before connection.

class NotInitializedError

Error raised when connecting an uninitialized Sonnet module.

Before they can be connected, all Sonnet modules must call AbstractModule.__init__ (e.g. via a super call).

class NotSupportedError

Error raised when something that cannot be supported is requested.

For example a Dilated Convolution module cannot be transposed.

class ParentNotBuiltError

Error raised when the parent of a module has not been built yet.

For example, when making a transpose of modules that inherit from module.Transposable, the parent has to be connected to the graph before the child transpose is connected, to ensure that shape inference has already occurred.

class RNNCellWrapper

RNN core that delegates to a tf.contrib.rnn.RNNCell.

RNNCellWrapper.__init__(cell_ctor, *args, **kwargs)

Constructs the cell, within this module's variable scope.

Args:
  • cell_ctor: Callable that instantiates a tf.contrib.rnn.RNNCell.
  • *args: Arguments to pass to cell_ctor.
  • **kwargs: Keyword arguments to pass to cell_ctor. If name is provided, it is passed to RNNCore.__init__ as well. If custom_getter is provided, it is passed to RNNCore.__init__ but not to cell_ctor.
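
A minimal sketch wrapping a standard TensorFlow cell (the cell type and sizes are assumptions for the example):

core = snt.RNNCellWrapper(tf.contrib.rnn.GRUCell, num_units=64)
inputs = tf.placeholder(tf.float32, shape=[8, 32])
state = core.initial_state(batch_size=8)
output, next_state = core(inputs, state)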

RNNCellWrapper.__call__(inputs, prev_state)

RNNCellWrapper.connected_subgraphs

Returns the subgraphs created by this module so far.

RNNCellWrapper.defun()

Wraps this module's call method in a callable graph function.

RNNCellWrapper.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

RNNCellWrapper.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

RNNCellWrapper.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.graph

Returns the Graph instance which the module is connected to, or None.

RNNCellWrapper.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that intializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
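
A hedged sketch of requesting a learnable initial state (snt.LSTM and its hidden_size argument are assumed from elsewhere in this library):

import sonnet as snt

core = snt.LSTM(hidden_size=32)
# With trainable=True the initial state is backed by variables that are
# learned during training; the default returns a non-trainable state.
learned_state = core.initial_state(batch_size=16, trainable=True)
default_state = core.initial_state(batch_size=16)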

RNNCellWrapper.is_connected

Returns true iff the Module has been connected to the Graph at least once.

RNNCellWrapper.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.module_name

Returns the name of the Module.

RNNCellWrapper.name_scopes

Returns a tuple of all name_scopes generated by this module.

RNNCellWrapper.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.output_size

RNNCellWrapper.scope_name

Returns the full name of the Module's variable scope.

RNNCellWrapper.state_size

RNNCellWrapper.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCellWrapper.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or scalar Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.
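
As a concrete sketch of the shapes involved (snt.VanillaRNN is assumed to have state_size equal to its hidden_size):

import sonnet as snt
import tensorflow as tf

core = snt.VanillaRNN(hidden_size=4)
state = core.zero_state(batch_size=8, dtype=tf.float32)
# state is a Tensor of shape (8, 4) filled with zeros.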

class RNNCore

Superclass for Recurrent Neural Network Cores.

This class defines the basic functionality that every core should implement: mainly the initial_state method, which returns an example of the core's initial state. It also inherits from the interface snt.AbstractModule.

As with any other snt.Module, any subclass must implement a _build method that constructs the graph corresponding to a core. Such a _build method should always have the same interface, which is the following:

output, next_state = self._build(input, prev_state)

where output, next_state, input, and prev_state are arbitrarily nested tensors. Such structures can be defined according to the following grammar:

element = tuple(element*) | list(element*) | tf.Tensor

This class is meant to be used with TensorFlow containers such as rnn in tensorflow.python.ops.rnn. These containers only accept inputs which are compatible with the tf.contrib.rnn.RNNCell API, so all RNNCores should expose the state_size and output_size properties.
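
A minimal sketch of a custom core implementing this interface; the Counter class below is illustrative and not part of the library:

import sonnet as snt
import tensorflow as tf

class Counter(snt.RNNCore):
  """Toy core whose state accumulates its inputs."""

  def __init__(self, size, name='counter'):
    super(Counter, self).__init__(name=name)
    self._size = size

  def _build(self, inputs, prev_state):
    next_state = prev_state + inputs
    return next_state, next_state  # (output, next_state)

  @property
  def state_size(self):
    return tf.TensorShape([self._size])

  @property
  def output_size(self):
    return tf.TensorShape([self._size])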

RNNCore.__init__(_sentinel=None, custom_getter=None, name=None)

Performs the initialisation necessary for all AbstractModule instances.

Every subclass of AbstractModule must begin its constructor with a call to this constructor, i.e.

super(MySubModule, self).__init__(custom_getter=custom_getter, name=name).

If you instantiate sub-modules in init you must create them within the _enter_variable_scope context manager to ensure they are in the module's variable scope. Alternatively, instantiate sub-modules in _build.

Args:
  • _sentinel: Variable that only carries a non-None value if __init__ was called without named parameters. If this is the case, a deprecation warning is issued in the form of a ValueError.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of this module. Used to construct the Templated build function. If None the module's class name is used (converted to snake case).
Raises:
  • TypeError: If name is not a string.
  • TypeError: If a given custom_getter is not callable.
  • ValueError: If __init__ was called without named arguments.
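
For instance, a sketch of creating a sub-module in the constructor inside the module's variable scope (WrappedLinear is illustrative):

import sonnet as snt

class WrappedLinear(snt.AbstractModule):
  def __init__(self, output_size, name='wrapped_linear'):
    super(WrappedLinear, self).__init__(name=name)
    # Sub-modules built in __init__ must be created inside
    # _enter_variable_scope to live in this module's variable scope.
    with self._enter_variable_scope():
      self._linear = snt.Linear(output_size=output_size)

  def _build(self, inputs):
    return self._linear(inputs)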

RNNCore.__call__(*args, **kwargs)

Add elements to the Graph, computing output Tensors from input Tensors.

Subclasses must implement this method, which will be wrapped in a Template.

Args:
  • *args: Input Tensors.
  • **kwargs: Additional Python flags controlling connection.
Returns:

output Tensor(s).

RNNCore.connected_subgraphs

Returns the subgraphs created by this module so far.

RNNCore.defun()

Wraps this module's call method in a callable graph function.

RNNCore.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

RNNCore.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

RNNCore.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.graph

Returns the Graph instance which the module is connected to, or None.

RNNCore.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.

RNNCore.is_connected

Returns true iff the Module has been connected to the Graph at least once.

RNNCore.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.module_name

Returns the name of the Module.

RNNCore.name_scopes

Returns a tuple of all name_scopes generated by this module.

RNNCore.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.output_size

Integer or TensorShape: size of outputs produced by this cell.

RNNCore.scope_name

Returns the full name of the Module's variable scope.

RNNCore.state_size

size(s) of state(s) used by this cell.

It can be represented by an Integer, a TensorShape or a tuple of Integers or TensorShapes.

RNNCore.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RNNCore.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or scalar Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class RelationalMemory

Relational Memory Core.

RelationalMemory.__init__(mem_slots, head_size, num_heads=1, num_blocks=1, forget_bias=1.0, input_bias=0.0, gate_style='unit', attention_mlp_layers=2, key_size=None, name='relational_memory')

Constructs a RelationalMemory object.

Args:
  • mem_slots: The total number of memory slots to use.
  • head_size: The size of an attention head.
  • num_heads: The number of attention heads to use. Defaults to 1.
  • num_blocks: Number of times to compute attention per time step. Defaults to 1.
  • forget_bias: Bias to use for the forget gate, assuming we are using some form of gating. Defaults to 1.
  • input_bias: Bias to use for the input gate, assuming we are using some form of gating. Defaults to 0.
  • gate_style: Whether to use per-element gating ('unit'), per-memory-slot gating ('memory'), or no gating at all (None). Defaults to 'unit'.
  • attention_mlp_layers: Number of layers to use in the post-attention MLP. Defaults to 2.
  • key_size: Size of vector to use for key & query vectors in the attention computation. Defaults to None, in which case we use head_size.
  • name: Name of the module.
Raises:
  • ValueError: gate_style not one of [None, 'memory', 'unit'].
  • ValueError: num_blocks is < 1.
  • ValueError: attention_mlp_layers is < 1.

RelationalMemory.__call__(inputs, memory, treat_input_as_matrix=False)

Adds relational memory to the TensorFlow graph.

Args:
  • inputs: Tensor input.
  • memory: Memory output from the previous time step.
  • treat_input_as_matrix: Optional, whether to treat the input as a sequence of matrices. Defaults to False, in which case the input is flattened into a vector.
Returns:
  • output: This time step's output.
  • next_memory: The next version of memory to use.
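
A hedged usage sketch for a single step of relational memory; the shapes are illustrative:

import sonnet as snt
import tensorflow as tf

rmc = snt.RelationalMemory(mem_slots=4, head_size=32, num_heads=2)
inputs = tf.placeholder(tf.float32, shape=[8, 64])  # (batch, features)
memory = rmc.initial_state(batch_size=8)
output, next_memory = rmc(inputs, memory)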

RelationalMemory.connected_subgraphs

Returns the subgraphs created by this module so far.

RelationalMemory.defun()

Wraps this module's call method in a callable graph function.

RelationalMemory.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

RelationalMemory.forget_gate

Returns the forget gate Tensor.

RelationalMemory.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

RelationalMemory.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.graph

Returns the Graph instance which the module is connected to, or None.

RelationalMemory.initial_state(batch_size, trainable=False)

Creates the initial memory.

Each row of the memory should be initialized to be unique, so the matrix is initialized to the identity and then padded or truncated as necessary so that init_state is of size (batch_size, self._mem_slots, self._mem_size).

Args:
  • batch_size: The size of the batch.
  • trainable: Whether the initial state is trainable. This is always True.
Returns:
  • init_state: A truncated or padded matrix of size (batch_size, self._mem_slots, self._mem_size).

RelationalMemory.input_gate

Returns the input gate Tensor.

RelationalMemory.is_connected

Returns true iff the Module has been connected to the Graph at least once.

RelationalMemory.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.module_name

Returns the name of the Module.

RelationalMemory.name_scopes

Returns a tuple of all name_scopes generated by this module.

RelationalMemory.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.output_size

RelationalMemory.scope_name

Returns the full name of the Module's variable scope.

RelationalMemory.state_size

RelationalMemory.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

RelationalMemory.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or scalar Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

class Residual

Adds a residual connection to a base module.

This module wraps a module M: where M would traditionally output M(x), Residual(M)(x) = M(x) + x.
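
A minimal sketch; note that the wrapped module must preserve the input size so that M(x) + x is well defined:

import sonnet as snt
import tensorflow as tf

linear = snt.Linear(output_size=64)
residual = snt.Residual(linear)
x = tf.placeholder(tf.float32, shape=[8, 64])
y = residual(x)  # equivalent to linear(x) + x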

Residual.__init__(base_module, name='residual')

Residual.__call__(inputs, **kwargs)

Residual.connected_subgraphs

Returns the subgraphs created by this module so far.

Residual.defun()

Wraps this module's call method in a callable graph function.

Residual.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Residual.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

Residual.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.graph

Returns the Graph instance which the module is connected to, or None.

Residual.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Residual.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.module_name

Returns the name of the Module.

Residual.name_scopes

Returns a tuple of all name_scopes generated by this module.

Residual.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.scope_name

Returns the full name of the Module's variable scope.

Residual.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Residual.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class ResidualCore

Adds a residual connection to a base RNN core.

This module wraps a module M: where M would traditionally output M(x), Residual(M)(x) = M(x) + x.
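
A minimal sketch; as with Residual, the core's output size must match its input size:

import sonnet as snt
import tensorflow as tf

core = snt.ResidualCore(snt.VanillaRNN(hidden_size=64))
x = tf.placeholder(tf.float32, shape=[8, 64])
state = core.initial_state(batch_size=8)
output, next_state = core(x, state)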

ResidualCore.__init__(base_core, name='residual_core')

ResidualCore.__call__(inputs, prev_state, **kwargs)

ResidualCore.connected_subgraphs

Returns the subgraphs created by this module so far.

ResidualCore.defun()

Wraps this module's call method in a callable graph function.

ResidualCore.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

ResidualCore.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

ResidualCore.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.graph

Returns the Graph instance which the module is connected to, or None.

ResidualCore.initial_state(*args, **kwargs)

ResidualCore.is_connected

Returns true iff the Module has been connected to the Graph at least once.

ResidualCore.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.module_name

Returns the name of the Module.

ResidualCore.name_scopes

Returns a tuple of all name_scopes generated by this module.

ResidualCore.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.output_size

ResidualCore.scope_name

Returns the full name of the Module's variable scope.

ResidualCore.state_size

ResidualCore.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

ResidualCore.zero_state(*args, **kwargs)

class SelectInput

Returns a subset of its inputs in an arbitrarily nested configuration.

This module can be used for multiple purposes.

The basic usage is to select a tensor or a subset of tensors:

output = snt.SelectInput(idx=0, name='select')(input0, input1)
==> input0

output = snt.SelectInput(idx=[0, 2], name='select')(input0, input1, input2)
==> (input0, input2)

Another usage is to change the orders of the input tensors:

output = snt.SelectInput(idx=[1, 0], name='select')(input0, input1)
==> (input1, input0)

Another usage is to duplicate an input:

output = snt.SelectInput(idx=[0, 0], name='select')(input0)
==> (input0, input0)

Another usage is to add arbitrary nesting:

output = snt.SelectInput(
    idx=[0, [1, [2]]], name='select')(input0, input1, input2)
==> (input0, (input1, (input2,)))

SelectInput.__init__(idx, name='select_input')

Module constructor.

Args:
  • idx: Indexes of the tensors to select. If idx is an integer, then a Tensor is returned. If idx is a (nested) list/tuple, then a (nested) tuple of Tensors is returned.
  • name: Name of the module.
Raises:
  • TypeError: If idx is not a list, tuple or integer.

SelectInput.__call__(*inputs)

Connects the module into the graph.

Args:
  • *inputs: Tensor variables to select.
Returns:

Subset of inputs in an arbitrarily nested configuration.

Raises:
  • ValueError: If any entry of idx is out of bounds with respect to the size of inputs.

SelectInput.connected_subgraphs

Returns the subgraphs created by this module so far.

SelectInput.defun()

Wraps this module's call method in a callable graph function.

SelectInput.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

SelectInput.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

SelectInput.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.graph

Returns the Graph instance which the module is connected to, or None.

SelectInput.is_connected

Returns true iff the Module has been connected to the Graph at least once.

SelectInput.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.module_name

Returns the name of the Module.

SelectInput.name_scopes

Returns a tuple of all name_scopes generated by this module.

SelectInput.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.scope_name

Returns the full name of the Module's variable scope.

SelectInput.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SelectInput.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class SeparableConv1D

Performs an in-plane convolution to each channel independently.

This acts as a light wrapper around the TensorFlow op tf.nn.separable_conv2d, abstracting away variable creation and sharing.

SeparableConv1D.__init__(output_channels, channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NWC', custom_getter=None, name='separable_conv1d')

Constructs a SeparableConv1D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. Must be an integer.
  • channel_multiplier: Number of channels to expand pointwise (depthwise) convolution to. Must be an integer. Must be > 0. When channel_multiplier is set to 1, applies a different filter to each input channel. Numbers larger than 1 cause the filter to be applied to channel_multiplier input channels. Outputs are concatenated together.
  • kernel_shape: List with 1 element in the following layout: [kernel_width], or an integer that is used as the kernel width.
  • stride: List with 4 elements of kernel strides, or integer that is used to define stride in all dimensions. Layout of list: [1, stride_y, stride_x, 1].
  • rate: Sequence of dilation rates (of size 1), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard 1D convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any of stride is also > 1.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 1.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past").
  If you use the same padding for all dimensions, and it is one of SAME or VALID, then this is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NWC), or the second dimension ("NCW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • ValueError: If channel_multiplier isn't of type (numbers.Integral or tf.Dimension).
  • ValueError: If channel_multiplier is less than 1.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_1D_DATA_FORMATS).
  • base.IncompatibleShapeError: If the given kernel shape is not an integer, or is not a sequence of one integer.
  • base.IncompatibleShapeError: If the given stride is not an integer, or is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given rate is not an integer, or is not a sequence of two integers.
  • base.IncompatibleShapeError: If a mask is a TensorFlow Tensor with a not fully defined shape.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w_dw', 'w_pw' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • TypeError: If mask is given and it is not convertible to a Tensor.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
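
A hedged construction sketch for a causal separable 1D convolution over NWC inputs; the shapes are illustrative:

import sonnet as snt
import tensorflow as tf

conv1d = snt.SeparableConv1D(
    output_channels=16,
    channel_multiplier=1,
    kernel_shape=3,
    padding=snt.CAUSAL)
x = tf.placeholder(tf.float32, shape=[8, 100, 4])  # (batch, width, channels)
y = conv1d(x)  # (8, 100, 16); causal padding preserves the width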

SeparableConv1D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

SeparableConv1D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

SeparableConv1D.channel_multiplier

Returns the channel multiplier argument.

SeparableConv1D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

SeparableConv1D.connected_subgraphs

Returns the subgraphs created by this module so far.

SeparableConv1D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

SeparableConv1D.data_format

Returns the data format.

SeparableConv1D.defun()

Wraps this module's call method in a callable graph function.

SeparableConv1D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

SeparableConv1D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.get_possible_initializer_keys(cls, use_bias=True)

SeparableConv1D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.graph

Returns the Graph instance which the module is connected to, or None.

SeparableConv1D.has_bias

Returns True if bias Variable is present in the module.

SeparableConv1D.initializers

Returns the initializers dictionary.

SeparableConv1D.input_channels

Returns the number of input channels.

SeparableConv1D.input_shape

Returns the input shape.

SeparableConv1D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

SeparableConv1D.kernel_shape

Returns the kernel shape.

SeparableConv1D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.mask

Returns the mask.

SeparableConv1D.module_name

Returns the name of the Module.

SeparableConv1D.name_scopes

Returns a tuple of all name_scopes generated by this module.

SeparableConv1D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.output_channels

Returns the number of output channels.

SeparableConv1D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

SeparableConv1D.paddings

Returns a tuple with the padding algorithm used for each dimension.

SeparableConv1D.partitioners

Returns the partitioners dictionary.

SeparableConv1D.rate

Returns the dilation rate.

SeparableConv1D.regularizers

Returns the regularizers dictionary.

SeparableConv1D.scope_name

Returns the full name of the Module's variable scope.

SeparableConv1D.stride

Returns the stride.

SeparableConv1D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv1D.w

Returns the Variable containing the weight matrix.

SeparableConv1D.w_dw

Returns the Variable containing the depthwise weight matrix.

SeparableConv1D.w_pw

Returns the Variable containing the pointwise weight matrix.

class SeparableConv2D

Performs an in-plane convolution to each channel independently.

This acts as a light wrapper around the TensorFlow op tf.nn.separable_conv2d, abstracting away variable creation and sharing.

SeparableConv2D.__init__(output_channels, channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', use_bias=True, initializers=None, partitioners=None, regularizers=None, data_format='NHWC', custom_getter=None, name='separable_conv2d')

Constructs a SeparableConv2D module.

See the following documentation for an explanation of VALID versus SAME padding modes: https://www.tensorflow.org/api_guides/python/nn#Convolution

Args:
  • output_channels: Number of output channels. Must be an integer.
  • channel_multiplier: Number of channels to expand pointwise (depthwise) convolution to. Must be an integer. Must be > 0. When channel_multiplier is set to 1, applies a different filter to each input channel. Numbers larger than 1 cause the filter to be applied to channel_multiplier input channels. Outputs are concatenated together.
  • kernel_shape: List with 2 elements in the following layout: [filter_height, filter_width] or integer that is used to define the list in all dimensions.
  • stride: List with 4 elements of kernel strides, or integer that is used to define stride in all dimensions. Layout of list: [1, stride_y, stride_x, 1].
  • rate: Sequence of dilation rates (of size 2), or integer that is used to define dilation rate in all dimensions. 1 corresponds to standard 2D convolution, rate > 1 corresponds to dilated convolution. Cannot be > 1 if any of stride is also > 1.
  • padding: Padding algorithm. Either snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL, or a sequence of these paddings of length 2.
    • snt.SAME and snt.VALID are explained in the Tensorflow docs at https://www.tensorflow.org/api_guides/python/nn#Convolution.
    • snt.FULL pre- and post-pads with the maximum padding which does not result in a convolution over just padded elements.
    • snt.CAUSAL pre-pads to ensure that each output value only depends on input values at the same or preceding indices ("no dependence on the future").
    • snt.REVERSE_CAUSAL post-pads to ensure that each output value only depends on input values at the same or greater indices ("no dependence on the past").
  If you use the same padding for all dimensions, and it is one of SAME or VALID, then this is supported directly by the underlying convolution op. In all other cases, the input data will be padded using tf.pad before calling the convolution op.
  • use_bias: Whether to include bias parameters. Default True.
  • initializers: Optional dict containing ops to initialize the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with keys 'w_dw' for depthwise and 'w_pw' for pointwise) and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • data_format: A string. Specifies whether the channel dimension of the input and output is the last dimension (default, NHWC), or the second dimension ("NCHW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • ValueError: If channel_multiplier isn't of type (numbers.Integral or tf.Dimension).
  • ValueError: If channel_multiplier is less than 1.
  • ValueError: If the given data_format is not a supported format (see SUPPORTED_2D_DATA_FORMATS).
  • base.IncompatibleShapeError: If the given kernel shape is not an integer, or is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given stride is not an integer, or is not a sequence of two integers.
  • base.IncompatibleShapeError: If the given rate is not an integer, or is not a sequence of two integers.
  • base.IncompatibleShapeError: If a mask is a TensorFlow Tensor with a not fully defined shape.
  • base.NotSupportedError: If rate in any dimension and the stride in any dimension are simultaneously > 1.
  • ValueError: If the given padding is not snt.VALID, snt.SAME, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a sequence of these.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w_dw', 'w_pw' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.
  • TypeError: If mask is given and it is not convertible to a Tensor.
  • ValueError: If the passed-in data_format doesn't have a channel dimension.
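
A construction sketch with illustrative shapes:

import sonnet as snt
import tensorflow as tf

conv2d = snt.SeparableConv2D(
    output_channels=32,
    channel_multiplier=2,
    kernel_shape=[3, 3],
    padding=snt.SAME)
images = tf.placeholder(tf.float32, shape=[8, 28, 28, 3])  # NHWC
features = conv2d(images)  # (8, 28, 28, 32)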

SeparableConv2D.__call__(inputs)

Connects the _ConvND module into the graph, with input Tensor inputs.

If this is not the first time the module has been connected to the graph, the input Tensor provided here must have the same number of channels, in order for the existing variables to be the correct size for the multiplication; the batch size and input spatial dimensions may differ for each connection.

Args:
  • inputs: An N-D Tensor of the same rank as data_format, and of type tf.float16, tf.bfloat16 or tf.float32.
Returns:

An N-D Tensor of shape [batch_size, output_dim_1, output_dim_2, ..., output_channels].

Raises:
  • ValueError: If connecting the module into the graph any time after the first time and the inferred size of the input does not match previous invocations.
  • base.IncompatibleShapeError: If the input tensor has the wrong number of dimensions.
  • base.UnderspecifiedError: If the channel dimension of inputs isn't defined.
  • base.IncompatibleShapeError: If a mask is present and its shape is incompatible with the shape of the weights.
  • TypeError: If input Tensor dtype is not compatible with either tf.float16, tf.bfloat16 or tf.float32.

SeparableConv2D.b

Returns the Variable containing the bias.

Returns:

Variable object containing the bias, from the most recent call.

Raises:
  • base.NotConnectedError: If the module has not been connected to the graph yet, meaning the variables do not exist.
  • AttributeError: If the module does not use bias.

SeparableConv2D.channel_multiplier

Returns the channel multiplier argument.

SeparableConv2D.clone(name=None)

Returns a cloned _ConvND module.

Args:
  • name: Optional string assigning name of cloned module. The default name is constructed by appending "_clone" to self.module_name.
Returns:

A copy of the current class.

SeparableConv2D.connected_subgraphs

Returns the subgraphs created by this module so far.

SeparableConv2D.conv_op_padding

Returns the padding algorithm used for the underlying convolution op.

SeparableConv2D.data_format

Returns the data format.

SeparableConv2D.defun()

Wraps this module's call method in a callable graph function.

SeparableConv2D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

SeparableConv2D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.get_possible_initializer_keys(cls, use_bias=True)

SeparableConv2D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.graph

Returns the Graph instance which the module is connected to, or None.

SeparableConv2D.has_bias

Returns True if bias Variable is present in the module.

SeparableConv2D.initializers

Returns the initializers dictionary.

SeparableConv2D.input_channels

Returns the number of input channels.

SeparableConv2D.input_shape

Returns the input shape.

SeparableConv2D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

SeparableConv2D.kernel_shape

Returns the kernel shape.

SeparableConv2D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.mask

Returns the mask.

SeparableConv2D.module_name

Returns the name of the Module.

SeparableConv2D.name_scopes

Returns a tuple of all name_scopes generated by this module.

SeparableConv2D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.output_channels

Returns the number of output channels.

SeparableConv2D.padding

Returns the padding algorithm used, if this is the same for all dims.

Use .paddings if you want a tuple with the padding algorithm used for each dimension.

Returns:

The padding algorithm used, if this is the same for all dimensions.

Raises:
  • ValueError: If different padding algorithms are used for different dimensions.

SeparableConv2D.paddings

Returns a tuple with the padding algorithm used for each dimension.

SeparableConv2D.partitioners

Returns the partitioners dictionary.

SeparableConv2D.rate

Returns the dilation rate.

SeparableConv2D.regularizers

Returns the regularizers dictionary.

SeparableConv2D.scope_name

Returns the full name of the Module's variable scope.

SeparableConv2D.stride

Returns the stride.

SeparableConv2D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SeparableConv2D.w

Returns the Variable containing the weight matrix.

SeparableConv2D.w_dw

Returns the Variable containing the depthwise weight matrix.

SeparableConv2D.w_pw

Returns the Variable containing the pointwise weight matrix.

class Sequential

Builds a module out of a sequence of callables.

Note that Sequential is limited in the range of possible architectures it can handle. This is a deliberate design decision: Sequential is only meant for the simple case of fusing together modules/ops where the input of a particular module/op is the output of the previous one. Another restriction is that it is not possible to have extra arguments in the _build method that are passed through to the constituents of the module; for example, if a Sequential contains a BatchNorm module, there is no way to toggle its is_training flag through the Sequential's call. If this is the desired use case, the recommended solution is to use snt.Module to wrap a custom function, as shown in the following example:

https://github.com/deepmind/sonnet/blob/master/sonnet/examples/module_with_build_args.py

Sequential.__init__(layers, name='sequential')

Constructs a Sequential module.

This feeds the output of each layer into the next and returns the output of the final layer.

If a layer returns a tuple, it is assumed that this must be unpacked into the argument list of the next layer. If it is not a tuple, it is simply passed through to the next layer unchanged.

Args:
  • layers: Iterable of callables to stack together, which can be modules or ops.
  • name: Name of the module.
Raises:
  • TypeError: If layers is None or contains any non-callable items.

Sequential.__call__(*args)

Connects the Sequential module into the graph.

Args:
  • *args: A tuple of inputs, to be unpacked as the arguments to the first layer.
Returns:

The output value of the last layer.
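
A minimal usage sketch; the layer sizes and the choice of snt.Linear and tf.nn.relu are illustrative assumptions:

import tensorflow as tf
import sonnet as snt

# Fuse two Linear modules with a ReLU in between; each layer's output
# feeds the next layer's input.
mlp = snt.Sequential([
    snt.Linear(output_size=128),
    tf.nn.relu,
    snt.Linear(output_size=10),
])
inputs = tf.placeholder(tf.float32, shape=[None, 784])
logits = mlp(inputs)  # Output value of the last layer.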

Sequential.connected_subgraphs

Returns the subgraphs created by this module so far.

Sequential.defun()

Wraps this module's call method in a callable graph function.

Sequential.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

Sequential.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Sequential.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

Sequential.get_variables(*args, **kwargs)

Provides a warning that get_variables on Sequential always returns ().

Sequential.graph

Returns the Graph instance which the module is connected to, or None.

Sequential.is_connected

Returns true iff the Module has been connected to the Graph at least once.

Sequential.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Sequential.layers

Sequential.module_name

Returns the name of the Module.

Sequential.name_scopes

Returns a tuple of all name_scopes generated by this module.

Sequential.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Sequential.scope_name

Returns the full name of the Module's variable scope.

Sequential.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Sequential.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

Sequential.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class SkipConnectionCore

Adds a skip connection to the base RNN core.

The output of the wrapped core is the concatenation of the output of the base core with its input. The state of the wrapped core is the state of the base core.

SkipConnectionCore.__init__(base_core, input_shape=None, name='skip_connection_core')

Construct a SkipConnectionCore.

Args:
  • base_core: Base RNNCore to wrap.
  • input_shape: Shape of the input as tuple, excluding the batch size.
  • name: Name of the module.
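
A minimal unroll sketch; the base core choice (snt.LSTM) and all sizes are illustrative assumptions:

import tensorflow as tf
import sonnet as snt

# Passing input_shape lets the wrapped core report its output_size
# before the first connection.
core = snt.SkipConnectionCore(snt.LSTM(hidden_size=64), input_shape=(32,))
input_sequence = tf.placeholder(tf.float32, shape=[20, 8, 32])  # [time, batch, features].
initial_state = core.initial_state(batch_size=8)
output_sequence, final_state = tf.nn.dynamic_rnn(
    core, input_sequence, initial_state=initial_state, time_major=True)
# Each output step is the LSTM output concatenated with the input: 64 + 32 features.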

SkipConnectionCore.__call__(inputs, prev_state, **kwargs)

SkipConnectionCore.connected_subgraphs

Returns the subgraphs created by this module so far.

SkipConnectionCore.defun()

Wraps this module's call method in a callable graph function.

SkipConnectionCore.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

SkipConnectionCore.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

SkipConnectionCore.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.graph

Returns the Graph instance which the module is connected to, or None.

SkipConnectionCore.initial_state(*args, **kwargs)

SkipConnectionCore.is_connected

Returns true iff the Module has been connected to the Graph at least once.

SkipConnectionCore.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.module_name

Returns the name of the Module.

SkipConnectionCore.name_scopes

Returns a tuple of all name_scopes generated by this module.

SkipConnectionCore.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.output_size

SkipConnectionCore.scope_name

Returns the full name of the Module's variable scope.

SkipConnectionCore.state_size

SkipConnectionCore.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SkipConnectionCore.zero_state(*args, **kwargs)

class SliceByDim

Slices a tensor along specific dimensions.

The user can slice a tensor by specifying only the list of dimensions that they want to slice, together with the lists of integers containing the beginning indices of the slicing, and the size of the slices. Hence, with SliceByDim slicing can be performed without knowing in advance the rank of the input tensor.

TensorFlow also offers a built-in op performing slicing, tf.slice. However, tf.slice requires all the slicing dimensions to be specified, using wildcards when no slicing is required. For example, with tf.slice, slicing half a 5D tensor along dimension 1 would be:

output = tf.slice(inputs,
                  begin=[0, 0, 0, 0, 0],
                  size=[-1, inputs.get_shape()[1].value//2, -1, -1, -1])

The same operation using SliceByDim would be:

output = SliceByDim(dims=[1], begin=[0],
                    size=[inputs.get_shape()[1].value//2])(inputs)

SliceByDim can be used to specify multiple slicing dimensions, for example:

output = SliceByDim(dims=[1, 3], begin=[0, 0], size=[12, 24])(inputs)

SliceByDim.__init__(dims, begin, size, name='slice_by_dim')

Constructs the SliceByDim module.

Args:
  • dims: The dimensions to slice along, as a list of unique integers. Negative integers index from the final dimension backwards, as in python arrays.
  • begin: The beginning indices of the slicing, as a list of integers. Must be the same length as the dims list.
  • size: The size of the slices, as a list of integers. Must be the same length as the dims list.
  • name: The name of the module.
Raises:
  • ValueError: If dims has non-unique integers, or if the size of begin is different from the size of dims, or if the size of size is different from the size of dims.

SliceByDim.__call__(inputs)

Connects the SliceByDim module into the graph.

Args:
  • inputs: Tensor to slice. Its rank must be greater than the maximum dimension specified in dims (plus one, as Python is 0-indexed).
Returns:

The sliced tensor.

Raises:
  • ValueError: If inputs tensor has insufficient rank.

SliceByDim.connected_subgraphs

Returns the subgraphs created by this module so far.

SliceByDim.defun()

Wraps this module's call method in a callable graph function.

SliceByDim.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

SliceByDim.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

SliceByDim.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.graph

Returns the Graph instance which the module is connected to, or None.

SliceByDim.is_connected

Returns true iff the Module has been connected to the Graph at least once.

SliceByDim.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.module_name

Returns the name of the Module.

SliceByDim.name_scopes

Returns a tuple of all name_scopes generated by this module.

SliceByDim.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.scope_name

Returns the full name of the Module's variable scope.

SliceByDim.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

SliceByDim.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class TileByDim

Tile a tensor along specific dimensions.

The user can tile a tensor by specifying only the list of dimensions that they want to tile, together with the lists of integers containing the multiples of the tiling. Hence, with TileByDim tiling can be performed without knowing in advance the rank of the input tensor.

TensorFlow also offers a built-in op performing tiling, tf.tile. However, tf.tile requires all the tiling dimensions to be specified, using 1 when no tiling is required. For example, with tf.tile, tiling a 5D tensor along dimension 1 by 2 would be:

output = tf.tile(inputs, multiples=[1, 2, 1, 1, 1])

The same operation using TileByDim would be:

output = TileByDim(dims=[1], multiples=[2])(inputs)

TileByDim can be used to specify multiple tiling dimensions, for example:

output = TileByDim(dims=[1, 3], multiples=[2, 4])(inputs)

TileByDim.__init__(dims, multiples, name='tile_by_dim')

Constructs the TileByDim module.

Args:
  • dims: The dimensions to tile along, as a list of unique integers.
  • multiples: The multiple of the tiling, as a list of integers. Must be the same length as the dims list.
  • name: The name of the module.
Raises:
  • ValueError: If dims has non-unique integers, or if the size of multiples is different from the size of dims.

TileByDim.__call__(inputs)

Connects the TileByDim module into the graph.

Args:
  • inputs: Tensor to tile.
Returns:

The tiled tensor.

TileByDim.connected_subgraphs

Returns the subgraphs created by this module so far.

TileByDim.defun()

Wraps this module's call method in a callable graph function.

TileByDim.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

TileByDim.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

TileByDim.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.graph

Returns the Graph instance which the module is connected to, or None.

TileByDim.is_connected

Returns true iff the Module has been connected to the Graph at least once.

TileByDim.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.module_name

Returns the name of the Module.

TileByDim.name_scopes

Returns a tuple of all name_scopes generated by this module.

TileByDim.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.scope_name

Returns the full name of the Module's variable scope.

TileByDim.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TileByDim.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class TrainableInitialState

Helper Module that creates a learnable initial state for an RNNCore.

This class receives an example (possibly nested) initial state of an RNNCore, and returns a state that has the same shape, structure, and values, but is trainable. Additionally, the user may specify a boolean mask that indicates which parts of the initial state should be trainable.

This allows users to train an unrolled RNNCore with a learnable initial state in the following way:

core = ... # Any RNNCore module object.
initial_state = core.initial_state(batch_size, dtype)
trainable_initial_state = snt.TrainableInitialState(initial_state)()
output, final_state = tf.nn.dynamic_rnn(
    core, input_sequence, initial_state=trainable_initial_state)

TrainableInitialState.__init__(initial_state, mask=None, name='trainable_initial_state')

Constructs the Module that introduces a trainable state in the graph.

It receives an initial state that will be used as the initial values for the trainable variables that the module contains, and optionally a mask that indicates the parts of the initial state that should be learnable.

Args:
  • initial_state: tensor or arbitrarily nested iterables of tensors.
  • mask: optional boolean mask. It should have the same nested structure as the given initial_state.
  • name: module name.
Raises:
  • TypeError: if mask is not a list of booleans or None.

TrainableInitialState.__call__()

Connects the module to the graph.

Returns:

The learnable state, which has the same type, structure and shape as the initial_state passed to the constructor.

TrainableInitialState.connected_subgraphs

Returns the subgraphs created by this module so far.

TrainableInitialState.defun()

Wraps this module's call method in a callable graph function.

TrainableInitialState.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

TrainableInitialState.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

TrainableInitialState.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.graph

Returns the Graph instance which the module is connected to, or None.

TrainableInitialState.is_connected

Returns true iff the Module has been connected to the Graph at least once.

TrainableInitialState.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.module_name

Returns the name of the Module.

TrainableInitialState.name_scopes

Returns a tuple of all name_scopes generated by this module.

TrainableInitialState.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.scope_name

Returns the full name of the Module's variable scope.

TrainableInitialState.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableInitialState.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class TrainableVariable

Provides learnable parameter Tensor.

TrainableVariable.__init__(shape, dtype=tf.float32, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='trainable_variable')

Constructs a TrainableVariable module.

Args:
  • shape: Tensor shape.
  • dtype: Tensor data type.
  • initializers: Optional dictionary containing ops to initialize the weight Tensor, with key 'w'.
  • partitioners: Optional dict containing a partitioner to partition the weight (with key 'w'). As a default, no partitioner is used.
  • regularizers: Optional dict containing regularizers for the weights (with key 'w'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • custom_getter: Optional callable or dictionary of callables to use as custom_getter for the module.
  • name: Name of the module.
Raises:
  • KeyError: If initializers contains any keys other than 'w'.
  • KeyError: If partitioners contains any keys other than 'w'.
  • KeyError: If regularizers contains any keys other than 'w'.
  • TypeError: If any of the given initializers are not callable.
  • TypeError: If any of the given partitioners are not callable.
  • TypeError: If any of the given regularizers are not callable.

TrainableVariable.__call__()

Connects the TrainableVariable module into the graph.

Returns:

A Tensor of the shape determined in the constructor.
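
A minimal sketch; the scalar shape and initializer are illustrative assumptions:

import tensorflow as tf
import sonnet as snt

# A learnable scalar temperature, exposed as a Tensor by calling the module.
temperature = snt.TrainableVariable(
    shape=[], initializers={'w': tf.constant_initializer(1.0)})
logits = tf.placeholder(tf.float32, shape=[None, 10])
scaled_logits = logits / temperature()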

TrainableVariable.connected_subgraphs

Returns the subgraphs created by this module so far.

TrainableVariable.defun()

Wraps this module's call method in a callable graph function.

TrainableVariable.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

TrainableVariable.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

TrainableVariable.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.graph

Returns the Graph instance which the module is connected to, or None.

TrainableVariable.is_connected

Returns true iff the Module has been connected to the Graph at least once.

TrainableVariable.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.module_name

Returns the name of the Module.

TrainableVariable.name_scopes

Returns a tuple of all name_scopes generated by this module.

TrainableVariable.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.scope_name

Returns the full name of the Module's variable scope.

TrainableVariable.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

TrainableVariable.w

Returns the Variable containing the weights Tensor.

Returns:

Variable object containing the weights, from the most recent call.

Raises:
  • base.Error: If the module has not been connected to the graph yet, meaning the variables do not exist.

class Transposable

Transposable module interface.

The Transposable interface requires that transposable modules implement a method called transpose, returning a module that is the transposed version of the one the method is called on. Calling the method twice should return a module with the same specifications as the original module.

When implementing a transposable module, special care is required to make sure that parameters needed to instantiate the module are provided as functions whose invocation is deferred to graph construction time.

For example, in Linear we might want to call:

linear = snt.Linear(name="linear", output_size=output_size)
linear_transpose = linear.transpose()

where the output_size for linear_transpose is not known yet, as linear is not yet connected to the graph: output_size is passed to linear_transpose's constructor as a lambda returning linear.input_size. The lambda will return the correct value once linear is given an input. Notice that linear_transpose's output_size value does not need to be defined until the module is connected to the graph.
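
A sketch of the resulting deferred behaviour, assuming snt.Linear's transpose and illustrative sizes:

import tensorflow as tf
import sonnet as snt

linear = snt.Linear(output_size=20)
linear_transpose = linear.transpose()
hiddens = linear(tf.placeholder(tf.float32, shape=[None, 5]))
reconstruction = linear_transpose(hiddens)  # output_size resolves to 5.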

Transposable.input_shape()

Returns shape of input Tensor passed at last call to build.

Transposable.transpose(name=None, **kwargs)

Builds and returns transposed version of module.

Args:
  • name: Name of the transposed module.
  • **kwargs: Additional Python flags controlling transposition.
Returns:

Transposed version of the module.

class UnderspecifiedError

Error raised when too little information is available.

This does not typically mean the user is trying to do something that doesn't work (in which case IncompatibleShapeError should be used), just that some more information needs to be provided in order to build the Graph.

class VanillaRNN

Basic fully connected vanilla RNN core.

VanillaRNN.__init__(hidden_size, activation=tanh, initializers=None, partitioners=None, regularizers=None, name='vanilla_rnn')

Constructs a basic RNN core.

Args:
  • hidden_size: hidden size dimensionality.
  • activation: activation function to use.
  • initializers: optional dict containing ops to initialize the weights. This dictionary may contain the keys 'in_to_hidden' and/or 'hidden_to_hidden'.
  • partitioners: optional dict containing ops to partition the weights. This dictionary may contain the keys 'in_to_hidden' and/or 'hidden_to_hidden'.
  • regularizers: optional dict containing ops to regularize the weights. This dictionary may contain the keys 'in_to_hidden' and/or 'hidden_to_hidden'.
  • name: name of the module.
Raises:
  • KeyError: if initializers contains any keys other than 'in_to_hidden' or 'hidden_to_hidden'.
  • KeyError: if partitioners contains any keys other than 'in_to_hidden' or 'hidden_to_hidden'.
  • KeyError: if regularizers contains any keys other than 'in_to_hidden' or 'hidden_to_hidden'.
  • TypeError: If any of the given initializers are not callable.
  • TypeError: If any of the given partitioners are not callable.
  • TypeError: If any of the given regularizers are not callable.

VanillaRNN.__call__(input_, prev_state)

Connects the VanillaRNN module into the graph.

If this is not the first time the module has been connected to the graph, the Tensors provided as input_ and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection.

Args:
  • input_: a 2D Tensor of size [batch_size, input_size].
  • prev_state: a 2D Tensor of size [batch_size, hidden_size].
Returns:
  • output: a 2D Tensor of size [batch_size, hidden_size].
  • next_state: a Tensor of size [batch_size, hidden_size].
Raises:
  • ValueError: if connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations.
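
A minimal unroll sketch; all sizes are illustrative assumptions:

import tensorflow as tf
import sonnet as snt

core = snt.VanillaRNN(hidden_size=32)
input_sequence = tf.placeholder(tf.float32, shape=[10, 4, 16])  # [time, batch, features].
initial_state = core.initial_state(batch_size=4)
output_sequence, final_state = tf.nn.dynamic_rnn(
    core, input_sequence, initial_state=initial_state, time_major=True)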

VanillaRNN.connected_subgraphs

Returns the subgraphs created by this module so far.

VanillaRNN.defun()

Wraps this module's call method in a callable graph function.

VanillaRNN.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

VanillaRNN.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

VanillaRNN.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.graph

Returns the Graph instance which the module is connected to, or None.

VanillaRNN.hidden_to_hidden_linear

VanillaRNN.hidden_to_hidden_variables

VanillaRNN.in_to_hidden_linear

VanillaRNN.in_to_hidden_variables

VanillaRNN.initial_state(batch_size, dtype=tf.float32, trainable=False, trainable_initializers=None, trainable_regularizers=None, name=None, **unused_kwargs)

Builds the default start state for an RNNCore.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • dtype: The data type to use for the state.
  • trainable: Boolean that indicates whether to learn the initial state. Note that initializers and regularizers will be ignored if trainable=False.
  • trainable_initializers: An initializer function or nested structure of functions with same structure as the state_size property of the core, to be used as initializers of the initial state variable.
  • trainable_regularizers: Optional regularizer function or nested structure of functions with the same structure as the state_size property of the core, to be used as regularizers of the initial state variable. As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • name: Optional string used to prefix the initial state variable names, in the case of a trainable initial state. If not provided, defaults to the name of the module.
Returns:

A tensor or nested tuple of tensors with same structure and shape as the state_size property of the core.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.

VanillaRNN.is_connected

Returns true iff the Module has been connected to the Graph at least once.

VanillaRNN.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.module_name

Returns the name of the Module.

VanillaRNN.name_scopes

Returns a tuple of all name_scopes generated by this module.

VanillaRNN.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.output_size

VanillaRNN.scope_name

Returns the full name of the Module's variable scope.

VanillaRNN.state_size

VanillaRNN.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

VanillaRNN.zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
Returns:

If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size x state_size] filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size x s] for each s in state_size.

check_initializers(initializers, keys)

Checks the given initializers.

This checks that initializers is a dictionary that only contains keys in keys, and furthermore the entries in initializers are functions or further dictionaries (the latter used, for example, in passing initializers to modules inside modules) that must satisfy the same constraints.

Args:
  • initializers: Dictionary of initializers (allowing nested dictionaries) or None.
  • keys: Iterable of valid keys for initializers.
Returns:

Copy of checked dictionary of initializers. If initializers=None, an empty dictionary will be returned.

Raises:
  • KeyError: If an initializer is provided for a key not in keys.
  • TypeError: If a provided initializer is not a callable function, or initializers is not a Mapping.
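
A minimal sketch of validating user-supplied initializers (the keys are illustrative; check_partitioners and check_regularizers below behave analogously):

import tensorflow as tf
import sonnet as snt

# Returns a checked copy; a key outside {'w', 'b'} would raise KeyError.
initializers = snt.check_initializers(
    {'w': tf.truncated_normal_initializer(stddev=0.1)}, keys={'w', 'b'})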

check_partitioners(partitioners, keys)

Checks the given partitioners.

This checks that partitioners is a dictionary that only contains keys in keys, and furthermore the entries in partitioners are functions or further dictionaries (the latter used, for example, in passing partitioners to modules inside modules) that must satisfy the same constraints.

Args:
  • partitioners: Dictionary of partitioners (allowing nested dictionaries) or None.
  • keys: Iterable of valid keys for partitioners.
Returns:

Checked dictionary of partitioners. If partitioners=None, an empty dictionary will be returned.

Raises:
  • KeyError: If a partitioner is provided for a key not in keys.
  • TypeError: If a provided partitioner is not a callable function, or partitioners is not a Mapping.

check_regularizers(regularizers, keys)

Checks the given regularizers.

This checks that regularizers is a dictionary that only contains keys in keys, and furthermore the entries in regularizers are functions or further dictionaries (the latter used, for example, in passing regularizers to modules inside modules) that must satisfy the same constraints.

Args:
  • regularizers: Dictionary of regularizers (allowing nested dictionaries) or None.
  • keys: Iterable of valid keys for regularizers.
Returns:

Copy of checked dictionary of regularizers. If regularizers=None, an empty dictionary will be returned.

Raises:
  • KeyError: If a regularizer is provided for a key not in keys.
  • TypeError: If a provided regularizer is not a callable function, or regularizers is not a Mapping.

clip_gradient(net, clip_value_min, clip_value_max, name=None)

Clips respective gradients of a given tensor.

Acts as identity for the forward pass, but clips gradient tensor element-wise by value during the backward pass. Any gradient values less than clip_value_min or greater than clip_values_max are set to the respective limit values.

Args:
  • net: A tf.Tensor.
  • clip_value_min: A 0-D Tensor or scalar. The minimum value to clip by.
  • clip_value_max: A 0-D Tensor or scalar. The maximum value to clip by.
  • name: A name for the operation (optional, default 'clip_gradient').
Returns:

A tf.Tensor with the same type as the input tensor.

Raises:
  • ValueError: If net dtype is non-float.
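
A minimal sketch:

import tensorflow as tf
import sonnet as snt

x = tf.placeholder(tf.float32, shape=[None, 8])
y = snt.clip_gradient(x, clip_value_min=-1.0, clip_value_max=1.0)  # Identity forward.
loss = tf.reduce_sum(tf.square(y))
grad_x, = tf.gradients(loss, [x])  # Every element of grad_x lies in [-1, 1].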

count_variables_by_type(variables=None)

Returns a dict mapping dtypes to number of variables and scalars.

Args:
  • variables: iterable of tf.Variables, or None. If None is passed, then all global and local variables in the current graph are used.
Returns:

A dict mapping tf.dtype keys to a dict containing the keys 'num_scalars' and 'num_variables'.
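
A minimal sketch; the exact counts depend on the variables in the graph (here, a Linear's w of shape [4, 10] and b of shape [10]):

import tensorflow as tf
import sonnet as snt

linear = snt.Linear(output_size=10)
_ = linear(tf.placeholder(tf.float32, shape=[None, 4]))
print(snt.count_variables_by_type())
# e.g. {tf.float32: {'num_scalars': 50, 'num_variables': 2}}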

custom_getter_router(custom_getter_map, name_fn)

Creates a custom getter that matches requests to a dict of custom getters.

Custom getters are callables which implement the [custom getter API](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/get_variable).

The returned custom getter dispatches calls based on pattern matching the name of the requested variable to the keys of custom_getter_map. For example,

{
  ".*/w": snt.custom_getters.stop_gradient,
}

will match all variables named with the suffix "/w". The name_fn is provided to allow processing of the name, such as stripping off a scope prefix before matching.

Args:
  • custom_getter_map: Mapping of regular expressions to custom getter functions.
  • name_fn: Callable to map variable name through before matching to regular expressions. This might, for example, strip off a scope prefix.
Returns:

A custom getter.

Raises:
  • TypeError: If an entry in custom_getter_map is not a callable function.
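
A hedged sketch of routing a getter by variable name; the 'model/' scope prefix stripped by name_fn is an assumption for illustration:

getter = snt.custom_getter_router(
    custom_getter_map={".*/w": snt.custom_getters.stop_gradient},
    name_fn=lambda name: name[len("model/"):])  # Strip the assumed prefix.
with tf.variable_scope("model"):
  lin = snt.Linear(10, custom_getter=getter)  # Gradients stop at ".../w" variables.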

deprecation_warning(deprecation_message)

Logs a warning message that the user is using deprecated functionality.

format_variable_map(variable_map, join_lines=True)

Takes a key-to-variable map and formats it as a table.

format_variables(variables, join_lines=True)

Takes a collection of variables and formats it as a table.

get_normalized_variable_map(scope_or_module, collection='variables', context=None, group_sliced_variables=True)

Builds map of tf.Variables in scope or module with normalized names.

The names of the variables are normalized to remove the scope prefix.

Args:
  • scope_or_module: Scope or module to build map from.
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.GLOBAL_VARIABLES, which includes non-trainable variables such as moving averages.
  • context: Scope or module, identical to or parent of scope. If given, this will be used as the stripped prefix. By default None, which means context=scope.
  • group_sliced_variables: Boolean, if set to True, sliced variables are grouped together in the returned map; if set to False, each partition of a sliced variable is a separate (key, value) pair.
Returns:

Dictionary mapping normalized variable name to tf.Variable, or a list of tf.Variables if the variable is a sliced (partitioned) variable.

Raises:
  • ValueError: If context is given but is not a proper prefix of scope.

get_saver(scope, collections=('variables',), context=None, **kwargs)

Builds a tf.train.Saver for the scope or module, with normalized names.

The names of the variables are normalized to remove the scope prefix. This allows the same variables to be restored into another similar scope or module using a complementary tf.train.Saver object.

Args:
  • scope: Scope or module. Variables within will be saved or restored.
  • collections: Sequence of collections of variables to restrict tf.train.Saver to. By default this is tf.GraphKeys.GLOBAL_VARIABLES which includes moving averages variables as well as trainable variables.
  • context: Scope or module, identical to or parent of scope. If given, this will be used as the stripped prefix.
  • **kwargs: Extra keyword arguments to pass to tf.train.Saver.
Returns:

A tf.train.Saver object for Variables in the scope or module.
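
A hedged sketch of the intended round trip, assuming lin and other_lin are two modules with matching variable structure, sess is an existing tf.Session, and the checkpoint path is illustrative:

saver = snt.get_saver(lin)
saver.save(sess, "/tmp/lin.ckpt")
other_saver = snt.get_saver(other_lin)
other_saver.restore(sess, "/tmp/lin.ckpt")  # Names match after prefix stripping.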

get_variables_in_module(module, collection='trainable_variables')

Returns tuple of tf.Variables declared inside an snt.Module.

Note that this operates by searching the variable scope a module contains, and so does not know about any modules which were constructed elsewhere but used inside this module.

Args:
  • module: snt.Module instance to query the scope of.
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

get_variables_in_scope(scope, collection='trainable_variables')

Returns a tuple of tf.Variables in a scope for a given collection.

Args:
  • scope: tf.VariableScope or string to retrieve variables from.
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

has_variable_scope(obj)

Determines whether the given object has a variable scope.

highway_core_with_recurrent_dropout(hidden_size, num_layers, keep_prob=0.5, **kwargs)

Highway core with recurrent dropout.

Args:
  • hidden_size: (int) Hidden size dimensionality.
  • num_layers: (int) Number of highway layers.
  • keep_prob: the probability to keep an entry when applying dropout.
  • **kwargs: Extra keyword arguments to pass to the highway core.
Returns:

A tuple (train_core, test_core) where train_core is a highway core with recurrent dropout enabled to be used for training and test_core is the same highway core without recurrent dropout.

log_variables(variables=None)

Logs variable information.

This function logs the name, shape, type, collections, and device for either all variables or a given iterable of variables. In the "Device" column, the nature of the variable (legacy, or resource for ResourceVariables) is also specified in parentheses.

Args:
  • variables: iterable of variables; if not provided, then all variables (in the default graph) are logged.

lstm_with_recurrent_dropout(hidden_size, keep_prob=0.5, **kwargs)

LSTM with recurrent dropout.

Args:
  • hidden_size: the LSTM hidden size.
  • keep_prob: the probability to keep an entry when applying dropout.
  • **kwargs: Extra keyword arguments to pass to the LSTM.
Returns:

A tuple (train_lstm, test_lstm) where train_lstm is an LSTM with recurrent dropout enabled to be used for training and test_lstm is the same LSTM without recurrent dropout.
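
A minimal sketch of the intended pattern (hidden size is illustrative):

train_lstm, test_lstm = snt.lstm_with_recurrent_dropout(hidden_size=64, keep_prob=0.5)
# train_lstm applies recurrent dropout during training; test_lstm shares the
# same variables but applies no dropout, for evaluation.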

lstm_with_zoneout(hidden_size, keep_prob_c=0.5, keep_prob_h=0.95, **kwargs)

LSTM with zoneout.

Args:
  • hidden_size: the LSTM hidden size.
  • keep_prob_c: the probability to use the new value of the cell state rather than freezing it.
  • keep_prob_h: the probability to use the new value of the hidden state rather than freezing it.
  • **kwargs: Extra keyword arguments to pass to the LSTM.
Returns:

A tuple (train_lstm, test_lstm) where train_lstm is an LSTM with recurrent dropout enabled to be used for training and test_lstm is the same LSTM without zoneout.

merge_leading_dims(array_or_tensor, n_dims=2)

Merge the first dimensions of a tensor.

Args:
  • array_or_tensor: Tensor to have its first dimensions merged. Can also be an array or numerical value, which will be converted to a tensor for batch application, if needed.
  • n_dims: Number of dimensions to merge.
Returns:

Either the input value converted to a Tensor, with the requested dimensions merged, or the unmodified input value if the input has fewer than n_dims dimensions.

Raises:
  • ValueError: If the rank of array_or_tensor is not well-defined.

observe_connections(observer)

Notifies the observer whenever any Sonnet module is connected to the graph.

If a module contains nested modules, the observer is notified once for each nested module, followed by the containing module.

For example:

def logging_observer(connected_subgraph):
  logging.info(connected_subgraph.module.module_name)

with snt.observe_connections(logging_observer):
  output = imagenet_module(input_tensor)
Args:
  • observer: Callable accepting a single argument. Will be called with a ConnectedSubGraph each time a module is connected to the graph.
Yields:
  • None: just yields control to the inner context.

parse_string_to_constructor(ctor_string)

Returns a callable which corresponds to the constructor string.

Various modules (e.g. ConvNet2D) take constructor arguments which are callables, indicating a submodule to build. These can be passed as actual constructors, e.g. snt.LayerNorm; however, that makes the config for that module not trivially serializable. This function tries to map a string representation to the underlying callable, allowing configs to remain serializable where necessary.

Args:
  • ctor_string: string representing some module in Sonnet. If the string is provided with no dots, we assume it is a member of Sonnet available at top level, i.e. "LayerNorm" maps to snt.LayerNorm.
Raises:
  • ValueError: if no matching constructor can be found.
Returns:

Callable constructor which corresponds to ctor_string.
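
For example, a minimal sketch (both spellings should resolve to the same constructor, per the normalization_ctor documentation for nets.ConvNet2D):

ctor = snt.parse_string_to_constructor("LayerNorm")       # Top-level Sonnet member.
ctor = snt.parse_string_to_constructor("snt.LayerNorm")   # Dotted form.
layer_norm = ctor()  # Instantiate as usual.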

remove_unsupported_kwargs(module_or_fn, all_kwargs_dict)

Removes any kwargs not supported by module_or_fn from all_kwargs_dict.

A new dict is returned with shallow copies of keys & values from all_kwargs_dict, as long as the key is accepted by module_or_fn. The returned dict can then be used to connect module_or_fn (along with some other inputs, i.e. non-keyword arguments, in general).

snt.supports_kwargs is used to tell whether a given kwarg is supported. Note that this method may give false negatives, which would lead to extraneous removals in the result of this function. Please read the docstring for snt.supports_kwargs for details, and manually inspect the results from this function if in doubt.

Args:
  • module_or_fn: some callable which can be interrogated by snt.supports_kwargs. Generally a Sonnet module or a method (wrapped in @reuse_variables) of a Sonnet module.
  • all_kwargs_dict: a dict containing strings as keys, or None.
Raises:
  • ValueError: if all_kwargs_dict is not a dict.
Returns:

A dict containing some subset of the keys and values in all_kwargs_dict. This subset may be empty. If all_kwargs_dict is None, this will be an empty dict.
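
A hedged sketch, where module, inputs, and the kwarg names are illustrative:

all_kwargs = {"is_training": True, "unsupported_flag": 1}
safe_kwargs = snt.remove_unsupported_kwargs(module, all_kwargs)
output = module(inputs, **safe_kwargs)  # Only kwargs the module accepts remain.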

reuse_variables(method)

Wraps an arbitrary method so it does variable sharing.

This decorator creates variables the first time it calls method, and reuses them for subsequent calls. The object that calls method provides a tf.VariableScope, either as a variable_scope attribute or as the return value of an _enter_variable_scope() method.

The first time the wrapped method is invoked, it enters the caller's tf.VariableScope with reuse=False. On all subsequent calls it enters the same variable scope with reuse=True.

Variables are created in the context of the tf.VariableScope provided by the caller object. Ops are created with an additional tf.name_scope(), which adds a scope for the wrapped method name. For example:

class MyClass(object):

  def __init__(self, name):
    with tf.variable_scope(None, default_name=name) as variable_scope:
      self.variable_scope = variable_scope

  @snt.reuse_variables
  def add_x(self, tensor):
    x = tf.get_variable("x", shape=tensor.get_shape())
    return tensor + x

module = MyClass("my_module_name")
input_tensor = tf.zeros(shape=(5,))

# This creates the variable "my_module_name/x"
# and op "my_module_name/add_x/add"
output = module.add_x(input_tensor)

For performance when executing eagerly it may be desirable to additionally annotate these methods using defun, such that they are encapsulated as graph functions. This is not recommended if your method returns a variable since the output of defun would be an op that returned the variable's value when evaluated (rather than the variable instance).

class FooModule(snt.AbstractModule):
  def _build(self, inputs):
    return complex_math(inputs)

  @tfe.defun
  @snt.reuse_variables
  def more_complex_stuff(self, inputs):
    return more_complex_math(inputs)
Args:
  • method: The method to wrap.
Returns:

The wrapped method.

scale_gradient(net, scale, name='scale_gradient')

Scales gradients for the backwards pass.

This might be used to, for example, allow one part of a model to learn at a lower rate than the rest.

WARNING: Think carefully about how your optimizer works. If, for example, you use rmsprop, the gradient is always rescaled (with some additional epsilon) towards unity. This means scale_gradient won't have the effect of lowering the learning rate.

If scale is 0.0, this op reduces to tf.stop_gradient. If scale is 1.0, this op reduces to tf.identity.

Args:
  • net: A tf.Tensor or in eager mode a callable that produces a tf.Tensor.
  • scale: The scale factor for the gradient on the backwards pass.
  • name: A name for the operation (optional).
Returns:

In graph mode returns a tf.Tensor with the same type as the input tensor. In eager mode returns a callable wrapping net whose gradients are scaled.

Raises:
  • ValueError: If net dtype is non-float and scale is not zero or one.
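
For example, a hedged sketch that lets a shared trunk learn ten times slower than the head (trunk_output and head_module are illustrative names):

trunk_output = snt.scale_gradient(trunk_output, 0.1)
logits = head_module(trunk_output)  # Forward pass unchanged; gradients flowing
                                    # back into the trunk are scaled by 0.1.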

split_leading_dim(tensor, inputs, n_dims=2)

Split the first dimension of a tensor.

Args:
  • tensor: Tensor to have its first dimension split.
  • inputs: Original reference input to look up the dimensions of.
  • n_dims: Number of dimensions to split.
Returns:

The input tensor, with its first dimension split.
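
As a hedged round-trip sketch combining this with merge_leading_dims (shapes are illustrative):

inputs = tf.zeros([7, 3, 5])                                # e.g. [time, batch, features]
merged = snt.merge_leading_dims(inputs, n_dims=2)           # shape [21, 5]
restored = snt.split_leading_dim(merged, inputs, n_dims=2)  # shape [7, 3, 5] again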

summarize_variables(variables=None)

Logs a summary of variable information.

This function groups Variables by dtype and prints out the number of Variables and the total number of scalar values for each datatype, as well as the total memory consumed.

For Variables of type tf.string, the memory usage cannot be accurately calculated from the Graph as the memory requirements change based on what strings are actually stored, which can only be determined inside a session. In this case, the amount of memory used to store the pointers to the strings is logged, along with a warning.

Args:
  • variables: iterable of variables; if not provided, then all variables (in the default graph) are summarized.

supports_kwargs(module_or_fn, kwargs_list)

Determines whether the provided callable supports all the kwargs.

This is useful when you have a module that might or might not support a kwarg such as is_training. Rather than calling the module and catching the error, risking the potential modification of underlying state, this function introspects the module to see what kwargs are actually supported, using the python inspect module.

Note that many TF functions do not export a valid argspec object; rather they have a generic *args, **kwargs signature due to various layers of wrapping (deprecation decorators, etc.). In those circumstances we return MAYBE_SUPPORTED, and users will have to use another method to tell whether the kwargs are supported (e.g. by just calling the function).

Args:
  • module_or_fn: some callable, generally an object or a method of some object. If an object is provided, we check whether module_or_fn.__call__ supports the provided kwargs, which for a Sonnet module will automatically check the signature of _build. If module_or_fn is a function/method, then we check its signature directly, so non-Sonnet functions can be used.
  • kwargs_list: string or iterable of strings of keyword arg names to test for. If an empty iterable is provided this function will always return True.
Raises:
  • ValueError: if a non-string is provided in kwargs_list.
Returns:

a string, one of 'supported', 'not_supported' or 'maybe_supported'.
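
A hedged sketch of guarding an optional kwarg (module and inputs are illustrative):

if snt.supports_kwargs(module, "is_training") == "supported":
  output = module(inputs, is_training=True)
else:
  output = module(inputs)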

trainable_initial_state(batch_size, state_size, dtype, initializers=None, regularizers=None, name=None)

Creates an initial state consisting of trainable variables.

The trainable variables are created with the same shapes as the elements of state_size and are tiled to produce an initial state.

Args:
  • batch_size: An int, or scalar int32 Tensor representing the batch size.
  • state_size: A TensorShape or nested tuple of TensorShapes to use for the shape of the trainable variables.
  • dtype: The data type used to create the variables and thus initial state.
  • initializers: An optional container of the same structure as state_size containing initializers for the variables.
  • regularizers: An optional container of the same structure as state_size containing regularizers for the variables.
  • name: optional string used to prefix the initial state variable names.
Returns:

A Tensor or nested tuple of Tensors with the same size and structure as state_size, where each Tensor is a tiled trainable Variable.

Raises:
  • ValueError: if the user passes initializers that are not functions.
  • ValueError: if the user passes regularizers that are not functions.
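
A minimal sketch for an LSTM core (batch size and input size are illustrative):

core = snt.LSTM(hidden_size=64)
inputs = tf.placeholder(tf.float32, shape=[32, 10])  # [batch_size, input_size]
init_state = snt.trainable_initial_state(
    batch_size=32, state_size=core.state_size, dtype=tf.float32)
output, next_state = core(inputs, init_state)  # Initial state is learned.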

variable_map_items(variable_map)

Yields an iterator over (string, variable) pairs in the variable map.

In general, variable maps map variable names to either a tf.Variable, or list of tf.Variables (in case of sliced variables).

Args:
  • variable_map: dict, variable map over which to iterate.
Yields:

(string, tf.Variable) pairs.

class custom_getters.Context

Contextually switching a custom getter on.

Example usage, using snt.custom_getters.stop_gradient with Context to selectively disable gradients flowing to variables for particular connections of the module:

  custom_getter = snt.custom_getters.Context(snt.custom_getters.stop_gradient)
  lin = snt.Linear(10, custom_getter=custom_getter)

  lin(net1)  # custom getter not used, gradients on
  with custom_getter:
    lin(net2)  # custom getter used, gradients off

Warning: If the custom getter affects the way the variable is created, then switching it on or off after the variable has been created will have no effect. For example, it is not possible to selectively switch off trainability using custom_getters.non_trainable, since this is a creation-time attribute. It is however possible to selectively switch off gradients using custom_getters.stop_gradient, since this applies an operation to the variable.

custom_getters.Context.__init__(getter, verbose=False)

Initializes a contextual switch for a custom getter.

Args:
  • getter: The custom getter which we may want to switch on.
  • verbose: Log out every time a variable is fetched, and whether or not getter is used.
Returns:

A custom getter which can also be used as a context manager. Entering the context enables the custom getter.

custom_getters.non_trainable(getter, *args, **kwargs)

Custom getter which makes a variable non-trainable.

Usage like:

with tf.variable_scope("", custom_getter=snt.custom_getters.non_trainable):
  net = snt.Linear(num_hidden)(net)

or, using the custom_getter constructor argument,

linear = snt.Linear(num_hidden, custom_getter=snt.custom_getters.non_trainable)
net = linear(net)

will result in the variables inside the linear having trainable=False, i.e. won't be added to tf.trainable_variables() and thus won't be optimized.

Warning: If reuse=True and the variable has previously been created in the same graph with trainable=True, this custom getter will do nothing. Similarly if the variable is reused after being created by this custom getter it will still be non-trainable, even if trainable=True.

When used with a Sonnet module, the module must be constructed inside the variable scope with the custom getter. Just building the module inside said variable scope will not use the custom getter.

Args:
  • getter: The true getter to call.
  • *args: Arguments, in the same format as tf.get_variable.
  • **kwargs: Keyword arguments, in the same format as tf.get_variable.
Returns:

The return value of getter(*args, **kwargs) except with trainable=False enforced.

custom_getters.override_args(**kwargs)

Creates a custom getter that applies specified named arguments.

Args:
  • **kwargs: Overriding arguments for the custom getter to use in preference to the named arguments it's called with.
Returns:

Custom getter.

custom_getters.override_default_args(**kwargs)

Creates a custom getter that applies specified named arguments.

The returned custom getter treats the specified named arguments as revised defaults, and does not override any non-None argument values supplied by the original get_variable call (or by a nested scope's custom getter).

Args:
  • **kwargs: Overriding arguments for the custom getter to use in preference to the named arguments it's called with.
Returns:

Custom getter.
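
A hedged sketch, treating a particular initializer as a revised default (the stddev is illustrative; explicit initializers supplied at get_variable time still win):

getter = snt.custom_getters.override_default_args(
    initializer=tf.truncated_normal_initializer(stddev=0.01))
lin = snt.Linear(10, custom_getter=getter)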

custom_getters.restore_initializer(filename, name_fn=None, collection='variables')

Custom getter to restore all variables with snt.restore_initializer.

Args:
  • filename: The filename of the checkpoint.
  • name_fn: A function which can map the name of the variable requested. This allows restoring variables with values having different names in the checkpoint.
  • collection: Only set the restore initializer for variables in this collection. If None, it will attempt to restore all variables. By default tf.GraphKeys.GLOBAL_VARIABLES.
Returns:

A restore_initializer custom getter, which is a function taking arguments (getter, name, *args, **kwargs).
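
A minimal sketch (the checkpoint path is illustrative):

getter = snt.custom_getters.restore_initializer(filename="/tmp/model.ckpt")
with tf.variable_scope("", custom_getter=getter):
  lin = snt.Linear(10)  # Variables initialize from the checkpoint when created.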

custom_getters.stop_gradient(getter, *args, **kwargs)

Custom getter which prevents variables being optimized.

Usage like:

with tf.variable_scope("", custom_getter=snt.custom_getters.stop_gradient):
  net = snt.Linear(num_hidden)(net)

or, using the custom_getter constructor argument,

linear = snt.Linear(num_hidden, custom_getter=snt.custom_getters.stop_gradient)
net = linear(net)

will result in the gradient with respect to the variables in the linear module being None. By default, the variables will still be in the trainable variables collection.

When used with a Sonnet module, the module must be constructed inside the variable scope with the custom getter. Just building the module inside said variable scope will not use the custom getter.

Args:
  • getter: The true getter to call.
  • *args: Arguments, in the same format as tf.get_variable.
  • **kwargs: Keyword arguments, in the same format as tf.get_variable.
Returns:

The return value of getter(*args, **kwargs) with a tf.stop_gradient.

class custom_getters.bayes_by_backprop.EstimatorModes

class custom_getters.bayes_by_backprop._VariableMetadata

VariableMetadata(raw_variable_name, raw_variable_shape, scope_name, posterior, posterior_estimate, prior, kl_cost, prior_vars, posterior_vars)

custom_getters.bayes_by_backprop._VariableMetadata.kl_cost

Alias for field number 6

custom_getters.bayes_by_backprop._VariableMetadata.posterior

Alias for field number 3

custom_getters.bayes_by_backprop._VariableMetadata.posterior_estimate

Alias for field number 4

custom_getters.bayes_by_backprop._VariableMetadata.posterior_vars

Alias for field number 8

custom_getters.bayes_by_backprop._VariableMetadata.prior

Alias for field number 5

custom_getters.bayes_by_backprop._VariableMetadata.prior_vars

Alias for field number 7

custom_getters.bayes_by_backprop._VariableMetadata.raw_variable_name

Alias for field number 0

custom_getters.bayes_by_backprop._VariableMetadata.raw_variable_shape

Alias for field number 1

custom_getters.bayes_by_backprop._VariableMetadata.scope_name

Alias for field number 2

custom_getters.bayes_by_backprop.adaptive_gaussian_prior_builder(getter, name, *args, **kwargs)

A pre-canned builder for adaptive scalar gaussian prior distributions.

Given a true getter function and arguments forwarded from tf.get_variable, return a distribution object for a scalar-valued adaptive gaussian prior which will be broadcast over a variable of the requisite shape. This prior's parameters (e.g. loc and scale for a gaussian) will consist of a single learned scalar for the entire tf.Variable for which it serves as the prior, regardless of that tf.Variable's shape.

Args:
  • getter: The getter passed to a custom_getter. Please see the documentation for tf.get_variable.
  • name: The name argument passed to tf.get_variable.
  • *args: See positional arguments passed to tf.get_variable.
  • **kwargs: See keyword arguments passed to tf.get_variable.
Returns:

An instance of tfp.distributions.Normal representing the prior distribution over the variable in question.

custom_getters.bayes_by_backprop.analytic_kl_builder(posterior, prior, sample)

A pre-canned builder for the analytic kl divergence.

custom_getters.bayes_by_backprop.bayes_by_backprop_getter(posterior_builder=diagonal_gaussian_posterior_builder, prior_builder=fixed_gaussian_prior_builder, kl_builder=stochastic_kl_builder, sampling_mode_tensor=None, fresh_noise_per_connection=True, keep_control_dependencies=False)

Creates a custom getter which does Bayes by Backprop.

Please see tf.get_variable for general documentation on custom getters.

All arguments are optional. If nothing is configured, then a diagonal gaussian posterior will be used, and a fixed N(0, 0.01) prior will be used. Please see the default posterior_builder and prior_builder for a more detailed understanding of the default settings.

Args:
  • posterior_builder: A builder function which constructs an instance of tfp.distributions.Distribution which shall serve as the posterior over the tf.Variable of interest. The builder receives the getter and the arguments forwarded from tf.get_variable. Suppose one wrote

    tf.get_variable(
        'weights', shape=(3,), initializer=tf.zeros_initializer, dtype=tf.float32)

    then the posterior_builder argument would receive the name, shape, initializer, and dtype arguments passed above. The builder must return a tfp.distributions.Distribution object.

    Please see the tf.get_variable for documentation on custom_getter and getter, and see bbb.diagonal_gaussian_posterior_builder (the default) for an example of using this builder API.

  • prior_builder: A builder function which constructs an instance of tfp.distributions.Distribution which shall serve as the prior over the tf.Variable of interest. Identical API to posterior_builder. See bbb.fixed_gaussian_prior_builder (the default) for an example.

  • kl_builder: A builder function which receives the posterior distribution, prior distribution, and a sample from the posterior. It returns a scalar-shaped tf.Tensor representing the total KL cost for the tf.Variable in question. See bbb.stochastic_kl_builder (default) and bbb.analytic_kl_builder for examples.
  • sampling_mode_tensor: A tf.Tensor which determines how an estimate from the posterior is produced. It must be scalar-shaped and have a dtype of tf.string. Valid values for this tensor are bbb.EstimatorModes.sample (which is the default), bbb.EstimatorModes.mean, and bbb.EstimatorModes.last_sample. bbb.EstimatorModes.sample is appropriate for training, and bbb.EstimatorModes.mean can be used at test time.
  • fresh_noise_per_connection: A boolean. Indicates that each time a stochastic variable is retrieved with this custom getter, new sampling noise should be used. This is True by default. If this argument is set to False, then the same noise is used for each connection. Note that this does not apply to connections within a tf.while_loop; the same sampling noise is always used in different iterations of a tf.while_loop within one session.run() call. See the unit tests for details.
  • keep_control_dependencies: A boolean. This argument should only be used by advanced users. Indicates that each time a stochastic variable is retrieved in the loop body of a tf.while_loop construct, new sampling noise should be used. The default behavior is False, so that RNNs use the same weights at each recurrent time step. This is done by removing the creation of the Variable from any existing control flow contexts. Notably, the Variables will be created outside the context of any tf.while_loop, making them fetchable. When this argument is True, any Variables used in the loop body of a tf.while_loop will be non-fetchable. If the KL cost needs to be evaluated, the Variable must first be used outside the loop body. This op using the Variable simply needs to be placed on the graph to get a stochastic estimate of the KL; it doesn't need to ever be used. Example:

    ```
    def loop_body(i):
      logits = sonnet_module(queue)
      i = i + 1
      return i

    with tf.variable_scope('bbb', custom_getter=bbb.bayes_by_backprop_getter(
        fresh_noise_per_connection=True,
        keep_control_dependencies=True)):
      unused_op = sonnet_module(queue)  # Adds KL estimate to bbb Collection
      final_i = tf.while_loop(lambda i: i < 5, loop_body, tf.constant(0.))
    ```

    Here when we add unused_op to the graph, we also add a number of tensors associated with the particular stochastic variable, including its contribution to the KL cost, to a graph-level registry. These are organized in a per-stochastic-variable data structure and can be accessed with bbb.get_variable_metadata(). Without this line, these Tensors would instead be added the first time the Variable is used in the while_loop, which would make them non-fetchable.

    In all cases, the KL cost is only added once per Variable, which is the correct behavior, since if a variable is used multiple times in a model, the KL cost should remain unaffected.

Returns:

A custom_getter function which implements Bayes by Backprop.
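
A hedged end-to-end sketch with the default builders (the scope name and inputs are illustrative; bbb abbreviates snt.custom_getters.bayes_by_backprop):

bbb = snt.custom_getters.bayes_by_backprop
with tf.variable_scope('net', custom_getter=bbb.bayes_by_backprop_getter()):
  lin = snt.Linear(10)   # Must be constructed inside the scope.
out = lin(inputs)        # Weights are sampled from the learned posterior.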

custom_getters.bayes_by_backprop.diagonal_gaussian_posterior_builder(getter, name, shape=None, *args, **kwargs)

A pre-canned builder for diagonal gaussian posterior distributions.

Given a true getter function and arguments forwarded from tf.get_variable, return a distribution object for a diagonal posterior over a variable of the requisite shape.

Args:
  • getter: The getter passed to a custom_getter. Please see the documentation for tf.get_variable.
  • name: The name argument passed to tf.get_variable.
  • shape: The shape argument passed to tf.get_variable.
  • *args: See positional arguments passed to tf.get_variable.
  • **kwargs: See keyword arguments passed to tf.get_variable.
Returns:

An instance of tfp.distributions.Normal representing the posterior distribution over the variable in question.

custom_getters.bayes_by_backprop.fixed_gaussian_prior_builder(getter, name, dtype=None, *args, **kwargs)

A pre-canned builder for fixed gaussian prior distributions.

Given a true getter function and arguments forwarded from tf.get_variable, return a distribution object for a scalar-valued fixed gaussian prior which will be broadcast over a variable of the requisite shape.

Args:
  • getter: The getter passed to a custom_getter. Please see the documentation for tf.get_variable.
  • name: The name argument passed to tf.get_variable.
  • dtype: The dtype argument passed to tf.get_variable.
  • *args: See positional arguments passed to tf.get_variable.
  • **kwargs: See keyword arguments passed to tf.get_variable.
Returns:

An instance of tfp.distributions.Normal representing the prior distribution over the variable in question.

custom_getters.bayes_by_backprop.get_total_kl_cost(name='total_kl_cost', filter_by_name_substring=None)

Get the total cost for all (or a subset of) the stochastic variables.

Args:
  • name: A name for the tensor representing the total kl cost.
  • filter_by_name_substring: A string used to filter which variables count toward the total KL cost. By default, this argument is None, and all variables trained using Bayes by Backprop are included. If this argument is provided, the variables whose KL costs are summed will be all those whose name contains filter_by_name_substring. An example use of this would be to select all variables within a particular scope.
Returns:

A tensor representing the total KL cost in the ELBO loss.
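
A hedged sketch of combining the KL cost with a likelihood term in the ELBO loss (negative_log_likelihood and num_training_examples are illustrative names):

bbb = snt.custom_getters.bayes_by_backprop
kl_cost = bbb.get_total_kl_cost()
total_loss = negative_log_likelihood + kl_cost / num_training_examples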

custom_getters.bayes_by_backprop.get_variable_metadata(scope_name_substring=None)

custom_getters.bayes_by_backprop.inverse_softplus(y)

The inverse of the softplus function.

Computes the inverse of softplus, a function which maps an unconstrained real number to the positive reals, e.g. to squash an unconstrained neural network activation to parameterize a variance.

Args:
  • y: A positive number.
Returns:

The number x such that softplus(x) = y.
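
Since softplus(x) = log(1 + exp(x)), the inverse is x = log(exp(y) - 1); a one-line equivalent sketch:

x = tf.log(tf.expm1(y))  # Same value as inverse_softplus(y) for y > 0.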

custom_getters.bayes_by_backprop.scale_variable_initializer(desired_scale)

custom_getters.bayes_by_backprop.stochastic_kl_builder(posterior, prior, sample)

A pre-canned builder for a ubiquitous stochastic KL estimator.

nest.assert_same_structure(*args, **kwargs)

nest.assert_shallow_structure(*args, **kwargs)

nest.flatten(*args, **kwargs)

nest.flatten_dict_items(*args, **kwargs)

nest.flatten_iterable(*args, **kwargs)

nest.flatten_up_to(*args, **kwargs)

nest.is_iterable(*args, **kwargs)

nest.is_sequence(*args, **kwargs)

nest.map(*args, **kwargs)

nest.map_up_to(*args, **kwargs)

nest.pack_iterable_as(*args, **kwargs)

nest.pack_sequence_as(*args, **kwargs)

nest.with_deprecation_warning(fn, extra_message='')

Wraps the function and prints a warn-once (per extra_message) warning.

class nets.AlexNet

Implementation of AlexNet with full and mini versions.

Based on: 'ImageNet Classification with Deep Convolutional Neural Networks' Alex Krizhevsky, Ilya Sutskever, Geoffrey E. Hinton, NIPS 2012 http://papers.nips.cc/paper/4824-imagenet-classification-w

nets.AlexNet.__init__(mode, use_batch_norm=False, batch_norm_config=None, initializers=None, partitioners=None, regularizers=None, bn_on_fc_layers=True, custom_getter=None, name='alex_net')

Constructs AlexNet.

Args:
  • mode: Construction mode of network: AlexNet.FULL or AlexNet.MINI.
  • use_batch_norm: Whether to use batch normalization between the output of a layer and the activation function.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializers are truncated normal initializers, which are commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf).
  • partitioners: Optional dict containing partitioners for the filters (with key 'w') and the biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • bn_on_fc_layers: If use_batch_norm is True, add batch normalization to the fully-connected layers. This is deprecated.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • base.Error: If the given mode is not one of AlexNet.FULL or AlexNet.MINI.
  • KeyError: If initializers, partitioners or regularizers contains any keys other than 'w' or 'b'.

nets.AlexNet.__call__(inputs, keep_prob=None, is_training=None, test_local_stats=True)

Connects the AlexNet module into the graph.

The is_training flag only controls the batch norm settings; if it is False, dropout is not automatically disabled by overriding any input keep_prob. To avoid any confusion this may cause, an error is thrown if is_training=False and keep_prob would cause dropout to be applied.

Args:
  • inputs: A Tensor of size [batch_size, input_height, input_width, input_channels], representing a batch of input images.
  • keep_prob: A scalar Tensor representing the dropout keep probability. When is_training=False this must be None or 1 to give no dropout.
  • is_training: Boolean to indicate if we are currently training. Must be specified if batch normalization or dropout is used.
  • test_local_stats: Boolean to indicate to snt.BatchNorm if batch normalization should use local batch statistics at test time. By default True.
Returns:

A Tensor of size [batch_size, output_size], where output_size depends on the mode the network was constructed in.

Raises:
  • base.IncompatibleShapeError: If any of the input image dimensions (input_height, input_width) are too small for the given network mode.
  • ValueError: If keep_prob is not None or 1 when is_training=False.
  • ValueError: If is_training is not explicitly specified when using batch normalization.
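
A minimal connection sketch (the input size and keep_prob are illustrative; see min_input_size for the actual lower bound):

alexnet = snt.nets.AlexNet(mode=snt.nets.AlexNet.FULL)
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits = alexnet(images, keep_prob=0.5, is_training=True)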

nets.AlexNet.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.AlexNet.conv_modules

Returns list containing convolutional modules of network.

Returns:

A list containing the Conv2D modules.

nets.AlexNet.defun()

Wraps this module's call method in a callable graph function.

nets.AlexNet.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.AlexNet.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

nets.AlexNet.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.graph

Returns the Graph instance which the module is connected to, or None.

nets.AlexNet.initializers

nets.AlexNet.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.AlexNet.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.linear_modules

Returns list containing linear modules of network.

Returns:

A list containing the Linear modules.

nets.AlexNet.min_input_size

Returns integer specifying the minimum width and height for the input.

Note that the input can be non-square, but both the width and height must be >= this number in size.

Returns:

The minimum size as an integer.

nets.AlexNet.module_name

Returns the name of the Module.

nets.AlexNet.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.AlexNet.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.partitioners

nets.AlexNet.regularizers

nets.AlexNet.scope_name

Returns the full name of the Module's variable scope.

nets.AlexNet.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNet.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.AlexNetFull

AlexNet constructed in the 'FULL' mode.

nets.AlexNetFull.__init__(use_batch_norm=False, batch_norm_config=None, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='alex_net_full')

Constructs AlexNet.

Args:
  • use_batch_norm: Whether to use batch normalization between the output of a layer and the activation function.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializers are truncated normal initializers, which are commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf).
  • partitioners: Optional dict containing partitioners for the filters (with key 'w') and the biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • KeyError: If initializers, partitioners or regularizers contains any keys other than 'w' or 'b'.

nets.AlexNetFull.__call__(inputs, keep_prob=None, is_training=None, test_local_stats=True)

Connects the AlexNet module into the graph.

The is_training flag only controls the batch norm settings; if it is False, dropout is not automatically disabled by overriding any input keep_prob. To avoid any confusion this may cause, an error is thrown if is_training=False and keep_prob would cause dropout to be applied.

Args:
  • inputs: A Tensor of size [batch_size, input_height, input_width, input_channels], representing a batch of input images.
  • keep_prob: A scalar Tensor representing the dropout keep probability. When is_training=False this must be None or 1 to give no dropout.
  • is_training: Boolean to indicate if we are currently training. Must be specified if batch normalization or dropout is used.
  • test_local_stats: Boolean to indicate to snt.BatchNorm if batch normalization should use local batch statistics at test time. By default True.
Returns:

A Tensor of size [batch_size, output_size], where output_size depends on the mode the network was constructed in.

Raises:
  • base.IncompatibleShapeError: If any of the input image dimensions (input_height, input_width) are too small for the given network mode.
  • ValueError: If keep_prob is not None or 1 when is_training=False.
  • ValueError: If is_training is not explicitly specified when using batch normalization.

nets.AlexNetFull.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.AlexNetFull.conv_modules

Returns list containing convolutional modules of network.

Returns:

A list containing the Conv2D modules.

nets.AlexNetFull.defun()

Wraps this module's call method in a callable graph function.

nets.AlexNetFull.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.AlexNetFull.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

nets.AlexNetFull.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.graph

Returns the Graph instance which the module is connected to, or None.

nets.AlexNetFull.initializers

nets.AlexNetFull.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.AlexNetFull.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.linear_modules

Returns list containing linear modules of network.

Returns:

A list containing the Linear modules.

nets.AlexNetFull.min_input_size

Returns integer specifying the minimum width and height for the input.

Note that the input can be non-square, but both the width and height must be >= this number in size.

Returns:

The minimum size as an integer.

nets.AlexNetFull.module_name

Returns the name of the Module.

nets.AlexNetFull.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.AlexNetFull.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.partitioners

nets.AlexNetFull.regularizers

nets.AlexNetFull.scope_name

Returns the full name of the Module's variable scope.

nets.AlexNetFull.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetFull.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.AlexNetMini

AlexNet constructed in the 'MINI' mode.

nets.AlexNetMini.__init__(use_batch_norm=False, batch_norm_config=None, initializers=None, partitioners=None, regularizers=None, custom_getter=None, name='alex_net_mini')

Constructs AlexNet.

Args:
  • use_batch_norm: Whether to use batch normalization between the output of a layer and the activation function.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules.
  • initializers: Optional dict containing ops to initialize the filters (with key 'w') or biases (with key 'b'). The default initializers are truncated normal initializers, which are commonly used when the inputs are zero centered (see https://arxiv.org/pdf/1502.03167v3.pdf).
  • partitioners: Optional dict containing partitioners for the filters (with key 'w') and the biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • KeyError: If initializers, partitioners or regularizers contains any keys other than 'w' or 'b'.

nets.AlexNetMini.__call__(inputs, keep_prob=None, is_training=None, test_local_stats=True)

Connects the AlexNet module into the graph.

The is_training flag only controls the batch norm settings; if it is False, dropout is not automatically disabled by overriding any input keep_prob. To avoid any confusion this may cause, an error is thrown if is_training=False and keep_prob would cause dropout to be applied.

Args:
  • inputs: A Tensor of size [batch_size, input_height, input_width, input_channels], representing a batch of input images.
  • keep_prob: A scalar Tensor representing the dropout keep probability. When is_training=False this must be None or 1 to give no dropout.
  • is_training: Boolean to indicate if we are currently training. Must be specified if batch normalization or dropout is used.
  • test_local_stats: Boolean to indicate to snt.BatchNorm if batch normalization should use local batch statistics at test time. By default True.
Returns:

A Tensor of size [batch_size, output_size], where output_size depends on the mode the network was constructed in.

Raises:
  • base.IncompatibleShapeError: If any of the input image dimensions (input_height, input_width) are too small for the given network mode.
  • ValueError: If keep_prob is not None or 1 when is_training=False.
  • ValueError: If is_training is not explicitly specified when using batch normalization.

nets.AlexNetMini.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.AlexNetMini.conv_modules

Returns list containing convolutional modules of network.

Returns:

A list containing the Conv2D modules.

nets.AlexNetMini.defun()

Wraps this module's call method in a callable graph function.

nets.AlexNetMini.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.AlexNetMini.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

Set with strings corresponding to the strings that may be passed to the constructor.

nets.AlexNetMini.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.graph

Returns the Graph instance which the module is connected to, or None.

nets.AlexNetMini.initializers

nets.AlexNetMini.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.AlexNetMini.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.linear_modules

Returns list containing linear modules of network.

Returns:

A list containing the Linear modules.

nets.AlexNetMini.min_input_size

Returns integer specifying the minimum width and height for the input.

Note that the input can be non-square, but both the width and height must be >= this number in size.

Returns:

The minimum size as an integer.

nets.AlexNetMini.module_name

Returns the name of the Module.

nets.AlexNetMini.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.AlexNetMini.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.partitioners

nets.AlexNetMini.regularizers

nets.AlexNetMini.scope_name

Returns the full name of the Module's variable scope.

nets.AlexNetMini.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred vs. get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.AlexNetMini.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.
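
As a minimal sketch of the variable-access pattern described above (shown with snt.Linear, whose constructor arguments are unambiguous; the same properties and methods exist on nets.AlexNetMini and every other module):

import tensorflow as tf
import sonnet as snt

module = snt.Linear(output_size=10)
inputs = tf.placeholder(tf.float32, shape=[None, 5])
outputs = module(inputs)  # the module must be connected before variables exist

# Preferred: properties that do not rely on global collections.
weights_and_biases = module.trainable_variables

# Collection-based alternative, restricted to this module's variable scope.
same_variables = module.get_variables(tf.GraphKeys.TRAINABLE_VARIABLES)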

class nets.ConvNet2D

A 2D Convolutional Network module.

nets.ConvNet2D.__init__(output_channels, kernel_shapes, strides, paddings, rates=(1,), activation=relu, activate_final=False, normalization_ctor=None, normalization_kwargs=None, normalize_final=None, initializers=None, partitioners=None, regularizers=None, use_batch_norm=None, use_bias=True, batch_norm_config=None, data_format='NHWC', custom_getter=None, name='conv_net_2d')

Constructs a ConvNet2D module.

By default, neither batch normalization nor activation is applied to the output of the final layer.

Args:
  • output_channels: Iterable of output channels, as defined in conv.Conv2D. Output channels can be defined either as number or via a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that entries can be called when build is called. Each entry in the iterable defines properties in the corresponding convolutional layer.
  • kernel_shapes: Iterable of kernel sizes as defined in conv.Conv2D; if the list contains one element only, the same kernel shape is used in each layer of the network.
  • strides: Iterable of kernel strides as defined in conv.Conv2D; if the list contains one element only, the same stride is used in each layer of the network.
  • paddings: Iterable of padding options as defined in conv.Conv2D. Each can be snt.SAME, snt.VALID, snt.FULL, snt.CAUSAL, snt.REVERSE_CAUSAL or a pair of these to use for height and width. If the Iterable contains one element only, the same padding is used in each layer of the network.
  • rates: Iterable of dilation rates as defined in conv.Conv2D; if the list contains one element only, the same rate is used in each layer of the network.
  • activation: An activation op.
  • activate_final: Boolean determining if the activation and batch normalization, if turned on, are applied to the final layer.
  • normalization_ctor: Constructor to return a callable which will perform normalization at each layer. Defaults to None / no normalization. Examples of what could go here: snt.BatchNormV2, snt.LayerNorm. If a string is provided, importlib is used to convert the string to a callable, so either snt.LayerNorm or "snt.LayerNorm" can be provided.
  • normalization_kwargs: kwargs to be provided to normalization_ctor when it is called.
  • normalize_final: Whether to apply normalization after the final conv layer. Default is to take the value of activate_final.
  • initializers: Optional dict containing ops to initialize the filters of the whole network (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters of the whole network (with key 'w') or biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • use_batch_norm: Boolean determining if batch normalization is applied after convolution. Deprecated, use normalization_ctor instead.
  • use_bias: Boolean or iterable of booleans determining whether to include bias parameters in the convolutional layers. Default True.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules. Deprecated, use normalization_kwargs instead.
  • data_format: A string, one of "NCHW" or "NHWC". Specifies whether the channel dimension of the input and output is the last dimension (default, "NHWC"), or the second dimension ("NCHW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API. Note that this custom_getter will not be passed to the transpose method. If you want to use a custom getter with the transposed of this convolutional network, you should provide one to the transpose method instead.
  • name: Name of the module.
Raises:
  • TypeError: If output_channels is not iterable; or if kernel_shapes is not iterable; or if strides is not iterable; or if paddings is not iterable; or if activation is not callable.
  • ValueError: If output_channels is empty; or if kernel_shapes does not have length 1 or len(output_channels); or if strides does not have length 1 or len(output_channels); or if paddings does not have length 1 or len(output_channels); or if rates does not have length 1 or len(output_channels); or if the given data_format is not a supported format ("NHWC" or "NCHW"); or if normalization_ctor is provided but cannot be mapped to a callable.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.

nets.ConvNet2D.__call__(inputs, **normalization_build_kwargs)

Assembles the ConvNet2D and connects it to the graph.

Args:
  • inputs: A 4D Tensor of shape [batch_size, input_height, input_width, input_channels].
  • **normalization_build_kwargs: kwargs passed to the normalization module at _build time.
Returns:

A 4D Tensor of shape [batch_size, output_height, output_width, output_channels[-1]].

Raises:
  • ValueError: If is_training is not explicitly specified when using batch normalization.
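
A minimal sketch of constructing and connecting a ConvNet2D; the layer sizes and the choice of snt.BatchNormV2 are illustrative, not canonical defaults:

import tensorflow as tf
import sonnet as snt

net = snt.nets.ConvNet2D(
    output_channels=[16, 32, 64],
    kernel_shapes=[3],   # a single element is broadcast to every layer
    strides=[1],
    paddings=[snt.SAME],
    normalization_ctor=snt.BatchNormV2)

images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
# is_training must be given explicitly because a normalization module is used.
outputs = net(images, is_training=True)  # shape [None, 28, 28, 64]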

nets.ConvNet2D.activate_final

nets.ConvNet2D.activation

nets.ConvNet2D.batch_norm_config

nets.ConvNet2D.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.ConvNet2D.defun()

Wraps this module's call method in a callable graph function.

nets.ConvNet2D.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.ConvNet2D.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

A set of strings corresponding to the initializer keys that may be passed to the constructor.

nets.ConvNet2D.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.graph

Returns the Graph instance which the module is connected to, or None.

nets.ConvNet2D.initializers

nets.ConvNet2D.input_shape

Returns shape of input Tensor passed at last call to build.

nets.ConvNet2D.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.ConvNet2D.kernel_shapes

nets.ConvNet2D.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.layers

Returns a tuple containing the convolutional layers of the network.

nets.ConvNet2D.module_name

Returns the name of the Module.

nets.ConvNet2D.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.ConvNet2D.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.normalization_ctor

nets.ConvNet2D.normalization_kwargs

nets.ConvNet2D.normalize_final

nets.ConvNet2D.output_channels

nets.ConvNet2D.paddings

nets.ConvNet2D.partitioners

nets.ConvNet2D.rates

nets.ConvNet2D.regularizers

nets.ConvNet2D.scope_name

Returns the full name of the Module's variable scope.

nets.ConvNet2D.strides

nets.ConvNet2D.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.transpose(name=None, output_channels=None, kernel_shapes=None, strides=None, paddings=None, activation=None, activate_final=None, normalization_ctor=None, normalization_kwargs=None, normalize_final=None, initializers=None, partitioners=None, regularizers=None, use_batch_norm=None, use_bias=None, batch_norm_config=None, data_format=None, custom_getter=None)

Returns transposed version of this network.

Args:
  • name: Optional string specifying the name of the transposed module. The default name is constructed by appending "_transpose" to self.module_name.
  • output_channels: Optional iterable of numbers of output channels.
  • kernel_shapes: Optional iterable of kernel sizes. The default value is constructed by reversing self.kernel_shapes.
  • strides: Optional iterable of kernel strides. The default value is constructed by reversing self.strides.
  • paddings: Optional iterable of padding options, either snt.SAME or snt.VALID; The default value is constructed by reversing self.paddings.
  • activation: Optional activation op. Default value is self.activation.
  • activate_final: Optional boolean determining if the activation and batch normalization, if turned on, are applied to the final layer.
  • normalization_ctor: Constructor to return a callable which will perform normalization at each layer. Defaults to None / no normalization. Examples of what could go here: snt.BatchNormV2, snt.LayerNorm. If a string is provided, importlib is used to convert the string to a callable, so either snt.LayerNorm or "snt.LayerNorm" can be provided.
  • normalization_kwargs: kwargs to be provided to normalization_ctor when it is called.
  • normalize_final: Whether to apply normalization after the final conv layer. Default is to take the value of activate_final.
  • initializers: Optional dict containing ops to initialize the filters of the whole network (with key 'w') or biases (with key 'b'). The default value is self.initializers.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). The default value is self.partitioners.
  • regularizers: Optional dict containing regularizers for the filters of the whole network (with key 'w') or biases (with key 'b'). The default is self.regularizers.
  • use_batch_norm: Optional boolean determining if batch normalization is applied after convolution. The default value is self.use_batch_norm.
  • use_bias: Optional boolean or iterable of booleans determining whether to include bias parameters in the convolutional layers. Default is constructed by reversing self.use_bias.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules. Default is self.batch_norm_config.
  • data_format: Optional string, one of "NCHW" or "NHWC". Specifies whether the channel dimension of the input and output is the last dimension. Default is self._data_format.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
Returns:

Matching ConvNet2DTranspose module.

Raises:
  • ValueError: If output_channels is specified and its length does not match the number of layers.
  • ValueError: If the given data_format is not a supported format ("NHWC" or "NCHW").
  • NotImplementedError: If the convolutions are dilated.
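
For instance, a convolutional autoencoder can derive its decoder from the encoder's structure; a sketch with illustrative shapes (the encoder must be connected before the transposed network is connected, so that output shapes can be inferred):

import tensorflow as tf
import sonnet as snt

images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
encoder = snt.nets.ConvNet2D(
    output_channels=[16, 32], kernel_shapes=[3], strides=[2], paddings=[snt.SAME])
latents = encoder(images)          # [None, 8, 8, 32]

decoder = encoder.transpose()      # a matching ConvNet2DTranspose
reconstruction = decoder(latents)  # [None, 32, 32, 3]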

nets.ConvNet2D.use_batch_norm

nets.ConvNet2D.use_bias

nets.ConvNet2D.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but it only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2D.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.ConvNet2DTranspose

A 2D Transpose-Convolutional Network module.

nets.ConvNet2DTranspose.__init__(output_channels, output_shapes, kernel_shapes, strides, paddings, activation=relu, activate_final=False, normalization_ctor=None, normalization_kwargs=None, normalize_final=None, initializers=None, partitioners=None, regularizers=None, use_batch_norm=False, use_bias=True, batch_norm_config=None, data_format='NHWC', custom_getter=None, name='conv_net_2d_transpose')

Constructs a ConvNet2DTranspose module.

output_{shapes,channels} can be defined either as an iterable of {iterables,integers} or via a callable. In the latter case, since the function invocation is deferred to graph construction time, the user need only ensure that the entries can be called and return meaningful values when build is called. Each entry in the iterable defines properties in the corresponding convolutional layer.

By default, neither batch normalization nor activation is applied to the output of the final layer.

Args:
  • output_channels: Iterable of numbers of output channels.
  • output_shapes: Iterable of output shapes as defined in conv.Conv2DTranspose; if the iterable contains one element only, the same shape is used in each layer of the network.
  • kernel_shapes: Iterable of kernel sizes as defined in conv.Conv2D; if the list contains one element only, the same kernel shape is used in each layer of the network.
  • strides: Iterable of kernel strides as defined in conv.Conv2D; if the list contains one element only, the same stride is used in each layer of the network.
  • paddings: Iterable of padding options, either snt.SAME or snt.VALID; if the Iterable contains one element only, the same padding is used in each layer of the network.
  • activation: An activation op.
  • activate_final: Boolean determining if the activation and batch normalization, if turned on, are applied to the final layer.
  • normalization_ctor: Constructor to return a callable which will perform normalization at each layer. Defaults to None / no normalization. Examples of what could go here: snt.BatchNormV2, snt.LayerNorm. If a string is provided, importlib is used to convert the string to a callable, so either snt.LayerNorm or "snt.LayerNorm" can be provided.
  • normalization_kwargs: kwargs to be provided to normalization_ctor when it is called.
  • normalize_final: Whether to apply normalization after the final conv layer. Default is to take the value of activate_final.
  • initializers: Optional dict containing ops to initialize the filters of the whole network (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). As a default, no partitioners are used.
  • regularizers: Optional dict containing regularizers for the filters of the whole network (with key 'w') or biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • use_batch_norm: Boolean determining if batch normalization is applied after convolution.
  • use_bias: Boolean or iterable of booleans determining whether to include bias parameters in the convolutional layers. Default True.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules.
  • data_format: A string, one of "NCHW" or "NHWC". Specifies whether the channel dimension of the input and output is the last dimension (default, "NHWC"), or the second dimension ("NCHW").
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • TypeError: If output_channels is not iterable; or if output_shapes is not iterable; or if kernel_shapes is not iterable; or if strides is not iterable; or if paddings is not iterable; or if activation is not callable.
  • ValueError: If output_channels is empty; or if kernel_shapes does not have length 1 or len(output_channels); or if strides does not have length 1 or len(output_channels); or if paddings does not have length 1 or len(output_channels).
  • ValueError: If the given data_format is not a supported format ("NHWC" or "NCHW").
  • ValueError: If normalization_ctor is provided but cannot be converted to a callable.
  • KeyError: If initializers, partitioners or regularizers contain any keys other than 'w' or 'b'.
  • TypeError: If any of the given initializers, partitioners or regularizers are not callable.

nets.ConvNet2DTranspose.__call__(inputs, **normalization_build_kwargs)

Assembles the ConvNet2DTranspose and connects it to the graph.

Args:
  • inputs: A 4D Tensor of shape [batch_size, input_height, input_width, input_channels].
  • **normalization_build_kwargs: kwargs passed to the normalization module at _build time.
Returns:

A 4D Tensor of shape [batch_size, output_height, output_width, output_channels[-1]].

Raises:
  • ValueError: If is_training is not explicitly specified when using batch normalization.
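
A minimal sketch of a standalone ConvNet2DTranspose; the shapes and channel counts are illustrative:

import tensorflow as tf
import sonnet as snt

net = snt.nets.ConvNet2DTranspose(
    output_channels=[32, 3],
    output_shapes=[[16, 16], [32, 32]],  # spatial output shape of each layer
    kernel_shapes=[3],
    strides=[2],
    paddings=[snt.SAME])

latents = tf.placeholder(tf.float32, shape=[None, 8, 8, 64])
images = net(latents)  # shape [None, 32, 32, 3]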

nets.ConvNet2DTranspose.activate_final

nets.ConvNet2DTranspose.activation

nets.ConvNet2DTranspose.batch_norm_config

nets.ConvNet2DTranspose.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.ConvNet2DTranspose.defun()

Wraps this module's call method in a callable graph function.

nets.ConvNet2DTranspose.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.ConvNet2DTranspose.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

A set of strings corresponding to the initializer keys that may be passed to the constructor.

nets.ConvNet2DTranspose.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.graph

Returns the Graph instance which the module is connected to, or None.

nets.ConvNet2DTranspose.initializers

nets.ConvNet2DTranspose.input_shape

Returns shape of input Tensor passed at last call to build.

nets.ConvNet2DTranspose.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.ConvNet2DTranspose.kernel_shapes

nets.ConvNet2DTranspose.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.layers

Returns a tuple containing the convolutional layers of the network.

nets.ConvNet2DTranspose.module_name

Returns the name of the Module.

nets.ConvNet2DTranspose.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.ConvNet2DTranspose.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.normalization_ctor

nets.ConvNet2DTranspose.normalization_kwargs

nets.ConvNet2DTranspose.normalize_final

nets.ConvNet2DTranspose.output_channels

nets.ConvNet2DTranspose.output_shapes

nets.ConvNet2DTranspose.paddings

nets.ConvNet2DTranspose.partitioners

nets.ConvNet2DTranspose.rates

nets.ConvNet2DTranspose.regularizers

nets.ConvNet2DTranspose.scope_name

Returns the full name of the Module's variable scope.

nets.ConvNet2DTranspose.strides

nets.ConvNet2DTranspose.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.transpose(name=None, output_channels=None, kernel_shapes=None, strides=None, paddings=None, activation=None, activate_final=None, normalization_ctor=None, normalization_kwargs=None, normalize_final=None, initializers=None, partitioners=None, regularizers=None, use_batch_norm=None, use_bias=None, batch_norm_config=None, data_format=None, custom_getter=None)

Returns transposed version of this network.

Args:
  • name: Optional string specifying the name of the transposed module. The default name is constructed by appending "_transpose" to self.module_name.
  • output_channels: Optional iterable of numbers of output channels.
  • kernel_shapes: Optional iterable of kernel sizes. The default value is constructed by reversing self.kernel_shapes.
  • strides: Optional iterable of kernel strides. The default value is constructed by reversing self.strides.
  • paddings: Optional iterable of padding options, either snt.SAME or snt.VALID; The default value is constructed by reversing self.paddings.
  • activation: Optional activation op. Default value is self.activation.
  • activate_final: Optional boolean determining if the activation and batch normalization, if turned on, are applied to the final layer.
  • normalization_ctor: Constructor to return a callable which will perform normalization at each layer. Defaults to None / no normalization. Examples of what could go here: snt.BatchNormV2, snt.LayerNorm. If a string is provided, importlib is used to convert the string to a callable, so either snt.LayerNorm or "snt.LayerNorm" can be provided.
  • normalization_kwargs: kwargs to be provided to normalization_ctor when it is called.
  • normalize_final: Whether to apply normalization after the final conv layer. Default is to take the value of activate_final.
  • initializers: Optional dict containing ops to initialize the filters of the whole network (with key 'w') or biases (with key 'b'). The default value is self.initializers.
  • partitioners: Optional dict containing partitioners to partition weights (with key 'w') or biases (with key 'b'). The default value is self.partitioners.
  • regularizers: Optional dict containing regularizers for the filters of the whole network (with key 'w') or biases (with key 'b'). The default is self.regularizers.
  • use_batch_norm: Optional boolean determining if batch normalization is applied after convolution. The default value is self.use_batch_norm.
  • use_bias: Optional boolean or iterable of booleans determining whether to include bias parameters in the convolutional layers. Default is constructed by reversing self.use_bias.
  • batch_norm_config: Optional mapping of additional configuration for the snt.BatchNorm modules. Default is self.batch_norm_config.
  • data_format: Optional string, one of "NCHW" or "NHWC". Specifies whether the channel dimension of the input and output is the last dimension. Default is self._data_format.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
Returns:

Matching ConvNet2D module.

Raises:
  • ValueError: If output_channels is specified and its length does not match the number of layers.

nets.ConvNet2DTranspose.use_batch_norm

nets.ConvNet2DTranspose.use_bias

nets.ConvNet2DTranspose.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but it only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.ConvNet2DTranspose.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.Dilation

A convolutional module for per-pixel classification.

Consists of 8 convolutional layers, 4 of which are dilated. When applied to the output of a model like VGG-16 (before fully connected layers), can be used to make predictions on a per-pixel basis.

Note that the default initializers for the 'basic' model size require that the number of input channels be equal to the number of output classes, and the initializers for the 'large' model require that it be a multiple.

Based on: 'Multi-Scale Context Aggregation by Dilated Convolutions' Fisher Yu, Vladlen Koltun, ICLR 2016 https://arxiv.org/abs/1511.07122

Properties:
  • conv_modules: List of sonnet modules; the 8 convolution layers used in the Dilation module.

nets.Dilation.__init__(num_output_classes, initializers=None, regularizers=None, model_size='basic', name='dilation')

Creates a dilation module.

Args:
  • num_output_classes: Int. Number of output classes to predict for each pixel in an image.
  • initializers: Optional dict containing ops to initialize filters (with key 'w') or biases (with key 'b'). The default initializer makes this module equivalent to the identity.
  • regularizers: Optional dict containing regularizers for the weights (with key 'w') or biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • model_size: string. One of 'basic' or 'large'.
  • name: string. Name of module.

nets.Dilation.__call__(images)

Build dilation module.

Args:
  • images: Tensor of shape [batch_size, height, width, depth] and dtype float32. Represents a set of images with an arbitrary depth. Note that when using the default initializer, depth must equal num_output_classes.
Returns:

Tensor of shape [batch_size, height, width, num_output_classes] and dtype float32. Represents, for each image and pixel, logits for per-class predictions.

Raises:
  • IncompatibleShapeError: If images is not rank 4.
  • ValueError: If model_size is not one of 'basic' or 'large'.
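
A sketch of applying Dilation on top of per-pixel features; recall that with the default 'basic' initializers the input depth must equal num_output_classes:

import tensorflow as tf
import sonnet as snt

num_classes = 10
features = tf.placeholder(tf.float32, shape=[None, 64, 64, num_classes])

dilation = snt.nets.Dilation(num_output_classes=num_classes)
logits = dilation(features)  # [None, 64, 64, num_classes]; initially the identity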

nets.Dilation.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.Dilation.conv_modules

nets.Dilation.defun()

Wraps this module's call method in a callable graph function.

nets.Dilation.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.Dilation.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

A set of strings corresponding to the initializer keys that may be passed to the constructor.

nets.Dilation.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.graph

Returns the Graph instance which the module is connected to, or None.

nets.Dilation.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.Dilation.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.module_name

Returns the name of the Module.

nets.Dilation.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.Dilation.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.scope_name

Returns the full name of the Module's variable scope.

nets.Dilation.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but it only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.Dilation.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.MLP

A Multi-Layer perceptron module.

nets.MLP.__init__(output_sizes, activation=relu, activate_final=False, initializers=None, partitioners=None, regularizers=None, use_bias=True, use_dropout=False, custom_getter=None, name='mlp')

Constructs an MLP module.

Args:
  • output_sizes: An iterable of output dimensionalities as defined in basic.Linear. Output size can be defined either as number or via a callable. In the latter case, since the function invocation is deferred to graph construction time, the user must only ensure that entries can be called when build is called. Each entry in the iterable defines properties in the corresponding linear layer.
  • activation: An activation op. The activation is applied to intermediate layers, and optionally to the output of the final layer.
  • activate_final: Boolean determining if the activation is applied to the output of the final layer. Default False.
  • initializers: Optional dict containing ops to initialize the linear layers' weights (with key 'w') or biases (with key 'b').
  • partitioners: Optional dict containing partitioners to partition the linear layers' weights (with key 'w') or biases (with key 'b').
  • regularizers: Optional dict containing regularizers for the linear layers' weights (with key 'w') and the biases (with key 'b'). As a default, no regularizers are used. A regularizer should be a function that takes a single Tensor as an input and returns a scalar Tensor output, e.g. the L1 and L2 regularizers in tf.contrib.layers.
  • use_bias: Whether to include bias parameters in the linear layers. Default True.
  • use_dropout: Whether to perform dropout on the linear layers. Default False.
  • custom_getter: Callable or dictionary of callables to use as custom getters inside the module. If a dictionary, the keys correspond to regexes to match variable names. See the tf.get_variable documentation for information about the custom_getter API.
  • name: Name of the module.
Raises:
  • KeyError: If initializers contains any keys other than 'w' or 'b'.
  • KeyError: If regularizers contains any keys other than 'w' or 'b'.
  • ValueError: If output_sizes is empty.
  • TypeError: If activation is not callable; or if output_sizes is not iterable.

nets.MLP.__call__(inputs, is_training=True, dropout_keep_prob=0.5)

Assembles the MLP and connects it to the graph.

Args:
  • inputs: A 2D Tensor of size [batch_size, input_size].
  • is_training: A bool or tf.Bool Tensor. Indicates whether we are currently training. Defaults to True.
  • dropout_keep_prob: The probability that each element is kept when both use_dropout and is_training are True. Defaults to 0.5.
Returns:

A 2D Tensor of size [batch_size, output_sizes[-1]].
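
A minimal sketch of connecting an MLP twice with shared variables, with dropout active only during training:

import tensorflow as tf
import sonnet as snt

mlp = snt.nets.MLP(output_sizes=[128, 64, 10], use_dropout=True)

x = tf.placeholder(tf.float32, shape=[None, 784])
train_logits = mlp(x, is_training=True, dropout_keep_prob=0.5)
eval_logits = mlp(x, is_training=False)  # same variables, dropout disabled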

nets.MLP.activate_final

nets.MLP.activation

nets.MLP.clone(name=None)

Creates a new MLP with the same structure.

Args:
  • name: Optional string specifying the name of the new module. The default name is constructed by appending "_clone" to the original name.
Returns:

A cloned MLP module.

nets.MLP.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.MLP.defun()

Wraps this module's call method in a callable graph function.

nets.MLP.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.MLP.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.get_possible_initializer_keys(cls, use_bias=True)

nets.MLP.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.graph

Returns the Graph instance which the module is connected to, or None.

nets.MLP.initializers

Returns the initializers dictionary.

nets.MLP.input_shape

Returns shape of input Tensor passed at last call to build.

nets.MLP.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.MLP.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.layers

Returns a tuple containing the linear layers of the MLP.

nets.MLP.module_name

Returns the name of the Module.

nets.MLP.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.MLP.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.output_size

Returns the size of the module output, not including the batch dimension.

This allows the MLP to be used inside a DeepRNN.

Returns:

The scalar size of the module output.
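
For example, because output_size is defined, an MLP can be stacked after a recurrent core inside snt.DeepRNN; a sketch with illustrative sizes:

import sonnet as snt

core = snt.DeepRNN([
    snt.LSTM(hidden_size=64),
    snt.nets.MLP(output_sizes=[32, 16]),  # valid because MLP exposes output_size
])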

nets.MLP.output_sizes

Returns a tuple of all output sizes of all the layers.

nets.MLP.partitioners

Returns the partitioners dictionary.

nets.MLP.regularizers

Returns the regularizers dictionary.

nets.MLP.scope_name

Returns the full name of the Module's variable scope.

nets.MLP.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.transpose(name=None, activate_final=None)

Returns transposed MLP.

Args:
  • name: Optional string specifying the name of the transposed module. The default name is constructed by appending "_transpose" to self.module_name.
  • activate_final: Optional boolean determining if the activation and batch normalization, if turned on, are applied to the final layer.
Returns:

Matching transposed MLP module.
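
For instance, an MLP autoencoder can tie the decoder's structure to the encoder's; a sketch with illustrative sizes (the encoder must be connected first so the transposed output sizes can be inferred):

import tensorflow as tf
import sonnet as snt

x = tf.placeholder(tf.float32, shape=[None, 784])
encoder = snt.nets.MLP(output_sizes=[256, 64])
code = encoder(x)

decoder = encoder.transpose()   # output sizes become [256, 784]
reconstruction = decoder(code)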

nets.MLP.use_bias

nets.MLP.use_dropout

nets.MLP.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but it only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.MLP.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.VectorQuantizer

Sonnet module representing the VQ-VAE layer.

Implements the algorithm presented in 'Neural Discrete Representation Learning' by van den Oord et al. https://arxiv.org/abs/1711.00937

Input any tensor to be quantized. The last dimension will be used as the space in which to quantize; all other dimensions will be flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

Args:
  • embedding_dim: Integer representing the dimensionality of the tensors in the quantized space. Inputs to the module must be in this format as well.
  • num_embeddings: Integer, the number of vectors in the quantized space.
  • commitment_cost: Scalar which controls the weighting of the loss terms (see equation 4 in the paper; this variable is beta).

nets.VectorQuantizer.__init__(embedding_dim, num_embeddings, commitment_cost, name='vq_layer')

nets.VectorQuantizer.__call__(inputs, is_training)

Connects the module to some inputs.

Args:
  • inputs: Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.
  • is_training: boolean, whether this connection is to training data.
Returns:

dict containing the following keys and values:

  • quantize: Tensor containing the quantized version of the input.
  • loss: Tensor containing the loss to optimize.
  • perplexity: Tensor containing the perplexity of the encodings.
  • encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.
  • encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.
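
A minimal sketch of quantizing encoder features and folding the VQ loss into a training objective; the placeholder below stands in for a real encoder's output:

import tensorflow as tf
import sonnet as snt

vq = snt.nets.VectorQuantizer(
    embedding_dim=64, num_embeddings=512, commitment_cost=0.25)

# Stand-in for encoder output; the last dimension must equal embedding_dim.
z_e = tf.placeholder(tf.float32, shape=[None, 8, 8, 64])

vq_out = vq(z_e, is_training=True)
z_q = vq_out["quantize"]  # same shape as z_e; feed this to the decoder
vq_loss = vq_out["loss"]  # add this term to the reconstruction loss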

nets.VectorQuantizer.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.VectorQuantizer.defun()

Wraps this module's call method in a callable graph function.

nets.VectorQuantizer.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.VectorQuantizer.embeddings

nets.VectorQuantizer.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

A set of strings corresponding to the initializer keys that may be passed to the constructor.

nets.VectorQuantizer.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.

This method explicitly re-enters the Graph which this module has been connected to.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.graph

Returns the Graph instance which the module is connected to, or None.

nets.VectorQuantizer.is_connected

Returns true iff the Module has been connected to the Graph at least once.

nets.VectorQuantizer.last_connected_subgraph

Returns the last subgraph created by this module.

Returns:

The last connected subgraph.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.module_name

Returns the name of the Module.

nets.VectorQuantizer.name_scopes

Returns a tuple of all name_scopes generated by this module.

nets.VectorQuantizer.non_trainable_variables

All non-trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.quantize(encoding_indices)

nets.VectorQuantizer.scope_name

Returns the full name of the Module's variable scope.

nets.VectorQuantizer.trainable_variables

All trainable tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.variable_scope

Returns the variable_scope declared by the module.

It is valid for library users to access the internal templated variable_scope, but it only makes sense to do so after connection. Therefore we raise an error here if the variable_scope is requested before connection.

The only case where it does make sense to access the variable_scope before connection is to get the post-uniquification name, which we support using the separate .scope_name property.

Returns:
  • variable_scope: tf.VariableScope instance of the internal tf.Template.
Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizer.variables

All tf.Variables used when the module is connected.

This property does not rely on global collections and should generally be preferred over get_variables and get_all_variables.

See the documentation for AbstractModule._capture_variables() for more information about what variables are captured.

Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

class nets.VectorQuantizerEMA

Sonnet module representing the VQ-VAE layer.

Implements a slightly modified version of the algorithm presented in 'Neural Discrete Representation Learning' by van den Oord et al. https://arxiv.org/abs/1711.00937

The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses exponential moving averages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac, ...) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

Input any tensor to be quantized. The last dimension will be used as the space in which to quantize; all other dimensions will be flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

Args:
  • embedding_dim: Integer representing the dimensionality of the tensors in the quantized space. Inputs to the module must be in this format as well.
  • num_embeddings: Integer, the number of vectors in the quantized space.
  • commitment_cost: Scalar which controls the weighting of the loss terms (see equation 4 in the paper).
  • decay: Float, decay for the moving averages.
  • epsilon: Small float constant to avoid numerical instability.

nets.VectorQuantizerEMA.__init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, name='VectorQuantizerEMA')

nets.VectorQuantizerEMA.__call__(inputs, is_training)

Connects the module to some inputs.

Args:
  • inputs: Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.
  • is_training: boolean, whether this connection is to training data. When this is set to False, the internal moving average statistics will not be updated.
Returns:

dict containing the following keys and values:

  • quantize: Tensor containing the quantized version of the input.
  • loss: Tensor containing the loss to optimize.
  • perplexity: Tensor containing the perplexity of the encodings.
  • encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.
  • encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.
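
Usage mirrors VectorQuantizer, with the extra decay argument; note that is_training=False holds the moving-average statistics fixed. A sketch with illustrative sizes:

import tensorflow as tf
import sonnet as snt

vq_ema = snt.nets.VectorQuantizerEMA(
    embedding_dim=64, num_embeddings=512, commitment_cost=0.25, decay=0.99)

z_e = tf.placeholder(tf.float32, shape=[None, 8, 8, 64])
train_out = vq_ema(z_e, is_training=True)  # EMA statistics are updated
eval_out = vq_ema(z_e, is_training=False)  # EMA statistics held fixed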

nets.VectorQuantizerEMA.connected_subgraphs

Returns the subgraphs created by this module so far.

nets.VectorQuantizerEMA.defun()

Wraps this module's call method in a callable graph function.

nets.VectorQuantizerEMA.defun_wrapped

Returns boolean indicating whether this module is defun wrapped.

nets.VectorQuantizerEMA.embeddings

nets.VectorQuantizerEMA.get_all_variables(collection='trainable_variables')

Returns all tf.Variables used when the module is connected.

See the documentation for AbstractModule._capture_variables() for more information.

Args:
  • collection: Collection to restrict query to. By default this is tf.GraphKeys.TRAINABLE_VARIABLES, which doesn't include non-trainable variables such as moving averages.
Returns:

A sorted (by variable name) tuple of tf.Variable objects.

Raises:
  • NotConnectedError: If the module is not connected to the Graph.

nets.VectorQuantizerEMA.get_possible_initializer_keys(cls)

Returns the keys the dictionary of variable initializers may contain.

This provides the user with a way of knowing the initializer keys that are available without having to instantiate a sonnet module. Subclasses may override this class method if they need additional arguments to determine what initializer keys may be provided.

Returns:

A set of strings corresponding to the initializer keys that may be passed to the constructor.

nets.VectorQuantizerEMA.get_variables(collection='trainable_variables')

Returns tuple of tf.Variables declared inside this module.

Note that this operates by searching this module's variable scope, and so does not know about any modules that were constructed elsewhere but used inside this module.