# Source code for vitalDSP.transforms.wavelet_transform

"""
Wavelet Transform Module for Physiological Signal Processing

This module provides comprehensive wavelet transform capabilities for physiological
signals including ECG, PPG, EEG, and other vital signs. It implements Discrete
Wavelet Transform (DWT) with multiple mother wavelets and inverse transform
capabilities for signal analysis and reconstruction.

Author: vitalDSP Team
Date: 2025-01-27
Version: 1.0.0

Key Features:
- Undecimated (stationary) Discrete Wavelet Transform; coefficient arrays keep the input length
- Multiple mother wavelets (Haar, Daubechies, Coiflets, etc.)
- Inverse Wavelet Transform for signal reconstruction
- Multi-level decomposition capabilities
- Signal length preservation options
- Integration with mother wavelet utilities

Examples:
--------
Basic wavelet transform:
    >>> import numpy as np
    >>> from vitalDSP.transforms.wavelet_transform import WaveletTransform
    >>> signal = np.sin(np.linspace(0, 10, 1000)) + np.random.normal(0, 0.1, 1000)
    >>> wt = WaveletTransform(signal, wavelet_name="haar")
    >>> coefficients = wt.perform_wavelet_transform()
    >>> print(f"Number of coefficient arrays: {len(coefficients)}")

Signal reconstruction:
    >>> reconstructed = wt.perform_inverse_wavelet_transform(coefficients)
    >>> print(f"Reconstruction error: {np.mean((signal - reconstructed)**2):.6f}")

Different wavelets:
    >>> wt_db4 = WaveletTransform(signal, wavelet_name="db4")
    >>> wt_coif2 = WaveletTransform(signal, wavelet_name="coif2")
    >>> db4_coeffs = wt_db4.perform_wavelet_transform()
    >>> coif2_coeffs = wt_coif2.perform_wavelet_transform()
"""

import numpy as np
from vitalDSP.utils.signal_processing.mother_wavelets import Wavelet
from scipy.signal import convolve


