tfwavelets package

Submodules

tfwavelets.dwtcoeffs module

The ‘dwtcoeffs’ module contains predefined wavelets, as well as the classes needed to define your own wavelets.

Wavelets are defined by the Wavelet class. A Wavelet object mainly consists of four Filter objects (defined by the Filter class) representing the decomposition and reconstruction low pass and high pass filters.

Examples

You can define your own wavelet by creating four filters and combining them into a wavelet:

>>> import numpy as np
>>> from tfwavelets.dwtcoeffs import Filter, Wavelet
>>> decomp_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> decomp_hp = Filter([1 / np.sqrt(2), -1 / np.sqrt(2)], 1)
>>> recon_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> recon_hp = Filter([-1 / np.sqrt(2), 1 / np.sqrt(2)], 1)
>>> haar = Wavelet(decomp_lp, decomp_hp, recon_lp, recon_hp)

class tfwavelets.dwtcoeffs.Filter(coeffs, zero)

Bases: object

Class representing a filter.

coeffs

Filter coefficients

Type

tf.constant

zero

Origin of the filter, i.e. the index in the coeffs array that corresponds to position 0.

Type

int

edge_matrices

List of edge matrices, used for circular convolution. Stored as 3D TF tensors (constants).

Type

iterable

num_neg()

Number of negative-indexed coefficients, excluding the origin.

Returns

Number of negative-indexed coefficients.

Return type

int

num_pos()

Number of positive-indexed coefficients in the filter, including the origin (strictly speaking, the number of non-negative-indexed coefficients).

Returns

Number of non-negative-indexed coefficients in the filter.

Return type

int
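The relationship between coeffs, zero, num_neg() and num_pos() can be illustrated with a small NumPy-only sketch. It assumes that num_neg() equals zero and num_pos() equals len(coeffs) - zero, which follows from the definitions above; FilterSketch is a hypothetical stand-in, not a tfwavelets class.

```python
import numpy as np


class FilterSketch:
    """Hypothetical stand-in for tfwavelets.dwtcoeffs.Filter (NumPy only)."""

    def __init__(self, coeffs, zero):
        self.coeffs = np.asarray(coeffs, dtype=float)
        self.zero = zero  # index in coeffs that sits at filter position 0

    def num_neg(self):
        # Coefficients stored before the origin have negative indices.
        return self.zero

    def num_pos(self):
        # The origin and every coefficient after it have non-negative indices.
        return len(self.coeffs) - self.zero

    def indices(self):
        # Filter position of each stored coefficient: k = i - zero.
        return np.arange(len(self.coeffs)) - self.zero


f = FilterSketch([0.1, 0.2, 0.3, 0.4], zero=1)
print(f.num_neg(), f.num_pos(), f.indices().tolist())  # 1 3 [-1, 0, 1, 2]
```

Together the two counts always add up to the filter length.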

class tfwavelets.dwtcoeffs.TrainableFilter(initial_coeffs, zero, name=None)

Bases: tfwavelets.dwtcoeffs.Filter

Class representing a trainable filter.

coeffs

Filter coefficients

Type

tf.Variable

zero

Origin of the filter, i.e. the index in the coeffs array that corresponds to position 0.

Type

int

class tfwavelets.dwtcoeffs.TrainableWavelet(wavelet)

Bases: tfwavelets.dwtcoeffs.Wavelet

Class representing a trainable wavelet.

decomp_lp

Filter coefficients for the decomposition low-pass filter.

Type

TrainableFilter

decomp_hp

Filter coefficients for the decomposition high-pass filter.

Type

TrainableFilter

recon_lp

Filter coefficients for the reconstruction low-pass filter.

Type

TrainableFilter

recon_hp

Filter coefficients for the reconstruction high-pass filter.

Type

TrainableFilter

class tfwavelets.dwtcoeffs.Wavelet(decomp_lp, decomp_hp, recon_lp, recon_hp)

Bases: object

Class representing a wavelet.

decomp_lp

Filter coefficients for the decomposition low-pass filter.

Type

Filter

decomp_hp

Filter coefficients for the decomposition high-pass filter.

Type

Filter

recon_lp

Filter coefficients for the reconstruction low-pass filter.

Type

Filter

recon_hp

Filter coefficients for the reconstruction high-pass filter.

Type

Filter

tfwavelets.nodes module

The ‘nodes’ module contains functions that construct TF subgraphs computing the 1D or 2D DWT or IDWT. Use it when you need a DWT inside your own TF graph.

tfwavelets.nodes.cyclic_conv1d(input_node, filter_)

Cyclic (periodic) convolution of an input signal with a filter.

Parameters
  • input_node – Input signal (3-tensor [batch, width, in_channels])

  • filter_ (Filter) – Filter object

Returns

Tensor with the result of a periodic convolution
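For reference, a periodic convolution of a 1D signal with a filter whose origin sits at index zero can be written in plain NumPy. This sketch assumes the textbook definition y[n] = sum_k h[k] x[(n - k) mod N] with filter position k = i - zero; the flip and offset conventions inside nodes.cyclic_conv1d may differ, and cyclic_conv1d_reference is a hypothetical helper that ignores the batch and channel axes.

```python
import numpy as np


def cyclic_conv1d_reference(x, coeffs, zero):
    """Periodic convolution y[n] = sum_k h[k] * x[(n - k) mod N].

    coeffs[i] holds the filter tap h[k] at position k = i - zero.
    """
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    for i, c in enumerate(coeffs):
        k = i - zero  # position of this tap relative to the origin
        # np.roll(x, k)[n] == x[(n - k) mod N], so indices wrap around.
        y += c * np.roll(x, k)
    return y


x = np.array([1.0, 2.0, 3.0, 4.0])
print(cyclic_conv1d_reference(x, [1.0, -1.0], zero=0).tolist())
# [-3.0, 1.0, 1.0, 1.0]
```

The wrap-around is visible in the first output sample: it pairs x[0] with x[3], the last sample of the signal.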

tfwavelets.nodes.cyclic_conv1d_alt(input_node, filter_)

Alternative cyclic convolution. Uses more memory than cyclic_conv1d.

Parameters
  • input_node – Input signal

  • filter_ (Filter) – Filter object

Returns

Tensor with the result of a periodic convolution.

tfwavelets.nodes.dwt1d(input_node, wavelet, levels=1)

Constructs a TF computational graph computing the 1D DWT of an input signal.

Parameters
  • input_node – A 3D tensor containing the signal. The dimensions should be [batch, signal, channels].

  • wavelet – Wavelet object

  • levels – Number of levels.

Returns

The output node of the DWT graph.
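For intuition about what a multi-level 1D DWT computes, here is a NumPy sketch using the orthonormal Haar filters on a single 1D signal. It assumes the output layout [approximation, coarsest detail, ..., finest detail] along the signal axis and a length divisible by 2**levels; the exact ordering and sign conventions of nodes.dwt1d may differ, and haar_dwt1d is a hypothetical helper, not part of tfwavelets.

```python
import numpy as np


def haar_dwt1d(x, levels=1):
    """Multi-level orthonormal Haar DWT of a 1D signal.

    Returns [approx, coarsest detail, ..., finest detail] concatenated.
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2))  # high-pass, downsampled
        approx = (even + odd) / np.sqrt(2)         # low-pass, downsampled
    # Coarsest detail band first, finest last.
    return np.concatenate([approx] + details[::-1])


print(haar_dwt1d([1.0, 2.0, 3.0, 4.0], levels=2))
# approximately [5, -2, -0.7071, -0.7071]
```

Each level only transforms the previous approximation band, which is why the output has the same length as the input.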

tfwavelets.nodes.dwt2d(input_node, wavelet, levels=1)

Constructs a TF computational graph computing the 2D DWT of an input signal.

Parameters
  • input_node – A 3D tensor containing the signal. The dimensions should be [rows, cols, channels].

  • wavelet – Wavelet object.

  • levels – Number of levels.

Returns

The output node of the DWT graph.
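A 2D DWT is commonly computed separably: one 1D transform along each axis, then recursion on the low-low (LL) quadrant. The sketch below does this with Haar filters on a single-channel 2D array, assuming the LL band ends up in the top-left corner; how nodes.dwt2d arranges the LH, HL and HH quadrants (and handles the channel axis) is not confirmed here, and haar_dwt2d is a hypothetical helper.

```python
import numpy as np


def _haar_axis(x, axis):
    """One orthonormal Haar analysis step along `axis`: [approx | detail]."""
    x = np.moveaxis(x, axis, 0)
    even, odd = x[0::2], x[1::2]
    out = np.concatenate([even + odd, even - odd]) / np.sqrt(2)
    return np.moveaxis(out, 0, axis)


def haar_dwt2d(image, levels=1):
    """Multi-level separable 2D Haar DWT of a single-channel image."""
    out = np.array(image, dtype=float)
    rows, cols = out.shape
    for _ in range(levels):
        block = out[:rows, :cols]
        block = _haar_axis(block, 1)  # transform along axis 1 (within rows)
        block = _haar_axis(block, 0)  # then along axis 0 (within columns)
        out[:rows, :cols] = block
        rows, cols = rows // 2, cols // 2  # recurse on the LL quadrant
    return out


coeffs = haar_dwt2d(np.ones((4, 4)), levels=1)
print(coeffs[:2, :2])  # LL quadrant, approximately all 2.0; the rest is ~0
```

For a constant image all energy lands in the LL band, since every detail coefficient is a difference of equal values.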

tfwavelets.nodes.idwt1d(input_node, wavelet, levels=1)

Constructs a TF graph that computes the 1D inverse DWT for a given wavelet.

Parameters
  • input_node (tf.placeholder) – Input signal. A 3D tensor with dimensions [batch, signal, channels].

  • wavelet (tfwavelets.dwtcoeffs.Wavelet) – Wavelet object.

  • levels (int) – Number of levels.

Returns

Output node of IDWT graph.
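The inverse transform undoes each level in reverse order: starting from the coarsest approximation, every step merges an approximation band with the detail band of the same length into a signal twice as long. This NumPy sketch assumes orthonormal Haar filters and the coefficient layout [approximation, coarsest detail, ..., finest detail]; nodes.idwt1d may use a different layout, and haar_idwt1d is a hypothetical helper.

```python
import numpy as np


def haar_idwt1d(coeffs, levels=1):
    """Invert a multi-level orthonormal Haar DWT of a 1D signal.

    Expects [approx, coarsest detail, ..., finest detail] concatenated.
    """
    coeffs = np.asarray(coeffs, dtype=float)
    approx_len = coeffs.shape[0] // 2 ** levels
    approx = coeffs[:approx_len]
    pos = approx_len
    for _ in range(levels):
        # The detail band at this level has the same length as approx.
        detail = coeffs[pos:pos + approx.shape[0]]
        pos += approx.shape[0]
        out = np.empty(2 * approx.shape[0])
        out[0::2] = (approx + detail) / np.sqrt(2)  # even samples
        out[1::2] = (approx - detail) / np.sqrt(2)  # odd samples
        approx = out
    return approx


c = [5.0, -2.0, -1 / np.sqrt(2), -1 / np.sqrt(2)]
print(haar_idwt1d(c, levels=2))  # approximately [1, 2, 3, 4]
```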

tfwavelets.nodes.idwt2d(input_node, wavelet, levels=1)

Constructs a TF graph that computes the 2D inverse DWT for a given wavelet.

Parameters
  • input_node (tf.placeholder) – Input signal. A 3D tensor with dimensions [rows, cols, channels].

  • wavelet (tfwavelets.dwtcoeffs.Wavelet) – Wavelet object.

  • levels (int) – Number of levels.

Returns

Output node of IDWT graph.
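The 2D inverse mirrors the separable forward transform: starting from the coarsest block in the top-left corner, each level inverts the 1D step along both axes and doubles the block size. This NumPy sketch assumes orthonormal Haar filters, a single channel, and the LL band stored top-left; the quadrant layout used by nodes.idwt2d is not confirmed here, and haar_idwt2d is a hypothetical helper.

```python
import numpy as np


def _inverse_haar_axis(x, axis):
    """Undo one orthonormal Haar step along `axis` ([approx | detail] -> signal)."""
    x = np.moveaxis(x, axis, 0)
    half = x.shape[0] // 2
    approx, detail = x[:half], x[half:]
    out = np.empty_like(x)
    out[0::2] = (approx + detail) / np.sqrt(2)
    out[1::2] = (approx - detail) / np.sqrt(2)
    return np.moveaxis(out, 0, axis)


def haar_idwt2d(coeffs, levels=1):
    """Invert a multi-level 2D Haar DWT whose LL band sits top-left."""
    out = np.array(coeffs, dtype=float)
    # Start from the block covering the coarsest level, grow 2x per level.
    rows = out.shape[0] // 2 ** (levels - 1)
    cols = out.shape[1] // 2 ** (levels - 1)
    for _ in range(levels):
        block = out[:rows, :cols]
        block = _inverse_haar_axis(block, 0)
        block = _inverse_haar_axis(block, 1)
        out[:rows, :cols] = block
        rows, cols = rows * 2, cols * 2
    return out


c = np.zeros((4, 4))
c[:2, :2] = 2.0
print(haar_idwt2d(c))  # approximately a 4x4 array of ones
```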

tfwavelets.nodes.upsample(input_node, odd=False)

Upsamples the input by a factor of two: doubles its length along the spatial dimension and fills the inserted positions with zeros.

Parameters
  • input_node – 3-tensor [batch, spatial dim, channels] to be upsampled

  • odd – Bool, optional. If True, the content of input_node is placed on the odd indices of the output. Otherwise (the default), it is placed on the even indices.

Returns

The upsampled output Tensor.
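The zero-filling behaviour described above can be reproduced in NumPy for a [batch, spatial dim, channels] array. upsample_sketch is a hypothetical helper that only mirrors the documented semantics; it is not the TF implementation.

```python
import numpy as np


def upsample_sketch(x, odd=False):
    """Double the spatial axis (axis 1) of a [batch, width, channels] array.

    Input samples go to even output indices by default, odd ones if odd=True;
    all other positions are zero.
    """
    x = np.asarray(x)
    batch, width, channels = x.shape
    out = np.zeros((batch, 2 * width, channels), dtype=x.dtype)
    start = 1 if odd else 0
    out[:, start::2, :] = x
    return out


x = np.array([1, 2, 3]).reshape(1, 3, 1)
print(upsample_sketch(x)[0, :, 0].tolist())            # [1, 0, 2, 0, 3, 0]
print(upsample_sketch(x, odd=True)[0, :, 0].tolist())  # [0, 1, 0, 2, 0, 3]
```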

tfwavelets.utils module

The ‘utils’ module contains helper functions, used mainly in the implementation of the other modules.

tfwavelets.utils.adapt_filter(filter)

Expands the dimensions of a 1D vector to match the tensor dimensions required in a TF graph.

Parameters

filter (np.ndarray) – A 1D vector containing filter coefficients

Returns

A 3D array with two singleton dimensions appended (dimensions 2 and 3).

Return type

np.ndarray
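The reshape described above amounts to appending two length-1 axes, turning a filter of length n into an array of shape (n, 1, 1). That shape matches the [filter_width, in_channels, out_channels] kernel layout that tf.nn.conv1d expects with one input and one output channel; that this is the purpose of adapt_filter is an assumption. adapt_filter_sketch is a hypothetical NumPy re-creation.

```python
import numpy as np


def adapt_filter_sketch(filter_coeffs):
    """Reshape a 1D filter of length n into an array of shape (n, 1, 1)."""
    coeffs = np.asarray(filter_coeffs)
    # Append two singleton axes: (n,) -> (n, 1) -> (n, 1, 1)
    return np.expand_dims(np.expand_dims(coeffs, -1), -1)


kernel = adapt_filter_sketch([0.5, 0.5])
print(kernel.shape)  # (2, 1, 1)
```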

tfwavelets.utils.to_tf_mat(matrices)

Expands the dimensions of 2D matrices to match the tensor dimensions required in a TF graph, and wraps them as TF constants.

Parameters

matrices (iterable) – A list (or tuple) of 2D numpy arrays.

Returns

A list of all the matrices converted to 3D TF tensors.

Return type

iterable

tfwavelets.wrappers module

The ‘wrappers’ module contains functions that wrap the functionality in nodes. They construct a full TF graph, launch a session, and evaluate the graph. Use them when you just want to compute the DWT/IDWT of a signal.

tfwavelets.wrappers.dwt1d(signal, wavelet='haar', levels=1)

Computes the DWT of a 1D signal.

Parameters
  • signal (np.ndarray) – A 1D array to compute DWT of.

  • wavelet (str) – Type of wavelet (haar or db2)

  • levels (int) – Number of DWT levels

Returns

The DWT of the signal.

Return type

np.ndarray

Raises

ValueError – If wavelet is not supported
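The wavelet argument is a name rather than a Wavelet object, so each wrapper has to map the string to a wavelet and reject unknown names. The sketch below shows one way such a lookup can raise the documented ValueError. resolve_wavelet and the placeholder registry are hypothetical; the real wrappers presumably map 'haar' and 'db2' to predefined Wavelet objects from dwtcoeffs, but their internals are not shown here.

```python
def resolve_wavelet(name, registry):
    """Look up a wavelet by name, raising ValueError if it is unsupported."""
    try:
        return registry[name]
    except KeyError:
        raise ValueError(
            f"Unsupported wavelet {name!r}; expected one of {sorted(registry)}"
        ) from None


# Placeholder values; in tfwavelets these would be Wavelet objects.
SUPPORTED_WAVELETS = {"haar": "haar-wavelet", "db2": "db2-wavelet"}
print(resolve_wavelet("db2", SUPPORTED_WAVELETS))  # db2-wavelet
```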

tfwavelets.wrappers.dwt2d(signal, wavelet='haar', levels=1)

Computes the DWT of a 2D signal.

Parameters
  • signal (np.ndarray) – A 2D array to compute DWT of.

  • wavelet (str) – Type of wavelet (haar or db2)

  • levels (int) – Number of DWT levels

Returns

The DWT of the signal.

Return type

np.ndarray

Raises

ValueError – If wavelet is not supported

tfwavelets.wrappers.idwt1d(signal, wavelet='haar', levels=1)

Computes the inverse DWT of a 1D signal.

Parameters
  • signal (np.ndarray) – A 1D array to compute IDWT of.

  • wavelet (str) – Type of wavelet (haar or db2)

  • levels (int) – Number of DWT levels

Returns

The IDWT of the signal.

Return type

np.ndarray

Raises

ValueError – If wavelet is not supported

tfwavelets.wrappers.idwt2d(signal, wavelet='haar', levels=1)

Computes the inverse DWT of a 2D signal.

Parameters
  • signal (np.ndarray) – A 2D array to compute the IDWT of.

  • wavelet (str) – Type of wavelet (haar or db2)

  • levels (int) – Number of DWT levels

Returns

The IDWT of the signal.

Return type

np.ndarray

Raises

ValueError – If wavelet is not supported

Module contents

The tfwavelets package provides discrete wavelet transforms (DWT) in TensorFlow.

The package consists of the following modules:

  • ‘nodes’ contains functions that construct TF subgraphs computing the 1D or 2D DWT or IDWT. Use it when you need a DWT inside your own TF graph.

  • ‘wrappers’ contains functions that wrap the functionality in nodes. They construct a full TF graph, launch a session, and evaluate the graph. Use them when you just want to compute the DWT/IDWT of a signal.

  • ‘dwtcoeffs’ contains predefined wavelets, as well as the classes needed to define your own wavelets.

  • ‘utils’ contains helper functions, used mainly in the implementation of the other modules.