"""
The 'dwtcoeffs' module contains predefined wavelets, as well as the classes necessary to
create more user-defined wavelets.
Wavelets are defined by the Wavelet class. A Wavelet object mainly consists of four Filter
objects (defined by the Filter class) representing the decomposition and reconstruction
low pass and high pass filters.
Examples:
You can define your own wavelet by creating four filters and combining them into a wavelet:
>>> decomp_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> decomp_hp = Filter([1 / np.sqrt(2), -1 / np.sqrt(2)], 1)
>>> recon_lp = Filter([1 / np.sqrt(2), 1 / np.sqrt(2)], 0)
>>> recon_hp = Filter([-1 / np.sqrt(2), 1 / np.sqrt(2)], 1)
>>> haar = Wavelet(decomp_lp, decomp_hp, recon_lp, recon_hp)
"""
import numpy as np
import tensorflow as tf
from tfwavelets.utils import adapt_filter, to_tf_mat
class Filter:
    """
    Class representing a filter.

    Attributes:
        coeffs (tf.constant): Filter coefficients, passed through
            ``adapt_filter`` and stored as a float32 TF constant.
        zero (int): Origin of filter (which index of coeffs array is
            actually indexed as 0).
        edge_matrices (iterable): List of edge matrices, used for circular convolution.
            Stored as 3D TF tensors (constants).
    """

    def __init__(self, coeffs, zero):
        """
        Create a filter based on given filter coefficients.

        Args:
            coeffs (np.ndarray or sequence of float): 1D filter coefficients.
            zero (int): Origin of filter (which index of coeffs array is
                actually indexed as 0). Must satisfy ``0 <= zero < len(coeffs)``.

        Raises:
            ValueError: If ``coeffs`` is not one-dimensional, or ``zero`` is
                not a valid index into it.
        """
        self.coeffs = tf.constant(adapt_filter(coeffs), dtype=tf.float32)

        # Private float32 copy of the raw coefficients, in the order given.
        # Convert the caller's input directly. Round-tripping through the TF
        # constant above would inherit any reshaping done by ``adapt_filter``
        # and may require eager execution. Copying also means later changes
        # to the caller's array do not leak into the filter.
        self._coeffs = np.array(coeffs, dtype=np.float32)
        if self._coeffs.ndim != 1:
            raise ValueError(
                "Filter coefficients must be one-dimensional, got shape {}".format(
                    self._coeffs.shape))
        if not 0 <= zero < len(self._coeffs):
            raise ValueError(
                "Filter origin {} is outside the {} available coefficients".format(
                    zero, len(self._coeffs)))
        self.zero = zero

        self.edge_matrices = to_tf_mat(self._edge_matrices())

    def __getitem__(self, item):
        """
        Return filter coefficient(s) at the requested index or slice.

        Indices are relative to the filter origin: ``f[0]`` is the coefficient
        stored at position ``zero``, and ``f[-1]`` is the one before it.

        Args:
            item (int or slice): Origin-relative index or slice. Omitted slice
                bounds (``None``) keep their usual "from start" / "to end"
                meaning.

        Returns:
            np.float32 or np.ndarray: Item(s) at specified place(s).

        Raises:
            IndexError: If an integer index lies outside the coefficients.
        """
        if isinstance(item, slice):
            # Shift only the bounds that were given; ``None + zero`` would
            # raise TypeError for slices such as ``f[:2]``.
            start = None if item.start is None else item.start + self.zero
            stop = None if item.stop is None else item.stop + self.zero
            return self._coeffs[start:stop:item.step]

        index = item + self.zero
        # A negative shifted index would silently wrap around to the end of
        # the array instead of signalling an out-of-range request.
        if index < 0:
            raise IndexError(
                "Filter index {} is out of range (lowest index is {})".format(
                    item, -self.zero))
        return self._coeffs[index]

    def num_pos(self):
        """
        Number of positive indexed coefficients in filter, including the origin.

        Strictly speaking, this is the number of non-negative indexed
        coefficients.

        Returns:
            int: Number of non-negative indexed coefficients in filter.
        """
        return len(self._coeffs) - self.zero

    def num_neg(self):
        """
        Number of negative indexed coefficients, excluding the origin.

        Returns:
            int: Number of negative indexed coefficients
        """
        return self.zero

    def _edge_matrices(self):
        """
        Compute the submatrices needed at the ends for circular convolution.

        This method does not modify ``self._coeffs``.

        Returns:
            Tuple of 2d-arrays, (top-left, top-right, bottom-left, bottom-right).
        """
        # Reverse a local view only. ``__getitem__`` and ``TrainableWavelet``
        # read ``self._coeffs`` after construction and expect the original
        # order.
        reversed_coeffs = np.asarray(self._coeffs)[::-1]
        n, = reversed_coeffs.shape

        # Some padding is necessary to keep the submatrices
        # from having columns in common
        padding = max(self.zero, n - self.zero - 1)
        matrix_size = n + padding
        filter_matrix = np.zeros((matrix_size, matrix_size), dtype=np.float32)

        # Negative indexed filter coeffs (and 0)
        negative = reversed_coeffs[-(self.zero + 1):]
        # Filter coeffs with strictly positive indices
        positive = reversed_coeffs[:-(self.zero + 1)]

        # Insert first row
        filter_matrix[0, :len(negative)] = negative
        # Because -0 == 0, a length of 0 makes it impossible to broadcast
        # (nor is it necessary)
        if len(positive) > 0:
            filter_matrix[0, -len(positive):] = positive

        # Cycle previous row to compute the entire filter matrix
        for i in range(1, matrix_size):
            filter_matrix[i, :] = np.roll(filter_matrix[i - 1, :], 1)

        # TODO: Indexing not thoroughly tested
        num_pos = len(positive)
        num_neg = len(negative)
        top_left = filter_matrix[:num_pos, :(num_pos + num_neg - 1)]
        top_right = filter_matrix[:num_pos, -num_pos:]
        bottom_left = filter_matrix[-num_neg + 1:, :num_neg - 1]
        bottom_right = filter_matrix[-num_neg + 1:, -(num_pos + num_neg - 1):]

        # Indexing wrong when there are no negative indexed coefficients
        if num_neg == 1:
            bottom_left = np.zeros((0, 0), dtype=np.float32)
            bottom_right = np.zeros((0, 0), dtype=np.float32)

        return top_left, top_right, bottom_left, bottom_right
