# Source code for vitalDSP.filtering.artifact_removal

"""
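Artifact Removal Module for Physiological Signal Processing

This module provides the ArtifactRemoval class for removing artifacts
from physiological signals such as ECG, PPG, and EEG: mean and baseline
correction, median filtering, wavelet denoising, notch filtering, and
PCA/ICA-based reconstruction.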

Author: vitalDSP Team
Date: 2025-01-27
Version: 1.0.0

Key Features:
- Object-oriented design with comprehensive classes
- Multiple processing methods and functions
- NumPy integration for numerical computations
- SciPy integration for advanced signal processing

Examples:
--------
Basic usage:
    >>> import numpy as np
    >>> from vitalDSP.filtering.artifact_removal import ArtifactRemoval
    >>> signal = np.random.randn(1000)
    >>> processor = ArtifactRemoval(signal)
    >>> result = processor.mean_subtraction()
    >>> print(f'Processing result: {result}')
"""

import numpy as np
from scipy.signal import butter, filtfilt, convolve, medfilt
from scipy.signal.windows import gaussian
from vitalDSP.utils.signal_processing.mother_wavelets import Wavelet
from sklearn.decomposition import IncrementalPCA


class ArtifactRemoval:
    """
    A class for removing various types of artifacts from signals.

    Methods
    -------
    mean_subtraction : function
        Removes a constant offset by subtracting the mean of the signal.
    baseline_correction : function
        Corrects baseline drift with a zero-phase high-pass Butterworth filter.
    median_filter_removal : function
        Removes spike artifacts using a sliding median filter.
    wavelet_denoising : function
        Removes noise by soft-thresholding wavelet detail coefficients.
    adaptive_filtering : function
        Iteratively relaxes the signal toward a reference signal (see its Notes).
    notch_filter : function
        Removes narrow-band interference (e.g. powerline) with a notch filter.
    pca_artifact_removal : function
        Reconstructs overlapping windows from their leading principal components.
    ica_artifact_removal : function
        Uses Independent Component Analysis (ICA) to remove artifacts.
    """

    def __init__(self, signal):
        """
        Initialize the ArtifactRemoval class with the signal.

        Parameters
        ----------
        signal : array_like
            The input signal from which artifacts need to be removed. Non-ndarray
            inputs are converted with ``np.array``; an existing ndarray is stored
            as-is (not copied).

        Notes
        -----
        - Most methods assume a 1D signal.
        """
        if not isinstance(signal, np.ndarray):
            signal = np.array(signal)
        self.signal = signal

    def mean_subtraction(self):
        """
        Remove a constant offset by subtracting the global mean of the signal.

        Only the single global mean is removed, so this corrects a DC offset but
        not a baseline that drifts over time (use ``baseline_correction`` for that).

        Returns
        -------
        clean_signal : numpy.ndarray
            The zero-mean signal (floating point).

        Examples
        --------
        >>> signal = np.array([1, 2, 3, 4, 5])
        >>> ar = ArtifactRemoval(signal)
        >>> print(ar.mean_subtraction())
        [-2. -1.  0.  1.  2.]
        """
        return self.signal - np.mean(self.signal)

    def baseline_correction(self, cutoff=0.5, fs=1000, order=2):
        """
        Correct baseline drift by applying a zero-phase high-pass filter.

        A Butterworth high-pass filter is applied forward and backward with
        ``scipy.signal.filtfilt``, so the output has no phase delay. This is
        useful for removing low-frequency baseline wander in ECG or PPG.

        Parameters
        ----------
        cutoff : float, optional
            Cutoff frequency of the high-pass filter in Hz. Must be positive.
            Values at or above the Nyquist frequency (``fs / 2``) are clamped to
            ``0.99 * fs / 2``. Default is 0.5.
        fs : float, optional
            Sampling frequency of the signal in Hz. Default is 1000.
        order : int, optional
            Order of the Butterworth filter. Default is 2.

        Returns
        -------
        clean_signal : numpy.ndarray
            The baseline-corrected signal.

        Raises
        ------
        ValueError
            If ``cutoff`` is not positive, or if the signal is too short for
            ``filtfilt``'s default padding (it must be longer than
            ``3 * (order + 1)`` samples).

        Examples
        --------
        >>> fs = 250
        >>> t = np.arange(0, 4, 1 / fs)
        >>> signal = np.sin(2 * np.pi * 5 * t) + 3.0  # 5 Hz tone plus a DC offset
        >>> clean = ArtifactRemoval(signal).baseline_correction(cutoff=0.5, fs=fs)
        >>> clean.shape
        (1000,)
        """
        if cutoff <= 0:
            raise ValueError(f"cutoff must be positive, got {cutoff}.")
        nyquist = 0.5 * fs
        if cutoff >= nyquist:
            # butter() requires a normalized cutoff strictly below 1.
            cutoff = nyquist * 0.99
        b, a = butter(order, cutoff / nyquist, btype="high")
        return filtfilt(b, a, self.signal)

    def median_filter_removal(self, kernel_size=3):
        """
        Remove spike artifacts using a sliding median filter.

        The signal is edge-padded by ``kernel_size // 2`` samples on each side and
        each output sample is the median of the ``kernel_size`` samples starting at
        the same index in the padded signal (centered for odd kernels).

        Parameters
        ----------
        kernel_size : int, optional
            Size of the median window (positive integer). Larger windows smooth
            more but may remove genuine signal features. Default is 3.

        Returns
        -------
        clean_signal : numpy.ndarray
            The filtered 1D signal, with the same length and dtype as the input.
            For integer input with an even kernel, half-integer medians are
            truncated by the cast back to the integer dtype.

        Raises
        ------
        ValueError
            If ``kernel_size`` is less than 1.

        Examples
        --------
        >>> signal = np.array([1, 100, 3, 4, 5])
        >>> ar = ArtifactRemoval(signal)
        >>> print(ar.median_filter_removal(kernel_size=3))
        [1 3 4 4 5]
        """
        if kernel_size < 1:
            raise ValueError(
                f"kernel_size must be a positive integer, got {kernel_size}."
            )
        if self.signal.size == 0:
            # np.pad cannot edge-pad an empty array.
            return self.signal.copy()

        half = kernel_size // 2
        padded = np.pad(self.signal, (half, half), mode="edge")
        # One window per output sample; an even kernel yields one extra trailing
        # window, which is dropped so the output length matches the input.
        windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size)
        medians = np.median(windows[: len(self.signal)], axis=1)
        return medians.astype(self.signal.dtype, copy=False)

    def wavelet_denoising(
        self,
        wavelet_type="db",
        level=1,
        order=4,
        custom_wavelet=None,
        smoothing="lowpass",
        **smoothing_params,
    ):
        """
        Remove noise using wavelet-based denoising with various mother wavelets.

        The signal is repeatedly convolved with the mother wavelet (low-pass) and
        its quadrature-mirror high-pass filter and downsampled by 2. The detail
        coefficients are soft-thresholded with a universal threshold, and the
        signal is rebuilt by upsampling and convolving with the same filters.
        Optionally, a smoothing step is applied to the reconstruction.

        Parameters
        ----------
        wavelet_type : str, optional
            One of 'haar', 'db', 'sym', 'coif', or 'custom'. Default is 'db'.
        level : int, optional
            Number of decomposition levels (at least 1). Default is 1.
        order : int, optional
            Wavelet order for 'db', 'sym', and 'coif'. Default is 4.
        custom_wavelet : array_like, optional
            Low-pass filter coefficients, required when ``wavelet_type`` is
            'custom'. Default is None.
        smoothing : str or None, optional
            Post-reconstruction smoothing: 'lowpass', 'gaussian', 'median', or
            'moving_average'. Any falsy value (e.g. None) disables smoothing.
            Default is 'lowpass'.
        **smoothing_params
            Keyword arguments forwarded to the smoother. ``fs`` defaults to 100.0
            here; other defaults are those of ``_apply_smoothing`` (lowpass:
            ``cutoff=0.2``, ``order=5``; gaussian: ``sigma=1.0``; median:
            ``kernel_size=3``; moving_average: ``window_size=5``).

        Returns
        -------
        clean_signal : numpy.ndarray
            The denoised signal with the same length as the original signal.

        Raises
        ------
        ValueError
            If ``level`` < 1, ``wavelet_type`` is unknown, 'custom' is requested
            without ``custom_wavelet``, or the smoothing method is unsupported.

        Examples
        --------
        >>> signal = np.sin(np.linspace(0, 8 * np.pi, 256)) + 0.1 * np.random.randn(256)
        >>> ar = ArtifactRemoval(signal)
        >>> clean = ar.wavelet_denoising(wavelet_type='db', level=2, order=4)
        >>> # Example using a custom wavelet and no smoothing
        >>> clean = ar.wavelet_denoising(
        ...     wavelet_type='custom', custom_wavelet=np.array([0.5, 0.5]), smoothing=None
        ... )
        >>> clean.shape
        (256,)
        """
        if level < 1:
            raise ValueError(f"level must be at least 1, got {level}.")

        if wavelet_type == "custom":
            if custom_wavelet is None:
                raise ValueError(
                    "A custom wavelet must be provided if wavelet_type is 'custom'."
                )
            mother_wavelet = custom_wavelet
        elif wavelet_type in ("haar", "db", "sym", "coif"):
            # The Wavelet helper is only needed for the built-in families.
            wavelet = Wavelet()
            if wavelet_type == "haar":
                mother_wavelet = wavelet.haar()
            elif wavelet_type == "db":
                mother_wavelet = wavelet.db(order)
            elif wavelet_type == "sym":
                mother_wavelet = wavelet.sym(order)
            else:
                mother_wavelet = wavelet.coif(order)
        else:
            raise ValueError(
                "Invalid wavelet_type. Must be 'haar', 'db', 'sym', 'coif', or 'custom'."
            )

        # Quadrature-mirror high-pass: time-reversed low-pass with alternating signs.
        N = len(mother_wavelet)
        high_pass = np.array([(-1) ** n * mother_wavelet[N - 1 - n] for n in range(N)])

        # Decomposition: filter, then keep every second sample.
        approx_coeffs = self.signal.copy()
        detail_coeffs = []
        for _ in range(level):
            approx = np.convolve(approx_coeffs, mother_wavelet, mode="full")
            detail = np.convolve(approx_coeffs, high_pass, mode="full")
            approx_coeffs = approx[::2]
            detail_coeffs.append(detail[::2])

        # Universal threshold sqrt(2 ln n) * sigma, with sigma estimated as
        # median(|d|) / 0.6745 of the last (coarsest) detail level.
        # NOTE(review): the Donoho-Johnstone estimator conventionally uses the
        # finest level (detail_coeffs[0]); the two coincide only when level == 1.
        threshold = (
            np.sqrt(2 * np.log(len(self.signal)))
            * np.median(np.abs(detail_coeffs[-1]))
            / 0.6745
        )
        # Soft thresholding of every detail level.
        detail_coeffs = [
            np.sign(d) * np.maximum(np.abs(d) - threshold, 0) for d in detail_coeffs
        ]

        # Reconstruction: upsample by zero insertion, filter, and sum branches.
        for i in reversed(range(level)):
            upsampled_approx = np.zeros(len(approx_coeffs) * 2)
            upsampled_approx[::2] = np.real(approx_coeffs)
            upsampled_detail = np.zeros(len(detail_coeffs[i]) * 2)
            upsampled_detail[::2] = np.real(detail_coeffs[i])
            # Truncate to the same length before summing.
            upsampled_approx = upsampled_approx[: len(upsampled_detail)]
            approx_coeffs = (
                np.convolve(upsampled_approx, mother_wavelet, mode="full")[
                    : len(upsampled_approx)
                ]
                + np.convolve(upsampled_detail, high_pass, mode="full")[
                    : len(upsampled_detail)
                ]
            )

        # Match the original signal length.
        clean_signal = approx_coeffs[: len(self.signal)]

        if smoothing:
            params = dict(smoothing_params)
            params.setdefault("fs", 100.0)
            clean_signal = self._apply_smoothing(clean_signal, smoothing, **params)
        return clean_signal

    @staticmethod
    def _apply_smoothing(signal, method, **kwargs):
        """
        Apply a smoothing method to the signal.

        Parameters
        ----------
        signal : numpy.ndarray
            The input signal to be smoothed.
        method : str
            'lowpass' (kwargs: cutoff=0.2, fs=1.0, order=5), 'gaussian'
            (sigma=1.0), 'median' (kernel_size=3, must be odd for medfilt), or
            'moving_average' (window_size=5).
        kwargs : dict
            Additional parameters for the smoothing method.

        Returns
        -------
        smoothed_signal : numpy.ndarray
            The smoothed signal.

        Raises
        ------
        ValueError
            If ``method`` is not supported.
        """
        if method == "lowpass":
            cutoff = kwargs.get("cutoff", 0.2)
            fs = kwargs.get("fs", 1.0)
            order = kwargs.get("order", 5)
            return ArtifactRemoval._lowpass_filter(signal, cutoff, fs, order)
        elif method == "gaussian":
            sigma = kwargs.get("sigma", 1.0)
            return ArtifactRemoval._gaussian_smoothing(signal, sigma)
        elif method == "median":
            kernel_size = kwargs.get("kernel_size", 3)
            return ArtifactRemoval._median_smoothing(signal, kernel_size)
        elif method == "moving_average":
            window_size = kwargs.get("window_size", 5)
            return ArtifactRemoval._moving_average_smoothing(signal, window_size)
        else:
            raise ValueError(f"Unsupported smoothing method: {method}")

    @staticmethod
    def _gaussian_smoothing(signal, sigma):
        """Convolve with a normalized Gaussian window of length int(6*sigma + 1)."""
        size = int(6 * sigma + 1)
        gaussian_kernel = gaussian(size, sigma)
        smoothed_signal = convolve(signal, gaussian_kernel, mode="same") / np.sum(
            gaussian_kernel
        )
        return smoothed_signal

    @staticmethod
    def _median_smoothing(signal, kernel_size):
        """Median-filter with scipy.signal.medfilt (kernel_size must be odd)."""
        return medfilt(signal, kernel_size)

    @staticmethod
    def _moving_average_smoothing(signal, window_size):
        """Convolve with a boxcar of length window_size ('same' output length)."""
        kernel = np.ones(window_size) / window_size
        smoothed_signal = convolve(signal, kernel, mode="same")
        return smoothed_signal

    @staticmethod
    def _lowpass_filter(signal, cutoff, fs=100.0, order=5):
        """
        Apply a zero-phase low-pass Butterworth filter to smooth the signal.

        Parameters
        ----------
        signal : numpy.ndarray
            The input signal to be smoothed.
        cutoff : float
            The cutoff frequency in Hz; must satisfy 0 < cutoff < fs / 2.
        fs : float
            The sampling frequency of the signal in Hz.
        order : int, optional
            The order of the Butterworth filter. Default is 5.

        Returns
        -------
        smoothed_signal : numpy.ndarray
            The smoothed signal.

        Raises
        ------
        ValueError
            If ``cutoff`` is not strictly between 0 and the Nyquist frequency.
        """
        nyquist = 0.5 * fs
        if not 0 < cutoff < nyquist:
            raise ValueError(
                f"Low-pass cutoff must satisfy 0 < cutoff < fs/2 ({nyquist}), got {cutoff}."
            )
        b, a = butter(order, cutoff / nyquist, btype="low", analog=False)
        return filtfilt(b, a, signal)

    def adaptive_filtering(
        self, reference_signal=None, learning_rate=0.01, num_iterations=100
    ):
        """
        Iteratively relax the signal toward a reference signal.

        Each iteration computes ``error = filtered_signal - reference_signal`` and
        applies ``filtered_signal -= learning_rate * error``. The update acts on
        every sample independently (there are no filter taps), so the result has
        the closed form::

            reference + (1 - learning_rate) ** num_iterations * (signal - reference)

        Parameters
        ----------
        reference_signal : array_like, optional
            Target signal, same length as the input. If None, a zero signal is
            used, in which case the output is
            ``signal * (1 - learning_rate) ** num_iterations``.
        learning_rate : float, default=0.01
            Step size. The output moves toward the reference for
            ``0 < learning_rate < 2`` and diverges for larger values.
        num_iterations : int, default=100
            Number of update steps.

        Returns
        -------
        filtered_signal : numpy.ndarray
            float64 array with the same length as the input.

        Raises
        ------
        ValueError
            If ``reference_signal`` length doesn't match the signal length.

        Notes
        -----
        NOTE(review): this is not a tapped-delay-line LMS noise canceller. With
        ``reference_signal`` set to an artifact recording, the output converges
        toward the artifact itself rather than toward ``signal - artifact``.
        Without a reference, the method scales the whole signal (signal and
        offset alike) toward zero. Confirm the intended semantics before relying
        on this method for artifact removal.

        See Also
        --------
        baseline_correction : For baseline drift removal.
        mean_subtraction : For DC offset removal.

        Examples
        --------
        >>> ar = ArtifactRemoval(np.array([1.0, 2.0, 3.0]))
        >>> ar.adaptive_filtering(learning_rate=0.5, num_iterations=1)
        array([0.5, 1. , 1.5])
        >>> ar.adaptive_filtering(
        ...     reference_signal=np.array([1.0, 1.0, 1.0]),
        ...     learning_rate=0.5, num_iterations=1)
        array([1. , 1.5, 2. ])
        """
        # astype copies, so self.signal is never modified.
        filtered_signal = self.signal.astype(np.float64)

        if reference_signal is None:
            reference_signal = np.zeros_like(filtered_signal, dtype=np.float64)
        else:
            reference_signal = np.asarray(reference_signal, dtype=np.float64)
            if len(reference_signal) != len(self.signal):
                raise ValueError(
                    f"Reference signal length ({len(reference_signal)}) must match "
                    f"signal length ({len(self.signal)})"
                )

        for _ in range(num_iterations):
            error = filtered_signal - reference_signal
            filtered_signal -= learning_rate * error

        return filtered_signal

    def notch_filter(self, freq=50, fs=1000, Q=30):
        """
        Remove narrow-band interference with a zero-phase IIR notch filter.

        A second-order notch (``scipy.signal.iirnotch``) is applied forward and
        backward with ``filtfilt``. This is typically used for powerline
        interference (50/60 Hz).

        Parameters
        ----------
        freq : float
            The frequency to be removed in Hz (e.g. 50 for powerline interference).
        fs : float
            The sampling frequency of the signal in Hz.
        Q : float
            Quality factor. Higher Q gives a narrower notch.

        Returns
        -------
        clean_signal : numpy.ndarray
            The filtered signal.

        Raises
        ------
        ValueError
            From SciPy, if ``freq`` is not strictly between 0 and ``fs / 2``, or if
            the signal is not longer than ``filtfilt``'s default padding (9 samples).

        Examples
        --------
        >>> t = np.arange(0, 1, 1 / 1000)
        >>> signal = np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 50 * t)
        >>> clean = ArtifactRemoval(signal).notch_filter(freq=50, fs=1000, Q=30)
        >>> clean.shape
        (1000,)
        """
        from scipy.signal import iirnotch

        b, a = iirnotch(freq, Q, fs=fs)
        return filtfilt(b, a, self.signal)

    def pca_artifact_removal(self, num_components=1, window_size=100, overlap=50):
        """
        Use Principal Component Analysis (PCA) to remove artifacts.

        The signal is cut into windows of ``window_size`` samples, advancing by
        ``window_size - overlap``. Each window is projected onto the
        ``num_components`` leading eigenvectors of the window covariance and
        reconstructed. The windows are then overlap-averaged back onto the
        signal's time axis.

        Parameters
        ----------
        num_components : int
            The number of principal components to retain.
        window_size : int, optional
            Window length in samples (default 100).
        overlap : int, optional
            Samples shared by consecutive windows; must be less than
            ``window_size`` (default 50).

        Returns
        -------
        clean_signal : numpy.ndarray
            float64 array of the signal's length. Samples not covered by any
            window (a tail shorter than one step) keep their original values.

        Raises
        ------
        ValueError
            If the window parameters are invalid, fewer than two windows fit in
            the signal, the covariance has fewer dimensions than
            ``num_components``, or the eigendecomposition fails.

        Examples
        --------
        >>> signal = np.array([1, 2, 3, 4, 5])
        >>> ar = ArtifactRemoval(signal)
        >>> clean_signal = ar.pca_artifact_removal(num_components=1, window_size=2, overlap=1)
        >>> clean_signal.shape
        (5,)
        """
        if window_size < 1:
            raise ValueError(f"window_size must be at least 1, got {window_size}.")
        if overlap >= window_size:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than window_size ({window_size})."
            )

        starts = range(0, len(self.signal) - window_size + 1, window_size - overlap)
        segments = np.array([self.signal[s : s + window_size] for s in starts])

        if segments.size == 0:
            raise ValueError(
                "No valid segments available for PCA. Ensure the signal has sufficient length and variability."
            )
        if len(segments) < 2:
            # np.cov of a single observation divides by zero (ddof=1).
            raise ValueError(
                "At least two windows are required for PCA; use a longer signal or a smaller window_size."
            )

        # Center each sample position across windows.
        signal_mean = np.mean(segments, axis=0)
        centered_signal = segments - signal_mean

        covariance_matrix = np.cov(centered_signal, rowvar=False)
        if covariance_matrix.ndim < 2 or covariance_matrix.shape[0] < num_components:
            raise ValueError(
                "Covariance matrix is not 2D or has insufficient dimensionality. Ensure the input signal has sufficient variation."
            )

        try:
            eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
        except np.linalg.LinAlgError as e:
            raise ValueError(
                f"Linear Algebra error during eigenvalue decomposition: {e}"
            ) from e

        # eigh returns ascending eigenvalues; keep the largest ones.
        sorted_indices = np.argsort(eigenvalues)[::-1]
        selected_components = eigenvectors[:, sorted_indices[:num_components]]

        reconstructed_segments = (
            centered_signal @ selected_components @ selected_components.T + signal_mean
        )
        return self._overlap_average(reconstructed_segments, starts, window_size)

    def ica_artifact_removal(
        self,
        num_components=1,
        max_iterations=1000,
        tol=1e-5,
        seed=23,
        window_size=None,
        step_size=None,
        batch_size=1000,
    ):
        """
        Use Independent Component Analysis (ICA) to remove artifacts.

        For 1D signals without windowing, the work is delegated to
        ``vitalDSP.signal_quality_assessment.blind_source_separation.ica_artifact_removal``
        with ``auto_synthetic=True``. That function is described as building
        synthetic components (derivatives, delayed and smoothed versions) from
        the single channel.

        With ``window_size`` and ``step_size``, the 1D signal is cut into
        windows that form the columns of a (window_size x n_windows) matrix. The
        steps are:

        1. Project the matrix onto ``num_components`` principal components with
           IncrementalPCA.
        2. Run FastICA (tanh non-linearity, symmetric decorrelation).
        3. Map the components back and overlap-average the windows onto the
           original time axis.

        Parameters
        ----------
        num_components : int
            Number of components. For the automatic 1D path,
            ``max(3, num_components + 2)`` components are requested.
        max_iterations : int
            The maximum number of iterations for convergence.
        tol : float
            The tolerance level for convergence.
        seed : int
            Seed for the local random generator used to initialize the FastICA
            weights (the global NumPy random state is not modified).
        window_size : int, optional
            Window length for the windowed path.
        step_size : int, optional
            Window step for the windowed path. Must be used with window_size.
        batch_size : int, optional
            The batch size for IncrementalPCA, to bound memory usage.

        Returns
        -------
        clean_signal : numpy.ndarray
            Windowed path: float64 array of the signal's length. Samples covered
            by no window keep their original values. Automatic 1D path: whatever
            the delegated function returns (presumably the input's shape;
            verify there).

        Raises
        ------
        ValueError
            If the signal is not 1D and no windowing is requested, windowing is
            requested for a signal that is not 1D, no window fits the signal, or
            ``num_components`` exceeds the number of windows.

        Notes
        -----
        In the windowed path every retained component is mapped back, and the
        decorrelated unmixing matrix is orthogonal. The result is therefore the
        rank-``num_components`` PCA reconstruction of the windows; the ICA step
        would only matter if selected components were dropped.

        Examples
        --------
        >>> # Example 1: 1D signal (most common case)
        >>> signal_1d = np.sin(2*np.pi*np.linspace(0,10,1000)) + 0.1*np.random.randn(1000)
        >>> ar = ArtifactRemoval(signal_1d)
        >>> clean = ar.ica_artifact_removal(num_components=3)
        >>>
        >>> # Example 2: with windowing
        >>> signal = np.array([1, 2, 3, 4, 5, 6, 7, 8])
        >>> ar = ArtifactRemoval(signal)
        >>> clean = ar.ica_artifact_removal(num_components=1, window_size=4, step_size=2)
        >>> clean.shape
        (8,)
        """
        # Automatic path for 1D signals when windowing is not requested.
        if self.signal.ndim == 1 and not (window_size and step_size):
            from vitalDSP.signal_quality_assessment.blind_source_separation import (
                ica_artifact_removal as enhanced_ica,
            )

            return enhanced_ica(
                self.signal,
                max_iter=max_iterations,
                tol=tol,
                auto_synthetic=True,
                # Ensure enough components for good separation.
                n_components=max(3, num_components + 2),
            )

        if not (window_size and step_size):
            raise ValueError(
                "window_size and step_size must both be provided for signals that are not 1D."
            )
        if self.signal.ndim != 1:
            raise ValueError("Windowed ICA supports only 1D signals.")

        starts = range(0, len(self.signal) - window_size + 1, step_size)
        if len(starts) == 0:
            raise ValueError(
                f"No windows could be formed from a signal of length {len(self.signal)} "
                f"with window_size={window_size} and step_size={step_size}."
            )

        # Columns are windows: shape (window_size, n_windows).
        multi_dimensional_signal = np.array(
            [self.signal[s : s + window_size] for s in starts]
        ).T

        n_features = multi_dimensional_signal.shape[1]
        if num_components > n_features:
            raise ValueError(
                f"n_components={num_components} invalid for n_features={n_features}. "
                f"For 1D signals, omit window_size/step_size to use automatic synthetic components."
            )

        column_mean = np.mean(multi_dimensional_signal, axis=0)
        centered = multi_dimensional_signal - column_mean

        # Reduce to num_components dimensions; IncrementalPCA bounds memory by batch_size.
        ipca = IncrementalPCA(n_components=num_components, batch_size=batch_size)
        projected = ipca.fit_transform(centered)

        # FastICA with symmetric decorrelation; a local RNG avoids touching the
        # global NumPy seed (RandomState(seed).rand matches np.random.seed + rand).
        rng = np.random.RandomState(seed)
        W = self._symmetric_decorrelation(
            rng.rand(num_components, projected.shape[1])
        )
        n_samples = projected.shape[0]
        for _ in range(max_iterations):
            g = np.tanh(projected @ W.T)
            # Row i: E[x g(w_i^T x)] - E[g'(w_i^T x)] w_i
            W_new = (g.T @ projected) / n_samples - np.mean(1 - g**2, axis=0)[
                :, np.newaxis
            ] * W
            W_new = self._symmetric_decorrelation(W_new)
            converged = np.max(np.abs(np.abs(np.diag(W_new @ W.T)) - 1)) < tol
            W = W_new
            if converged:
                break

        # Separate, then map the components back to the PCA subspace and to windows.
        sources = projected @ W.T
        projected_hat = sources @ np.linalg.pinv(W.T)
        windows_hat = ipca.inverse_transform(projected_hat) + column_mean

        return self._overlap_average(windows_hat.T, starts, window_size)

    @staticmethod
    def _symmetric_decorrelation(W):
        """Return the orthogonal matrix closest to W, (W W^T)^(-1/2) W, via SVD."""
        u, _, vt = np.linalg.svd(W, full_matrices=False)
        return u @ vt

    def _overlap_average(self, segments, starts, window_size):
        """
        Overlap-average reconstructed windows back onto the signal's sample axis.

        Parameters
        ----------
        segments : numpy.ndarray
            Reconstructed windows, one per row, each ``window_size`` long.
        starts : iterable of int
            Start index of each window in ``self.signal`` (same order as rows).
        window_size : int
            Length of every window.

        Returns
        -------
        numpy.ndarray
            float64 array of the signal's length. Covered samples hold the mean of
            all window values at that position; uncovered samples keep the
            original signal value.
        """
        n = len(self.signal)
        total = np.zeros(n)
        counts = np.zeros(n)
        for segment, start in zip(segments, starts):
            total[start : start + window_size] += segment
            counts[start : start + window_size] += 1
        clean_signal = self.signal.astype(np.float64)
        covered = counts > 0
        clean_signal[covered] = total[covered] / counts[covered]
        return clean_signal