"""
The 'wrappers' module contains functions that wrap the functionality in the ``nodes``
module. They construct a full TF graph, launch a session, and evaluate the graph. Intended
to be used when you just want to compute the DWT/IDWT of a signal.
"""
import numpy as np
import tfwavelets as tfw
import tensorflow as tf
def dwt1d(signal, wavelet="haar", levels=1):
    """
    Computes the DWT of a 1D signal.

    Args:
        signal (np.ndarray or array_like): A 1D array to compute DWT of. Any
            array-like accepted by ``np.asarray`` (e.g. a list) also works.
        wavelet (str): Type of wavelet ("haar", "db2", "db3" or "db4")
        levels (int): Number of DWT levels

    Returns:
        np.ndarray: The DWT of the signal, with singleton dimensions removed.

    Raises:
        ValueError: If wavelet is not supported
    """
    # Prepare signal for tf. np.asarray (unlike ndarray.astype) also accepts
    # lists/tuples; cast to 32-bit floats to match the float32 placeholder.
    signal = np.asarray(signal, dtype=np.float32)
    # Add a leading batch axis and a trailing channel axis:
    # shape (N,) -> (1, N, 1), the 3D [batch, signal, channels] layout that
    # tf.nn.conv1d works with.
    signal = np.expand_dims(signal, 0)
    signal = np.expand_dims(signal, -1)
    # Construct and compute TF graph
    return _construct_and_compute_graph(
        signal,
        tfw.nodes.dwt1d,
        _parse_wavelet(wavelet),
        levels,
    )
def dwt2d(signal, wavelet="haar", levels=1):
    """
    Computes the DWT of a 2D signal.

    Args:
        signal (np.ndarray or array_like): A 2D array to compute DWT of. Any
            array-like accepted by ``np.asarray`` (e.g. a nested list) also works.
        wavelet (str): Type of wavelet ("haar", "db2", "db3" or "db4")
        levels (int): Number of DWT levels

    Returns:
        np.ndarray: The DWT of the signal, with singleton dimensions removed.

    Raises:
        ValueError: If wavelet is not supported
    """
    # Prepare signal for tf. np.asarray (unlike ndarray.astype) also accepts
    # nested lists; cast to 32-bit floats to match the float32 placeholder.
    signal = np.asarray(signal, dtype=np.float32)
    # Add a trailing channel axis: shape (rows, cols) -> (rows, cols, 1).
    # NOTE(review): unlike dwt1d no batch axis is added; presumably the 2D
    # node expects a [rows, cols, channels] tensor — TODO confirm.
    signal = np.expand_dims(signal, -1)
    # Construct and compute TF graph
    return _construct_and_compute_graph(
        signal,
        tfw.nodes.dwt2d,
        _parse_wavelet(wavelet),
        levels,
    )
def idwt1d(signal, wavelet="haar", levels=1):
    """
    Computes the inverse DWT of a 1D signal.

    Args:
        signal (np.ndarray or array_like): A 1D array to compute IDWT of. Any
            array-like accepted by ``np.asarray`` (e.g. a list) also works.
        wavelet (str): Type of wavelet ("haar", "db2", "db3" or "db4")
        levels (int): Number of DWT levels

    Returns:
        np.ndarray: The IDWT of the signal, with singleton dimensions removed.

    Raises:
        ValueError: If wavelet is not supported
    """
    # Prepare signal for tf. np.asarray (unlike ndarray.astype) also accepts
    # lists/tuples; cast to 32-bit floats to match the float32 placeholder.
    signal = np.asarray(signal, dtype=np.float32)
    # Add a leading batch axis and a trailing channel axis:
    # shape (N,) -> (1, N, 1), the 3D [batch, signal, channels] layout that
    # tf.nn.conv1d works with.
    signal = np.expand_dims(signal, 0)
    signal = np.expand_dims(signal, -1)
    # Construct and compute TF graph
    return _construct_and_compute_graph(
        signal,
        tfw.nodes.idwt1d,
        _parse_wavelet(wavelet),
        levels,
    )
def idwt2d(signal, wavelet="haar", levels=1):
    """
    Computes the inverse DWT of a 2D signal.

    Args:
        signal (np.ndarray or array_like): A 2D array to compute IDWT of. Any
            array-like accepted by ``np.asarray`` (e.g. a nested list) also works.
        wavelet (str): Type of wavelet ("haar", "db2", "db3" or "db4")
        levels (int): Number of DWT levels

    Returns:
        np.ndarray: The IDWT of the signal, with singleton dimensions removed.

    Raises:
        ValueError: If wavelet is not supported
    """
    # Prepare signal for tf. np.asarray (unlike ndarray.astype) also accepts
    # nested lists; cast to 32-bit floats to match the float32 placeholder.
    signal = np.asarray(signal, dtype=np.float32)
    # Add a trailing channel axis: shape (rows, cols) -> (rows, cols, 1).
    # NOTE(review): unlike idwt1d no batch axis is added; presumably the 2D
    # node expects a [rows, cols, channels] tensor — TODO confirm.
    signal = np.expand_dims(signal, -1)
    # Construct and compute TF graph
    return _construct_and_compute_graph(
        signal,
        tfw.nodes.idwt2d,
        _parse_wavelet(wavelet),
        levels,
    )
def _construct_and_compute_graph(input_signal, node, wavelet_obj, levels):
    """
    Build a TF graph that applies ``node`` to a placeholder, then evaluate it.

    Args:
        input_signal (np.ndarray): Array fed into a float32 placeholder of the
            same shape. The 1D wrappers pass a 3D array laid out as
            [batch, signal, channels].
        node: Graph-building function to apply, any kind of dwt/idwt.
        wavelet_obj: Wavelet object, forwarded to ``node`` as ``wavelet``.
        levels: Number of levels, forwarded to ``node`` as ``levels``.

    Returns:
        np.ndarray: The evaluated output of ``node`` with all size-1
        dimensions removed (via ``np.squeeze``).
    """
    # NOTE(review): the placeholder and the node's ops are created in the
    # default TF graph, so every call grows that graph. Building into a
    # dedicated tf.Graph would avoid this, but only if the wavelet objects do
    # not hold tensors bound to the default graph — TODO confirm first.
    placeholder = tf.placeholder(dtype=tf.float32, shape=input_signal.shape)
    output_node = node(placeholder, wavelet=wavelet_obj, levels=levels)

    with tf.Session() as session:
        evaluated = session.run(output_node, feed_dict={placeholder: input_signal})

    # Strip the size-1 axes (e.g. batch/channel axes added by the wrappers).
    return np.squeeze(evaluated)
def _parse_wavelet(wavelet):
    """
    Look up a wavelet by name in ``tfw.dwtcoeffs`` and return it if supported.

    Args:
        wavelet (str): Name of wavelet. Supported names are "haar", "db2",
            "db3" and "db4".

    Returns:
        The wavelet object ``tfw.dwtcoeffs.<wavelet>``, which the wrappers
        pass on to the tfwavelets nodes.

    Raises:
        ValueError: If wavelet is not supported
    """
    supported = ("haar", "db2", "db3", "db4")
    if wavelet not in supported:
        # This helper is shared by dwt1d, dwt2d, idwt1d and idwt2d, so the
        # message must not name one specific transform.
        raise ValueError(
            "Wavelet {} is not supported (supported: {})".format(
                wavelet, ", ".join(supported)
            )
        )
    # Each supported name is also the attribute name in tfw.dwtcoeffs; the
    # attribute is only accessed for the requested wavelet.
    return getattr(tfw.dwtcoeffs, wavelet)