"""
Health Analysis Module for Physiological Signal Processing
This module provides comprehensive capabilities for physiological
signal processing including ECG, PPG, EEG, and other vital signs.
Author: vitalDSP Team
Date: 2025-01-27
Version: 1.0.0
Key Features:
- Object-oriented design with comprehensive classes
- Multiple processing methods and functions
- NumPy integration for numerical computations
- SciPy integration for advanced signal processing
- Interactive visualization capabilities
Examples:
--------
Basic usage:
>>> from vitalDSP.health_analysis.health_report_visualization import HealthReportVisualizer
>>> visualizer = HealthReportVisualizer(config)  # config must be a dict
>>> plot_paths = visualizer.create_visualizations(feature_data, output_dir="visualizations")
>>> print(plot_paths)
"""
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import pandas as pd
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.tsa.seasonal import seasonal_decompose
from scipy.signal import spectrogram, periodogram, welch
from matplotlib.patches import Rectangle
from scipy.stats import pearsonr
from collections import Counter
from scipy.interpolate import make_interp_spline
from vitalDSP.utils.config_utilities.common import find_peaks
import logging
import threading
# Select the non-interactive Agg backend so figures are rendered off-screen
# and saved to files without needing a GUI.
matplotlib.use("Agg")
# A value of 0 turns off matplotlib's "too many open figures" warning.
# NOTE: this only silences the warning; it does not make pyplot thread-safe.
matplotlib.rcParams["figure.max_open_warning"] = 0
# Module-wide lock; acquired by _thread_safe_matplotlib_operation so that
# wrapped matplotlib calls run one at a time.
matplotlib_lock = threading.Lock()
# NOTE(review): basicConfig at import time configures the process-wide root
# logger — confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
[docs]
class HealthReportVisualizer:
"""
A class responsible for creating visualizations of health feature data, including line plots and heatmaps.
The class takes feature data and creates visualizations, such as normal distributions for ranges and heatmaps, and stores them as images.
"""
def __init__(self, config, segment_duration="1_min"):
    """
    Set up the visualizer with its feature configuration.

    Args:
        config (dict): Configuration data for the features (described by the
            original documentation as normal ranges and interpretations).
            Anything that is not a dict is rejected.
        segment_duration (str): Segment-length label, stored unchanged on the
            instance. Defaults to "1_min".

    Raises:
        TypeError: If ``config`` is not a dictionary.

    Example Usage:
        >>> visualizer = HealthReportVisualizer(config)
    """
    config_is_dict = isinstance(config, dict)
    if config_is_dict is False:
        raise TypeError("Config should be a dictionary.")

    self.segment_duration = segment_duration
    self.logger = logging.getLogger(__name__)
    self.config = config