class TrainableFilter(Filter):
    """
    Filter whose coefficients live in a trainable TensorFlow variable.

    Attributes:
        coeffs (tf.Variable): Filter coefficients, initialised from the
            coefficients given to the constructor.
        zero (int): Origin of filter (which index of coeffs array is
            actually indexed as 0).
    """

    def __init__(self, initial_coeffs, zero, name=None):
        """
        Create a trainable filter initialized with given filter coefficients.

        Args:
            initial_coeffs (np.ndarray): Initial filter coefficients
            zero (int): Origin of filter (which index of coeffs array
                is actually indexed as 0).
            name (str): Optional. Name of tf variable created to hold
                the filter coeffs.
        """
        super().__init__(initial_coeffs, zero)

        starting_value = adapt_filter(initial_coeffs)
        # NOTE(review): the norm is taken over axes [1, 2]. If adapt_filter
        # returns shape (n, 1, 1), those axes are singleton, so each
        # coefficient would be clipped individually rather than the filter as
        # a whole. Confirm this is intended.
        norm_limit = tf.keras.constraints.max_norm(np.sqrt(2), [1, 2])
        self.coeffs = tf.Variable(
            initial_value=starting_value,
            trainable=True,
            name=name,
            dtype=tf.float32,
            constraint=norm_limit,
        )

        # The NumPy copy and the precomputed edge matrices describe only the
        # initial coefficients. They would go stale once training updates the
        # variable, so replace them with placeholders.
        self._coeffs = [None for _ in self._coeffs]
        self.edge_matrices = None
