tfwavelets package
Submodules
tfwavelets.dwtcoeffs module
The ‘dwtcoeffs’ module contains predefined wavelets, as well as the classes needed to create user-defined wavelets.
Wavelets are defined by the Wavelet class. A Wavelet object mainly consists of four Filter objects (defined by the Filter class) representing the decomposition and reconstruction low-pass and high-pass filters.
Examples
You can define your own wavelet by creating four filters and combining them into a wavelet:
>>> import numpy as np
>>> from tfwavelets.dwtcoeffs import Filter, Wavelet
>>> decomp_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> decomp_hp = Filter([1 / np.sqrt(2), -1 / np.sqrt(2)], 1)
>>> recon_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> recon_hp = Filter([-1 / np.sqrt(2), 1 / np.sqrt(2)], 1)
>>> haar = Wavelet(decomp_lp, decomp_hp, recon_lp, recon_hp)
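As a quick sanity check on the example above, the following NumPy-only sketch verifies two properties of these Haar coefficients: the two decomposition filters, stacked as rows, form an orthogonal 2×2 matrix, and each reconstruction filter is the time reverse of the matching decomposition filter. It does not use tfwavelets itself; the arrays simply mirror the coefficient lists passed to Filter.

```python
import numpy as np

s = 1 / np.sqrt(2)

# Plain arrays mirroring the coefficient lists passed to Filter above.
decomp_lp = np.array([s, s])
decomp_hp = np.array([s, -s])
recon_lp = np.array([s, s])
recon_hp = np.array([-s, s])

# Rows of the analysis matrix are the decomposition filters; orthogonality
# means its transpose undoes one level of the transform on each sample pair.
analysis = np.vstack([decomp_lp, decomp_hp])
is_orthogonal = np.allclose(analysis @ analysis.T, np.eye(2))

# Each reconstruction filter equals the reversed decomposition filter.
is_time_reversed = np.allclose(recon_lp, decomp_lp[::-1]) and np.allclose(recon_hp, decomp_hp[::-1])

print(is_orthogonal, is_time_reversed)  # True True
```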
class tfwavelets.dwtcoeffs.Filter(coeffs, zero)
Bases: object
Class representing a filter.
- Attributes
coeffs (tf.constant) – Filter coefficients.
zero (int) – Origin of the filter, i.e. the index of the coeffs array that is treated as index 0.
edge_matrices (iterable) – List of edge matrices used for circular convolution, stored as 3D TF tensors (constants).
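To make the zero attribute concrete: with coefficients c and origin zero, entry c[i] is the filter tap at signed index i - zero, so taps before the origin get negative indices. The helper below is hypothetical (it is not part of tfwavelets) and only illustrates this indexing rule as stated in the attribute description.

```python
def taps_by_index(coeffs, zero):
    """Return {signed_index: coefficient}, treating coeffs[zero] as index 0.

    Hypothetical helper for illustration; not part of tfwavelets.
    """
    return {i - zero: c for i, c in enumerate(coeffs)}


# A 3-tap smoothing filter centred on its middle coefficient.
taps = taps_by_index([0.25, 0.5, 0.25], zero=1)
print(taps)  # {-1: 0.25, 0: 0.5, 1: 0.25}
```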
class tfwavelets.dwtcoeffs.TrainableFilter(initial_coeffs, zero, name=None)
Bases: tfwavelets.dwtcoeffs.Filter
Class representing a trainable filter.
- Attributes
coeffs (tf.Variable) – Filter coefficients.
zero (int) – Origin of the filter, i.e. the index of the coeffs array that is treated as index 0.
class tfwavelets.dwtcoeffs.TrainableWavelet(wavelet)
Bases: tfwavelets.dwtcoeffs.Wavelet
Class representing a trainable wavelet.
- Attributes
decomp_lp – Filter coefficients for the decomposition low-pass filter.
decomp_hp – Filter coefficients for the decomposition high-pass filter.
recon_lp – Filter coefficients for the reconstruction low-pass filter.
recon_hp – Filter coefficients for the reconstruction high-pass filter.
tfwavelets.nodes module
The ‘nodes’ module contains functions that construct TF subgraphs computing the 1D or 2D DWT or IDWT. Use it when you need a DWT inside your own TF graph.
tfwavelets.nodes.cyclic_conv1d(input_node, filter_)
Cyclic (periodic) convolution.
- Parameters
input_node – Input signal (3-tensor [batch, width, in_channels])
filter_ (Filter) – Filter object
- Returns
Tensor with the result of a periodic convolution.
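For reference, a periodic (circular) convolution treats the signal as if it wraps around, so no boundary padding is needed: y[n] = Σ_k h[k] · x[(n - k) mod N]. The NumPy sketch below computes this for a single 1D signal, using the Filter convention that coeffs[zero] is tap 0, and cross-checks it against an FFT-based circular convolution. It is an independent reference under these assumptions; the docs do not state whether cyclic_conv1d flips the filter (true convolution) or not (correlation, like TensorFlow's convolution ops, which do not flip the kernel), nor exactly how the origin is aligned.

```python
import numpy as np


def cyclic_conv1d_reference(x, coeffs, zero):
    """Periodic convolution y[n] = sum_k h[k] * x[(n - k) mod N], where h[k] = coeffs[k + zero]."""
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    for i, c in enumerate(coeffs):
        k = i - zero  # signed tap index of coeffs[i]
        # np.roll(x, k)[n] == x[(n - k) mod N], i.e. a circular shift by k.
        y += c * np.roll(x, k)
    return y


def cyclic_conv1d_fft(x, coeffs, zero):
    """Same result via the DFT: wrap the taps into a length-N kernel, multiply spectra."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    h = np.zeros(n)
    for i, c in enumerate(coeffs):
        h[(i - zero) % n] += c  # negative tap indices wrap to the end
    return np.real(np.fft.ifft(np.fft.fft(x) * np.fft.fft(h)))


s = 1 / np.sqrt(2)
x = np.array([1.0, 2.0, 3.0, 4.0])
y = cyclic_conv1d_reference(x, [s, s], zero=0)
# y[0] = s * (x[0] + x[3]) because the signal wraps around.
print(np.round(y / s, 6))  # [5. 3. 5. 7.]
```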
tfwavelets.nodes.cyclic_conv1d_alt(input_node, filter_)
Alternative cyclic convolution. Uses more memory than cyclic_conv1d.
- Parameters
input_node – Input signal
filter_ (Filter) – Filter object
- Returns
Tensor with the result of a periodic convolution.
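One way a cyclic convolution can trade memory for simplicity is to materialise the full N × N circulant matrix and apply it with a single matrix product, which costs O(N²) memory instead of O(N). The sketch below shows that construction in NumPy under the same tap-indexing assumption as Filter (coeffs[zero] is tap 0). Whether cyclic_conv1d_alt is implemented this way is not stated in these docs; the sketch only illustrates the memory trade-off.

```python
import numpy as np


def circulant_matrix(coeffs, zero, n):
    """Build C (n x n) with C[m, (m - k) mod n] = h[k], so C @ x is a periodic convolution."""
    c = np.zeros((n, n))
    for i, coef in enumerate(coeffs):
        k = i - zero  # signed tap index, as for Filter.zero
        for m in range(n):
            c[m, (m - k) % n] += coef
    return c


s = 1 / np.sqrt(2)
x = np.array([1.0, 2.0, 3.0, 4.0])
C = circulant_matrix([s, s], zero=0, n=x.shape[0])  # N*N entries stored explicitly
y = C @ x
print(np.round(y / s, 6))  # [5. 3. 5. 7.]
```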
tfwavelets.nodes.dwt1d(input_node, wavelet, levels=1)
Constructs a TF computational graph computing the 1D DWT of an input signal.
- Parameters
input_node – A 3D tensor containing the signal. The dimensions should be [batch, signal, channels].
wavelet – Wavelet object.
levels – Number of levels.
- Returns
The output node of the DWT graph.
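The NumPy sketch below shows what a multi-level Haar DWT of a [batch, signal, channels] array looks like, as a reference for checking outputs. Two conventions are assumptions, not documented behaviour of dwt1d: the output keeps the input shape with approximation coefficients first and detail coefficients after them (each level re-transforming only the approximation part), and details are computed as (even - odd)/√2, matching the sign of decomp_hp in the Haar example. The signal length is assumed divisible by 2**levels.

```python
import numpy as np


def haar_dwt1d_reference(x, levels=1):
    """Multi-level Haar DWT along axis 1 of a [batch, signal, channels] array.

    Assumed layout: [a_L | d_L | ... | d_1] along the signal axis.
    """
    out = np.array(x, dtype=float)  # work on a copy
    length = out.shape[1]
    s = 1 / np.sqrt(2)
    for _ in range(levels):
        even = out[:, 0:length:2, :]
        odd = out[:, 1:length:2, :]
        approx = s * (even + odd)  # low-pass, keeping every second sample
        detail = s * (even - odd)  # high-pass, keeping every second sample
        out[:, :length, :] = np.concatenate([approx, detail], axis=1)
        length //= 2  # next level only transforms the approximation part
    return out


x = np.array([1.0, 2.0, 3.0, 4.0]).reshape(1, 4, 1)  # batch=1, channels=1
coeffs = haar_dwt1d_reference(x, levels=2)
# coeffs[0, :, 0] holds 5, -2, -1/sqrt(2), -1/sqrt(2)
```

Because the Haar filters are orthonormal, the transform preserves energy: the sum of squares is 30 both before and after.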
tfwavelets.nodes.dwt2d(input_node, wavelet, levels=1)
Constructs a TF computational graph computing the 2D DWT of an input signal.
- Parameters
input_node – A 3D tensor containing the signal. The dimensions should be [rows, cols, channels].
wavelet – Wavelet object.
levels – Number of levels.
- Returns
The output node of the DWT graph.
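A separable 2D DWT applies the 1D transform along one spatial axis and then along the other. The NumPy sketch below does this for a [rows, cols, channels] array with Haar filters. The quadrant layout (LL top-left, with each further level re-transforming only that quadrant) and the ordering of the two passes are assumptions for illustration, not documented behaviour of dwt2d. Both spatial sizes are assumed divisible by 2**levels.

```python
import numpy as np


def _haar_step(x, axis):
    """One Haar analysis step along `axis`: approximations first, then details."""
    s = 1 / np.sqrt(2)
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return np.concatenate([s * (even + odd), s * (even - odd)], axis=axis)


def haar_dwt2d_reference(x, levels=1):
    """Multi-level separable Haar DWT of a [rows, cols, channels] array."""
    out = np.array(x, dtype=float)
    rows, cols = out.shape[0], out.shape[1]
    for _ in range(levels):
        block = out[:rows, :cols, :]
        # Transform along axis 0, then along axis 1.
        out[:rows, :cols, :] = _haar_step(_haar_step(block, axis=0), axis=1)
        rows //= 2  # the next level only transforms the top-left (LL) quadrant
        cols //= 2
    return out


image = np.ones((4, 4, 1))
coeffs = haar_dwt2d_reference(image, levels=2)
# A constant image has no detail: all energy ends up in coeffs[0, 0, 0] (= 4).
```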
tfwavelets.nodes.idwt1d(input_node, wavelet, levels=1)
Constructs a TF graph that computes the 1D inverse DWT for a given wavelet.
- Parameters
input_node (tf.placeholder) – Input signal. A 3D tensor with dimensions [batch, signal, channels].
wavelet (tfwavelets.dwtcoeffs.Wavelet) – Wavelet object.
levels (int) – Number of levels.
- Returns
Output node of the IDWT graph.
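The inverse reverses the analysis steps from the coarsest level outwards: each pair (a, d) is mapped back to the samples (a + d)/√2 and (a - d)/√2. The NumPy sketch assumes a coefficient layout of [a_L | d_L | ... | d_1] along the signal axis; this layout is an assumption, not documented behaviour of idwt1d.

```python
import numpy as np


def haar_idwt1d_reference(c, levels=1):
    """Inverse multi-level Haar DWT along axis 1 of a [batch, signal, channels] array.

    Assumes the layout [a_L | d_L | ... | d_1] along the signal axis.
    """
    out = np.array(c, dtype=float)
    s = 1 / np.sqrt(2)
    length = out.shape[1] // 2 ** (levels - 1)  # size of the coarsest block
    for _ in range(levels):
        half = length // 2
        approx = out[:, :half, :].copy()
        detail = out[:, half:length, :].copy()
        out[:, 0:length:2, :] = s * (approx + detail)  # even samples
        out[:, 1:length:2, :] = s * (approx - detail)  # odd samples
        length *= 2
    return out


s = 1 / np.sqrt(2)
coeffs = np.array([5.0, -2.0, -s, -s]).reshape(1, 4, 1)
signal = haar_idwt1d_reference(coeffs, levels=2)
# signal[0, :, 0] recovers 1, 2, 3, 4
```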
tfwavelets.nodes.idwt2d(input_node, wavelet, levels=1)
Constructs a TF graph that computes the 2D inverse DWT for a given wavelet.
- Parameters
input_node (tf.placeholder) – Input signal. A 3D tensor with dimensions [rows, cols, channels].
wavelet (tfwavelets.dwtcoeffs.Wavelet) – Wavelet object.
levels (int) – Number of levels.
- Returns
Output node of the IDWT graph.
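The 2D inverse undoes the separable steps level by level, starting from the coarsest top-left block. The NumPy sketch assumes a quadrant layout with LL top-left, where each forward level re-transforms only that quadrant; that layout, like the order of the two passes, is an assumption rather than documented behaviour of idwt2d.

```python
import numpy as np


def _haar_inverse_step(x, axis):
    """Undo one Haar step along `axis`: split into halves, then interleave samples."""
    s = 1 / np.sqrt(2)
    half = x.shape[axis] // 2
    approx = np.take(x, np.arange(half), axis=axis)
    detail = np.take(x, np.arange(half, 2 * half), axis=axis)
    even = s * (approx + detail)
    odd = s * (approx - detail)
    # Stack on a new axis right after `axis`, then merge the two so samples alternate.
    shape = list(even.shape)
    shape[axis] *= 2
    return np.stack([even, odd], axis=axis + 1).reshape(shape)


def haar_idwt2d_reference(c, levels=1):
    """Inverse multi-level separable Haar DWT of a [rows, cols, channels] array."""
    out = np.array(c, dtype=float)
    rows = out.shape[0] // 2 ** (levels - 1)  # coarsest block first
    cols = out.shape[1] // 2 ** (levels - 1)
    for _ in range(levels):
        block = out[:rows, :cols, :]
        out[:rows, :cols, :] = _haar_inverse_step(_haar_inverse_step(block, axis=1), axis=0)
        rows *= 2
        cols *= 2
    return out


coeffs = np.zeros((4, 4, 1))
coeffs[0, 0, 0] = 4.0  # under these conventions, the 2-level transform of a 4x4 image of ones
image = haar_idwt2d_reference(coeffs, levels=2)
```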
tfwavelets.nodes.upsample(input_node, odd=False)
Upsamples the input by a factor of two along the spatial dimension, filling the inserted positions with zeros.
- Parameters
input_node – 3-tensor [batch, spatial dim, channels] to be upsampled.
odd – Bool, optional. If True, the content of input_node is placed on the odd indices of the output. Otherwise (the default), it is placed on the even indices.
- Returns
The upsampled output Tensor.
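A NumPy sketch of the documented zero-insertion behaviour, applied along the spatial axis of a [batch, spatial dim, channels] array. It follows the parameter description above; it is an independent reference, not the library's TF implementation.

```python
import numpy as np


def upsample_reference(x, odd=False):
    """Double axis 1 of a [batch, spatial, channels] array by inserting zeros."""
    x = np.asarray(x)
    out = np.zeros((x.shape[0], 2 * x.shape[1], x.shape[2]), dtype=x.dtype)
    start = 1 if odd else 0  # odd=True fills indices 1, 3, 5, ...
    out[:, start::2, :] = x
    return out


x = np.array([1.0, 2.0, 3.0]).reshape(1, 3, 1)
even_filled = upsample_reference(x)[0, :, 0]           # 1, 0, 2, 0, 3, 0
odd_filled = upsample_reference(x, odd=True)[0, :, 0]  # 0, 1, 0, 2, 0, 3
```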
tfwavelets.utils module
The ‘utils’ module contains helper functions, mostly used internally by the other modules.
tfwavelets.utils.adapt_filter(filter)
Expands the dimensions of a 1D vector to match the tensor dimensions required in a TF graph.
- Parameters
filter (np.ndarray) – A 1D vector containing filter coefficients.
- Returns
A 3D array with two singleton dimensions added as dimensions 2 and 3.
- Return type
np.ndarray
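TensorFlow's tf.nn.conv1d expects filters shaped [filter_width, in_channels, out_channels], which is why a 1D coefficient vector needs two extra axes. A NumPy sketch of the documented shape change, assuming the two new axes are appended at the end (giving shape (n, 1, 1)); the function name is a hypothetical stand-in, not the library's code:

```python
import numpy as np


def adapt_filter_reference(coeffs):
    """Append two singleton axes so a 1D filter has shape (n, 1, 1)."""
    coeffs = np.asarray(coeffs)
    return coeffs[:, np.newaxis, np.newaxis]


s = 1 / np.sqrt(2)
kernel = adapt_filter_reference([s, s])
print(kernel.shape)  # (2, 1, 1)
```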
tfwavelets.utils.to_tf_mat(matrices)
Expands the dimensions of 2D matrices to match the tensor dimensions required in a TF graph, and wraps them as TF constants.
- Parameters
matrices (iterable) – A list (or tuple) of 2D NumPy arrays.
- Returns
A list of all the matrices converted to 3D TF tensors.
- Return type
iterable
tfwavelets.wrappers module
The ‘wrappers’ module contains functions that wrap the functionality in nodes. They construct a full TF graph, launch a session, and evaluate the graph. Use them when you just want to compute the DWT/IDWT of a signal.
tfwavelets.wrappers.dwt1d(signal, wavelet='haar', levels=1)
Computes the DWT of a 1D signal.
- Parameters
signal (np.ndarray) – A 1D array to compute the DWT of.
wavelet (str) – Type of wavelet ('haar' or 'db2').
levels (int) – Number of DWT levels.
- Returns
The DWT of the signal.
- Return type
np.ndarray
- Raises
ValueError – If wavelet is not supported.
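The wrapper pattern described for this module mostly comes down to validation and reshaping around a call into nodes. The NumPy-only sketch below shows the two pieces that are independent of TensorFlow: rejecting unsupported wavelet names with ValueError, and converting a 1D signal to and from the [batch, signal, channels] layout that nodes.dwt1d expects. All names in it are hypothetical helpers, not part of tfwavelets, and the graph construction and session evaluation are deliberately left out.

```python
import numpy as np

# Hypothetical: the wavelet names the docs list as supported.
SUPPORTED_WAVELETS = ("haar", "db2")


def check_wavelet(name):
    """Raise ValueError for names outside SUPPORTED_WAVELETS, as the docs describe."""
    if name not in SUPPORTED_WAVELETS:
        raise ValueError(f"Wavelet {name!r} is not supported")
    return name


def to_node_layout(signal):
    """Reshape a 1D signal of length n to [batch=1, signal=n, channels=1]."""
    signal = np.asarray(signal, dtype=np.float32)
    if signal.ndim != 1:
        raise ValueError("expected a 1D signal")
    return signal[np.newaxis, :, np.newaxis]


def from_node_layout(tensor_value):
    """Drop the batch and channel axes again after evaluation."""
    return np.asarray(tensor_value)[0, :, 0]


batched = to_node_layout([1.0, 2.0, 3.0, 4.0])  # shape (1, 4, 1)
restored = from_node_layout(batched)             # back to shape (4,)
```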
tfwavelets.wrappers.dwt2d(signal, wavelet='haar', levels=1)
Computes the DWT of a 2D signal.
- Parameters
signal (np.ndarray) – A 2D array to compute the DWT of.
wavelet (str) – Type of wavelet ('haar' or 'db2').
levels (int) – Number of DWT levels.
- Returns
The DWT of the signal.
- Return type
np.ndarray
- Raises
ValueError – If wavelet is not supported.
tfwavelets.wrappers.idwt1d(signal, wavelet='haar', levels=1)
Computes the inverse DWT of a 1D signal.
- Parameters
signal (np.ndarray) – A 1D array to compute the IDWT of.
wavelet (str) – Type of wavelet ('haar' or 'db2').
levels (int) – Number of DWT levels.
- Returns
The IDWT of the signal.
- Return type
np.ndarray
- Raises
ValueError – If wavelet is not supported.
tfwavelets.wrappers.idwt2d(signal, wavelet='haar', levels=1)
Computes the inverse DWT of a 2D signal.
- Parameters
signal (np.ndarray) – A 2D array to compute the IDWT of.
wavelet (str) – Type of wavelet ('haar' or 'db2').
levels (int) – Number of DWT levels.
- Returns
The IDWT of the signal.
- Return type
np.ndarray
- Raises
ValueError – If wavelet is not supported.
Module contents
The tfwavelets package provides discrete wavelet transforms in TensorFlow.
The package consists of the following modules:
- ‘nodes’ contains functions that construct TF subgraphs computing the 1D or 2D DWT or IDWT. Use it when you need a DWT inside your own TF graph.
- ‘wrappers’ contains functions that wrap the functionality in nodes. They construct a full TF graph, launch a session, and evaluate the graph. Use it when you just want to compute the DWT/IDWT of a signal.
- ‘dwtcoeffs’ contains predefined wavelets, as well as the classes needed to create user-defined wavelets.
- ‘utils’ contains helper functions, mostly used internally by the other modules.