def _thread_safe_matplotlib_operation(self, func, *args, **kwargs):
    """
    Run a matplotlib callable while holding the module-level matplotlib lock.

    Before the call the current figure is cleared and every open figure is
    closed; after a successful call the current figure is cleared again. If
    the call (or the clean-up after it) raises, the figure state is reset and
    the exception is propagated to the caller.

    Args:
        func: The matplotlib function to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    with matplotlib_lock:
        try:
            # Start from a clean pyplot state so earlier figures cannot interfere
            plt.clf()
            plt.close("all")
            outcome = func(*args, **kwargs)
            # Clear whatever the call left on the current figure
            plt.clf()
        except Exception:
            # Reset pyplot before handing the error back to the caller
            plt.clf()
            plt.close("all")
            raise
        return outcome
def _normalize_web_path(self, filepath):
    """
    Turn a file path into a web-friendly one that only uses forward slashes.

    Args:
        filepath (str): The path to convert; ``None`` is passed through.

    Returns:
        str: The path with every backslash replaced by "/", or ``None`` when
        no path was given.
    """
    if filepath is not None:
        # Windows-style separators are not valid in URLs
        return filepath.replace("\\", "/")
    return None
def _fetch_and_validate_normal_range(self, feature, value):
    """
    Look up a feature's normal range and derive finite bounds from it.

    The two configured bounds and the current ``value`` are pooled, NaN and
    Inf entries are dropped, and the smallest and largest remaining numbers
    are returned. The returned interval is therefore widened to contain
    ``value`` whenever ``value`` is finite. If only one finite number ``x``
    remains, a synthetic interval ``x - 2*|x|, x + 2*|x|`` is built around it.

    Args:
        feature (str): The feature name.
        value (float): The current value for the feature.

    Returns:
        tuple: ``(normal_min, normal_max)`` with ``normal_min <= normal_max``.

    Raises:
        ValueError: If no normal range is configured for the feature, or if
            both bounds and the value are NaN/Inf.
    """
    normal_range = self._get_normal_range_for_feature(feature)
    if not normal_range:
        raise ValueError(f"Normal range for feature '{feature}' not found.")
    normal_min, normal_max = normal_range

    # Keep only candidates that are neither NaN nor +/-Inf
    finite_values = [
        candidate
        for candidate in (normal_min, normal_max, value)
        if np.isfinite(candidate)
    ]
    if not finite_values:
        raise ValueError(f"All values for feature '{feature}' are NaN or Inf.")

    if len(finite_values) == 1:
        center = finite_values[0]
        # Use the magnitude so a negative centre still yields min <= max
        # (x - 2*x, x + 2*x would come out reversed for x < 0).
        spread = 2 * abs(center)
        return center - spread, center + spread

    return min(finite_values), max(finite_values)
[docs]
def create_visualizations(self, feature_data, output_dir="visualizations"):
    """
    Create every supported plot for each feature and collect the image paths.

    Args:
        feature_data (dict): Maps feature names to their values. A value may
            be a list, a tuple or a NumPy array of per-segment values (arrays
            are flattened), or a single scalar, which is treated as a
            one-element list.
        output_dir (str): Directory where the images are saved; it is created
            if missing. ``None`` falls back to "visualizations".

    Returns:
        dict: ``{feature: {plot_key: path}}`` where each path uses forward
        slashes. When a plot could not be produced, its entry holds whatever
        the plotting method returned instead (the plotting methods in this
        class return an "Error generating plot for ..." string).

    Example Usage:
        >>> visualizations = visualizer.create_visualizations(feature_data)
        >>> print(visualizations)
    """
    if output_dir is None:
        output_dir = "visualizations"
    os.makedirs(output_dir, exist_ok=True)

    # (result key, plotting method), in the order the plots are generated
    plotters = (
        ("heatmap", self._create_heatmap_plot),
        ("bell_plot", self._create_bell_shape_plot),
        ("radar_plot", self._create_radar_plot),
        ("violin_plot", self._create_violin_plot),
        ("line_with_rolling_stats", self._create_plot_line_with_rolling_stats),
        ("lag_plot", self._create_plot_lag),
        ("plot_periodogram", self._create_plot_periodogram),
        ("plot_spectrogram", self._create_plot_spectrogram),
        ("plot_spectral_density", self._create_spectral_density_plot),
        ("plot_box_swarm", self._create_box_swarm_plot),
    )

    visualization_paths = {}
    for feature, values in feature_data.items():
        # Normalise every accepted input shape to a plain list. Wrapping a
        # tuple/array as a single element would hand the plotters a nested
        # sequence, which they cannot process.
        if isinstance(values, np.ndarray):
            values = values.ravel().tolist()
        elif isinstance(values, tuple):
            values = list(values)
        elif not isinstance(values, list):
            values = [values]

        visualization_paths[feature] = {
            key: self._normalize_web_path(plot(feature, values, output_dir))
            for key, plot in plotters
        }
    return visualization_paths
def _create_bell_shape_plot(self, feature, values, output_dir):
    """
    Draw a histogram of the feature values with a reference bell curve on top.

    The bell curve is centred on the midpoint of the normal range and uses a
    quarter of the range width as its standard deviation. Values are clipped
    to 1.5 range-widths beyond either bound before being binned; values beyond
    those limits are additionally marked as outliers on the x-axis.

    Args:
        feature (str): The name of the feature.
        values (list): Values for the feature; ``values[0]`` is the value
            passed along when the normal range is fetched.
        output_dir (str): Directory where the plot will be saved.

    Returns:
        str: Path to the saved bell shape plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        normal_min, normal_max = self._fetch_and_validate_normal_range(
            feature, values[0]
        )
        range_width = normal_max - normal_min

        # Outlier thresholds: 1.5 range-widths beyond either bound
        outlier_min = normal_min - 1.5 * range_width
        outlier_max = normal_max + 1.5 * range_width
        capped_values = np.clip(values, outlier_min, outlier_max)

        # The curve describes the normal range, not the observed data
        mean_value = (normal_min + normal_max) / 2
        stddev_value = range_width / 4

        fig = plt.figure(figsize=(10, 6))
        sns.histplot(
            capped_values,
            bins=15,
            kde=False,
            color="#a4c6f7",
            edgecolor="black",
            alpha=0.7,
            label="Histogram",
        )

        x = np.linspace(
            normal_min - 2 * stddev_value, normal_max + 2 * stddev_value, 200
        )
        bell_curve = (1 / (stddev_value * np.sqrt(2 * np.pi))) * np.exp(
            -0.5 * ((x - mean_value) / stddev_value) ** 2
        )
        # The curve is multiplied by the tallest bin count before plotting
        plt.plot(
            x,
            bell_curve * max(np.histogram(capped_values, bins=15)[0]),
            color="#007bff",
            linestyle="--",
            label="Bell Curve",
        )

        # Vertical markers for the range bounds, the median and the midpoint
        median_value = np.median(capped_values)
        plt.axvline(
            normal_min,
            color="#f39c12",
            linestyle="--",
            label=f"Normal Min: {normal_min:.2f}",
        )
        plt.axvline(
            normal_max,
            color="#e74c3c",
            linestyle="--",
            label=f"Normal Max: {normal_max:.2f}",
        )
        plt.axvline(
            median_value,
            color="#2ecc71",
            linestyle="-",
            label=f"Median: {median_value:.2f}",
        )
        plt.axvline(
            mean_value,
            color="#8e44ad",
            linestyle="--",
            label=f"Mean: {mean_value:.2f}",
        )

        # Mark the un-capped outliers on the baseline. A single scatter call
        # keeps the legend to one "Outlier" entry even for repeated values.
        outliers = [v for v in values if v < outlier_min or v > outlier_max]
        if outliers:
            plt.scatter(
                outliers,
                [0] * len(outliers),
                color="#e74c3c",
                s=100,
                zorder=5,
                edgecolor="black",
                label="Outlier",
            )

        plt.annotate(
            f"Mean: {mean_value:.2f}\nStd Dev: {stddev_value:.2f}",
            xy=(mean_value, bell_curve.max() / 2),
            xytext=(mean_value + 1, bell_curve.max() / 1.5),
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize=12,
        )

        # Share of the raw (un-capped) values that fall inside the range
        in_range = [v for v in values if normal_min <= v <= normal_max]
        percentage_in_range = (len(in_range) / len(values)) * 100
        plt.annotate(
            f"{percentage_in_range:.1f}% within normal range",
            xy=(normal_min, bell_curve.max() / 1.5),
            fontsize=12,
            color="blue",
        )

        plt.title(f"{feature} Bell Shape Plot with Histogram and KDE", fontsize=16)
        plt.xlabel(f"{feature} Values", fontsize=14)
        plt.ylabel("Frequency", fontsize=14)
        plt.legend(loc="upper right", fontsize=12)

        filepath = os.path.join(output_dir, f"{feature}_bell_plot.png")
        plt.savefig(filepath, bbox_inches="tight", dpi=300)
    except Exception as e:
        self.logger.error(f"Error generating bell shape plot for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_spectral_density_plot(
    self,
    feature,
    values,
    output_dir,
    sampling_rate=1,
    nfft=16,
    nperseg=8,
    highlight_freqs=None,
    segment_overlap=20,
    peak_threshold=None,
):
    """
    Plot a smoothed Welch power spectral density (in dB) for a feature.

    The plot also shows a +/-10% band around the curve, optional highlighted
    frequency ranges, optional peaks, the dominant frequency and a linear
    trend line.

    Args:
        feature (str): Name of the feature being analyzed.
        values (list or np.array): Time-series data for the feature.
        output_dir (str): Directory to save the plot; created if missing.
        sampling_rate (int): Sampling rate passed to ``welch`` as ``fs``.
        nfft (int): FFT length passed to ``welch``.
        nperseg (int): Segment length passed to ``welch``; also the minimum
            number of data points required.
        highlight_freqs (list of tuples): Optional ``(min_freq, max_freq)``
            ranges to shade.
        segment_overlap (int): The frequency axis returned by ``welch`` is
            multiplied by ``60 / segment_overlap``.
        peak_threshold (float): Optional minimum height (in dB) handed to
            ``find_peaks``; ``None`` disables peak marking.

    Returns:
        str: Path to the saved spectral density plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        values = np.asarray(values)

        # Require at least one full Welch segment
        if len(values) < nperseg:
            raise ValueError(
                f"Insufficient data points: {len(values)} < {nperseg}"
            )

        freqs, psd = welch(values, fs=sampling_rate, nfft=nfft, nperseg=nperseg)
        freqs = freqs * (60 / segment_overlap)

        # Resample the PSD onto a dense grid for a smoother curve
        freqs_smooth = np.linspace(freqs.min(), freqs.max(), 500)
        try:
            spline = make_interp_spline(freqs, psd, k=3)
            psd_smooth = spline(freqs_smooth)
        except Exception as interp_error:
            # Fall back to linear interpolation if the spline cannot be built
            self.logger.warning(
                f"Using linear interpolation fallback: {interp_error}"
            )
            psd_smooth = np.interp(freqs_smooth, freqs, psd)

        fig = plt.figure(figsize=(10, 6))

        # Floor the values so log10 never sees zero or negative numbers
        psd_smooth_safe = np.maximum(psd_smooth, 1e-10)
        psd_smooth_db = 10 * np.log10(psd_smooth_safe)
        plt.plot(
            freqs_smooth,
            psd_smooth_db,
            color="blue",
            label="Spectral Density (dB)",
            linewidth=2,
        )

        # Illustrative +/-10% band; not a statistical confidence interval
        lower_bound = 10 * np.log10(psd_smooth_safe * 0.9)
        upper_bound = 10 * np.log10(psd_smooth_safe * 1.1)
        plt.fill_between(
            freqs_smooth,
            lower_bound,
            upper_bound,
            color="blue",
            alpha=0.2,
            label="Confidence Interval",
        )

        if highlight_freqs:
            for min_freq, max_freq in highlight_freqs:
                plt.axvspan(
                    min_freq,
                    max_freq,
                    color="orange",
                    alpha=0.3,
                    label=f"ROI: {min_freq}-{max_freq} Hz",
                )

        # `is not None` so that a 0 dB threshold is honoured
        if peak_threshold is not None:
            psd_safe = np.maximum(psd, 1e-10)
            # NOTE(review): find_peaks is unpacked as a 2-tuple here, while
            # _create_plot_line_with_rolling_stats uses its result directly
            # as an index array — confirm the helper's return type.
            peaks, _ = find_peaks(10 * np.log10(psd_safe), height=peak_threshold)
            plt.scatter(
                freqs[peaks],
                10 * np.log10(psd_safe[peaks]),
                color="red",
                zorder=5,
                s=80,
                edgecolors="black",
                label="Peaks",
            )

        dominant_freq = freqs[np.argmax(psd)]
        plt.axvline(
            x=dominant_freq,
            color="green",
            linestyle="--",
            label=f"Dominant Frequency: {dominant_freq:.2f} Hz",
        )

        # First-order fit of the smoothed dB curve
        trend = np.poly1d(np.polyfit(freqs_smooth, psd_smooth_db, 1))
        plt.plot(
            freqs_smooth,
            trend(freqs_smooth),
            color="red",
            linestyle="--",
            label="Trend Line",
            alpha=0.7,
        )

        plt.grid(True, which="both", linestyle="--", alpha=0.5)
        plt.title(
            f"Spectral Density Plot of {feature} (with Peaks & Trend)", fontsize=16
        )
        plt.xlabel("Frequency (Hz)", fontsize=14)
        plt.ylabel("Power/Frequency (dB/Hz)", fontsize=14)
        plt.xscale("log")
        # Anchor the legend outside the axes
        plt.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0))

        filepath = os.path.join(
            output_dir, f"{feature}_enhanced_spectral_density_plot.png"
        )
        os.makedirs(output_dir, exist_ok=True)
        plt.savefig(filepath, bbox_inches="tight", dpi=150)

        # Make sure the image really landed on disk
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Failed to create plot file: {filepath}")
    except Exception as e:
        self.logger.error(
            f"Error generating spectral density plot for {feature}: {e}"
        )
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_spectrogram(
    self,
    feature,
    values,
    output_dir,
    sampling_rate=1,
    nfft=16,
    nperseg=4,
    noverlap=2,
    threshold=0.55,
    seg_overlap=20,
):
    """
    Plot a spectrogram (power in dB over time and frequency) for a feature.

    A white contour marks the ``threshold`` power level, and a dashed red
    rectangle marks the region of interest returned by
    ``self.auto_detect_roi``.

    Args:
        feature (str): Name of the feature; used in the title and file name.
        values (list or np.array): Time-series data for the spectrogram.
        output_dir (str): Directory where the spectrogram image is saved.
        sampling_rate (int): Sampling rate passed to ``spectrogram`` as ``fs``.
        nfft (int): FFT length passed to ``spectrogram``.
        nperseg (int): Segment length passed to ``spectrogram``.
        noverlap (int): Number of points to overlap between segments.
        threshold (float): Power level passed to ``auto_detect_roi`` and used
            (converted to dB) as the contour level.
        seg_overlap (int): Frequencies are multiplied by ``60 / seg_overlap``
            and times divided by the same factor before plotting.

    Returns:
        str: File path to the saved spectrogram plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        values = np.array(values)

        min_required = max(nperseg, nfft)
        if len(values) < min_required:
            raise ValueError(
                f"Insufficient data points for spectrogram: {len(values)} < {min_required}"
            )

        frequencies, times, Sxx = spectrogram(
            values, fs=sampling_rate, nfft=nfft, nperseg=nperseg, noverlap=noverlap
        )

        # Reject outputs smaller than 2x2 before handing them to the plotting calls
        if Sxx.shape[0] < 2 or Sxx.shape[1] < 2:
            raise ValueError(
                f"Input z must be at least a (2, 2) shaped array, but has shape {Sxx.shape}"
            )

        # Rescale both axes by the 60 / seg_overlap factor
        frequencies_cpm = frequencies * 60 / seg_overlap
        times_minutes = times / (60 / seg_overlap)

        # presumably each is a (start, end) pair — they are indexed [0] and [1] below
        roi_time, roi_freq = self.auto_detect_roi(
            Sxx, times_minutes, frequencies_cpm, threshold
        )

        fig = plt.figure(figsize=(10, 6))

        # Floor the power so log10 never sees zero
        Sxx_db = 10 * np.log10(np.maximum(Sxx, 1e-10))
        plt.pcolormesh(
            times_minutes,
            frequencies_cpm,
            Sxx_db,
            shading="gouraud",
            cmap="magma",
        )
        plt.title(f"Spectrogram of {feature} Over Time", fontsize=14)
        plt.ylabel("Frequency (Cycles Per Minute)", fontsize=12)
        plt.xlabel("Time (Minutes)", fontsize=12)
        plt.colorbar(label="Power (dB)")

        # Contour at the threshold level, floored the same way as the data
        threshold_safe = max(threshold, 1e-10)
        plt.contour(
            times_minutes,
            frequencies_cpm,
            Sxx_db,
            levels=[10 * np.log10(threshold_safe)],
            colors="white",
            linewidths=1,
        )

        # Outline and label the region of interest
        roi_rect = Rectangle(
            (roi_time[0], roi_freq[0]),
            roi_time[1] - roi_time[0],
            roi_freq[1] - roi_freq[0],
            linewidth=2,
            edgecolor="red",
            facecolor="none",
            linestyle="--",
        )
        plt.gca().add_patch(roi_rect)
        plt.text(
            roi_time[0],
            roi_freq[1],
            "R.o.I",
            color="red",
            fontsize=12,
            verticalalignment="bottom",
            horizontalalignment="left",
        )

        filepath = os.path.join(output_dir, f"{feature}_enhanced_spectrogram.png")
        plt.savefig(filepath, bbox_inches="tight")
    except Exception as e:
        self.logger.error(f"Error generating spectrogram for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_periodogram(
    self,
    feature,
    values,
    output_dir,
    sampling_rate=1,
    nfft=16,
    threshold=0.55,
    seg_overlap=20,
):
    """
    Plot the periodogram of a feature on a logarithmic power axis.

    Horizontal lines mark the optional threshold, the mean power and one
    standard deviation either side of the mean.

    Args:
        feature (str): Name of the feature; used in the title and file name.
        values (list or np.array): Time-series data for the periodogram.
        output_dir (str): Directory where the periodogram image is saved.
        sampling_rate (int): Sampling rate passed to ``periodogram`` as ``fs``.
        nfft (int): FFT length passed to ``periodogram``.
        threshold (float): Power level drawn as a horizontal line; ``None``
            omits the line.
        seg_overlap (int): Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        str: File path to the saved periodogram plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        values = np.array(values)

        # NOTE(review): when len(values) > nfft, SciPy appears to analyse only
        # the first nfft samples — confirm this truncation is intended.
        frequencies, Pxx = periodogram(values, fs=sampling_rate, nfft=nfft)

        # Scale the frequency axis by 60 (Hz -> cycles per minute)
        frequencies_cpm = frequencies * 60

        fig = plt.figure(figsize=(12, 8))
        plt.semilogy(
            frequencies_cpm, Pxx, color="royalblue", linewidth=2.5, alpha=0.8
        )

        if threshold is not None:
            plt.axhline(
                threshold,
                color="crimson",
                linestyle="--",
                label="Threshold",
                linewidth=1.8,
                alpha=0.7,
            )

        # Mean power with a one-standard-deviation band for context
        mean_power = np.mean(Pxx)
        std_power = np.std(Pxx)
        band_low = mean_power - std_power
        band_high = mean_power + std_power
        plt.axhline(
            mean_power,
            color="forestgreen",
            linestyle=":",
            label=f"Mean Power: {mean_power:.2f}",
            linewidth=2.0,
        )
        plt.axhline(
            band_high,
            color="orange",
            linestyle=":",
            label=f"Mean + 1 Std Dev: {band_high:.2f}",
            linewidth=1.5,
        )
        plt.axhline(
            band_low,
            color="orange",
            linestyle=":",
            label=f"Mean - 1 Std Dev: {band_low:.2f}",
            linewidth=1.5,
        )
        plt.fill_between(
            frequencies_cpm,
            band_low,
            band_high,
            color="lightgoldenrodyellow",
            alpha=0.4,
            label="1 Std Dev Range",
        )

        plt.xlabel("Frequency (Cycles Per Minute)", fontsize=14, fontweight="bold")
        plt.ylabel("Power Spectral Density (PSD)", fontsize=14, fontweight="bold")
        plt.title(
            f"Periodogram of {feature}",
            fontsize=16,
            fontweight="bold",
            color="darkslategray",
        )
        plt.legend(loc="upper right", fontsize=12)

        filepath = os.path.join(output_dir, f"{feature}_periodogram_plot.png")
        plt.savefig(filepath, bbox_inches="tight")
    except Exception as e:
        self.logger.error(f"Error generating periodogram for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_line_with_rolling_stats(
    self, feature, values, output_dir, time=None, window=10
):
    """
    Plot the feature values over time together with rolling statistics.

    The chart shows the raw data, the rolling mean with a +/-1 standard
    deviation band, the rolling median with the 25-75% quantile band, the
    normal range bounds, and the peaks and troughs reported by ``find_peaks``.

    Args:
        feature (str): Name of the feature; used in the title and file name.
        values (list or np.array): Time series data points.
        output_dir (str): Directory to save the plot.
        time (list or np.array): Corresponding time points; defaults to
            ``0 .. len(values) - 1``.
        window (int): Window size for the rolling statistics.

    Returns:
        str: Path to the saved image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        if time is None:
            time = np.arange(len(values))

        df = pd.DataFrame({"data": values, "time": time})
        rolling = df["data"].rolling(window=window)
        df["rolling_mean"] = rolling.mean()
        df["rolling_std"] = rolling.std()
        df["rolling_median"] = rolling.median()
        df["quantile_25"] = rolling.quantile(0.25)
        df["quantile_75"] = rolling.quantile(0.75)

        fig = plt.figure(figsize=(12, 6))
        plt.plot(df["time"], df["data"], label="Data", color="blue", alpha=0.7)

        # Rolling mean with a one-standard-deviation band
        plt.plot(
            df["time"],
            df["rolling_mean"],
            label="Rolling Mean",
            color="orange",
            linestyle="--",
            linewidth=2,
        )
        plt.fill_between(
            df["time"],
            df["rolling_mean"] - df["rolling_std"],
            df["rolling_mean"] + df["rolling_std"],
            color="orange",
            alpha=0.3,
            label="Rolling Std Dev",
        )

        # Rolling median with the interquartile band
        plt.plot(
            df["time"],
            df["rolling_median"],
            label="Rolling Median",
            color="green",
            linestyle="--",
            linewidth=2,
        )
        plt.fill_between(
            df["time"],
            df["quantile_25"],
            df["quantile_75"],
            color="green",
            alpha=0.2,
            label="Interquartile Range (25-75%)",
        )

        # Normal range bounds; a missing range raises and is reported below
        normal_min, normal_max = self._fetch_and_validate_normal_range(
            feature, values[0]
        )
        plt.axhline(
            normal_min,
            color="red",
            linestyle="--",
            linewidth=1.5,
            label=f"Normal Min: {normal_min}",
        )
        plt.axhline(
            normal_max,
            color="red",
            linestyle="--",
            linewidth=1.5,
            label=f"Normal Max: {normal_max}",
        )

        # NOTE(review): find_peaks results are used directly as positional
        # indices here, while _create_spectral_density_plot unpacks a 2-tuple
        # from the same helper — confirm its return type.
        peaks = find_peaks(df["data"])
        troughs = find_peaks(-df["data"])
        plt.scatter(
            df["time"].iloc[peaks],
            df["data"].iloc[peaks],
            color="purple",
            s=100,
            label="Peaks",
            zorder=5,
        )
        plt.scatter(
            df["time"].iloc[troughs],
            df["data"].iloc[troughs],
            color="cyan",
            s=100,
            label="Troughs",
            zorder=5,
        )

        plt.title(
            f"Time-Series Data with Enhanced Rolling Statistics ({feature})",
            fontsize=14,
        )
        plt.xlabel("Time", fontsize=12)
        plt.ylabel("Values", fontsize=12)
        plt.legend(loc="upper right", fontsize=10, frameon=True)
        plt.grid(True, which="both", linestyle="--", linewidth=0.7, alpha=0.6)

        filepath = os.path.join(
            output_dir, f"{feature}_line_with_enhanced_stats.png"
        )
        plt.savefig(filepath, bbox_inches="tight", dpi=300)
    except Exception as e:
        self.logger.error(
            f"Error creating line chart with rolling statistics for {feature}: {e}"
        )
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_autocorrelation(self, feature, values, output_dir, lags=5):
    """
    Plot the autocorrelation function (ACF) of the time-series data.

    Args:
        feature (str): Name of the feature; used in the file name.
        values (list or np.array): Time series data points.
        output_dir (str): Directory to save the plot.
        lags (int): Number of lags to plot.

    Returns:
        str: Path to the saved image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        # Pass the axes explicitly: without `ax`, plot_acf draws on a figure
        # of its own, ignoring the size above and leaving this one open.
        plot_acf(values, lags=lags, ax=ax)
        ax.set_title("Autocorrelation Plot")
        filepath = os.path.join(output_dir, f"{feature}_autocorrelation_plot.png")
        fig.savefig(filepath, bbox_inches="tight")
    except Exception as e:
        self.logger.error(f"Error creating autocorrelation plot for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_box_swarm_plot(self, feature, values, output_dir):
    """
    Draw a box plot with an overlaid swarm plot against the normal range.

    Values are capped at 1.5 range-widths beyond either normal bound before
    plotting. The normal range is shaded, and its bounds and midpoint are
    drawn as labelled horizontal lines.

    Args:
        feature (str): The name of the feature.
        values (list): The values of the feature for the current segment.
        output_dir (str): Directory where the plot will be saved.

    Returns:
        str: Path to the saved box + swarm plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        normal_min, normal_max = self._fetch_and_validate_normal_range(
            feature, values[0]
        )
        range_width = normal_max - normal_min
        outlier_min = normal_min - 1.5 * range_width
        outlier_max = normal_max + 1.5 * range_width

        # Pull extreme values back onto the outlier limits
        capped_values = [
            min(max(value, outlier_min), outlier_max) for value in values
        ]

        fig = plt.figure(figsize=(10, 6))
        # NOTE: sns.set changes seaborn's global theme, so it also affects
        # figures drawn after this one.
        sns.set(style="whitegrid")

        sns.boxplot(capped_values, color="#D0E1F9", linewidth=2, width=0.3)
        sns.swarmplot(
            capped_values, color="#34495E", size=6, edgecolor="black", alpha=0.8
        )

        plt.fill_between(
            [-0.4, 0.4],
            normal_min,
            normal_max,
            color="#A9DFBF",
            alpha=0.25,
            label="Normal Range",
        )

        # Midpoint of the normal range (not the mean of the data)
        mean_value = (normal_max + normal_min) / 2
        plt.axhline(
            mean_value,
            color="#27AE60",
            linestyle="--",
            label=f"Normal Mean: {mean_value:.2f}",
        )
        plt.axhline(
            normal_min,
            color="#E74C3C",
            linestyle="--",
            alpha=0.6,
            label=f"Normal Min: {normal_min:.2f}",
        )
        plt.axhline(
            normal_max,
            color="#E74C3C",
            linestyle="--",
            alpha=0.6,
            label=f"Normal Max: {normal_max:.2f}",
        )

        plt.title(
            f"{feature} Distribution and Data Points", fontsize=16, color="#2C3E50"
        )
        plt.xlabel("")
        plt.ylabel(f"{feature} Values", fontsize=12, color="#2C3E50")
        plt.xticks([])  # No x-ticks since there's only one category
        plt.legend(
            loc="upper right",
            fontsize=10,
            frameon=True,
            fancybox=True,
            framealpha=0.7,
            shadow=True,
        )
        plt.grid(visible=False)
        sns.despine(left=True, bottom=True)

        # Numeric labels next to the three reference lines
        for level, label_color in (
            (mean_value, "#27AE60"),
            (normal_min, "#E74C3C"),
            (normal_max, "#E74C3C"),
        ):
            plt.text(
                0.35,
                level,
                f"{level:.2f}",
                verticalalignment="center",
                color=label_color,
            )

        filepath = os.path.join(
            output_dir, f"{feature}_enhanced_box_swarm_plot.png"
        )
        plt.savefig(filepath, bbox_inches="tight", dpi=300)
    except Exception as e:
        self.logger.error(
            f"Error creating enhanced box + swarm plot for {feature}: {e}"
        )
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_lag(self, feature, values, output_dir, lags=3):
    """
    Create a lag plot: each value plotted against the value ``lags`` steps later.

    The scatter is drawn over a 2-D density estimate and complemented with a
    linear trend line, the Pearson correlation coefficient, and the mean and
    median of the later values.

    Args:
        feature (str): Name of the feature to label the plot.
        values (list or np.array): Time series data points.
        output_dir (str): Directory to save the plot.
        lags (int): Lag, in samples, between the paired values; must be at
            least 1 and smaller than ``len(values)``.

    Returns:
        str: Path to the saved enhanced lag plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        values = np.array(values)

        # A non-positive lag would make the two slices below mismatched or empty
        if lags < 1:
            raise ValueError(f"The lag ({lags}) must be a positive integer.")
        if len(values) <= lags:
            raise ValueError(
                f"The length of the data ({len(values)}) must be greater than the lag ({lags})."
            )

        # Pair every sample x[t] with x[t + lags]
        data_lagged = values[:-lags]
        data_original = values[lags:]

        fig = plt.figure(figsize=(10, 6))

        sns.kdeplot(
            x=data_lagged,
            y=data_original,
            fill=True,
            cmap="Blues",
            thresh=0.05,
            alpha=0.5,
        )
        plt.scatter(
            data_lagged,
            data_original,
            alpha=0.6,
            label="Data Points",
            color="royalblue",
            s=50,
        )

        # Least-squares straight line through the pairs
        coeffs = np.polyfit(data_lagged, data_original, 1)
        plt.plot(
            data_lagged,
            np.polyval(coeffs, data_lagged),
            color="darkorange",
            label="Trend Line",
            linewidth=2,
            linestyle="--",
        )

        corr_coeff, _ = pearsonr(data_lagged, data_original)
        plt.text(
            0.05,
            0.95,
            f"Correlation (r) = {corr_coeff:.3f}",
            transform=plt.gca().transAxes,
            fontsize=12,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

        mean_original = np.mean(data_original)
        median_original = np.median(data_original)
        plt.axhline(
            mean_original,
            color="green",
            linestyle=":",
            label=f"Mean: {mean_original:.2f}",
        )
        plt.axhline(
            median_original,
            color="purple",
            linestyle=":",
            label=f"Median: {median_original:.2f}",
        )

        plt.title(f"Enhanced Lag Plot of {feature} with Lag = {lags}", fontsize=16)
        plt.xlabel(f"{feature} (t)", fontsize=14)
        plt.ylabel(f"{feature} (t+{lags})", fontsize=14)
        plt.legend(loc="upper left", fontsize=12)

        filepath = os.path.join(output_dir, f"{feature}_enhanced_lag_plot.png")
        plt.savefig(filepath, bbox_inches="tight")
    except Exception as e:
        self.logger.error(f"Error creating enhanced lag plot for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_plot_seasonal_decomposition(
    self, feature, values, output_dir, freq=1, time=None
):
    """
    Plot an additive seasonal decomposition of the time-series data.

    Args:
        feature (str): Name of the feature; used in the file name.
        values (list or np.array): Time series data points.
        output_dir (str): Directory to save the plot.
        freq (int): Seasonal period handed to ``seasonal_decompose``
            (e.g., 12 for monthly data).
        time (list or np.array): Corresponding time points, converted with
            ``pd.to_datetime``; defaults to ``0 .. len(values) - 1``.

    Returns:
        str: Path to the saved image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        if time is None:
            time = np.arange(len(values))
        df = pd.DataFrame({"data": values}, index=pd.to_datetime(time))
        decomposition = seasonal_decompose(
            df["data"], model="additive", period=freq
        )

        # Work on the figure returned by the decomposition itself instead of
        # relying on it being pyplot's current figure.
        fig = decomposition.plot()
        fig.set_size_inches(10, 8)
        fig.suptitle("Seasonal Decomposition of Time-Series Data", y=1.05)

        filepath = os.path.join(output_dir, f"{feature}_seasonal_decomposition.png")
        fig.savefig(filepath, bbox_inches="tight")
    except Exception as e:
        self.logger.error(
            f"Error creating seasonal decomposition plot for {feature}: {e}"
        )
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_violin_plot(self, feature, values, output_dir):
    """
    Draw a violin plot of the feature values against the normal range.

    Values are clipped to 1.5 range-widths beyond either normal bound. Each
    distinct clipped value is drawn as a marker whose size grows with how
    often it occurs; markers sitting exactly on a clipping limit are red, all
    others blue. The median of the clipped values is highlighted in purple.

    Args:
        feature (str): The name of the feature.
        values (list): The values of the feature for the current segment.
        output_dir (str): Directory where the plot will be saved.

    Returns:
        str: Path to the saved violin plot image, or the string
        "Error generating plot for <feature>" if anything fails.
    """
    fig = None
    try:
        normal_min, normal_max = self._fetch_and_validate_normal_range(
            feature, values[0]
        )
        range_width = normal_max - normal_min
        outlier_min = normal_min - 1.5 * range_width
        outlier_max = normal_max + 1.5 * range_width
        capped_values = np.clip(values, outlier_min, outlier_max)

        fig = plt.figure(figsize=(10, 7))
        sns.violinplot(
            data=capped_values,
            color="lightgreen",
            inner="quartile",
            linewidth=2,
            density_norm="area",
        )

        # Shaded band spanning the normal range
        plt.fill_betweenx(
            [normal_min, normal_max],
            0.25,
            0.75,
            color="lightblue",
            alpha=0.5,
            label="Normal Range",
        )

        # Midpoint of the normal range (not the mean of the data)
        mean_value = (normal_max + normal_min) / 2
        plt.axhline(
            mean_value,
            color="green",
            linestyle="--",
            label=f"Normal Range Mean: {mean_value:.2f}",
        )
        plt.axhline(
            normal_min,
            color="blue",
            linestyle="--",
            alpha=0.6,
            label=f"Normal Min: {normal_min:.2f}",
        )
        plt.axhline(
            normal_max,
            color="red",
            linestyle="--",
            alpha=0.6,
            label=f"Normal Max: {normal_max:.2f}",
        )

        # One marker group per distinct value; size grows with its frequency
        for value, count in Counter(capped_values).items():
            at_limit = value == outlier_min or value == outlier_max
            plt.scatter(
                [0.5] * count,
                [value] * count,
                color="red" if at_limit else "blue",
                zorder=5,
                s=50 + count * 10,
                edgecolors="black",
            )

        median_value = np.median(capped_values)
        plt.scatter(
            0.5,
            median_value,
            color="purple",
            zorder=6,
            s=150,
            edgecolors="black",
            label=f"Median: {median_value:.2f}",
        )

        plt.title(f"{feature} Enhanced Violin Plot", fontsize=16)
        plt.ylabel(f"{feature} Values", fontsize=14)
        plt.xticks([])  # Remove x-ticks as there's only one plot axis
        plt.legend(loc="upper right", fontsize=12)

        filepath = os.path.join(output_dir, f"{feature}_enhanced_violin_plot.png")
        plt.savefig(filepath, bbox_inches="tight", dpi=300)
    except Exception as e:
        self.logger.error(f"Error creating enhanced violin plot for {feature}: {e}")
        return f"Error generating plot for {feature}"
    finally:
        # Close the figure on failure as well, so it is not left open
        if fig is not None:
            plt.close(fig)
    return filepath