class Wavelet:
    """
    Container bundling the four filters that make up a wavelet.

    Attributes:
        decomp_lp (Filter): Filter coefficients for decomposition low pass filter
        decomp_hp (Filter): Filter coefficients for decomposition high pass filter
        recon_lp (Filter): Filter coefficients for reconstruction low pass filter
        recon_hp (Filter): Filter coefficients for reconstruction high pass filter
    """

    def __init__(self, decomp_lp, decomp_hp, recon_lp, recon_hp):
        """
        Store the four filters of a wavelet.

        Args:
            decomp_lp (Filter): Decomposition low pass filter
            decomp_hp (Filter): Decomposition high pass filter
            recon_lp (Filter): Reconstruction low pass filter
            recon_hp (Filter): Reconstruction high pass filter
        """
        # Decomposition (analysis) pair
        self.decomp_lp, self.decomp_hp = decomp_lp, decomp_hp
        # Reconstruction (synthesis) pair
        self.recon_lp, self.recon_hp = recon_lp, recon_hp
class TrainableWavelet(Wavelet):
    """
    Wavelet whose four filters are trainable copies of an existing wavelet.

    Attributes:
        decomp_lp (TrainableFilter): Filter coefficients for decomposition low pass filter
        decomp_hp (TrainableFilter): Filter coefficients for decomposition high pass filter
        recon_lp (TrainableFilter): Filter coefficients for reconstruction low pass filter
        recon_hp (TrainableFilter): Filter coefficients for reconstruction high pass filter
    """

    def __init__(self, wavelet):
        """
        Create a new trainable wavelet initialized as specified wavelet.

        Args:
            wavelet (Wavelet): Starting point for the trainable wavelet. Its
                filters must hold NumPy coefficients in ``_coeffs``.
                TrainableFilter replaces those with ``None`` placeholders, so
                presumably a TrainableWavelet cannot seed another one.
        """
        # NOTE(review): this relies on ``_coeffs`` holding each source filter's
        # coefficients in their original order. Check that Filter does not
        # reorder them after construction.
        def as_trainable(source_filter):
            return TrainableFilter(source_filter._coeffs, source_filter.zero)

        # Order matters: it matches Wavelet's (decomp_lp, decomp_hp,
        # recon_lp, recon_hp) parameters.
        sources = (wavelet.decomp_lp, wavelet.decomp_hp,
                   wavelet.recon_lp, wavelet.recon_hp)
        super().__init__(*[as_trainable(source) for source in sources])
def _orthogonal_wavelet(recon_lp_coeffs):
    """
    Build a Wavelet from its reconstruction low pass coefficients.

    The predefined wavelets below derive their four filters from the
    reconstruction low pass filter using exact sign flips and reversals:

        decomp_lp = recon_lp reversed                        (origin: last index)
        decomp_hp = recon_lp with even-indexed entries negated   (origin: 0)
        recon_lp  = recon_lp                                 (origin: 0)
        recon_hp  = decomp_hp reversed                       (origin: last index)

    Args:
        recon_lp_coeffs (sequence of float): Reconstruction low pass filter
            coefficients.

    Returns:
        Wavelet: The wavelet made of the four filters above.
    """
    recon_lp = np.array(recon_lp_coeffs)
    last = len(recon_lp) - 1

    # Negation is exact in floating point, so no rounding is introduced.
    decomp_hp = recon_lp.copy()
    decomp_hp[::2] = -decomp_hp[::2]

    return Wavelet(
        Filter(recon_lp[::-1].copy(), last),
        Filter(decomp_hp, 0),
        Filter(recon_lp, 0),
        Filter(decomp_hp[::-1].copy(), last),
    )


# Haar wavelet
haar = _orthogonal_wavelet([0.70710677, 0.70710677])

# Daubechies wavelets
db1 = haar

db2 = _orthogonal_wavelet([
    0.48296291314469025,
    0.836516303737469,
    0.22414386804185735,
    -0.12940952255092145,
])

db3 = _orthogonal_wavelet([
    0.3326705529509569,
    0.8068915093133388,
    0.4598775021193313,
    -0.13501102001039084,
    -0.08544127388224149,
    0.035226291882100656,
])

db4 = _orthogonal_wavelet([
    0.23037781330885523,
    0.7148465705525415,
    0.6308807679295904,
    -0.02798376941698385,
    -0.18703481171888114,
    0.030841381835986965,
    0.032883011666982945,
    -0.010597401784997278,
])