# Source code for vitalDSP.feature_engineering.ecg_autonomic_features

"""
ECG Autonomic Features Module for Physiological Signal Processing

This module provides comprehensive ECG feature extraction capabilities focusing
on autonomic nervous system analysis. It implements advanced algorithms for
detecting ECG waveform components, computing intervals, and identifying
arrhythmias for cardiovascular health assessment.

Author: vitalDSP Team
Date: 2025-01-27
Version: 1.0.0

Key Features:
- P-wave analysis (duration)
- PR Interval computation (P-wave onset to QRS onset)
- QRS Complex analysis (duration)
- ST Interval analysis (S-wave to T-wave peak duration)
- QT Interval computation (QRS onset to T-wave peak)
- Arrhythmia detection (AFib, VTach, Bradycardia)
- Waveform morphology analysis
- Comprehensive ECG feature extraction

Examples:
---------
Basic ECG feature extraction:
    >>> import numpy as np
    >>> from vitalDSP.feature_engineering.ecg_autonomic_features import ECGExtractor
    >>> ecg_signal = np.random.rand(1000)  # Replace with actual ECG signal
    >>> fs = 250  # Sampling frequency in Hz
    >>> extractor = ECGExtractor(ecg_signal, fs)
    >>> p_wave_duration = extractor.compute_p_wave_duration()
    >>> pr_interval = extractor.compute_pr_interval()
    >>> qrs_duration = extractor.compute_qrs_duration()
    >>> print(f"P-wave Duration: {p_wave_duration}, PR Interval: {pr_interval}")

Advanced ECG analysis:
    >>> qt_interval = extractor.compute_qt_interval()
    >>> st_interval = extractor.compute_st_interval()
    >>> arrhythmias = extractor.detect_arrhythmias()
    >>> print(f"QT Interval: {qt_interval}, ST Interval: {st_interval}")
    >>> print(f"Arrhythmias detected: {arrhythmias}")

Comprehensive feature extraction:
    >>> all_features = extractor.extract_all_features()
    >>> print(f"Extracted {len(all_features)} ECG features")
"""

import numpy as np
from vitalDSP.utils.signal_processing.peak_detection import PeakDetection
from vitalDSP.physiological_features.waveform import WaveformMorphology