def _create_heatmap_plot(self, feature, values, output_dir):
"""
Creates a heatmap plot for visualizing the feature values with highlights for the normal range.
Args:
feature (str): The name of the feature.
values (list): The list of values for the current segment.
output_dir (str): Directory where the plot will be saved.
Returns:
str: Path to the saved heatmap plot image.
"""
try:
normal_min, normal_max = self._fetch_and_validate_normal_range(
feature, values[0]
)
# Set outlier thresholds
outlier_min = normal_min - 1.5 * (normal_max - normal_min)
outlier_max = normal_max + 1.5 * (normal_max - normal_min)
# Create x-axis values for the heatmap
x = np.linspace(
normal_min - 1.5 * (normal_max - normal_min),
normal_max + 1.5 * (normal_max - normal_min),
100,
)
y = np.exp(
-((x - np.mean([normal_min, normal_max])) ** 2)
/ (2 * ((normal_max - normal_min) / 2) ** 2)
)
heatmap_data = np.outer(y, y)
plt.figure(figsize=(8, 4))
sns.heatmap(
heatmap_data,
cmap="coolwarm",
cbar=False,
xticklabels=False,
yticklabels=False,
)
# Track the number of occurrences for each value
value_counts = {}
for value in values:
capped_value = min(max(value, outlier_min), outlier_max)
if capped_value in value_counts:
value_counts[capped_value] += 1
else:
value_counts[capped_value] = 1
# Plot scatter points with size based on frequency of the value
for capped_value, count in value_counts.items():
value_interp = np.interp(capped_value, x, np.linspace(0, 100, len(x)))
scatter_size = (
100 + (count - 1) * 50
) # Increase size for repeated values
plt.scatter(
[value_interp],
[50],
color="#3498db",
s=scatter_size,
zorder=5,
edgecolor="white",
)
# Mark the normal range
plt.axvline(
np.interp(normal_min, x, np.linspace(0, 100, len(x))),
color="blue",
linestyle="-",
label="Normal Min",
)
plt.axvline(
np.interp(normal_max, x, np.linspace(0, 100, len(x))),
color="red",
linestyle="-",
label="Normal Max",
)
plt.title(f"Feature: {feature}\nNormal Range: [{normal_min}, {normal_max}]")
plt.legend()
# Save heatmap
filepath = os.path.join(output_dir, f"{feature}_heatmap.png")
plt.savefig(filepath, bbox_inches="tight")
plt.close()
except Exception as e:
self.logger.error(f"Error creating heatmap plot for {feature}: {e}")
return f"Error generating plot for {feature}"
return filepath
def _create_radar_plot(self, feature, values, output_dir):
"""
Creates an enhanced radar plot that shows how the feature value compares to the normal range.
Each axis represents different statistics of the signal, and the plot shows how the current feature deviates
from its normal range.
Args:
feature (str): The name of the feature.
values (list): The list of values for the current segment.
output_dir (str): Directory where the plot will be saved.
Returns:
str: Path to the saved radar plot image.
"""
try:
# Fetch normal range
normal_min, normal_max = self._fetch_and_validate_normal_range(
feature, values[0]
)
feature_names = ["Min Range", "Max Range", "Mid Range", "Mean", "Std Dev"]
# Handle NaN and Inf cases
if np.isnan(normal_min) or np.isnan(normal_max) or np.any(np.isnan(values)):
raise ValueError(f"NaN value encountered in feature '{feature}'.")
if np.isinf(normal_min):
normal_min = -10 * abs(np.min(values))
if np.isinf(normal_max):
normal_max = 10 * abs(np.max(values))
# Outlier thresholds
outlier_min = normal_min - 1.5 * (normal_max - normal_min)
outlier_max = normal_max + 1.5 * (normal_max - normal_min)
# Cap the values to fit within the plot and mark outliers
capped_values = np.clip(values, outlier_min, outlier_max)
# Statistical metrics for radar plot
mean_value = np.mean(capped_values)
std_value = np.std(capped_values)
# Ensure that all statistical values are finite
if not np.isfinite(mean_value) or not np.isfinite(std_value):
mean_value = 0 # Fallback value for plotting purposes
std_value = 0
# Prepare the "normal range" triangle data (min, max, and mid-point of normal range)
normal_values = [
normal_min,
normal_max,
(normal_min + normal_max) / 2,
mean_value,
std_value,
]
normal_values += normal_values[:1] # Complete the loop
# Prepare the "current value" triangle data: min, max, median, mean, std of capped values
current_value_triangle = [
np.min(capped_values),
np.max(capped_values),
np.median(capped_values),
mean_value,
std_value,
]
current_value_triangle += current_value_triangle[:1] # Complete the loop
# Calculate angles for radar plot
angles = np.linspace(
0, 2 * np.pi, len(feature_names), endpoint=False
).tolist()
angles += angles[:1] # Complete the loop for radar chart
# Start plot
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
# Configure axes
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
plt.xticks(angles[:-1], feature_names, color="#34495E", size=10)
# Set radial limits based on min and max of normal and capped data
y_min = min(normal_min, np.min(capped_values))
y_max = max(normal_max, np.max(capped_values))
plt.ylim(y_min, y_max)
# Plot normal range triangle
ax.plot(
angles,
normal_values,
linewidth=2,
linestyle="solid",
color="#16A085",
label="Normal Range",
)
ax.fill(angles, normal_values, color="#1ABC9C", alpha=0.3)
# Plot actual values triangle (min, max, median, mean, std)
ax.plot(
angles,
current_value_triangle,
linewidth=2,
linestyle="solid",
color="#2980B9",
label="Actual Values",
)
ax.fill(angles, current_value_triangle, color="#3498DB", alpha=0.3)
# Add a scatter point for the median value on the third axis
median_value = np.median(capped_values)
ax.scatter(
angles[2],
median_value,
color="#E74C3C",
zorder=5,
s=150,
edgecolors="black",
)
# Mark only one outlier in the plot
outlier_detected = False
for value in values:
if value < outlier_min or value > outlier_max and np.isfinite(value):
if not outlier_detected: # Only plot the first outlier
ax.scatter(
angles[2],
median_value,
color="#E67E22",
zorder=5,
s=200,
edgecolors="black",
label="Outlier",
)
outlier_detected = True
# Move the legend outside the plot
plt.legend(
loc="upper right",
bbox_to_anchor=(
1.3,
1.1,
),
frameon=True,
framealpha=0.7,
shadow=True,
fontsize=10,
) # Adjust the position
# Modern gridlines and spines
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
ax.spines["polar"].set_visible(False)
# Save the plot
filepath = os.path.join(output_dir, f"{feature}_radar_plot.png")
plt.savefig(filepath, bbox_inches="tight")
plt.close()
except Exception as e:
self.logger.error(f"Error creating radar plot for {feature}: {e}")
return f"Error generating plot for {feature}"
return filepath
def _get_normal_range_for_feature(self, feature):
"""
Retrieves the normal range for a given feature from the loaded configuration.
Args:
feature (str): The name of the feature to get the normal range for.
Returns:
tuple: The (min, max) normal range values for the feature, or None if not found.
Example Usage:
>>> normal_range = visualization._get_normal_range_for_feature("RMSSD")
>>> print(normal_range) # Output: (20, 100)
"""
feature_info = self.config.get(feature, {})
normal_range = feature_info.get("normal_range", {}).get(
self.segment_duration, None
)
if normal_range is not None:
# Handle string '-inf' and 'inf' cases
normal_range = [self._parse_inf_values(val) for val in normal_range]
return normal_range
def _parse_inf_values(self, val):
"""
Parses 'inf' and '-inf' strings and converts them to numpy infinity values.
Args:
val (str or float): The value to parse.
Returns:
float: Parsed value where 'inf' or '-inf' are converted to np.inf or -np.inf respectively.
Example Usage:
>>> value = visualization._parse_inf_values("-inf")
>>> print(value) # Output: -inf
"""
if isinstance(val, str):
if val.lower() == "inf":
return np.inf
elif val.lower() == "-inf":
return -np.inf
return val
[docs]
def auto_detect_roi(self, Sxx, times, frequencies, threshold=0.4):
    """
    Automatically detects the region of interest (ROI) in the spectrogram based on power.

    The spectrogram is normalised by its maximum, and the ROI is the bounding
    box (in time and frequency) of every cell whose normalised power exceeds
    ``threshold``.

    Args:
        Sxx (2D array): Spectrogram values, indexed as [frequency, time].
        times (1D array): Time values corresponding to the spectrogram.
        frequencies (1D array): Frequency values corresponding to the spectrogram.
        threshold (float): A power threshold (fraction of max) for detecting regions of interest.

    Returns:
        roi_time (tuple): Start and end time for the region of interest.
        roi_freq (tuple): Start and end frequency for the region of interest.

        The full time and frequency range is returned when no cell exceeds the
        threshold, or when the spectrogram is empty or has no positive finite
        peak to normalise by.
    """
    power = np.asarray(Sxx)

    # Without a positive, finite peak the normalisation below would divide by
    # zero or produce NaN; there is nothing to localise, so use the full range.
    peak = np.max(power) if power.size else 0.0
    if not np.isfinite(peak) or peak <= 0:
        return (times[0], times[-1]), (frequencies[0], frequencies[-1])

    # Cells whose normalised power (range [.., 1]) exceeds the threshold.
    mask = (power / peak) > threshold
    freq_indices, time_indices = np.nonzero(mask)

    # If no ROI is detected, return the full time and frequency range
    if time_indices.size == 0:
        return (times[0], times[-1]), (frequencies[0], frequencies[-1])

    # Bounding box of the cells above the threshold.
    roi_time = (times[np.min(time_indices)], times[np.max(time_indices)])
    roi_freq = (
        frequencies[np.min(freq_indices)],
        frequencies[np.max(freq_indices)],
    )
    return roi_time, roi_freq