class WaveletTransform:
    """
    Multi-level wavelet decomposition and reconstruction of a 1-D signal using
    filters obtained from the ``Wavelet`` mother-wavelet utilities.

    No downsampling is performed: every coefficient array produced by
    :meth:`perform_wavelet_transform` has the same length as the input signal
    (an undecimated / stationary scheme), and the same low-pass / high-pass
    filter pair is reused at every level.

    Methods
    -------
    perform_wavelet_transform : method
        Computes the wavelet decomposition of the signal.
    perform_inverse_wavelet_transform : method
        Reconstructs the signal from its wavelet coefficients.
    """

    def __init__(self, signal, wavelet_name="haar", same_length=True):
        """
        Initialize the transform with a signal and a mother wavelet.

        Parameters
        ----------
        signal : numpy.ndarray
            The input signal to be transformed (must support ``len`` and ``copy``).
        wavelet_name : str, optional
            Name of the wavelet to use (default is 'haar'). Recognised aliases are
            'haar', 'db1'-'db8', 'sym1'-'sym8', 'coif1'-'coif5', 'mexh' /
            'mexican_hat' and 'morl' / 'morlet'. Any other public method of
            ``Wavelet`` with that name (for example 'db') is called with no
            arguments.
        same_length : bool, optional
            Edge-handling mode (default is True). If True, the signal is
            reflect-padded around each filter and the reconstruction keeps the
            original length. If False, the signal is zero-padded on the right
            and the reconstruction is returned untrimmed.

        Raises
        ------
        ValueError
            If the specified wavelet name is neither a known alias nor a public,
            callable method of the Wavelet class.
        """
        self.signal = signal
        self.original_length = len(signal)  # used to trim the reconstruction
        self.wavelet_name = wavelet_name
        self.same_length = same_length

        wavelet_class = Wavelet()

        def _deferred(method_name, order):
            # Build a zero-argument callable; the Wavelet method is only looked up
            # and invoked if this particular wavelet is requested.
            return lambda: getattr(wavelet_class, method_name)(order=order)

        # Common wavelet naming conventions -> zero-argument filter factories.
        wavelet_mapping = {
            "haar": lambda: wavelet_class.haar(),
            # Common abbreviations for continuous wavelets
            "mexh": lambda: wavelet_class.mexican_hat(),
            "morl": lambda: wavelet_class.morlet(),
            "mexican_hat": lambda: wavelet_class.mexican_hat(),
            "morlet": lambda: wavelet_class.morlet(),
        }
        for family, max_order in (("db", 8), ("sym", 8), ("coif", 5)):
            for order in range(1, max_order + 1):
                wavelet_mapping[f"{family}{order}"] = _deferred(family, order)

        wavelet_method = wavelet_mapping.get(wavelet_name)
        if wavelet_method is None:
            # Fall back to a *public* Wavelet method of the same name; private or
            # dunder attributes and non-callables are rejected up front.
            if isinstance(wavelet_name, str) and not wavelet_name.startswith("_"):
                wavelet_method = getattr(wavelet_class, wavelet_name, None)
            if not callable(wavelet_method):
                raise ValueError(f"Wavelet '{wavelet_name}' not found in Wavelet class.")

        filters = wavelet_method()
        if isinstance(filters, tuple) and len(filters) == 2:
            self.low_pass, self.high_pass = filters
        else:
            import warnings

            self.low_pass = filters
            # Any return value that is not a (low_pass, high_pass) tuple lands here,
            # not only continuous wavelets.
            warnings.warn(
                f"Wavelet '{self.wavelet_name}' is a continuous wavelet and may not be suitable for DWT. "
                "Using a default difference filter for the high-pass component.",
                UserWarning,
            )
            # NOTE(review): [1, -1] is not derived from low_pass, so this pair is
            # generally not a perfect-reconstruction filter bank.
            self.high_pass = np.array([1, -1])

        # Ensure the wavelet filters are numpy arrays
        self.low_pass = np.asarray(self.low_pass)
        self.high_pass = np.asarray(self.high_pass)

    def _wavelet_decompose(self, data):
        """
        Perform a single-level wavelet decomposition using vectorized convolution.

        Produces same-length output arrays (undecimated/stationary wavelet
        transform), suitable for denoising and other applications where
        coefficient alignment with the original signal is important.

        Parameters
        ----------
        data : numpy.ndarray
            The input data to be transformed.

        Returns
        -------
        tuple
            approximation : numpy.ndarray
                Approximation (low-pass) coefficients, same length as input.
            detail : numpy.ndarray
                Detail (high-pass) coefficients, at most the input length.
        """
        output_length = len(data)
        filter_len = len(self.low_pass)

        if self.same_length:
            # Split the (filter_len - 1) extra samples between both ends so each
            # output sample is roughly centred under the filter; mirror the signal
            # at the edges. _wavelet_reconstruct undoes the left offset.
            pad_left = (filter_len - 1) // 2
            pad_right = filter_len // 2
            padded_data = np.pad(data, (pad_left, pad_right), "reflect")
        else:
            # Zero-pad on the right only.
            padded_data = np.pad(data, (0, filter_len - 1), "constant")

        # Convolving with the time-reversed filter evaluates the sliding dot
        # product sum_k filt[k] * padded_data[i + k]; "valid" keeps full overlaps.
        approximation = convolve(padded_data, self.low_pass[::-1], mode="valid")
        detail = convolve(padded_data, self.high_pass[::-1], mode="valid")

        # Padding is sized from the low-pass filter, so a shorter high-pass filter
        # yields surplus detail samples; trim both to the input length.
        return approximation[:output_length], detail[:output_length]

    def perform_wavelet_transform(self, level=1):
        """
        Perform the multi-level wavelet decomposition of the signal.

        Parameters
        ----------
        level : int, optional
            The number of decomposition levels (default is 1).

        Returns
        -------
        list
            Wavelet coefficients as a list of arrays. ``coeffs[0]`` through
            ``coeffs[-2]`` are detail coefficient arrays, one per level, ordered
            from the first decomposition level to the last. ``coeffs[-1]`` is the
            final approximation array.

        Examples
        --------
        >>> signal = np.sin(np.linspace(0, 10, 100)) + np.random.normal(0, 0.1, 100)
        >>> wavelet_transform = WaveletTransform(signal, wavelet_name='db')
        >>> coeffs = wavelet_transform.perform_wavelet_transform(level=3)
        >>> print(coeffs)
        """
        coeffs = []
        data = self.signal.copy()
        for _ in range(level):
            # Each level splits the previous approximation; details are kept.
            approximation, detail = self._wavelet_decompose(data)
            coeffs.append(detail)
            data = approximation
        coeffs.append(data)  # Final approximation at the highest level
        return coeffs

    def _wavelet_reconstruct(self, approximation, detail):
        """
        Perform a single-level inverse of :meth:`_wavelet_decompose`.

        Each coefficient array is convolved ("full" mode) with its filter, the
        two branches are summed and the result is divided by ``sqrt(2)``.

        Parameters
        ----------
        approximation : numpy.ndarray
            Approximation coefficients.
        detail : numpy.ndarray
            Detail coefficients.

        Returns
        -------
        numpy.ndarray
            Reconstructed data at this level. With ``same_length`` it has the
            length of ``approximation`` and is aligned with the decomposed input;
            otherwise it is the summed full convolution, truncated to the shorter
            of the two branches.
        """
        approx_conv = convolve(approximation, self.low_pass, mode="full")
        detail_conv = convolve(detail, self.high_pass, mode="full")

        # NOTE(review): for an orthonormal filter pair applied without
        # downsampling, the two branches sum to twice the input, which suggests
        # dividing by 2 rather than sqrt(2). Kept unchanged for backward
        # compatibility; confirm against the filters Wavelet returns.
        scale = np.sqrt(2)

        if not self.same_length:
            common = min(len(approx_conv), len(detail_conv))
            return (approx_conv[:common] + detail_conv[:common]) / scale

        # _wavelet_decompose padded (filter_len - 1) // 2 samples on the left, so
        # the full convolution lags the decomposed input by that many samples.
        # Start the output window there to undo the lag.
        output_length = len(approximation)
        lag = (len(self.low_pass) - 1) // 2

        # A full convolution is zero beyond its support, so zero-extending the
        # shorter branch keeps the window [lag, lag + output_length) covered.
        total = max(len(approx_conv), len(detail_conv))
        approx_conv = np.pad(approx_conv, (0, total - len(approx_conv)), "constant")
        detail_conv = np.pad(detail_conv, (0, total - len(detail_conv)), "constant")
        return ((approx_conv + detail_conv) / scale)[lag : lag + output_length]

    def perform_inverse_wavelet_transform(self, coeffs):
        """
        Reconstruct the signal from the coefficients of
        :meth:`perform_wavelet_transform`.

        Parameters
        ----------
        coeffs : list
            Wavelet coefficients: detail arrays followed by the final
            approximation array.

        Returns
        -------
        numpy.ndarray
            Reconstructed signal. With ``same_length`` it is trimmed to the
            original signal length.

        Examples
        --------
        >>> signal = np.sin(np.linspace(0, 10, 100)) + np.random.normal(0, 0.1, 100)
        >>> wavelet_transform = WaveletTransform(signal, wavelet_name='db')
        >>> coeffs = wavelet_transform.perform_wavelet_transform(level=3)
        >>> reconstructed_signal = wavelet_transform.perform_inverse_wavelet_transform(coeffs)
        >>> print(reconstructed_signal)
        """
        data = coeffs[-1]  # Start with the final approximation
        # Undo the levels in reverse order: the last detail array pairs with the
        # final approximation.
        for detail in reversed(coeffs[:-1]):
            data = self._wavelet_reconstruct(data, detail)
        return data[: self.original_length] if self.same_length else data