class ECGExtractor:
    """
    A class to extract ECG features including:

    - P-wave duration
    - PR Interval (P-wave onset to QRS onset)
    - QRS Complex duration
    - S-wave locations
    - QT Interval (QRS onset to T-wave peak)
    - ST Interval (S-wave to T-wave peak)
    - Rate/rhythm-based arrhythmia flags (AFib, VTach, Bradycardia)

    Example usage::

        ecg_signal = np.random.rand(1000)  # Replace with actual ECG signal
        fs = 250  # Sampling frequency in Hz
        extractor = ECGExtractor(ecg_signal, fs)
        p_wave_duration = extractor.compute_p_wave_duration()
        pr_interval = extractor.compute_pr_interval()
        qrs_duration = extractor.compute_qrs_duration()
        qt_interval = extractor.compute_qt_interval()
        st_interval = extractor.compute_st_interval()
        arrhythmias = extractor.detect_arrhythmias()
        print(f"P-wave Duration: {p_wave_duration}, PR Interval: {pr_interval}, QRS Duration: {qrs_duration}")
    """

    # Look-back window (in seconds) before a P-peak in which the P-wave
    # onset is searched for.
    _P_ONSET_SEARCH_WINDOW_S = 0.12

    def __init__(self, ecg_signal, sampling_frequency):
        """
        Validate the input and set up waveform-morphology detection.

        Parameters:
            ecg_signal (np.ndarray): ECG samples.
            sampling_frequency (float): Sampling rate in Hz; must be positive.

        Raises:
            TypeError: If ``ecg_signal`` is not a numpy array.
            ValueError: If the signal has fewer than 2 samples, contains
                NaN/inf values, or ``sampling_frequency`` is not positive.
        """
        if not isinstance(ecg_signal, np.ndarray):
            raise TypeError("Input signal must be a numpy array")
        if len(ecg_signal) < 2:
            raise ValueError("ECG signal is too short to compute features")
        if np.isnan(ecg_signal).any() or np.isinf(ecg_signal).any():
            raise ValueError("ECG signal contains invalid values")
        # Every duration computed below divides by fs; a non-positive rate
        # would yield negative/infinite values or a ZeroDivisionError later.
        if sampling_frequency <= 0:
            raise ValueError("Sampling frequency must be positive")
        self.ecg_signal = ecg_signal
        self.fs = sampling_frequency
        # WaveformMorphology provides the P/Q/S/T detectors used below.
        self.morphology = WaveformMorphology(
            ecg_signal, fs=sampling_frequency, signal_type="ECG"
        )

    def detect_r_peaks(self):
        """
        Detect R-peaks using PeakDetection with the ``"ecg_r_peak"`` method.

        Returns:
            np.ndarray: Indices of detected R-peaks (as returned by
            ``PeakDetection.detect_peaks``).

        Raises:
            ValueError: If no R-peaks are detected.
        """
        detector = PeakDetection(self.ecg_signal, method="ecg_r_peak")
        r_peaks = detector.detect_peaks()
        if len(r_peaks) == 0:
            raise ValueError("No R-peaks detected in ECG signal")
        return r_peaks

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _pair_with_following(starts, ends, bounded=False):
        """
        Pair each start index with the first end index strictly after it.

        Both sequences are assumed to be sorted ascending, as peak detectors
        normally return them (NOTE(review): confirm for all detector
        back-ends).

        Parameters:
            starts (sequence of int): Indices of the opening events.
            ends (sequence of int): Indices of the closing events.
            bounded (bool): If True, drop a pair whose end is not before the
                next start. This keeps a missing closing event (e.g. an
                undetected T-peak) from being matched with the following
                beat's event.

        Returns:
            list of tuple: ``(start, end)`` index pairs.
        """
        pairs = []
        end_idx = 0
        n_ends = len(ends)
        n_starts = len(starts)
        for i, start in enumerate(starts):
            while end_idx < n_ends and ends[end_idx] <= start:
                end_idx += 1
            if end_idx >= n_ends:
                # No end left after this start, nor after any later start.
                break
            end = ends[end_idx]
            if bounded and i + 1 < n_starts and end >= starts[i + 1]:
                continue
            pairs.append((start, end))
        return pairs

    def _pair_p_peaks_q_valleys(self, p_peaks, q_valleys):
        """Pair each P-peak with the nearest Q-valley that follows it."""
        return self._pair_with_following(p_peaks, q_valleys)

    @staticmethod
    def _derivative_boundary(derivative, start, stop, use_last):
        """
        Locate a wave boundary inside ``derivative[start:stop]``.

        The boundary is the last (``use_last=True``) or first sign change of
        the derivative in that window. If there is no sign change, it falls
        back to the point of minimum absolute derivative.

        Returns:
            int or None: Absolute sample index, or None for an empty window.
        """
        segment = derivative[start:stop]
        if len(segment) == 0:
            return None
        crossings = np.flatnonzero(np.diff(np.sign(segment)))
        if len(crossings) > 0:
            return start + (crossings[-1] if use_last else crossings[0])
        return start + np.argmin(np.abs(segment))

    def _p_search_start(self, p_peak):
        """First sample of the P-onset search window preceding ``p_peak``."""
        return max(0, p_peak - int(self.fs * self._P_ONSET_SEARCH_WINDOW_S))

    def _find_p_onset(self, p_peak, search_start, derivative):
        """Find the P-wave onset as the last derivative zero-crossing before
        the P-peak, falling back to the minimum-|derivative| point and
        finally to ``search_start``."""
        if search_start < p_peak < len(derivative):
            boundary = self._derivative_boundary(
                derivative, search_start, p_peak, use_last=True
            )
            if boundary is not None:
                return boundary
        return search_start

    def _find_p_offset(self, p_peak, q_valley, derivative):
        """Find the P-wave offset as the first derivative zero-crossing after
        the P-peak, falling back to the minimum-|derivative| point and
        finally to ``q_valley``."""
        if q_valley > p_peak and p_peak < len(derivative):
            boundary = self._derivative_boundary(
                derivative, p_peak, q_valley, use_last=False
            )
            if boundary is not None:
                return boundary
        return q_valley

    def _p_q_pairs(self, r_peaks):
        """Detect Q-valleys and P-peaks and pair each P-peak with the next
        Q-valley. Returns an empty list if either detection is empty."""
        if r_peaks is None:
            r_peaks = self.detect_r_peaks()
        q_valleys = self.morphology.detect_q_valley(r_peaks=r_peaks)
        p_peaks = self.morphology.detect_p_peak(r_peaks=r_peaks, q_valleys=q_valleys)
        if len(p_peaks) == 0 or len(q_valleys) == 0:
            return []
        return self._pair_p_peaks_q_valleys(p_peaks, q_valleys)

    def _paired_duration(self, starts, ends):
        """Pair each start with its following end (bounded by the next start)
        and hand the sessions to ``WaveformMorphology.compute_duration``.
        Returns 0.0 if no valid pair exists."""
        if len(starts) == 0 or len(ends) == 0:
            return 0.0
        pairs = self._pair_with_following(starts, ends, bounded=True)
        if not pairs:
            return 0.0
        return self.morphology.compute_duration(sessions=np.array(pairs), mode="Custom")

    # ------------------------------------------------------------------
    # Public feature extraction
    # ------------------------------------------------------------------

    def compute_p_wave_duration(self, r_peaks=None):
        """
        Compute the P-wave duration from the onset and offset around each
        detected P-peak.

        Parameters:
            r_peaks (array-like, optional): R-peak indices; detected if None.

        Returns:
            float: Mean P-wave duration in seconds, or 0.0 if none found.
        """
        pairs = self._p_q_pairs(r_peaks)
        if not pairs:
            return 0.0
        derivative = np.diff(self.ecg_signal)
        durations = []
        for p_peak, q_valley in pairs:
            onset = self._find_p_onset(p_peak, self._p_search_start(p_peak), derivative)
            offset = self._find_p_offset(p_peak, q_valley, derivative)
            if offset > onset:
                durations.append((offset - onset) / self.fs)
        return np.mean(durations) if durations else 0.0

    def compute_pr_interval(self, r_peaks=None):
        """
        Compute the PR interval from P-wave onset to QRS onset (Q-valley).

        Parameters:
            r_peaks (array-like, optional): R-peak indices; detected if None.

        Returns:
            float: Mean PR interval in seconds, or 0.0 if none found.
        """
        pairs = self._p_q_pairs(r_peaks)
        if not pairs:
            return 0.0
        derivative = np.diff(self.ecg_signal)
        intervals = []
        for p_peak, q_valley in pairs:
            p_onset = self._find_p_onset(p_peak, self._p_search_start(p_peak), derivative)
            if q_valley > p_onset:
                intervals.append((q_valley - p_onset) / self.fs)
        return np.mean(intervals) if intervals else 0.0

    def compute_qrs_duration(self, r_peaks=None):
        """
        Compute the QRS duration from WaveformMorphology QRS sessions.

        Parameters:
            r_peaks (array-like, optional): R-peak indices; detected if None.

        Returns:
            float: Mean QRS duration in seconds, or 0.0 if no sessions.
        """
        if r_peaks is None:
            r_peaks = self.detect_r_peaks()
        # Sessions are iterated as (start, end) index pairs; works for both
        # lists of tuples and (k, 2) arrays.
        qrs_sessions = self.morphology.detect_qrs_session(r_peaks)
        if len(qrs_sessions) == 0:
            return 0.0
        durations = [(end - start) / self.fs for start, end in qrs_sessions]
        return np.mean(durations)

    def compute_s_wave(self, r_peaks=None):
        """
        Detect S-wave sessions from the R-peaks using WaveformMorphology.

        Parameters:
            r_peaks (array-like, optional): R-peak indices; detected if None.

        Returns:
            Whatever ``WaveformMorphology.detect_s_session`` returns.
        """
        if r_peaks is None:
            r_peaks = self.detect_r_peaks()
        return self.morphology.detect_s_session(r_peaks)

    def compute_qt_interval(self):
        """
        Compute the QT interval from the Q-valley (QRS onset) to the
        following T-wave peak. The T-wave end itself is not located here.

        Each Q-valley is paired with the first T-peak after it, provided that
        T-peak precedes the next Q-valley. Missed or spurious detections
        therefore do not shift every later pairing.

        Returns:
            The result of ``WaveformMorphology.compute_duration`` in "Custom"
            mode (presumably the mean duration in seconds; confirm), or 0.0
            if no valid Q/T pair exists.
        """
        q_valleys = self.morphology.detect_q_valley()
        t_peaks = self.morphology.detect_t_peak()
        return self._paired_duration(q_valleys, t_peaks)

    def compute_st_interval(self):
        """
        Compute the ST interval from the S-valley to the following T-wave peak.

        Each S-valley is paired with the first T-peak after it, provided that
        T-peak precedes the next S-valley.

        Returns:
            The result of ``WaveformMorphology.compute_duration`` in "Custom"
            mode (presumably the mean duration in seconds; confirm), or 0.0
            if no valid S/T pair exists.
        """
        s_valleys = self.morphology.detect_s_valley()
        t_peaks = self.morphology.detect_t_peak()
        return self._paired_duration(s_valleys, t_peaks)

    def detect_arrhythmias(
        self,
        r_peaks=None,
        *,
        afib_rr_cv_threshold=0.2,
        tachycardia_rr_threshold=0.6,
        bradycardia_rr_threshold=1.0,
    ):
        """
        Flag basic rhythm abnormalities from RR-interval statistics.

        - "AFib": std(RR) > ``afib_rr_cv_threshold`` * mean(RR)
        - "VTach": mean RR < ``tachycardia_rr_threshold`` s (default 0.6 s,
          i.e. HR > 100 bpm). This is a rate-only check; QRS morphology is
          not examined.
        - "Bradycardia": mean RR > ``bradycardia_rr_threshold`` s (default
          1.0 s, i.e. HR < 60 bpm).

        Parameters:
            r_peaks (array-like, optional): R-peak indices; detected if None.
            afib_rr_cv_threshold (float): RR coefficient-of-variation limit.
            tachycardia_rr_threshold (float): Mean-RR limit in seconds.
            bradycardia_rr_threshold (float): Mean-RR limit in seconds.

        Returns:
            dict: ``{"AFib": bool, "VTach": bool, "Bradycardia": bool}``.
            All flags are False when fewer than two R-peaks are available.
        """
        if r_peaks is None:
            r_peaks = self.detect_r_peaks()
        arrhythmias = {"AFib": False, "VTach": False, "Bradycardia": False}
        rr_intervals = np.diff(r_peaks) / self.fs
        if rr_intervals.size == 0:
            # Fewer than two beats: no RR statistics can be formed.
            return arrhythmias
        mean_rr = np.mean(rr_intervals)
        arrhythmias["AFib"] = bool(np.std(rr_intervals) > afib_rr_cv_threshold * mean_rr)
        arrhythmias["VTach"] = bool(mean_rr < tachycardia_rr_threshold)
        arrhythmias["Bradycardia"] = bool(mean_rr > bradycardia_rr_threshold)
        return arrhythmias