"""
Physiological feature extraction callbacks for vitalDSP webapp.
This module handles extraction of physiological features from signals.
"""
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from dash import Input, Output, State, callback_context, no_update, html, dcc
from dash.exceptions import PreventUpdate
import dash_bootstrap_components as dbc
from scipy import signal
import logging
# Module-level logger shared by the plotting helpers and callbacks below
logger = logging.getLogger(__name__)
[docs]
def create_hrv_poincare_plot(rr_intervals, hrv_metrics):
    """Create a Poincaré plot (RR_n against RR_n+1) for HRV analysis.

    Args:
        rr_intervals: Sequence (list or array) of RR intervals. The axis
            labels and hover text present the values as milliseconds.
        hrv_metrics: Mapping of HRV metrics. When it holds finite numeric
            ``"poincare_sd1"`` and ``"poincare_sd2"`` entries, the SD1/SD2
            ellipse is drawn on top of the scatter. ``None`` is treated as
            an empty mapping.

    Returns:
        A plotly figure. An empty figure is returned when fewer than two
        (finite) RR intervals are available.
    """
    logger.info(f"Creating HRV Poincaré plot with {len(rr_intervals)} RR intervals")
    if len(rr_intervals) < 2:
        logger.warning("Not enough RR intervals for Poincaré plot")
        return go.Figure()

    # Convert up front so boolean-mask indexing also works for plain lists.
    rr_values = np.asarray(rr_intervals, dtype=float)

    # Drop NaN/inf values before pairing consecutive intervals.
    finite_rr_intervals = rr_values[np.isfinite(rr_values)]
    if finite_rr_intervals.size < 2:
        return create_empty_figure()

    # Poincaré pairs: each interval against its successor.
    rr_n = finite_rr_intervals[:-1]
    rr_n1 = finite_rr_intervals[1:]

    fig = go.Figure()

    # Scatter of consecutive RR pairs
    fig.add_trace(
        go.Scatter(
            x=rr_n,
            y=rr_n1,
            mode="markers",
            marker=dict(
                size=6,
                color="rgba(0, 123, 255, 0.6)",
                line=dict(width=1, color="rgba(0, 123, 255, 0.8)"),
            ),
            name="RR Intervals",
            hovertemplate="<b>RR<sub>n</sub>:</b> %{x:.1f} ms<br><b>RR<sub>n+1</sub>:</b> %{y:.1f} ms<extra></extra>",
        )
    )

    # Identity line spanning the plotted range (rr_n and rr_n1 together cover
    # exactly the finite intervals, so one min/max over them is sufficient).
    min_rr = float(finite_rr_intervals.min())
    max_rr = float(finite_rr_intervals.max())
    fig.add_trace(
        go.Scatter(
            x=[min_rr, max_rr],
            y=[min_rr, max_rr],
            mode="lines",
            line=dict(color="red", dash="dash", width=2),
            name="Identity Line",
            showlegend=True,
        )
    )

    # SD1/SD2 ellipse, only when both descriptors are usable numbers.
    metrics = hrv_metrics or {}
    try:
        sd1 = float(metrics["poincare_sd1"])
        sd2 = float(metrics["poincare_sd2"])
    except (KeyError, TypeError, ValueError):
        sd1 = sd2 = None

    if sd1 is not None and np.isfinite(sd1) and np.isfinite(sd2):
        # Centre on the mean of the *finite* intervals; the unfiltered mean
        # would be NaN as soon as one invalid interval is present.
        mean_rr = float(finite_rr_intervals.mean())

        theta = np.linspace(0, 2 * np.pi, 100)
        # SD2 is the spread along the identity line and SD1 the spread
        # perpendicular to it, so the ellipse axes are rotated by 45 degrees
        # relative to the plot axes.
        along_identity = sd2 * np.cos(theta)
        across_identity = sd1 * np.sin(theta)
        x_ellipse = mean_rr + (along_identity - across_identity) / np.sqrt(2)
        y_ellipse = mean_rr + (along_identity + across_identity) / np.sqrt(2)

        fig.add_trace(
            go.Scatter(
                x=x_ellipse,
                y=y_ellipse,
                mode="lines",
                line=dict(color="green", width=2),
                name=f"SD1={sd1:.2f}, SD2={sd2:.2f}",
                showlegend=True,
            )
        )

    fig.update_layout(
        title="HRV Poincaré Plot",
        xaxis_title="RR<sub>n</sub> (ms)",
        yaxis_title="RR<sub>n+1</sub> (ms)",
        width=500,
        height=400,
        showlegend=True,
    )
    return fig
[docs]
def create_hrv_time_series_plot(time_axis, rr_intervals, hrv_metrics):
    """Plot RR intervals against time, with an optional mean-RR reference line.

    Args:
        time_axis: Time values; only the first ``len(rr_intervals)`` entries
            are used for the x axis.
        rr_intervals: RR interval values plotted on the y axis.
        hrv_metrics: Mapping of HRV metrics; a ``"mean_rr"`` entry, when
            present, is drawn as a dashed horizontal line.

    Returns:
        A plotly figure (empty when no RR intervals are supplied).
    """
    interval_count = len(rr_intervals)
    if interval_count == 0:
        return go.Figure()

    # RR interval trace, truncating the time axis to the number of intervals
    rr_trace = go.Scatter(
        x=time_axis[:interval_count],
        y=rr_intervals,
        mode="lines+markers",
        name="RR Intervals",
        line={"color": "blue", "width": 2},
        marker={"size": 4, "color": "blue"},
        hovertemplate="<b>Time:</b> %{x:.2f}s<br><b>RR Interval:</b> %{y:.1f} ms<extra></extra>",
    )
    fig = go.Figure()
    fig.add_trace(rr_trace)

    # Horizontal reference line at the reported mean RR
    if "mean_rr" in hrv_metrics:
        mean_rr = hrv_metrics["mean_rr"]
        fig.add_hline(
            y=mean_rr,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Mean RR: {mean_rr:.1f} ms",
        )

    layout_options = {
        "title": "HRV Time Series",
        "xaxis_title": "Time (s)",
        "yaxis_title": "RR Interval (ms)",
        "width": 600,
        "height": 300,
        "showlegend": True,
    }
    fig.update_layout(**layout_options)
    return fig
[docs]
def create_morphology_analysis_plot(
    time_axis, signal_data, peaks, peak_heights, sampling_freq
):
    """Build a 2x2 morphology figure: signal with peaks plus three histograms.

    Args:
        time_axis: Time values for the signal; indexed with ``peaks``.
        signal_data: Signal samples; indexed with ``peaks``.
        peaks: Sample indices of detected peaks.
        peak_heights: Heights of the detected peaks.
        sampling_freq: Sampling frequency used to convert peak spacing
            (in samples) to milliseconds.

    Returns:
        A plotly figure with four subplots.
    """
    panel_titles = (
        "Signal with Detected Peaks",
        "Peak Height Distribution",
        "Peak Interval Distribution",
        "Amplitude Histogram",
    )
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=panel_titles,
        specs=[[{"secondary_y": False} for _ in range(2)] for _ in range(2)],
    )

    def _add_histogram(values, bins, label, colour, row, col):
        # All three histograms share opacity and differ only in data/colour.
        fig.add_trace(
            go.Histogram(
                x=values,
                nbinsx=bins,
                name=label,
                marker_color=colour,
                opacity=0.7,
            ),
            row=row,
            col=col,
        )

    # Panel (1, 1): raw signal, with peak markers overlaid when available
    signal_trace = go.Scatter(
        x=time_axis,
        y=signal_data,
        mode="lines",
        name="Signal",
        line={"color": "blue", "width": 1},
    )
    fig.add_trace(signal_trace, row=1, col=1)

    peak_count = len(peaks)
    if peak_count > 0:
        peak_trace = go.Scatter(
            x=time_axis[peaks],
            y=signal_data[peaks],
            mode="markers",
            name="Peaks",
            marker={"color": "red", "size": 8, "symbol": "diamond"},
            hovertemplate="<b>Peak:</b> %{y:.1f}<br><b>Time:</b> %{x:.2f}s<extra></extra>",
        )
        fig.add_trace(peak_trace, row=1, col=1)

    # Panel (1, 2): distribution of peak heights
    if len(peak_heights) > 0:
        _add_histogram(peak_heights, 20, "Peak Heights", "lightblue", 1, 2)

    # Panel (2, 1): distribution of peak-to-peak spacing
    if peak_count > 1:
        # Sample spacing -> seconds -> milliseconds
        peak_intervals = np.diff(peaks) / sampling_freq * 1000
        _add_histogram(peak_intervals, 20, "Peak Intervals", "lightgreen", 2, 1)

    # Panel (2, 2): distribution of signal amplitudes
    _add_histogram(signal_data, 50, "Amplitude Distribution", "lightcoral", 2, 2)

    fig.update_layout(title="Morphology Analysis", height=600, showlegend=False)

    axis_titles = [
        (1, 1, "Time (s)", "Amplitude"),
        (1, 2, "Peak Height", "Count"),
        (2, 1, "Peak Interval (ms)", "Count"),
        (2, 2, "Amplitude", "Count"),
    ]
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)
    return fig
[docs]
def create_energy_analysis_plot(frequencies, psd, energy_metrics):
    """Build a 2x2 energy figure: PSD, band energies, total energy, statistics.

    Args:
        frequencies: Frequency values for the PSD x axis.
        psd: Power spectral density values.
        energy_metrics: Mapping of energy metrics. The keys read here are
            ``"low_freq_energy"``, ``"mid_freq_energy"``,
            ``"high_freq_energy"``, ``"total_energy"``, ``"mean_energy"``
            and the optional ``"energy_variance"``.

    Returns:
        A plotly figure with four subplots; panels whose metrics are missing
        are left without a trace.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Power Spectral Density",
            "Frequency Band Energy Distribution",
            "Energy Time Evolution",
            "Energy Statistics",
        ),
        specs=[[{"secondary_y": False} for _ in range(2)] for _ in range(2)],
    )

    # Panel (1, 1): power spectral density
    psd_trace = go.Scatter(
        x=frequencies,
        y=psd,
        mode="lines",
        name="PSD",
        line={"color": "blue", "width": 2},
    )
    fig.add_trace(psd_trace, row=1, col=1)

    # Panel (1, 2): per-band energy, drawn only when all three bands exist
    band_keys = ("low_freq_energy", "mid_freq_energy", "high_freq_energy")
    if all(band_key in energy_metrics for band_key in band_keys):
        band_bar = go.Bar(
            x=["Low Freq", "Mid Freq", "High Freq"],
            y=[energy_metrics[band_key] for band_key in band_keys],
            name="Energy by Band",
            marker_color=["lightblue", "lightgreen", "lightcoral"],
        )
        fig.add_trace(band_bar, row=1, col=2)

    has_total = "total_energy" in energy_metrics

    # Panel (2, 1): simplified "evolution" - the total energy repeated as a
    # constant over 50 points on a 0..1 axis
    if has_total:
        time_points = np.linspace(0, 1, 50)
        constant_energy = np.full_like(time_points, energy_metrics["total_energy"])
        evolution_trace = go.Scatter(
            x=time_points,
            y=constant_energy,
            mode="lines",
            name="Total Energy",
            line={"color": "red", "width": 2},
        )
        fig.add_trace(evolution_trace, row=2, col=1)

    # Panel (2, 2): summary bars; variance falls back to 0 when absent
    if has_total and "mean_energy" in energy_metrics:
        stat_values = [
            energy_metrics["total_energy"],
            energy_metrics["mean_energy"],
            energy_metrics.get("energy_variance", 0),
        ]
        stats_bar = go.Bar(
            x=["Total", "Mean", "Variance"],
            y=stat_values,
            name="Energy Stats",
            marker_color="lightsteelblue",
        )
        fig.add_trace(stats_bar, row=2, col=2)

    fig.update_layout(title="Energy Analysis", height=600, showlegend=False)

    axis_titles = [
        (1, 1, "Frequency (Hz)", "Power"),
        (1, 2, "Frequency Band", "Energy"),
        (2, 1, "Time", "Energy"),
        (2, 2, "Statistic", "Value"),
    ]
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)
    return fig
[docs]
def create_quality_assessment_plot(signal_data, quality_metrics, time_axis):
    """Build a 2x2 signal-quality figure.

    Args:
        signal_data: Signal samples shown in the first panel.
        quality_metrics: Mapping of quality metrics. The keys read here are
            ``"signal_quality_index"``, ``"snr_db"`` and
            ``"detected_artifacts"``; each is optional.
        time_axis: Time values for the signal panel.

    Returns:
        A plotly figure with four subplots; panels whose metric is missing
        are left without a trace.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Signal Quality Over Time",
            "SNR Analysis",
            "Artifact Detection",
            "Quality Metrics Summary",
        ),
        specs=[[{"secondary_y": False} for _ in range(2)] for _ in range(2)],
    )

    # Panel (1, 1): the signal itself
    signal_trace = go.Scatter(
        x=time_axis,
        y=signal_data,
        mode="lines",
        name="Signal",
        line={"color": "blue", "width": 1},
    )
    fig.add_trace(signal_trace, row=1, col=1)

    has_index = "signal_quality_index" in quality_metrics
    has_snr = "snr_db" in quality_metrics
    has_artifacts = "detected_artifacts" in quality_metrics

    # Quality label: > 0.8 is "Excellent", > 0.6 is "Good", otherwise "Poor".
    # The annotation is positioned in paper (whole-figure) coordinates.
    if has_index:
        quality_index = quality_metrics["signal_quality_index"]
        quality_text, quality_color = "Poor", "red"
        for threshold, label, colour in (
            (0.8, "Excellent", "green"),
            (0.6, "Good", "orange"),
        ):
            if quality_index > threshold:
                quality_text, quality_color = label, colour
                break
        fig.add_annotation(
            x=0.5,
            y=0.9,
            xref="paper",
            yref="paper",
            text=f"Quality: {quality_text} ({quality_index:.2f})",
            showarrow=False,
            font={"color": quality_color, "size": 14},
        )

    # Panel (1, 2): single SNR bar
    if has_snr:
        snr_value = quality_metrics["snr_db"]
        snr_bar = go.Bar(
            x=["SNR"],
            y=[snr_value],
            name="SNR",
            marker_color="lightblue",
            text=[f"{snr_value:.1f} dB"],
            textposition="auto",
        )
        fig.add_trace(snr_bar, row=1, col=2)

    # Panel (2, 1): single artifact-count bar
    if has_artifacts:
        artifacts = quality_metrics["detected_artifacts"]
        artifact_bar = go.Bar(
            x=["Artifacts"],
            y=[artifacts],
            name="Artifacts",
            marker_color="red",
            text=[f"{artifacts}"],
            textposition="auto",
        )
        fig.add_trace(artifact_bar, row=2, col=1)

    # Panel (2, 2): every available metric side by side
    summary_fields = (
        ("Quality Index", "signal_quality_index", has_index),
        ("SNR (dB)", "snr_db", has_snr),
        ("Artifacts", "detected_artifacts", has_artifacts),
    )
    available = [(label, key) for label, key, present in summary_fields if present]
    if available:
        summary_bar = go.Bar(
            x=[label for label, _ in available],
            y=[quality_metrics[key] for _, key in available],
            name="Quality Metrics",
            marker_color="lightsteelblue",
        )
        fig.add_trace(summary_bar, row=2, col=2)

    fig.update_layout(title="Signal Quality Assessment", height=600, showlegend=False)

    axis_titles = [
        (1, 1, "Time (s)", "Amplitude"),
        (1, 2, "Metric", "Value"),
        (2, 1, "Metric", "Count"),
        (2, 2, "Quality Metric", "Value"),
    ]
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)
    return fig
[docs]
def create_comprehensive_analysis_plot(time_axis, signal_data, analysis_results):
    """Create a 3x2 overview figure covering several aspects of the signal.

    Args:
        time_axis: Time values for the signal.
        signal_data: Signal samples.
        analysis_results: Mapping of analysis outputs, or ``None``. The keys
            read here are ``"morphology_metrics"`` (its ``"peaks"`` entry),
            ``"frequency_metrics"`` (its ``"frequencies"`` and ``"psd"``
            entries), ``"quality_metrics"``, ``"hrv_metrics"`` and
            ``"energy_metrics"``. Missing or ``None`` entries simply leave
            the corresponding panel without data.

    Returns:
        A plotly figure with six subplots.
    """
    logger.info(
        f"Creating comprehensive analysis plot with signal length: {len(signal_data)}"
    )
    logger.info(f"Analysis results available: {analysis_results is not None}")

    # The log line above anticipates a missing result set, so tolerate it
    # instead of failing on the membership tests below.
    results = analysis_results or {}

    fig = make_subplots(
        rows=3,
        cols=2,
        subplot_titles=(
            "Signal Overview",
            "Peak Detection",
            "Frequency Analysis",
            "Statistical Distribution",
            "Quality Assessment",
            "Feature Summary",
        ),
        specs=[
            [{"secondary_y": False}, {"secondary_y": False}],
            [{"secondary_y": False}, {"secondary_y": False}],
            [{"secondary_y": False}, {"secondary_y": False}],
        ],
    )

    # Plot 1: Signal overview
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=signal_data,
            mode="lines",
            name="Signal",
            line=dict(color="blue", width=1),
        ),
        row=1,
        col=1,
    )

    # Plot 2: Peak detection (markers only)
    morphology_metrics = results.get("morphology_metrics") or {}
    peaks = morphology_metrics.get("peaks")
    if peaks is not None and len(peaks) > 0:
        # Index through arrays so list inputs (e.g. deserialised JSON) work.
        peak_indices = np.asarray(peaks, dtype=int)
        fig.add_trace(
            go.Scatter(
                x=np.asarray(time_axis)[peak_indices],
                y=np.asarray(signal_data)[peak_indices],
                mode="markers",
                name="Peaks",
                marker=dict(color="red", size=6, symbol="diamond"),
            ),
            row=1,
            col=2,
        )

    # Plot 3: Frequency analysis - needs both the frequency axis and the PSD
    frequency_metrics = results.get("frequency_metrics") or {}
    if "frequencies" in frequency_metrics and "psd" in frequency_metrics:
        fig.add_trace(
            go.Scatter(
                x=frequency_metrics["frequencies"],
                y=frequency_metrics["psd"],
                mode="lines",
                name="PSD",
                line=dict(color="green", width=2),
            ),
            row=2,
            col=1,
        )

    # Plot 4: Statistical distribution of the signal amplitudes
    fig.add_trace(
        go.Histogram(
            x=signal_data,
            nbinsx=50,
            name="Distribution",
            marker_color="lightblue",
            opacity=0.7,
        ),
        row=2,
        col=2,
    )

    # Plot 5: Quality assessment - first five metrics in mapping order
    if "quality_metrics" in results:
        quality_metrics = results["quality_metrics"] or {}
        quality_names = list(quality_metrics)[:5]
        quality_values = [quality_metrics[name] for name in quality_names]
        fig.add_trace(
            go.Bar(
                x=quality_names,
                y=quality_values,
                name="Quality",
                marker_color="lightcoral",
            ),
            row=3,
            col=1,
        )

    # Plot 6: Feature summary - number of entries in each metrics group
    feature_groups = [
        ("HRV", "hrv_metrics"),
        ("Morphology", "morphology_metrics"),
        ("Energy", "energy_metrics"),
        ("Quality", "quality_metrics"),
    ]
    fig.add_trace(
        go.Bar(
            x=[label for label, _ in feature_groups],
            # "or {}" keeps a None entry from breaking len()
            y=[len(results.get(key) or {}) for _, key in feature_groups],
            name="Features",
            marker_color="lightsteelblue",
        ),
        row=3,
        col=2,
    )

    fig.update_layout(
        title="Comprehensive Physiological Analysis", height=900, showlegend=False
    )
    return fig
[docs]
def normalize_signal_type(signal_type):
    """Normalize a signal type name to one of the supported upper-case codes.

    Args:
        signal_type: Requested signal type. Matching is case-insensitive and
            ignores surrounding whitespace. Falsy values (``None``, ``""``)
            select the default.

    Returns:
        One of ``"ECG"``, ``"PPG"``, ``"EEG"`` or ``"RESP"``. ``"PPG"`` is
        returned for falsy or unrecognised input; unrecognised input is also
        logged as a warning.
    """
    if not signal_type:
        return "PPG"

    # str() keeps unexpected non-string values from raising on .upper();
    # strip() accepts values such as " ecg " coming from files or forms.
    candidate = str(signal_type).strip().upper()
    if candidate in ("ECG", "PPG", "EEG", "RESP"):
        return candidate

    logger.warning(
        f"Invalid signal type '{signal_type}' detected. Defaulting to 'PPG'."
    )
    return "PPG"
[docs]
def update_time_slider_marks(data_store):
    """Build slider marks spanning 0 to the largest time value in the store.

    Args:
        data_store: Mapping expected to hold a ``"time_data"`` sequence
            (list or array) of time values.

    Returns:
        A dict with eleven evenly spaced marks from 0 to ``max(time_data)``,
        mapping each position to a label such as ``"2.5s"``. An empty dict is
        returned when the store or its time data is missing or empty, or
        when the marks cannot be computed.
    """
    if not data_store or "time_data" not in data_store:
        return {}
    try:
        time_data = data_store["time_data"]
        # Explicit length check: truth-testing a multi-element NumPy array
        # raises, which previously turned valid data into an empty result.
        if time_data is None or len(time_data) == 0:
            return {}

        # Marks at ten regular intervals between 0 and the maximum time
        max_time = float(max(time_data))
        step = max_time / 10
        marks = {}
        for i in range(11):
            time_val = i * step
            marks[time_val] = f"{time_val:.1f}s"
        return marks
    except Exception as e:
        # Best effort: the slider falls back to having no marks.
        logger.error(f"Error updating time slider marks: {e}")
        return {}
[docs]
def register_physiological_callbacks(app):
"""Register all physiological analysis callbacks."""
logger.info("=== REGISTERING PHYSIOLOGICAL CALLBACKS ===")
# Import vitalDSP modules when callbacks are registered
_import_vitaldsp_modules()
# Register additional callbacks for enhanced features
register_additional_physiological_callbacks(app)
# Auto-select signal type based on uploaded data
@app.callback(
    [Output("physio-signal-type", "value")],
    [Input("url", "pathname")],
    [State("physio-signal-type", "value")],
    prevent_initial_call=True,
)
def auto_select_physio_signal_type(pathname, current_signal_type):
    """Pick a signal type for the physiological page from the latest upload.

    The update is skipped unless the page is ``/physiological`` and no signal
    type is selected yet. Otherwise the type stored with the most recent data
    set is used; when it is missing or ``"auto"``, the type is guessed from
    the signal column name and, failing that, from the dominant frequency of
    a Welch PSD. Any failure falls back to ``"PPG"``.

    Returns:
        A one-element list holding the chosen signal type.

    Raises:
        PreventUpdate: When not on the physiological page, or when a signal
            type is already selected.
    """
    logger.info("=== AUTO-SELECT PHYSIO SIGNAL TYPE CALLBACK TRIGGERED ===")
    logger.info(f"Pathname: {pathname}, Current selection: {current_signal_type}")

    if pathname != "/physiological":
        logger.info("Not on physiological page, preventing update")
        raise PreventUpdate

    # Never override a choice that is already present.
    if current_signal_type is not None:
        logger.info(
            f"User has existing signal type selection: {current_signal_type}, preserving it"
        )
        raise PreventUpdate

    def _numeric_suffix(data_id):
        # Sort key: the integer after the first underscore, or 0 without one.
        return int(data_id.split("_")[1]) if "_" in data_id else 0

    try:
        from vitalDSP_webapp.services.data.enhanced_data_service import (
            get_enhanced_data_service,
        )

        data_service = get_enhanced_data_service()
        if not data_service:
            logger.warning("Data service not available")
            return ["PPG"]

        all_data = data_service.get_all_data()
        if not all_data:
            logger.info("No data available, using defaults")
            return ["PPG"]

        # Most recent entry = highest numeric suffix among the data ids
        latest_data_id = max(all_data.keys(), key=_numeric_suffix)
        data_info = data_service.get_data_info(latest_data_id)
        if not data_info:
            logger.info("No data info available, using defaults")
            return ["PPG"]

        # Debug output of what is stored for this data set
        logger.info(
            f"Physio data info keys: {list(data_info.keys()) if data_info else 'None'}"
        )
        logger.info(f"Physio full data info: {data_info}")

        stored_signal_type = data_info.get("signal_type", None)
        logger.info(f"Physio stored signal type: {stored_signal_type}")

        if stored_signal_type and stored_signal_type.lower() != "auto":
            # NOTE(review): the stored value is lower-cased here while every
            # other path returns upper-case "PPG"/"ECG" - confirm which form
            # the dropdown options actually use.
            signal_type = stored_signal_type.lower()
            logger.info(f"Physio using stored signal type: {signal_type}")
        else:
            # Nothing usable stored (missing or "auto"): start from the
            # default and try to refine it from the data itself.
            signal_type = "PPG"
            logger.info(
                "Physio auto-detecting signal type from data characteristics"
            )
            df = data_service.get_data(latest_data_id)
            if df is not None and not df.empty:
                column_mapping = data_service.get_column_mapping(latest_data_id)
                signal_column = column_mapping.get("signal", "")
                column_lower = signal_column.lower()

                # Column-name hints take precedence over signal analysis
                if any(hint in column_lower for hint in ["ecg", "electrocardio"]):
                    signal_type = "ECG"
                    logger.info("Auto-detected ECG signal type from column name")
                elif any(
                    hint in column_lower for hint in ["ppg", "pleth", "photopleth"]
                ):
                    signal_type = "PPG"
                    logger.info("Auto-detected PPG signal type from column name")
                else:
                    try:
                        # Without a mapped column, fall back to the second one
                        if signal_column:
                            signal_data = df[signal_column].values
                        else:
                            signal_data = df.iloc[:, 1].values
                        sampling_freq = data_info.get("sampling_freq", 1000)

                        # Heuristic: a dominant PSD frequency above 1 Hz is
                        # taken to indicate ECG, anything else PPG.
                        from scipy import signal

                        f, psd = signal.welch(
                            signal_data,
                            fs=sampling_freq,
                            nperseg=min(1024, len(signal_data) // 4),
                        )
                        dominant_freq = f[np.argmax(psd)]
                        if dominant_freq > 1.0:
                            signal_type = "ECG"
                            logger.info(
                                "Auto-detected ECG signal type from frequency analysis"
                            )
                        else:
                            signal_type = "PPG"
                            logger.info(
                                "Auto-detected PPG signal type from frequency analysis"
                            )
                    except Exception as e:
                        logger.warning(
                            f"Could not analyze signal characteristics: {e}"
                        )
                        signal_type = "PPG"

        logger.info(f"Auto-selected physiological signal type: {signal_type}")
        return [signal_type]
    except Exception as e:
        logger.error(f"Error in auto-selection: {e}")
        return ["PPG"]
@app.callback(
[
Output("physio-main-signal-plot", "figure"),
Output("physio-analysis-results", "children"),
Output("physio-analysis-plots", "figure"),
Output("store-physio-data", "data"),
Output("store-physio-features", "data"),
],
[
Input("url", "pathname"),
Input("physio-btn-update-analysis", "n_clicks"),
Input("physio-btn-nudge-m10", "n_clicks"),
Input("physio-btn-nudge-m1", "n_clicks"),
Input("physio-btn-nudge-p1", "n_clicks"),
Input("physio-btn-nudge-p10", "n_clicks"),
],
[
State(
"physio-start-position-slider", "value"
), # NEW: start position instead of time-range-slider
State(
"physio-duration-select", "value"
), # NEW: duration instead of start-time/end-time
State("physio-signal-type", "value"),
State("physio-signal-source-select", "value"),
State("physio-analysis-categories", "value"),
State("physio-hrv-options", "value"),
State("physio-morphology-options", "value"),
State("physio-advanced-features", "value"),
State("physio-quality-options", "value"),
State("physio-transform-options", "value"),
State("physio-advanced-computation", "value"),
State("physio-feature-engineering", "value"),
State("physio-preprocessing", "value"),
],
)
def physiological_analysis_callback(
pathname,
n_clicks,
nudge_m10,
nudge_m1,
nudge_p1,
nudge_p10,
start_position, # NEW: start position instead of slider_value
duration, # NEW: duration instead of start_time/end_time
signal_type,
signal_source,
analysis_categories,
hrv_options,
morphology_options,
advanced_features,
quality_options,
transform_options,
advanced_computation,
feature_engineering,
preprocessing,
):
"""Unified callback for physiological analysis - handles both page load and user interactions."""
ctx = callback_context
# Determine what triggered this callback
if not ctx.triggered:
logger.warning("No context triggered - raising PreventUpdate")
raise PreventUpdate
trigger_id = ctx.triggered[0]["prop_id"].split(".")[0]
logger.info("=== PHYSIOLOGICAL ANALYSIS CALLBACK TRIGGERED ===")
logger.info(f"Trigger ID: {trigger_id}")
logger.info(f"Pathname: {pathname}")
logger.info(f"Analysis categories: {analysis_categories}")
logger.info(f"HRV options: {hrv_options}")
logger.info(f"Morphology options: {morphology_options}")
logger.info(f"Advanced features: {advanced_features}")
logger.info(f"Quality options: {quality_options}")
logger.info(f"Transform options: {transform_options}")
logger.info(f"Advanced computation: {advanced_computation}")
logger.info(f"Feature engineering: {feature_engineering}")
logger.info(f"Preprocessing: {preprocessing}")
# Only run this when we're on the physiological page
if pathname != "/physiological":
logger.info("Not on physiological page, returning empty figures")
return (
create_empty_figure(),
"Navigate to Physiological Features page",
create_empty_figure(),
None,
None,
)
try:
# Get data from the data service
from vitalDSP_webapp.services.data.enhanced_data_service import (
get_enhanced_data_service,
)
data_service = get_enhanced_data_service()
# Get the most recent data
all_data = data_service.get_all_data()
if not all_data:
logger.warning("No data found in service")
return (
create_empty_figure(),
"No data available. Please upload and process data first.",
create_empty_figure(),
None,
None,
)
# Get the most recent data entry
latest_data_id = list(all_data.keys())[-1]
latest_data = all_data[latest_data_id]
logger.info(f"Found data: {latest_data_id}")
# Get column mapping
column_mapping = data_service.get_column_mapping(latest_data_id)
if not column_mapping:
logger.warning(
"Data has not been processed yet - no column mapping found"
)
return (
create_empty_figure(),
"Please process your data on the Upload page first (configure column mapping)",
create_empty_figure(),
None,
None,
)
logger.info(f"Column mapping found: {column_mapping}")
df = data_service.get_data(latest_data_id)
if df is None or df.empty:
logger.warning("Data frame is empty")
return (
create_empty_figure(),
"Data is empty or corrupted.",
create_empty_figure(),
None,
None,
)
# Get sampling frequency from the data info
sampling_freq = latest_data.get("info", {}).get("sampling_freq", 1000)
logger.info(f"Sampling frequency: {sampling_freq}")
# Set default values if not provided
start_position = start_position or 0
duration = duration or 60 # Default to 1 minute instead of 10 seconds
signal_type = signal_type or "auto"
analysis_categories = analysis_categories or [
"hrv",
"morphology",
"beat2beat",
"energy",
"envelope",
"segmentation",
"trend",
"waveform",
"statistical",
"frequency",
]
hrv_options = hrv_options or ["time_domain", "freq_domain", "nonlinear"]
morphology_options = morphology_options or ["peaks", "duration", "area"]
advanced_features = advanced_features or [
"cross_signal",
"ensemble",
"change_detection",
"power_analysis",
]
quality_options = quality_options or ["quality_index", "artifact_detection"]
transform_options = transform_options or ["wavelet", "fourier", "hilbert"]
advanced_computation = advanced_computation or [
"anomaly_detection",
"bayesian",
"kalman",
]
feature_engineering = feature_engineering or [
"ppg_light",
"ppg_autonomic",
"ecg_autonomic",
]
preprocessing = preprocessing or [
"noise_reduction",
"baseline_correction",
"filtering",
]
# Convert duration to numeric (Select returns string)
try:
duration = float(duration) if duration is not None else 60
except (ValueError, TypeError):
duration = 60 # Default to 1 minute if conversion fails
# Convert start_position to numeric
try:
start_position = (
float(start_position) if start_position is not None else 0
)
except (ValueError, TypeError):
start_position = 0
# Handle nudge buttons for start_position adjustments
if trigger_id in [
"physio-btn-nudge-m10",
"physio-btn-nudge-m1",
"physio-btn-nudge-p1",
"physio-btn-nudge-p10",
]:
# Adjust start position by percentage
if trigger_id == "physio-btn-nudge-m10":
start_position = max(0, start_position - 10)
elif trigger_id == "physio-btn-nudge-m1":
start_position = max(0, start_position - 5)
elif trigger_id == "physio-btn-nudge-p1":
start_position = min(100, start_position + 5)
elif trigger_id == "physio-btn-nudge-p10":
start_position = min(100, start_position + 10)
logger.info(f"Adjusted start position via nudge: {start_position}%")
# Handle start_position as percentage (0-100) - convert to time
# start_position is now a percentage of the data length
logger.info(f"Start position: {start_position}%, Duration: {duration}s")
logger.info(f"Signal type: {signal_type}")
logger.info(f"Analysis categories: {analysis_categories}")
logger.info(f"HRV options: {hrv_options}")
logger.info(f"Morphology options: {morphology_options}")
logger.info(f"Advanced features: {advanced_features}")
logger.info(f"Quality options: {quality_options}")
logger.info(f"Transform options: {transform_options}")
logger.info(f"Advanced computation: {advanced_computation}")
logger.info(f"Feature engineering: {feature_engineering}")
logger.info(f"Preprocessing: {preprocessing}")
# Extract time and signal columns
time_col = column_mapping.get("time", df.columns[0])
signal_col = column_mapping.get("signal", df.columns[1])
# Log column selection for debugging
logger.info(f"Available columns: {list(df.columns)}")
logger.info(f"Selected time column: {time_col}")
logger.info(f"Selected signal column: {signal_col}")
logger.info(f"Column mapping: {column_mapping}")
# Extract time and signal data
time_data = df[time_col].values
signal_data = df[signal_col].values
# Always keep the original signal_data for dynamic filtering
original_signal_data = signal_data.copy()
selected_signal = signal_data
signal_source_info = "Original Signal"
filter_info = None
# Signal source loading logic
logger.info("=== SIGNAL SOURCE LOADING ===")
logger.info(f"Signal source selection: {signal_source}")
if signal_source == "filtered":
# Try to load filtered data from filtering screen
filtered_data = data_service.get_filtered_data(latest_data_id)
filter_info = data_service.get_filter_info(latest_data_id)
if filtered_data is not None:
logger.info(
f"Found filtered data with shape: {filtered_data.shape}"
)
selected_signal = filtered_data
signal_source_info = "Filtered Signal"
else:
logger.info("No filtered data available, using original signal")
else:
logger.info("Using original signal as requested")
# Log data characteristics
logger.info(f"Time data range: {time_data[0]} to {time_data[-1]}")
logger.info(
f"Selected signal range: {np.min(selected_signal):.4f} to {np.max(selected_signal):.4f}"
)
logger.info(
f"Selected signal mean: {np.mean(selected_signal):.4f}, std: {np.std(selected_signal):.4f}"
)
logger.info(f"Total samples: {len(selected_signal)}")
logger.info(f"Signal source: {signal_source_info}")
# Suggest better column mappings if current ones seem wrong
if signal_col == "PLETH" and np.std(signal_data) < 0.01:
logger.warning(
f"PLETH column has very low variance ({np.std(signal_data):.6f}), might not be the best signal column"
)
# Look for columns with higher variance
for col in df.columns:
if col not in [time_col, signal_col] and df[col].dtype in [
"float64",
"float32",
"int64",
"int32",
]:
col_std = np.std(df[col].values)
if col_std > np.std(signal_data) * 10: # 10x higher variance
logger.info(
f"Consider using column '{col}' instead (std: {col_std:.6f})"
)
break
# Get suggestions for better signal columns
signal_suggestions = suggest_best_signal_column(df, time_col)
if signal_suggestions:
logger.info("Signal column suggestions (ranked by quality):")
for i, suggestion in enumerate(
signal_suggestions[:3]
): # Top 3 suggestions
logger.info(
f" {i+1}. {suggestion['column']} (score: {suggestion['score']}, variance: {suggestion['variance']:.6f})"
)
# If current signal column has very low score, suggest using the best one
current_score = next(
(
s["score"]
for s in signal_suggestions
if s["column"] == signal_col
),
0,
)
best_score = signal_suggestions[0]["score"]
if current_score < best_score - 2: # Significant difference
logger.warning(
f"Current signal column '{signal_col}' has low quality (score: {current_score})"
)
logger.warning(
f"Consider using '{signal_suggestions[0]['column']}' instead (score: {best_score})"
)
# Auto-switch to better column if current one is very poor
if (
current_score < 2 and best_score >= 4
): # Current is very poor, best is good
old_signal_col = signal_col
signal_col = signal_suggestions[0]["column"]
signal_data = df[signal_col].values
logger.info(
f"Auto-switched signal column from '{old_signal_col}' to '{signal_col}' for better analysis quality"
)
# Re-log signal characteristics with new column
logger.info(
f"New signal data range: {np.min(signal_data):.4f} to {np.max(signal_data):.4f}"
)
logger.info(
f"New signal data mean: {np.mean(signal_data):.4f}, std: {np.std(signal_data):.4f}"
)
# Determine time unit and adjust window if needed
time_range = time_data[-1] - time_data[0]
logger.info(f"Full time range: {time_range}")
# Handle datetime data conversion
if pd.api.types.is_datetime64_any_dtype(df[time_col]):
logger.info("Converting datetime time data to numeric seconds")
first_timestamp = df[time_col].iloc[0]
time_data_seconds = (
(df[time_col] - first_timestamp).dt.total_seconds().values
)
time_range_seconds = time_data_seconds[-1] - time_data_seconds[0]
logger.info(
f"Time data converted to seconds from first timestamp. Range: {time_range_seconds:.2f}s"
)
elif time_range > 1000: # Likely milliseconds
logger.info(
"Time data appears to be in milliseconds, converting to seconds"
)
time_data_seconds = time_data / 1000.0
time_range_seconds = time_range / 1000.0
else:
# Time data is already in seconds
time_data_seconds = time_data
time_range_seconds = time_range
logger.info("Time data is already in seconds")
# Convert start_position percentage to actual time
start_time_actual = (start_position / 100.0) * time_range_seconds
end_time_actual = start_time_actual + duration
# Ensure end_time doesn't exceed data range
if end_time_actual > time_range_seconds:
end_time_actual = time_range_seconds
start_time_actual = max(0, end_time_actual - duration)
logger.info(
f"Time window: {start_time_actual:.2f}s to {end_time_actual:.2f}s"
)
# Apply time window
start_idx = np.searchsorted(time_data_seconds, start_time_actual)
end_idx = np.searchsorted(time_data_seconds, end_time_actual)
# Ensure minimum signal length for analysis (at least 5 seconds worth of data)
min_samples = max(
5 * sampling_freq, 1000
) # At least 5 seconds or 1000 samples
if end_idx - start_idx < min_samples:
# Extend the window to get sufficient data
if end_idx + min_samples <= len(time_data_seconds):
end_idx = start_idx + min_samples
elif start_idx - min_samples >= 0:
start_idx = end_idx - min_samples
else:
# Use the full signal if we can't get enough data
start_idx = 0
end_idx = len(time_data_seconds)
logger.warning(
"Using full signal due to insufficient data in time window"
)
logger.info(
f"Using signal segment: {start_idx} to {end_idx} (length: {end_idx - start_idx})"
)
if end_idx > start_idx:
time_data = time_data_seconds[start_idx:end_idx]
signal_data = signal_data[start_idx:end_idx]
# Apply time window to both original and selected signals
original_signal_data = original_signal_data[start_idx:end_idx]
selected_signal = selected_signal[start_idx:end_idx]
# Check if we need to apply dynamic filtering for filtered signal
if signal_source == "filtered" and filter_info is not None:
# Check if the time range is within the original signal range
original_signal_length = len(
data_service.get_data(latest_data_id)[signal_col].values
)
expected_length = end_idx - start_idx
# If time range is outside original signal or we need dynamic filtering
if (
start_idx >= original_signal_length
or end_idx > original_signal_length
or expected_length != len(selected_signal)
):
logger.info("=== DYNAMIC FILTERING ===")
logger.info(f"Original signal length: {original_signal_length}")
logger.info(f"Time range: {start_idx} to {end_idx}")
logger.info(f"Expected window length: {expected_length}")
logger.info("Applying dynamic filtering to current window...")
# If time range is completely outside original signal, return error
if (
start_idx >= original_signal_length
or end_idx > original_signal_length
):
logger.warning(
f"Time range {start_idx}-{end_idx} is outside original signal range (0-{original_signal_length})"
)
return (
create_empty_figure(),
f"Time range is outside the available signal data. Please select a time range within 0 to {original_signal_length/sampling_freq:.1f} seconds.",
create_empty_figure(),
None,
None,
)
try:
# Import the same filtering function used in time domain
from vitalDSP_webapp.callbacks.analysis.signal_filtering_callbacks import (
apply_traditional_filter,
)
# Get the full original signal for the current time window
full_original_signal = data_service.get_data(
latest_data_id
)[signal_col].values
windowed_original_signal = full_original_signal[
start_idx:end_idx
]
# Get filter parameters
parameters = filter_info.get("parameters", {})
detrending_applied = filter_info.get(
"detrending_applied", False
)
# Apply detrending if it was applied in the original filtering
if detrending_applied:
from scipy import signal as scipy_signal
signal_data_detrended = scipy_signal.detrend(
windowed_original_signal
)
logger.info("Applied detrending to signal")
else:
signal_data_detrended = windowed_original_signal
# Apply the same filter type as used in filtering screen
filter_type = filter_info.get("filter_type", "traditional")
if filter_type == "traditional":
# Extract traditional filter parameters
filter_family = parameters.get(
"filter_family", "butter"
)
filter_response = parameters.get(
"filter_response", "bandpass"
)
low_freq = parameters.get("low_freq", 0.5)
high_freq = parameters.get("high_freq", 5)
filter_order = parameters.get("filter_order", 4)
# Apply traditional filter
selected_signal = apply_traditional_filter(
signal_data_detrended,
sampling_freq,
filter_family,
filter_response,
low_freq,
high_freq,
filter_order,
)
logger.info("Applied dynamic traditional filter")
else:
# For other filter types, use the original signal
selected_signal = signal_data_detrended
logger.info(
f"Using original signal for filter type: {filter_type}"
)
signal_source_info = "Filtered Signal (Dynamic)"
logger.info("Dynamic filtering completed successfully")
except Exception as e:
logger.error(f"Error in dynamic filtering: {e}")
logger.info("Falling back to original signal")
selected_signal = windowed_original_signal
signal_source_info = "Original Signal (Fallback)"
# Use selected_signal for analysis
signal_data = selected_signal
# Validate signal length for analysis
if len(signal_data) < min_samples:
logger.warning(
f"Signal segment too short ({len(signal_data)} samples), using fallback analysis"
)
# For very short signals, limit analysis to basic features
analysis_categories = ["morphology", "statistical", "frequency"]
hrv_options = []
advanced_features = []
quality_options = []
transform_options = []
advanced_computation = []
feature_engineering = []
preprocessing = []
# Auto-detect signal type if needed
if signal_type == "auto":
signal_type = detect_physiological_signal_type(
signal_data, sampling_freq
)
logger.info(f"Auto-detected signal type: {signal_type}")
# Create main signal plot
main_fig = create_physiological_signal_plot(
time_data, signal_data, signal_type, sampling_freq
)
# Perform analysis based on selected categories
raw_analysis_results = perform_physiological_analysis_enhanced(
time_data,
signal_data,
signal_type,
sampling_freq,
analysis_categories,
hrv_options,
morphology_options,
advanced_features,
quality_options,
transform_options,
advanced_computation,
feature_engineering,
preprocessing,
)
# Create comprehensive results display
analysis_results = create_comprehensive_results_display(
raw_analysis_results, signal_type, sampling_freq
)
# Create analysis plots
analysis_fig = create_physiological_analysis_plots(
time_data_seconds,
signal_data,
signal_type,
sampling_freq,
analysis_categories,
hrv_options,
morphology_options,
advanced_features,
quality_options,
transform_options,
advanced_computation,
feature_engineering,
preprocessing,
)
# Store data for other callbacks
physio_data = {
"time_data": time_data.tolist(),
"signal_data": signal_data.tolist(),
"signal_type": signal_type,
"sampling_freq": sampling_freq,
"analysis_categories": analysis_categories,
}
physio_features = {
"hrv_metrics": raw_analysis_results.get("hrv_metrics", {}),
"morphology_metrics": raw_analysis_results.get(
"morphology_metrics", {}
),
"beat2beat_metrics": raw_analysis_results.get("beat2beat_metrics", {}),
"energy_metrics": raw_analysis_results.get("energy_metrics", {}),
"envelope_metrics": raw_analysis_results.get("envelope_metrics", {}),
"segmentation_metrics": raw_analysis_results.get(
"segmentation_metrics", {}
),
"waveform_metrics": raw_analysis_results.get("waveform_metrics", {}),
"statistical_metrics": raw_analysis_results.get(
"statistical_metrics", {}
),
"frequency_metrics": raw_analysis_results.get("frequency_metrics", {}),
"advanced_features_metrics": raw_analysis_results.get(
"advanced_features_metrics", {}
),
"quality_metrics": raw_analysis_results.get("quality_metrics", {}),
"transform_metrics": raw_analysis_results.get("transform_metrics", {}),
"advanced_computation_metrics": raw_analysis_results.get(
"advanced_computation_metrics", {}
),
"feature_engineering_metrics": raw_analysis_results.get(
"feature_engineering_metrics", {}
),
"preprocessing_metrics": raw_analysis_results.get(
"preprocessing_metrics", {}
),
"trend_metrics": raw_analysis_results.get("trend_metrics", {}),
}
logger.info("Physiological analysis completed successfully")
return (
main_fig,
analysis_results,
analysis_fig,
physio_data,
physio_features,
)
except Exception as e:
logger.error(f"Error in physiological analysis: {e}")
import traceback
logger.error(traceback.format_exc())
error_fig = create_empty_figure()
error_fig.add_annotation(
text=f"Error: {str(e)}",
xref="paper",
yref="paper",
x=0.5,
y=0.5,
showarrow=False,
)
error_results = html.Div(
[
html.H5("Error in Physiological Analysis"),
html.P(f"Analysis failed: {str(e)}"),
html.P("Please check your data and parameters."),
]
)
return error_fig, error_results, error_fig, None, None
# Time input update callbacks
@app.callback(
    [
        Output("physio-start-position-slider", "value"),
        Output("physio-duration-select", "value"),
    ],
    [
        Input("physio-start-position-slider", "value"),
        Input("physio-duration-select", "value"),
        Input("physio-btn-nudge-m10", "n_clicks"),
        Input("physio-btn-nudge-m1", "n_clicks"),
        Input("physio-btn-nudge-p1", "n_clicks"),
        Input("physio-btn-nudge-p10", "n_clicks"),
    ],
    [
        State("physio-start-position-slider", "value"),
        State("physio-duration-select", "value"),
    ],
)
def update_physio_time_inputs(
    start_position,
    duration,
    nudge_m10,
    nudge_m1,
    nudge_p1,
    nudge_p10,
    current_start,
    current_duration,
):
    """Shift the start-position slider when a nudge button is pressed.

    The Input values are only used to trigger the callback; the values that
    are read come from the two State arguments.

    Args:
        start_position, duration: Triggering slider/select values (unused).
        nudge_m10, nudge_m1, nudge_p1, nudge_p10: Click counters of the nudge
            buttons (unused; the trigger id decides which one fired).
        current_start: Current slider value, ``0`` is assumed when ``None``.
        current_duration: Current duration value, ``10`` is assumed when
            ``None``.

    Returns:
        Tuple ``(start, duration)`` for the two outputs.

    Raises:
        PreventUpdate: When nothing triggered the callback.
    """
    if not callback_context.triggered:
        raise PreventUpdate

    source_id = callback_context.triggered[0]["prop_id"].split(".")[0]

    position = 0 if current_start is None else current_start
    window = 10 if current_duration is None else current_duration

    # Step applied per button. Note that the "m1"/"p1" buttons move by 5,
    # not by 1.
    nudge_steps = {
        "physio-btn-nudge-m10": -10,
        "physio-btn-nudge-m1": -5,
        "physio-btn-nudge-p1": 5,
        "physio-btn-nudge-p10": 10,
    }
    step = nudge_steps.get(source_id)

    # Any other trigger (slider or duration change) echoes the current values.
    if step is None:
        return position, window

    shifted = position + step
    # Backward nudges are only clamped at 0, forward nudges only at 100.
    if step < 0:
        return max(0, shifted), window
    return min(100, shifted), window
# Time slider range update callback
@app.callback(
    Output("physio-start-position-slider", "max"),
    [Input("store-uploaded-data", "data")],
)
def update_physio_time_slider_range(data_store):
    """Set the start-position slider maximum from the uploaded data.

    The first column of the stored table is treated as the time axis and its
    largest finite value becomes the slider maximum.

    Args:
        data_store: Content of ``store-uploaded-data``; expected to hold the
            table under the ``"data"`` key.

    Returns:
        The largest finite value of the first column as a plain ``float``,
        or ``100`` when no data is stored, the table is empty, the column
        has no finite value, or the column cannot be read as numbers.
    """
    default_max = 100
    if not data_store:
        return default_max
    try:
        df = pd.DataFrame(data_store["data"])
        if df.empty:
            return default_max
        # Get time column (assume first column)
        time_values = np.asarray(df.iloc[:, 0].values, dtype=float)
        # NaN/inf samples must not decide the slider range: np.max would
        # propagate NaN and hand the slider an unusable maximum.
        finite_values = time_values[np.isfinite(time_values)]
        if finite_values.size == 0:
            return default_max
        # Plain float keeps the component property JSON-friendly.
        return float(finite_values.max())
    except Exception as e:
        logger.error(f"Error updating time slider range: {e}")
        return default_max
# Helper functions for physiological analysis
def update_physio_time_slider_range(data_store):
    """Compute the time slider maximum from the uploaded data.

    The first column of the stored table is treated as the time axis and its
    largest finite value is returned.

    Args:
        data_store: Mapping that holds the uploaded table under the
            ``"data"`` key (anything ``pandas.DataFrame`` accepts).

    Returns:
        The largest finite value of the first column as a plain ``float``,
        or ``100`` when no data is stored, the table is empty, the column
        has no finite value, or the column cannot be read as numbers.
    """
    default_max = 100
    if not data_store:
        return default_max
    try:
        df = pd.DataFrame(data_store["data"])
        if df.empty:
            return default_max
        # Get time column (assume first column)
        time_values = np.asarray(df.iloc[:, 0].values, dtype=float)
        # NaN/inf samples must not decide the slider range: np.max would
        # propagate NaN and hand the slider an unusable maximum.
        finite_values = time_values[np.isfinite(time_values)]
        if finite_values.size == 0:
            return default_max
        # Plain float keeps the value JSON-friendly.
        return float(finite_values.max())
    except Exception as e:
        logger.error(f"Error updating time slider range: {e}")
        return default_max
def detect_physiological_signal_type(signal_data, sampling_freq):
    """Auto-detect the type of physiological signal.

    Peaks higher than ``mean + std`` and at least 0.3 s apart are located;
    when more than one peak is found and the mean peak-to-peak interval is
    below 0.8 s the signal is labelled ``"ecg"``, otherwise ``"ppg"``.

    Args:
        signal_data: 1-D sequence of samples.
        sampling_freq: Sampling frequency in Hz.

    Returns:
        ``"ecg"`` or ``"ppg"``. ``"ppg"`` is also the fallback when fewer
        than two peaks are found or when detection raises.
    """
    try:
        from scipy import signal as sp_signal

        samples = np.asarray(signal_data, dtype=float)
        height_threshold = np.mean(samples) + np.std(samples)
        # find_peaks rejects distance < 1, which int(fs * 0.3) yields for
        # sampling rates below ~3.4 Hz; clamp so detection still runs.
        min_distance = max(1, int(sampling_freq * 0.3))

        peaks, _ = sp_signal.find_peaks(
            samples,
            height=height_threshold,
            distance=min_distance,
        )

        if len(peaks) > 1:
            intervals = np.diff(peaks) / sampling_freq
            # Less than 0.8 seconds between peaks (more conservative)
            if np.mean(intervals) < 0.8:
                return "ecg"  # Likely ECG (faster heart rate)

        # Slow peaks, a single peak, or no peaks: default to PPG.
        return "ppg"
    except Exception as e:
        logger.warning(f"Error in signal type detection: {e}")
        return "ppg"  # Default fallback
def create_physiological_signal_plot(
    time_data, signal_data, signal_type, sampling_freq
):
    """Create the main physiological signal plot with enhanced visualization.

    The figure contains the raw signal, detected peaks with peak-to-peak
    interval labels (ECG/PPG only), a dashed baseline at the signal mean and
    two moving-average "envelope" traces.

    Args:
        time_data: Time stamps, one per sample.
        signal_data: Signal samples.
        signal_type: Signal label; peaks are only detected for ``"ecg"`` and
            ``"ppg"`` (case-insensitive). ``None`` is tolerated.
        sampling_freq: Sampling frequency in Hz.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # Arrays are required for the fancy indexing with peak positions below.
    time_data = np.asarray(time_data)
    signal_data = np.asarray(signal_data)
    signal_kind = str(signal_type or "").lower()
    signal_label = signal_kind.upper() or "UNKNOWN"

    fig = go.Figure()

    # Add main signal with enhanced styling
    fig.add_trace(
        go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name=f"{signal_label} Signal",
            line=dict(color="#1f77b4", width=2, shape="linear"),
            fill="tonexty",
            fillcolor="rgba(31, 119, 180, 0.1)",
        )
    )

    # Enhanced peak detection with red arrows
    if signal_kind in ("ecg", "ppg"):
        try:
            # Use vitalDSP for ECG/PPG peak detection
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                signal_data, fs=sampling_freq, signal_type=signal_label
            )
            raw_peaks = wm.r_peaks if signal_kind == "ecg" else wm.systolic_peaks
            # Normalise to an integer index array so a list or None result
            # does not break the indexing below.
            peaks = np.asarray([] if raw_peaks is None else raw_peaks, dtype=int)

            if len(peaks) > 0:
                # Add peak markers with red arrows
                fig.add_trace(
                    go.Scatter(
                        x=time_data[peaks],
                        y=signal_data[peaks],
                        mode="markers+text",
                        name="Detected Peaks",
                        text=[f"P{i+1}" for i in range(len(peaks))],
                        textposition="top center",
                        marker=dict(
                            color="red",
                            size=10,
                            symbol="arrow-up",
                            line=dict(color="darkred", width=2),
                        ),
                        textfont=dict(color="red", size=10, family="Arial Black"),
                    )
                )

                # Add peak-to-peak intervals (empty when there is one peak)
                intervals = np.diff(peaks) / sampling_freq
                for peak1, peak2, interval in zip(peaks[:-1], peaks[1:], intervals):
                    # Label each interval halfway between its two peaks
                    mid_x = (time_data[peak1] + time_data[peak2]) / 2
                    mid_y = (signal_data[peak1] + signal_data[peak2]) / 2
                    fig.add_annotation(
                        x=mid_x,
                        y=mid_y,
                        text=f"{interval:.2f}s",
                        showarrow=True,
                        arrowhead=2,
                        arrowsize=1,
                        arrowwidth=2,
                        arrowcolor="orange",
                        font=dict(size=10, color="orange"),
                        bgcolor="rgba(255, 255, 255, 0.8)",
                        bordercolor="orange",
                        borderwidth=1,
                    )

                # Add statistics annotation
                mean_interval = np.mean(intervals) if len(peaks) > 1 else 0
                heart_rate = 60 / mean_interval if mean_interval > 0 else 0
                fig.add_annotation(
                    x=0.02,
                    y=0.98,
                    xref="paper",
                    yref="paper",
                    text=f"Peaks: {len(peaks)}<br>Mean Interval: {mean_interval:.2f}s<br>Heart Rate: {heart_rate:.1f} BPM",
                    showarrow=False,
                    bgcolor="rgba(255, 255, 255, 0.9)",
                    bordercolor="red",
                    borderwidth=2,
                    font=dict(size=12, color="black"),
                )
        except Exception as e:
            # Peak overlay is best-effort; the base plot is still returned.
            logger.warning(f"Error in peak detection: {e}")

    # Add baseline reference
    baseline = np.mean(signal_data)
    fig.add_hline(
        y=baseline,
        line_dash="dash",
        line_color="gray",
        annotation_text="Baseline",
        annotation_position="right",
    )

    # Add signal envelope
    try:
        # Calculate moving average for envelope
        window_size = min(50, len(signal_data) // 10)
        if window_size > 1:
            kernel = np.ones(window_size) / window_size
            envelope_upper = np.convolve(signal_data, kernel, mode="same")
            # NOTE(review): the "lower envelope" is the negated moving
            # average, not a lower bound of the signal - confirm intent.
            envelope_lower = np.convolve(signal_data, -kernel, mode="same")
            fig.add_trace(
                go.Scatter(
                    x=time_data,
                    y=envelope_upper,
                    mode="lines",
                    name="Upper Envelope",
                    line=dict(color="rgba(255, 165, 0, 0.6)", width=1, dash="dot"),
                    showlegend=True,
                )
            )
            fig.add_trace(
                go.Scatter(
                    x=time_data,
                    y=envelope_lower,
                    mode="lines",
                    name="Lower Envelope",
                    line=dict(color="rgba(255, 165, 0, 0.6)", width=1, dash="dot"),
                    showlegend=True,
                    fill="tonexty",
                    fillcolor="rgba(255, 165, 0, 0.1)",
                )
            )
    except Exception as e:
        logger.debug(f"Could not create envelope: {e}")

    # Enhanced layout
    fig.update_layout(
        title=dict(
            text=f"{signal_label} Signal Analysis with Peak Detection",
            x=0.5,
            font=dict(size=20, color="#2c3e50"),
        ),
        xaxis=dict(
            title=dict(text="Time (seconds)", font=dict(size=14, color="#2c3e50")),
            gridcolor="rgba(128, 128, 128, 0.2)",
            zeroline=True,
            zerolinecolor="rgba(128, 128, 128, 0.5)",
            showgrid=True,
        ),
        yaxis=dict(
            title=dict(text="Amplitude", font=dict(size=14, color="#2c3e50")),
            gridcolor="rgba(128, 128, 128, 0.2)",
            zeroline=True,
            zerolinecolor="rgba(128, 128, 128, 0.5)",
            showgrid=True,
        ),
        showlegend=True,
        legend=dict(
            x=1.02,
            y=1.0,
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="rgba(0, 0, 0, 0.2)",
            borderwidth=1,
        ),
        hovermode="x unified",
        plot_bgcolor="white",
        paper_bgcolor="white",
        margin=dict(l=60, r=200, t=80, b=60),
        height=500,
    )

    # Add grid styling
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)")
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)")
    return fig
def _add_panel_note(fig, axis_index, text, border_color, x=0.98, y=0.98):
    """Pin a statistics box inside subplot ``axis_index`` (1-4).

    Axis-domain references are used so that ``x``/``y`` are fractions of the
    subplot area instead of data coordinates.
    """
    axis_suffix = "" if axis_index == 1 else str(axis_index)
    fig.add_annotation(
        x=x,
        y=y,
        xref=f"x{axis_suffix} domain",
        yref=f"y{axis_suffix} domain",
        xanchor="right",
        yanchor="top",
        text=text,
        showarrow=False,
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor=border_color,
        borderwidth=2,
        font=dict(size=10, color="black"),
    )


def _detect_dashboard_peaks(signal_data, signal_kind, sampling_freq):
    """Return peak indices: vitalDSP for ECG/PPG, scipy ``find_peaks`` otherwise."""
    if signal_kind in ("ecg", "ppg"):
        from vitalDSP.physiological_features.waveform import WaveformMorphology

        wm = WaveformMorphology(
            signal_data, fs=sampling_freq, signal_type=signal_kind.upper()
        )
        raw_peaks = wm.r_peaks if signal_kind == "ecg" else wm.systolic_peaks
    else:
        raw_peaks, _ = signal.find_peaks(
            signal_data,
            height=np.mean(signal_data) + 1.5 * np.std(signal_data),
            # find_peaks rejects distance < 1 (very low sampling rates).
            distance=max(1, int(sampling_freq * 0.3)),
            prominence=np.std(signal_data) * 0.5,
        )
    return np.asarray([] if raw_peaks is None else raw_peaks, dtype=int)


def _add_signal_overview_panel(
    fig, time_data, signal_data, peaks, signal_kind, sampling_freq
):
    """Fill the top-left panel: signal, peaks, ECG Q/S valleys, peak statistics."""
    fig.add_trace(
        go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line=dict(color="#1f77b4", width=2),
            fill="tonexty",
            fillcolor="rgba(31, 119, 180, 0.1)",
        ),
        row=1,
        col=1,
    )

    # Add peaks with red arrows for better visibility
    if len(peaks) > 0:
        fig.add_trace(
            go.Scatter(
                x=time_data[peaks],
                y=signal_data[peaks],
                mode="markers+text",
                name="Peaks",
                # Limit text to first 10 peaks
                text=[f"P{i+1}" for i in range(min(len(peaks), 10))],
                textposition="top center",
                marker=dict(
                    color="red",
                    size=8,
                    symbol="arrow-up",
                    line=dict(color="darkred", width=2),
                ),
                textfont=dict(color="red", size=8, family="Arial Black"),
            ),
            row=1,
            col=1,
        )

    # Add ECG-specific trough detection (Q and S valleys) for ECG signals
    if signal_kind == "ecg":
        try:
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                waveform=signal_data,
                fs=sampling_freq,
                signal_type="ECG",
                simple_mode=True,
            )
            valley_specs = (
                ("Q Valleys", "Q Valley", "orange", wm.detect_q_valley),
                ("S Valleys", "S Valley", "red", wm.detect_s_valley),
            )
            for trace_name, hover_label, color, detect in valley_specs:
                valleys = detect()
                if valleys is not None and len(valleys) > 0:
                    fig.add_trace(
                        go.Scatter(
                            x=time_data[valleys],
                            y=signal_data[valleys],
                            mode="markers",
                            name=trace_name,
                            marker=dict(color=color, size=6, symbol="triangle-down"),
                            hovertemplate=f"<b>{hover_label}:</b> %{{y}}<extra></extra>",
                        ),
                        row=1,
                        col=1,
                    )
        except Exception as e:
            logger.warning(f"ECG trough detection failed: {e}")

    # Add peak statistics
    mean_interval = np.mean(np.diff(peaks) / sampling_freq) if len(peaks) > 1 else 0
    heart_rate = 60 / mean_interval if mean_interval > 0 else 0
    _add_panel_note(
        fig,
        1,
        f"Peaks: {len(peaks)}<br>HR: {heart_rate:.1f} BPM<br>Mean Interval: {mean_interval:.2f}s",
        "red",
    )


def _add_beat_panel(fig, time_data, signal_data, peaks, sampling_freq):
    """Fill the top-right panel and return its ``(x_title, y_title)``.

    Beat intervals are shown when at least two peaks exist; otherwise windowed
    signal energy, with an amplitude histogram as the last resort.
    """
    if len(peaks) > 1:
        beat_intervals = np.diff(peaks) / sampling_freq
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(beat_intervals)),
                y=beat_intervals,
                mode="lines+markers",
                name="Beat Intervals",
                line=dict(color="#2ca02c", width=3),
                marker=dict(color="#2ca02c", size=6, symbol="diamond"),
                fill="tonexty",
                fillcolor="rgba(44, 160, 44, 0.1)",
            ),
            row=1,
            col=2,
        )

        # Add interval statistics
        mean_interval = np.mean(beat_intervals)
        std_interval = np.std(beat_intervals)
        heart_rate = 60 / mean_interval if mean_interval > 0 else 0
        fig.add_hline(
            y=mean_interval,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Mean: {mean_interval:.3f}s",
            annotation_position="right",
            row=1,
            col=2,
        )
        _add_panel_note(
            fig,
            2,
            f"Heart Rate: {heart_rate:.1f} BPM<br>Mean Interval: {mean_interval:.3f}s<br>Variability: {std_interval:.3f}s",
            "green",
        )
        return "Beat Index", "Interval (s)"

    axis_titles = ("Value", "Density")
    # If no peaks, show signal energy analysis instead
    try:
        window_size = min(50, len(signal_data) // 10)
        if window_size > 1:
            # Half-overlapping windows
            window_starts = range(0, len(signal_data) - window_size, window_size // 2)
            energy_values = [
                np.sum(signal_data[start : start + window_size] ** 2)
                for start in window_starts
            ]
            time_windows = [
                time_data[start + window_size // 2] for start in window_starts
            ]
            if energy_values:
                fig.add_trace(
                    go.Scatter(
                        x=time_windows,
                        y=energy_values,
                        mode="lines+markers",
                        name="Signal Energy",
                        line=dict(color="#ff7f0e", width=2),
                        marker=dict(color="#ff7f0e", size=4, symbol="circle"),
                    ),
                    row=1,
                    col=2,
                )
                mean_energy = np.mean(energy_values)
                total_energy = np.sum(signal_data**2)
                _add_panel_note(
                    fig,
                    2,
                    f"Mean Energy: {mean_energy:.2e}<br>Total Energy: {total_energy:.2e}<br>Energy Variance: {np.var(energy_values):.2e}",
                    "orange",
                )
                axis_titles = ("Time (s)", "Energy")
    except Exception as e:
        logger.debug(f"Could not create energy analysis: {e}")
        # Final fallback: show signal amplitude distribution
        hist, bin_edges = np.histogram(signal_data, bins=30, density=True)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        fig.add_trace(
            go.Bar(
                x=bin_centers,
                y=hist,
                name="Signal Distribution",
                marker_color="rgba(255, 165, 0, 0.7)",
                opacity=0.8,
            ),
            row=1,
            col=2,
        )
        _add_panel_note(
            fig,
            2,
            f"Mean: {np.mean(signal_data):.3f}<br>Std: {np.std(signal_data):.3f}<br>Range: {np.max(signal_data) - np.min(signal_data):.3f}",
            "orange",
        )
        axis_titles = ("Value", "Density")
    return axis_titles


def _add_frequency_panel(fig, signal_data, sampling_freq):
    """Fill the bottom-left panel with the Welch power spectrum and band powers."""
    try:
        freqs, psd = signal.welch(
            signal_data, fs=sampling_freq, nperseg=min(256, len(signal_data) // 2)
        )
        fig.add_trace(
            go.Scatter(
                x=freqs,
                y=psd,
                mode="lines",
                name="Power Spectrum",
                line=dict(color="#2ca02c", width=2),
                fill="tonexty",
                fillcolor="rgba(44, 160, 44, 0.1)",
            ),
            row=2,
            col=1,
        )

        # Add frequency band annotations, stacked so they do not overlap
        bands = (
            ("Low", freqs < 1.0, "blue"),
            ("Mid", (freqs >= 1.0) & (freqs < 10.0), "green"),
            ("High", freqs >= 10.0, "red"),
        )
        shown = 0
        for label, band_mask, color in bands:
            if np.any(band_mask):
                band_power = np.sum(psd[band_mask])
                _add_panel_note(
                    fig, 3, f"{label}: {band_power:.2e}", color, y=0.98 - 0.1 * shown
                )
                shown += 1
    except Exception as e:
        logger.debug(f"Could not create frequency analysis: {e}")


def _add_variability_panel(fig, time_data, signal_data, peaks, sampling_freq):
    """Fill the bottom-right panel and return its ``(x_title, y_title)``.

    A Poincare plot is drawn when at least two peaks exist (with a Hilbert
    envelope as fallback on error); otherwise a windowed quality score.
    """
    axis_titles = ("RR_n (ms)", "RR_{n+1} (ms)")

    if len(peaks) > 1:
        try:
            # RR intervals in milliseconds
            rr_intervals = np.diff(peaks) / sampling_freq * 1000
            finite_rr = rr_intervals[np.isfinite(rr_intervals)]

            sd1 = 0.0
            sd2 = 0.0
            if finite_rr.size >= 2:
                # Poincare plot (RR_n vs RR_{n+1})
                fig.add_trace(
                    go.Scatter(
                        x=finite_rr[:-1],
                        y=finite_rr[1:],
                        mode="markers",
                        name="Poincaré Plot",
                        marker=dict(
                            color="rgba(156, 39, 176, 0.7)",
                            size=8,
                            symbol="circle",
                            line=dict(color="rgba(156, 39, 176, 1)", width=1),
                        ),
                    ),
                    row=2,
                    col=2,
                )

                # SD1^2 = SDSD^2 / 2 and SD2^2 = 2 * SDNN^2 - SD1^2
                sd1 = float(np.std(np.diff(finite_rr)) / np.sqrt(2))
                sdnn = float(np.std(finite_rr))
                sd2 = float(np.sqrt(max(0.0, 2 * sdnn**2 - sd1**2)))

                center_x = np.mean(finite_rr[:-1])
                center_y = np.mean(finite_rr[1:])
                # NOTE(review): drawn as a circle of radius SD1 around the
                # cloud centre although the trace is named "SD1 Ellipse".
                theta = np.linspace(0, 2 * np.pi, 100)
                fig.add_trace(
                    go.Scatter(
                        x=center_x + sd1 * np.cos(theta),
                        y=center_y + sd1 * np.sin(theta),
                        mode="lines",
                        name="SD1 Ellipse",
                        line=dict(color="red", width=2, dash="dash"),
                        showlegend=False,
                    ),
                    row=2,
                    col=2,
                )

            _add_panel_note(
                fig,
                4,
                f"SD1: {sd1:.1f} ms<br>SD2: {sd2:.1f} ms<br>RR Count: {len(rr_intervals)}",
                "purple",
            )
        except Exception as e:
            logger.debug(f"Could not create Poincaré plot: {e}")
            # Fallback: show signal envelope instead
            try:
                envelope = np.abs(signal.hilbert(signal_data))
                fig.add_trace(
                    go.Scatter(
                        x=time_data,
                        y=envelope,
                        mode="lines",
                        name="Signal Envelope",
                        line=dict(color="#d62728", width=2),
                    ),
                    row=2,
                    col=2,
                )
                _add_panel_note(
                    fig,
                    4,
                    f"Envelope Mean: {np.mean(envelope):.3f}<br>Envelope Std: {np.std(envelope):.3f}",
                    "red",
                )
                axis_titles = ("Time (s)", "Envelope")
            except Exception as fallback_error:
                logger.debug(f"Could not create envelope fallback: {fallback_error}")
        return axis_titles

    # If no peaks, show signal quality metrics instead
    try:
        baseline = np.mean(signal_data)
        noise_level = np.std(signal_data)
        snr = 20 * np.log10(np.abs(baseline) / noise_level) if noise_level > 0 else 0

        # Quality score (|mean| / std) over half-overlapping windows
        window_size = min(50, len(signal_data) // 10)
        if window_size > 1:
            quality_scores = []
            time_windows = []
            for start in range(0, len(signal_data) - window_size, window_size // 2):
                window_data = signal_data[start : start + window_size]
                window_std = np.std(window_data)
                quality_scores.append(
                    np.abs(np.mean(window_data)) / window_std if window_std > 0 else 0
                )
                time_windows.append(time_data[start + window_size // 2])

            if quality_scores:
                fig.add_trace(
                    go.Scatter(
                        x=time_windows,
                        y=quality_scores,
                        mode="lines+markers",
                        name="Quality Score",
                        line=dict(color="#d62728", width=2),
                        marker=dict(color="#d62728", size=4, symbol="circle"),
                    ),
                    row=2,
                    col=2,
                )
                mean_quality = np.mean(quality_scores)
                _add_panel_note(
                    fig,
                    4,
                    f"Mean Quality: {mean_quality:.2f}<br>SNR: {snr:.1f} dB<br>Noise Level: {noise_level:.3f}",
                    "red",
                )
                axis_titles = ("Time (s)", "Quality Score")
    except Exception as e:
        logger.debug(f"Could not create quality metrics: {e}")
    return axis_titles


def create_physiological_analysis_plots(
    time_data,
    signal_data,
    signal_type,
    sampling_freq,
    analysis_categories,
    hrv_options,
    morphology_options,
    advanced_features,
    quality_options,
    transform_options,
    advanced_computation,
    feature_engineering,
    preprocessing,
):
    """Create comprehensive 2x2 grid layout with multiple analysis plots.

    Panels: signal overview with peaks (top-left), beat-to-beat intervals or
    an energy/histogram fallback (top-right), Welch power spectrum
    (bottom-left) and a Poincare plot or quality-score fallback
    (bottom-right).

    Args:
        time_data: Time stamps, one per sample.
        signal_data: Signal samples.
        signal_type: Signal label; ``"ecg"``/``"ppg"`` (case-insensitive) use
            vitalDSP peak detection, anything else (including ``None``) uses
            ``scipy.signal.find_peaks``.
        sampling_freq: Sampling frequency in Hz.
        analysis_categories, hrv_options, morphology_options,
        advanced_features, quality_options, transform_options,
        advanced_computation, feature_engineering, preprocessing:
            Accepted for interface compatibility; not read by this function.

    Returns:
        A ``plotly.graph_objects.Figure``; ``create_empty_figure()`` when
        building the dashboard raises.
    """
    try:
        time_data = np.asarray(time_data)
        signal_data = np.asarray(signal_data, dtype=float)
        # Normalise once so a missing signal type cannot raise later on.
        signal_kind = str(signal_type or "").lower()
        signal_label = signal_kind.upper() or "SIGNAL"

        # Find peaks for analysis - this is crucial for all plots
        peaks = _detect_dashboard_peaks(signal_data, signal_kind, sampling_freq)

        # Create a 2x2 subplot layout for comprehensive analysis view
        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=(
                "Signal Overview",
                "Beat-to-Beat Analysis",
                "Frequency Analysis",
                "Poincaré Plot",
            ),
            vertical_spacing=0.12,
            horizontal_spacing=0.1,
            specs=[
                [{"secondary_y": False}, {"secondary_y": False}],
                [{"secondary_y": False}, {"secondary_y": False}],
            ],
        )

        _add_signal_overview_panel(
            fig, time_data, signal_data, peaks, signal_kind, sampling_freq
        )
        beat_titles = _add_beat_panel(
            fig, time_data, signal_data, peaks, sampling_freq
        )
        _add_frequency_panel(fig, signal_data, sampling_freq)
        variability_titles = _add_variability_panel(
            fig, time_data, signal_data, peaks, sampling_freq
        )

        # Enhanced layout with better spacing and sizing
        fig.update_layout(
            title=dict(
                text=f"Comprehensive {signal_label} Analysis Dashboard",
                x=0.5,
                font=dict(size=18, color="#2c3e50"),
            ),
            height=800,  # Increased height for better plot visibility
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            margin=dict(l=60, r=200, t=80, b=60),
        )

        # Update all subplot axes with better grid and styling
        axis_style = dict(
            showgrid=True,
            gridwidth=1,
            gridcolor="rgba(128, 128, 128, 0.2)",
            title_font=dict(size=12),
            tickfont=dict(size=10),
        )
        for row in (1, 2):
            for col in (1, 2):
                fig.update_xaxes(row=row, col=col, **axis_style)
                fig.update_yaxes(row=row, col=col, **axis_style)

        # Axis titles follow what each panel actually displays
        panel_titles = {
            (1, 1): ("Time (s)", "Amplitude"),
            (1, 2): beat_titles,
            (2, 1): ("Frequency (Hz)", "Power"),
            (2, 2): variability_titles,
        }
        for (row, col), (x_title, y_title) in panel_titles.items():
            fig.update_xaxes(title_text=x_title, row=row, col=col)
            fig.update_yaxes(title_text=y_title, row=row, col=col)

        return fig
    except Exception as e:
        logger.error(f"Error creating comprehensive analysis plots: {e}")
        return create_empty_figure()
def analyze_hrv_fallback(signal_data, sampling_freq, hrv_options):
    """Fallback HRV analysis when vitalDSP is not available.

    Peaks higher than ``mean + std`` and at least 0.3 s apart are taken as
    beats; RR intervals (ms) are derived from them.

    Args:
        signal_data: 1-D sequence of samples.
        sampling_freq: Sampling frequency in Hz.
        hrv_options: Iterable of requested metric groups. Only
            ``"time_domain"`` is implemented here; ``None`` is treated as
            ``["time_domain"]``.

    Returns:
        Dict with ``mean_rr``, ``std_rr``, ``rmssd`` (ms), ``nn50`` (count)
        and ``pnn50`` (percent) when ``"time_domain"`` is requested, an
        empty dict when it is not, or ``{"error": message}`` when the signal
        is shorter than 5 s, fewer than two peaks are found, or the
        computation raises.
    """
    try:
        samples = np.asarray(signal_data, dtype=float)
        if hrv_options is None:
            hrv_options = ["time_domain"]

        # Check if signal is long enough for HRV analysis
        min_samples_for_hrv = 5 * sampling_freq  # At least 5 seconds
        if len(samples) < min_samples_for_hrv:
            logger.warning(
                f"Signal too short for HRV analysis ({len(samples)} samples, need at least {min_samples_for_hrv})"
            )
            return {
                "error": f"Signal too short for HRV analysis. Need at least {min_samples_for_hrv} samples."
            }

        # Fallback to basic analysis if vitalDSP not available
        logger.warning("vitalDSP HRV features not available, using basic analysis")

        # Find peaks using scipy (simpler fallback)
        peaks, _ = signal.find_peaks(
            samples,
            height=np.mean(samples) + np.std(samples),
            # find_peaks rejects distance < 1 (very low sampling rates).
            distance=max(1, int(sampling_freq * 0.3)),
        )
        if len(peaks) < 2:
            return {"error": "Insufficient peaks for HRV analysis"}

        # Calculate RR intervals
        rr_intervals = np.diff(peaks) / sampling_freq * 1000  # Convert to milliseconds

        hrv_metrics = {}
        if "time_domain" in hrv_options:
            successive_diffs = np.diff(rr_intervals)
            has_pairs = successive_diffs.size > 0
            nn50 = int(np.sum(np.abs(successive_diffs) > 50)) if has_pairs else 0
            hrv_metrics.update(
                {
                    "mean_rr": float(np.mean(rr_intervals)),
                    "std_rr": float(np.std(rr_intervals)),
                    "rmssd": (
                        float(np.sqrt(np.mean(successive_diffs**2)))
                        if has_pairs
                        else 0
                    ),
                    "nn50": nn50,
                    # pNN50 is relative to the number of successive
                    # differences (N - 1), not to the number of intervals.
                    "pnn50": nn50 / successive_diffs.size * 100 if has_pairs else 0,
                }
            )
        return hrv_metrics
    except Exception as e:
        logger.error(f"Error in HRV fallback analysis: {e}")
        return {"error": f"HRV fallback analysis failed: {str(e)}"}
def analyze_hrv(signal_data, sampling_freq, hrv_options):
    """Analyze Heart Rate Variability using vitalDSP.

    Args:
        signal_data: Signal samples handed to ``HRVFeatures``.
        sampling_freq: Sampling frequency handed to ``HRVFeatures``.
        hrv_options: Requested metric groups (``"time_domain"``,
            ``"frequency_domain"`` / ``"freq_domain"``). A falsy value
            selects ``["time_domain", "freq_domain"]``.

    Returns:
        Dict of the requested metrics (missing values default to ``0``), the
        result of ``analyze_hrv_fallback`` when the vitalDSP import fails,
        or ``{"error": message}`` when anything else raises.
    """
    # (output key, key in the vitalDSP time-domain result)
    time_key_map = (
        ("mean_rr", "mean_nn"),
        ("std_rr", "std_nn"),
        ("rmssd", "rmssd"),
        ("nn50", "nn50"),
        ("pnn50", "pnn50"),
    )
    freq_keys = ("total_power", "vlf_power", "lf_power", "hf_power", "lf_hf_ratio")

    try:
        # Handle case where hrv_options might be None or empty
        selected = hrv_options if hrv_options else ["time_domain", "freq_domain"]

        try:
            from vitalDSP.physiological_features.hrv import HRVFeatures

            computed = HRVFeatures(signal_data, sampling_freq).compute_all_features()

            # Map vitalDSP results to our format
            mapped = {}
            if "time_domain" in selected:
                time_part = computed.get("time_domain", {})
                mapped.update(
                    {target: time_part.get(source, 0) for target, source in time_key_map}
                )
            if "frequency_domain" in selected or "freq_domain" in selected:
                freq_part = computed.get("frequency_domain", {})
                mapped.update({key: freq_part.get(key, 0) for key in freq_keys})
            return mapped
        except ImportError:
            # Use fallback implementation
            return analyze_hrv_fallback(signal_data, sampling_freq, selected)
    except Exception as e:
        logger.error(f"Error in HRV analysis: {e}")
        return {"error": f"HRV analysis failed: {str(e)}"}
def analyze_morphology(
signal_data, sampling_freq, morphology_options, signal_type=None
):
"""Analyze signal morphology using vitalDSP."""
# Returns a dict of morphology metrics, or {"error": message} on failure.
# `signal_type` is accepted but not read anywhere in this function.
try:
morphology_metrics = {}
# Handle case where morphology_options might be None or empty
if not morphology_options:
morphology_options = ["peaks", "amplitude", "duration"]
# Use vitalDSP for morphology analysis
try:
from vitalDSP.physiological_features.morphology import MorphologyFeatures
# Create morphology features object
morph_features = MorphologyFeatures(signal_data, sampling_freq)
# NOTE(review): get_peak_features / get_amplitude_features /
# get_duration_features are assumed to exist on MorphologyFeatures and to
# return dicts - confirm against the installed vitalDSP version.
if "peak_detection" in morphology_options or "peaks" in morphology_options:
# Get peak detection features
peak_features = morph_features.get_peak_features()
morphology_metrics.update(
{
"num_peaks": peak_features.get("num_peaks", 0),
"peak_heights": peak_features.get("peak_heights", []),
"peak_positions": peak_features.get("peak_positions", []),
}
)
if "amplitude" in morphology_options:
# Get amplitude features
amplitude_features = morph_features.get_amplitude_features()
morphology_metrics.update(
{
"mean_amplitude": amplitude_features.get("mean_amplitude", 0),
"std_amplitude": amplitude_features.get("std_amplitude", 0),
"min_amplitude": amplitude_features.get("min_amplitude", 0),
"max_amplitude": amplitude_features.get("max_amplitude", 0),
"peak_to_peak": amplitude_features.get("peak_to_peak", 0),
}
)
if "duration" in morphology_options:
# Get duration features
duration_features = morph_features.get_duration_features()
morphology_metrics.update(duration_features)
except ImportError:
logger.warning("vitalDSP morphology features not available")
return {"error": "Morphology analysis not available"}
# NOTE(review): indentation is not visible here, so it is unclear whether the
# NumPy-based statistics below run after the vitalDSP branch (overwriting its
# amplitude values) or sit unreachable behind the return above - confirm.
if "amplitude" in morphology_options:
morphology_metrics.update(
{
"mean_amplitude": np.mean(signal_data),
"std_amplitude": np.std(signal_data),
"min_amplitude": np.min(signal_data),
"max_amplitude": np.max(signal_data),
"peak_to_peak": np.max(signal_data) - np.min(signal_data),
}
)
if "duration" in morphology_options:
duration_stats = {
"signal_duration": len(signal_data) / sampling_freq,
"sampling_freq": sampling_freq,
"num_samples": len(signal_data),
}
# The duration values are stored twice: flattened into the result and
# grouped under the "duration_stats" key.
morphology_metrics.update(duration_stats)
morphology_metrics["duration_stats"] = duration_stats
return morphology_metrics
except Exception as e:
# Any non-ImportError failure is reported through the "error" key.
logger.error(f"Error in morphology analysis: {e}")
return {"error": f"Morphology analysis failed: {str(e)}"}
def analyze_signal_quality(signal_data, sampling_freq):
    """Compute basic signal-quality statistics.

    The input is converted to a float array before any arithmetic so that
    integer sample types (e.g. int16 ADC counts) cannot overflow when squared
    or subtracted, and so that plain Python sequences are accepted.

    Args:
        signal_data: 1-D sequence or array of samples.
        sampling_freq: Sampling frequency; accepted for interface symmetry
            with the other analysers but not used in the calculation.

    Returns:
        dict with native Python numbers under the keys ``snr_db``,
        ``dynamic_range``, ``mean_value``, ``std_value`` and
        ``zero_crossings``; or ``{"error": ...}`` if the computation fails
        (for example on an empty signal).
    """
    try:
        samples = np.asarray(signal_data, dtype=float)

        # SNR estimate: mean squared amplitude relative to the variance.
        # A zero-variance (constant) signal is reported as infinite SNR.
        signal_power = np.mean(samples**2)
        noise_power = np.var(samples)
        if noise_power > 0:
            snr = 10 * np.log10(signal_power / noise_power)
        else:
            snr = float("inf")

        # Count positions where the sign of consecutive samples differs.
        sign_changes = np.diff(np.sign(samples)) != 0

        return {
            "snr_db": float(snr),
            "dynamic_range": float(np.max(samples) - np.min(samples)),
            "mean_value": float(np.mean(samples)),
            "std_value": float(np.std(samples)),
            "zero_crossings": int(np.sum(sign_changes)),
        }
    except Exception as e:
        logger.error(f"Error in quality analysis: {e}")
        return {"error": f"Quality analysis failed: {str(e)}"}
[docs]
def analyze_trends(signal_data, sampling_freq):
    """Summarise the linear trend of a signal.

    A first-degree polynomial is fitted against a time axis in seconds
    (sample index divided by ``sampling_freq``). The result reports the
    fitted slope, the fraction of variance explained by the line (clamped
    to the range 0..1) and a direction label.

    Returns:
        dict with ``trend_slope``, ``trend_strength`` and ``trend_direction``
        (``"increasing"``, ``"decreasing"``, ``"stable"``,
        ``"insufficient_data"`` or ``"no_variation"``), or ``{"error": ...}``
        if the computation raises.
    """

    def flat_result(label):
        # Shared shape for the two "nothing to fit" outcomes.
        return {
            "trend_slope": 0,
            "trend_strength": 0,
            "trend_direction": label,
        }

    try:
        sample_count = len(signal_data)
        if sample_count < 2:
            return flat_result("insufficient_data")

        seconds = np.arange(sample_count) / sampling_freq

        # A constant signal has no trend and would make the fit degenerate.
        if np.std(signal_data) == 0:
            return flat_result("no_variation")

        line_coeffs = np.polyfit(seconds, signal_data, 1)
        slope = line_coeffs[0]
        fitted = np.polyval(line_coeffs, seconds)

        # Coefficient of determination, guarded against a zero denominator.
        total_ss = np.sum((signal_data - np.mean(signal_data)) ** 2)
        if total_ss == 0:
            strength = 0
        else:
            residual_ss = np.sum((signal_data - fitted) ** 2)
            strength = 1 - (residual_ss / total_ss)
        strength = max(0, min(1, strength))

        if slope > 0.15:
            direction = "increasing"
        elif slope < -0.15:
            direction = "decreasing"
        else:
            direction = "stable"

        return {
            "trend_slope": float(slope),
            "trend_strength": float(strength),
            "trend_direction": direction,
        }
    except Exception as e:
        logger.error(f"Error in trend analysis: {e}")
        return {"error": f"Trend analysis failed: {str(e)}"}
[docs]
# Bootstrap grid classes shared by every result card column.
_RESULT_COLUMN_CLASS = "col-lg-4 col-md-6 col-sm-12 mb-3"

# Metric keys whose values are already expressed in milliseconds.
_MS_METRIC_KEYS = frozenset(
    {
        "mean_rr",
        "rmssd",
        "sdnn",
        "std_nn",
        "median_nn",
        "iqr_nn",
        "sdsd",
        "poincare_sd1",
        "poincare_sd2",
    }
)
# Count-like keys that are not covered by the count suffixes below.
_COUNT_METRIC_KEYS = frozenset({"anomalies_detected", "nn50"})
_COUNT_METRIC_SUFFIXES = (
    "_peaks",
    "_beats",
    "_segments",
    "_crossings",
    "_anomalies",
)
# Amplitude-like keys that are not covered by "_height" / "_amplitude".
_AMPLITUDE_METRIC_KEYS = frozenset(
    {
        "peak_to_peak",
        "envelope_mean",
        "envelope_range",
        "rms",
        "mean",
        "median",
        "std",
        "iqr",
        "bayesian_prior_mean",
        "ecg_autonomic_response",
        "noise_level",
    }
)
_DIMENSIONLESS_METRIC_KEYS = frozenset(
    {
        "spectral_centroid",
        "fourier_peak",
        "hilbert_phase",
        "cross_signal_correlation",
        "signal_quality_index",
    }
)


def _format_result_metric(key, value):
    """Return the display text for one metric, or None if it should be hidden.

    Numeric values equal to zero and values that are neither numbers nor
    non-empty strings are hidden. NumPy integer/floating scalars are accepted
    in addition to Python ``int``/``float``; string values (such as a trend
    direction label) are shown title-cased with underscores replaced by
    spaces. The order of the checks below decides which unit a key gets when
    it matches more than one rule, so it must not be rearranged casually.
    """
    if isinstance(value, str):
        label = value.strip()
        return label.replace("_", " ").title() if label else None

    # np.float64 already subclasses float; other NumPy scalars (np.int64,
    # np.float32, ...) do not, so convert them to native numbers.
    if isinstance(value, (np.integer, np.floating)) and not isinstance(
        value, (int, float)
    ):
        value = value.item()
    if not isinstance(value, (int, float)) or value == 0:
        return None

    if key.endswith("_db"):
        return f"{value:.1f} dB"
    if key.endswith(("_ratio", "_strength")):
        if np.isinf(value):
            return "N/A (HF power = 0)"
        if np.isnan(value):
            return "N/A"
        return f"{value:.3f}"
    if key.endswith(("_freq", "_frequency")):
        return f"{value:.3f} Hz"
    if key == "pnn50" or key.startswith("pnn_") or key == "cvnn":
        # Percentage values (pNN50, pNN20, coefficient of variation).
        return f"{value:.2f}%"
    if key.endswith(("nu_power", "_normalized")):
        # Normalized power values (LFnu, HFnu).
        return f"{value:.3f} (n.u.)"
    if key.endswith("_entropy"):
        return "N/A" if np.isnan(value) else f"{value:.3f}"
    if (
        key.startswith("dfa_")
        or key.endswith(("_alpha1", "_alpha2"))
        or key in ("fractal_dimension", "lyapunov_exponent", "beat_regularity")
    ):
        # Dimensionless nonlinear / regularity values.
        return f"{value:.3f}"
    if key.endswith("_ms") or key in _MS_METRIC_KEYS:
        # Already in milliseconds - no unit conversion.
        return f"{format_large_number(value, unit='ms')} ms"
    if key in ("mean_beat_interval", "beat_variability"):
        # Reported in seconds; display in milliseconds.
        return f"{format_large_number(value * 1000, unit='ms')} ms"
    if key.endswith(_COUNT_METRIC_SUFFIXES) or key in _COUNT_METRIC_KEYS:
        return format_large_number(value, as_integer=True)
    if key.endswith(("_power", "_energy")):
        return f"{format_large_number(value, use_scientific=True)} units²"
    if key.endswith(("_height", "_amplitude")) or key in _AMPLITUDE_METRIC_KEYS:
        return f"{format_large_number(value)} units"
    if key in _DIMENSIONLESS_METRIC_KEYS:
        return format_large_number(value)
    if key in ("mean_segment_length", "signal_bandwidth"):
        return f"{format_large_number(value)} samples"
    if isinstance(value, float) and abs(value) < 0.01:
        return f"{value:.2e}"
    return format_large_number(value)


def _build_metric_card(title, icon, metrics_dict, color="primary"):
    """Build one compact metric card, or return None if nothing is displayable."""
    metric_items = []
    for key, value in metrics_dict.items():
        formatted_value = _format_result_metric(key, value)
        if formatted_value is None:
            continue
        metric_items.append(
            html.Div(
                [
                    html.Span(
                        f"{key.replace('_', ' ').title()}: ",
                        className="text-muted fw-bold",
                    ),
                    html.Span(formatted_value, className="text-dark"),
                ],
                className="d-flex justify-content-between align-items-center py-1 border-bottom border-light",
            )
        )
    if not metric_items:
        return None
    return dbc.Card(
        [
            dbc.CardHeader(
                [
                    html.Div(
                        [
                            html.Span(icon, className="me-2 fs-4"),
                            html.Span(title, className="fs-6 fw-bold text-dark"),
                        ],
                        className="d-flex align-items-center",
                    )
                ],
                className=f"bg-{color} bg-opacity-10 border-0 py-2",
            ),
            dbc.CardBody(
                [html.Div(metric_items, className="small")],
                className="py-2 px-3",
            ),
        ],
        className="h-100 border-0 shadow-sm",
    )


def create_comprehensive_results_display(results, signal_type, sampling_freq):
    """Create modern, compact results display for all analysis types.

    Args:
        results: Mapping from section name (e.g. ``"hrv_metrics"``) to that
            section's metrics dict. A section is skipped when it is absent
            or when its dict contains an ``"error"`` key.
        signal_type: Signal label shown upper-cased in the heading; a falsy
            value is shown as ``UNKNOWN`` instead of failing the display.
        sampling_freq: Not used by this function.

    Returns:
        A Dash ``html.Div``: a responsive grid of metric cards, a
        "No Analysis Results" placeholder when no card has displayable
        values, or an error panel if building the display raises.
    """
    try:

        def pick(*keys):
            """Return a builder that copies ``keys`` from a section (default 0)."""
            return lambda data: {key: data.get(key, 0) for key in keys}

        def hrv_nonlinear(data):
            # Only list nonlinear metrics that were actually computed.
            keys = (
                "poincare_sd1",
                "poincare_sd2",
                "poincare_sd1_sd2_ratio",
                "dfa_alpha1",
                "sample_entropy",
                "approximate_entropy",
            )
            return {key: data[key] for key in keys if data.get(key) is not None}

        def morphology(data):
            heights = data.get("peak_heights")
            if heights is None:
                heights = [0]
            # np.mean of an empty sequence is NaN; report 0 (hidden) instead.
            mean_height = np.mean(heights) if np.size(heights) > 0 else 0
            return {
                "num_peaks": data.get("num_peaks", 0),
                "mean_peak_height": mean_height,
                "peak_to_peak": data.get("peak_to_peak", 0),
                "std_amplitude": data.get("std_amplitude", 0),
            }

        def frequency(data):
            bands = data.get("power_bands", {})
            return {
                "dominant_frequency": data.get("dominant_frequency", 0),
                "spectral_centroid": data.get("spectral_centroid", 0),
                "lf_power": bands.get("low", 0),
                "hf_power": bands.get("high", 0),
            }

        def trend(data):
            return {
                "trend_direction": data.get("trend_direction", "Unknown"),
                "trend_strength": data.get("trend_strength", 0),
            }

        # (results key, card title, icon, colour, metrics builder) in
        # display order.
        card_specs = [
            (
                "hrv_metrics",
                "HRV - Time Domain",
                "💓",
                "danger",
                pick("mean_rr", "sdnn", "rmssd", "nn50", "pnn50", "pnn_20"),
            ),
            (
                "hrv_metrics",
                "HRV - Frequency Domain",
                "📡",
                "warning",
                pick(
                    "total_power",
                    "ulf_power",
                    "vlf_power",
                    "lf_power",
                    "hf_power",
                    "lf_hf_ratio",
                    "lfnu_power",
                    "hfnu_power",
                ),
            ),
            ("hrv_metrics", "HRV - Nonlinear", "🌀", "info", hrv_nonlinear),
            ("morphology_metrics", "Morphology Analysis", "📊", "info", morphology),
            (
                "beat2beat_metrics",
                "Beat-to-Beat Analysis",
                "🫀",
                "success",
                pick(
                    "num_beats",
                    "mean_beat_interval",
                    "beat_variability",
                    "beat_regularity",
                ),
            ),
            (
                "energy_metrics",
                "Energy Analysis",
                "⚡",
                "warning",
                pick(
                    "total_energy",
                    "mean_energy",
                    "low_freq_energy",
                    "high_freq_energy",
                ),
            ),
            (
                "envelope_metrics",
                "Envelope Analysis",
                "📦",
                "secondary",
                pick("envelope_mean", "envelope_range", "envelope_peaks"),
            ),
            (
                "segmentation_metrics",
                "Signal Segmentation",
                "✂️",
                "dark",
                pick("num_segments", "zero_crossings", "mean_segment_length"),
            ),
            (
                "waveform_metrics",
                "Waveform Analysis",
                "🌊",
                "primary",
                pick("rms", "skewness", "kurtosis", "peak_to_peak"),
            ),
            (
                "statistical_metrics",
                "Statistical Analysis",
                "📊",
                "info",
                pick("mean", "median", "std", "iqr"),
            ),
            ("frequency_metrics", "Frequency Analysis", "🔊", "success", frequency),
            (
                "advanced_features_metrics",
                "Advanced Features",
                "🚀",
                "warning",
                pick("cross_signal_correlation", "total_power"),
            ),
            (
                "quality_metrics",
                "Signal Quality",
                "⚖️",
                "danger",
                pick(
                    "signal_quality_index",
                    "snr_db",
                    "artifacts_detected",
                    "artifact_ratio",
                ),
            ),
            (
                "transform_metrics",
                "Signal Transforms",
                "🔄",
                "primary",
                pick("wavelet_energy", "fourier_peak", "hilbert_phase"),
            ),
            (
                "advanced_computation_metrics",
                "Advanced Computation",
                "🧠",
                "dark",
                pick("anomalies_detected", "bayesian_prior_mean"),
            ),
            (
                "feature_engineering_metrics",
                "Feature Engineering",
                "🔧",
                "info",
                pick(
                    "ppg_light_intensity",
                    "ppg_autonomic_response",
                    "ecg_autonomic_response",
                ),
            ),
            (
                "preprocessing_metrics",
                "Preprocessing Analysis",
                "🔧",
                "secondary",
                pick("noise_level", "signal_bandwidth"),
            ),
            ("trend_metrics", "Trend Analysis", "📈", "success", trend),
        ]

        sections = []
        for section_key, title, icon, color, build_metrics in card_specs:
            if section_key not in results:
                continue
            section_data = results[section_key]
            if "error" in section_data:
                continue
            card = _build_metric_card(title, icon, build_metrics(section_data), color)
            if card is not None:
                sections.append(html.Div(card, className=_RESULT_COLUMN_CLASS))

        if not sections:
            return html.Div(
                [
                    html.Div(
                        [
                            html.I(
                                className="fas fa-info-circle text-muted fs-1 d-block text-center mb-3"
                            ),
                            html.H5(
                                "No Analysis Results",
                                className="text-center text-muted",
                            ),
                            html.P(
                                "Please select analysis categories and run the analysis.",
                                className="text-center text-muted",
                            ),
                        ],
                        className="text-center py-5",
                    )
                ]
            )

        signal_label = str(signal_type).upper() if signal_type else "UNKNOWN"
        return html.Div(
            [
                html.Div(
                    [
                        html.H4(
                            "📋 Analysis Results",
                            className="text-center mb-4 text-dark",
                        ),
                        html.P(
                            f"Comprehensive physiological analysis for {signal_label} signal",
                            className="text-center text-muted mb-4",
                        ),
                    ],
                    className="mb-4",
                ),
                html.Div(sections, className="row g-3"),
            ]
        )
    except Exception as e:
        logger.error(f"Error creating results display: {e}")
        return html.Div(
            [
                html.Div(
                    [
                        html.I(
                            className="fas fa-exclamation-triangle text-danger fs-1 d-block text-center mb-3"
                        ),
                        html.H5("Error", className="text-center text-danger"),
                        html.P(
                            f"Failed to create results display: {str(e)}",
                            className="text-center text-muted",
                        ),
                    ],
                    className="text-center py-5",
                )
            ]
        )
[docs]
def create_hrv_plots(
    time_data, signal_data, sampling_freq, hrv_options, signal_type=None
):
    """Create HRV-specific plots with enhanced visualization and analysis.

    Beats are detected with vitalDSP's ``WaveformMorphology`` for ECG/PPG
    signals and with ``scipy.signal.find_peaks`` otherwise. RR intervals
    (in milliseconds) derived from the peaks are always plotted; further
    panels are added for each option present in ``hrv_options``.

    Args:
        time_data: Not used by this function.
        signal_data: 1-D array of signal samples.
        sampling_freq: Sampling frequency used to convert peak spacing to time.
        hrv_options: Collection that may contain ``"time_domain"``,
            ``"freq_domain"`` and/or ``"nonlinear"``; ``None`` means none.
        signal_type: Optional signal label; ``"ecg"``/``"ppg"``
            (case-insensitive) select the vitalDSP peak detectors.

    Returns:
        A Plotly figure with one subplot row per panel, or the result of
        ``create_empty_figure()`` when fewer than two peaks are found or an
        error occurs.
    """

    def add_corner_note(fig, row, text, bordercolor, y=0.98, **style):
        """Pin a text box to the top-right corner of the subplot in ``row``."""
        # "<axis> domain" references position the note as a fraction of the
        # subplot's plotting area rather than in data coordinates.
        axis_id = "" if row == 1 else str(row)
        fig.add_annotation(
            x=0.98,
            y=y,
            xref=f"x{axis_id} domain",
            yref=f"y{axis_id} domain",
            xanchor="right",
            yanchor="top",
            text=text,
            showarrow=False,
            bordercolor=bordercolor,
            **style,
        )

    try:
        hrv_options = hrv_options or []
        signal_kind = signal_type.lower() if signal_type else None

        # Use vitalDSP for ECG/PPG peak detection, scipy for others
        if signal_kind in ("ecg", "ppg"):
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                signal_data, fs=sampling_freq, signal_type=signal_type.upper()
            )
            peaks = wm.r_peaks if signal_kind == "ecg" else wm.systolic_peaks
        else:
            signal_std = np.std(signal_data)
            peaks, _ = signal.find_peaks(
                signal_data,
                height=np.mean(signal_data) + 1.5 * signal_std,
                # find_peaks rejects distance < 1, which int() would yield
                # for sampling rates below ~3.4 Hz.
                distance=max(1, int(sampling_freq * 0.3)),
                prominence=signal_std * 0.5,
            )

        if len(peaks) < 2:
            return create_empty_figure()

        # RR intervals in milliseconds
        rr_intervals = np.diff(peaks) / sampling_freq * 1000

        panel_titles = ["RR Intervals"]
        if "time_domain" in hrv_options:
            panel_titles.append("Time Domain")
        if "freq_domain" in hrv_options:
            panel_titles.append("Frequency Domain")
        if "nonlinear" in hrv_options:
            panel_titles.append("Nonlinear Analysis")
        num_plots = len(panel_titles)

        fig = make_subplots(
            rows=num_plots,
            cols=1,
            subplot_titles=panel_titles,
            vertical_spacing=0.12,
            specs=[[{"secondary_y": False}] for _ in range(num_plots)],
        )
        current_row = 1

        # --- RR intervals over beat index -------------------------------
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(rr_intervals)),
                y=rr_intervals,
                mode="lines+markers",
                name="RR Intervals",
                line=dict(color="#1f77b4", width=3),
                marker=dict(color="#1f77b4", size=6, symbol="circle"),
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.1)",
            ),
            row=current_row,
            col=1,
        )
        mean_rr = np.mean(rr_intervals)
        std_rr = np.std(rr_intervals)
        reference_lines = (
            (mean_rr, "dash", "red", f"Mean: {mean_rr:.1f} ms"),
            (mean_rr + std_rr, "dot", "orange", f"+1σ: {mean_rr + std_rr:.1f} ms"),
            (mean_rr - std_rr, "dot", "orange", f"-1σ: {mean_rr - std_rr:.1f} ms"),
        )
        for level, dash_style, line_color, label in reference_lines:
            fig.add_hline(
                y=level,
                line_dash=dash_style,
                line_color=line_color,
                annotation_text=label,
                annotation_position="right",
                row=current_row,
                col=1,
            )
        current_row += 1

        # --- Time domain: RR histogram with normal fit ------------------
        if "time_domain" in hrv_options:
            # At least one bin is required, even with a single RR interval.
            bin_count = max(1, min(20, len(rr_intervals) // 2))
            hist, bin_edges = np.histogram(rr_intervals, bins=bin_count, density=True)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            fig.add_trace(
                go.Bar(
                    x=bin_centers,
                    y=hist,
                    name="RR Distribution",
                    marker_color="rgba(76, 175, 80, 0.7)",
                    opacity=0.8,
                ),
                row=current_row,
                col=1,
            )
            # The normal curve is undefined for zero spread, so skip it then.
            if std_rr > 0:
                try:
                    from scipy.stats import norm

                    x_norm = np.linspace(
                        np.min(rr_intervals), np.max(rr_intervals), 100
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=x_norm,
                            y=norm.pdf(x_norm, mean_rr, std_rr),
                            mode="lines",
                            name="Normal Fit",
                            line=dict(color="red", width=2, dash="dash"),
                        ),
                        row=current_row,
                        col=1,
                    )
                except Exception as e:
                    # Best-effort overlay: keep the histogram, note the failure.
                    logger.warning(f"Could not add normal fit to RR histogram: {e}")
            current_row += 1

        # --- Frequency domain: Welch PSD with band powers ---------------
        if "freq_domain" in hrv_options:
            try:
                if len(rr_intervals) > 10:
                    freqs_welch, psd_welch = signal.welch(
                        rr_intervals,
                        fs=1000 / np.mean(rr_intervals),
                        nperseg=min(len(rr_intervals) // 2, 64),
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=freqs_welch,
                            y=psd_welch,
                            mode="lines",
                            name="Power Spectrum",
                            line=dict(color="#d62728", width=3),
                            fill="tonexty",
                            fillcolor="rgba(214, 39, 40, 0.1)",
                        ),
                        row=current_row,
                        col=1,
                    )
                    # np.trapz was renamed to np.trapezoid in NumPy 2.0.
                    integrate = getattr(np, "trapezoid", None) or np.trapz
                    bands = (
                        ("VLF", freqs_welch < 0.04, "blue"),
                        ("LF", (freqs_welch >= 0.04) & (freqs_welch < 0.15), "green"),
                        ("HF", (freqs_welch >= 0.15) & (freqs_welch < 0.4), "red"),
                    )
                    notes_shown = 0
                    for label, band_mask, border in bands:
                        if not np.any(band_mask):
                            continue
                        band_power = integrate(
                            psd_welch[band_mask], freqs_welch[band_mask]
                        )
                        # Stack the notes downwards so they do not overlap.
                        add_corner_note(
                            fig,
                            current_row,
                            f"{label}: {band_power:.1f}",
                            border,
                            y=0.98 - 0.16 * notes_shown,
                            bgcolor="rgba(255, 255, 255, 0.8)",
                        )
                        notes_shown += 1
            except Exception as e:
                logger.warning(f"Could not create frequency domain plot: {e}")
            current_row += 1

        # --- Nonlinear: Poincaré plot with SD1/SD2 ellipse --------------
        if "nonlinear" in hrv_options:
            finite_rr = rr_intervals[np.isfinite(rr_intervals)]
            if len(finite_rr) >= 2:
                fig.add_trace(
                    go.Scatter(
                        x=finite_rr[:-1],
                        y=finite_rr[1:],
                        mode="markers",
                        name="Poincaré Plot",
                        marker=dict(
                            color="rgba(156, 39, 176, 0.7)",
                            size=8,
                            symbol="circle",
                            line=dict(color="rgba(156, 39, 176, 1)", width=1),
                        ),
                    ),
                    row=current_row,
                    col=1,
                )
                try:
                    # SD1^2 = SDSD^2 / 2 and SD2^2 = 2*SDNN^2 - SDSD^2 / 2.
                    sd1 = np.std(np.diff(finite_rr)) / np.sqrt(2)
                    sd2 = np.sqrt(max(0.0, 2 * np.var(finite_rr) - sd1**2))
                    center_x = np.mean(finite_rr[:-1])
                    center_y = np.mean(finite_rr[1:])
                    # Ellipse rotated 45°: SD2 lies along the identity line,
                    # SD1 perpendicular to it.
                    theta = np.linspace(0, 2 * np.pi, 100)
                    along = sd2 * np.cos(theta)
                    across = sd1 * np.sin(theta)
                    fig.add_trace(
                        go.Scatter(
                            x=center_x + (along - across) / np.sqrt(2),
                            y=center_y + (along + across) / np.sqrt(2),
                            mode="lines",
                            name="SD1/SD2 Ellipse",
                            line=dict(color="red", width=2, dash="dash"),
                            showlegend=False,
                        ),
                        row=current_row,
                        col=1,
                    )
                    add_corner_note(
                        fig,
                        current_row,
                        f"SD1: {sd1:.1f} ms<br>SD2: {sd2:.1f} ms",
                        "purple",
                        bgcolor="rgba(255, 255, 255, 0.9)",
                        borderwidth=2,
                    )
                except Exception as e:
                    logger.warning(f"Could not add Poincaré statistics: {e}")

        # Enhanced layout
        fig.update_layout(
            title=dict(
                text="Enhanced Heart Rate Variability Analysis",
                x=0.5,
                font=dict(size=20, color="#2c3e50"),
            ),
            height=250 * num_plots,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            margin=dict(l=60, r=200, t=80, b=60),
        )
        # Same light grid on every subplot
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)"
        )
        for row in range(1, num_plots + 1):
            fig.update_xaxes(row=row, col=1, **grid_style)
            fig.update_yaxes(row=row, col=1, **grid_style)
        return fig
    except Exception as e:
        logger.error(f"Error creating enhanced HRV plots: {e}")
        return create_empty_figure()
[docs]
def create_morphology_plots(
    time_data, signal_data, sampling_freq, morphology_options, signal_type=None
):
    """Create enhanced morphology-specific plots with red arrows and comprehensive analysis.

    The first row always shows the signal with detected peaks marked.
    Additional rows are added for each of ``"peaks"``, ``"duration"`` and
    ``"area"`` present in ``morphology_options``. A selected panel keeps its
    own row even when there is too little data to draw it, so later panels
    always appear under their own titles.

    Args:
        time_data: Array of time stamps matching ``signal_data``.
        signal_data: 1-D array of signal samples.
        sampling_freq: Sampling frequency used to convert peak spacing to time.
        morphology_options: Collection of selected panel names; ``None``
            means none.
        signal_type: Optional signal label; ``"ecg"``/``"ppg"``
            (case-insensitive) select the vitalDSP peak detectors.

    Returns:
        A Plotly figure, or the result of ``create_empty_figure()`` if an
        error occurs.
    """

    def add_corner_note(fig, row, text, bordercolor, font_size=11):
        """Pin a statistics box to the top-right corner of the subplot in ``row``."""
        # "<axis> domain" references position the note as a fraction of the
        # subplot's plotting area rather than in data coordinates.
        axis_id = "" if row == 1 else str(row)
        fig.add_annotation(
            x=0.98,
            y=0.98,
            xref=f"x{axis_id} domain",
            yref=f"y{axis_id} domain",
            xanchor="right",
            yanchor="top",
            text=text,
            showarrow=False,
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor=bordercolor,
            borderwidth=2,
            font=dict(size=font_size, color="black"),
        )

    try:
        morphology_options = morphology_options or []
        signal_kind = signal_type.lower() if signal_type else None

        # Use vitalDSP for ECG/PPG peak detection, scipy for others
        if signal_kind in ("ecg", "ppg"):
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                signal_data, fs=sampling_freq, signal_type=signal_type.upper()
            )
            peaks = wm.r_peaks if signal_kind == "ecg" else wm.systolic_peaks
        else:
            signal_std = np.std(signal_data)
            peaks, _ = signal.find_peaks(
                signal_data,
                height=np.mean(signal_data) + 1.5 * signal_std,
                # find_peaks rejects distance < 1, which int() would yield
                # for sampling rates below ~3.4 Hz.
                distance=max(1, int(sampling_freq * 0.3)),
                prominence=signal_std * 0.5,
            )
        # Peaks are used as indices below; make sure they are an int array.
        peaks = np.asarray(peaks, dtype=int)

        show_peaks = "peaks" in morphology_options
        show_duration = "duration" in morphology_options
        show_area = "area" in morphology_options

        panel_titles = ["Signal Analysis with Peak Detection"]
        if show_peaks:
            panel_titles.append("Peak Analysis")
        if show_duration:
            panel_titles.append("Duration Analysis")
        if show_area:
            panel_titles.append("Area Analysis")
        num_plots = len(panel_titles)

        fig = make_subplots(
            rows=num_plots,
            cols=1,
            subplot_titles=panel_titles,
            vertical_spacing=0.12,
            specs=[[{"secondary_y": False}] for _ in range(num_plots)],
        )
        current_row = 1

        # --- Main signal with peak markers ------------------------------
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=signal_data,
                mode="lines",
                name="Signal",
                line=dict(color="#1f77b4", width=2),
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.1)",
            ),
            row=current_row,
            col=1,
        )
        has_peaks = len(peaks) > 0
        if has_peaks:
            peak_heights = signal_data[peaks]
            mean_height = np.mean(peak_heights)
            std_height = np.std(peak_heights)
            fig.add_trace(
                go.Scatter(
                    x=time_data[peaks],
                    y=peak_heights,
                    mode="markers+text",
                    name="Detected Peaks",
                    text=[f"P{i+1}" for i in range(len(peaks))],
                    textposition="top center",
                    marker=dict(
                        color="red",
                        size=12,
                        symbol="arrow-up",
                        line=dict(color="darkred", width=2),
                    ),
                    textfont=dict(color="red", size=10, family="Arial Black"),
                ),
                row=current_row,
                col=1,
            )
            add_corner_note(
                fig,
                current_row,
                f"Peaks: {len(peaks)}<br>Mean Height: {mean_height:.3f}<br>Height Std: {std_height:.3f}",
                "red",
                font_size=12,
            )
        current_row += 1

        # --- Peak height distribution -----------------------------------
        if show_peaks:
            if has_peaks:
                # At least one bin is required, even with a single peak.
                bin_count = max(1, min(15, len(peaks) // 2))
                hist, bin_edges = np.histogram(
                    peak_heights, bins=bin_count, density=True
                )
                bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
                fig.add_trace(
                    go.Bar(
                        x=bin_centers,
                        y=hist,
                        name="Peak Heights Distribution",
                        marker_color="rgba(255, 165, 0, 0.7)",
                        opacity=0.8,
                    ),
                    row=current_row,
                    col=1,
                )
                # The normal curve is undefined for zero spread, so skip it.
                if std_height > 0:
                    try:
                        from scipy.stats import norm

                        x_norm = np.linspace(
                            np.min(peak_heights), np.max(peak_heights), 100
                        )
                        fig.add_trace(
                            go.Scatter(
                                x=x_norm,
                                y=norm.pdf(x_norm, mean_height, std_height),
                                mode="lines",
                                name="Normal Fit",
                                line=dict(color="red", width=2, dash="dash"),
                            ),
                            row=current_row,
                            col=1,
                        )
                    except Exception as e:
                        # Best-effort overlay: keep the histogram, note the failure.
                        logger.warning(
                            f"Could not add normal fit to peak histogram: {e}"
                        )
                add_corner_note(
                    fig,
                    current_row,
                    f"Min: {np.min(peak_heights):.3f}<br>Max: {np.max(peak_heights):.3f}<br>Range: {np.max(peak_heights) - np.min(peak_heights):.3f}",
                    "orange",
                )
            # Advance even when nothing was drawn so that the following
            # panels stay aligned with their subplot titles.
            current_row += 1

        # --- Beat-interval (duration) analysis --------------------------
        if show_duration:
            if len(peaks) > 1:
                beat_intervals = np.diff(peaks) / sampling_freq
                fig.add_trace(
                    go.Scatter(
                        x=np.arange(len(beat_intervals)),
                        y=beat_intervals,
                        mode="lines+markers",
                        name="Beat Intervals",
                        line=dict(color="#2ca02c", width=3),
                        marker=dict(color="#2ca02c", size=6, symbol="diamond"),
                        fill="tonexty",
                        fillcolor="rgba(44, 160, 44, 0.1)",
                    ),
                    row=current_row,
                    col=1,
                )
                mean_interval = np.mean(beat_intervals)
                std_interval = np.std(beat_intervals)
                fig.add_hline(
                    y=mean_interval,
                    line_dash="dash",
                    line_color="red",
                    annotation_text=f"Mean: {mean_interval:.3f}s",
                    annotation_position="right",
                    row=current_row,
                    col=1,
                )
                add_corner_note(
                    fig,
                    current_row,
                    f"Mean Interval: {mean_interval:.3f}s<br>Std: {std_interval:.3f}s<br>Heart Rate: {60/mean_interval:.1f} BPM",
                    "green",
                )
            current_row += 1

        # --- Area analysis ----------------------------------------------
        if show_area:
            rectified = np.abs(signal_data)
            fig.add_trace(
                go.Scatter(
                    x=time_data,
                    y=np.cumsum(rectified),
                    mode="lines",
                    name="Cumulative Area",
                    line=dict(color="#9467bd", width=3),
                    fill="tonexty",
                    fillcolor="rgba(148, 103, 189, 0.1)",
                ),
                row=current_row,
                col=1,
            )
            fig.add_trace(
                go.Scatter(
                    x=time_data,
                    y=rectified,
                    mode="lines",
                    name="Signal Envelope",
                    line=dict(color="#e377c2", width=2, dash="dot"),
                    opacity=0.7,
                ),
                row=current_row,
                col=1,
            )
            total_area = np.sum(rectified)
            mean_area = np.mean(rectified)
            add_corner_note(
                fig,
                current_row,
                f"Total Area: {total_area:.3f}<br>Mean Area: {mean_area:.3f}<br>Signal Length: {len(signal_data)} samples",
                "purple",
            )

        # Enhanced layout
        fig.update_layout(
            title=dict(
                text="Enhanced Morphology Analysis with Peak Detection",
                x=0.5,
                font=dict(size=20, color="#2c3e50"),
            ),
            height=250 * num_plots,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            margin=dict(l=60, r=200, t=80, b=60),
        )
        # Same light grid on every subplot
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)"
        )
        for row in range(1, num_plots + 1):
            fig.update_xaxes(row=row, col=1, **grid_style)
            fig.update_yaxes(row=row, col=1, **grid_style)
        return fig
    except Exception as e:
        logger.error(f"Error creating enhanced morphology plots: {e}")
        return create_empty_figure()
# New Analysis Functions for Comprehensive Physiological Features
[docs]
def analyze_beat_to_beat(signal_data, sampling_freq, signal_type=None):
    """Detect beats and summarise the intervals between consecutive beats.

    ECG and PPG signals are handed to vitalDSP's ``WaveformMorphology``
    (``r_peaks`` / ``systolic_peaks``). Any other signal type falls back to
    ``scipy.signal.find_peaks`` with a mean-plus-one-std height threshold and
    a minimum spacing of 0.3 * ``sampling_freq`` samples.

    Args:
        signal_data: 1-D sequence of samples.
        sampling_freq: Sampling frequency; peak index differences are divided
            by it to obtain intervals (seconds when the frequency is in Hz).
        signal_type: Optional label; "ecg" or "ppg" (any letter case) selects
            the vitalDSP detectors.

    Returns:
        dict with ``num_beats``, ``mean_beat_interval``, ``beat_variability``
        (std of the intervals), ``beat_regularity`` (``1 / (1 + variability)``)
        and ``beat_intervals`` (list). Returns ``{"error": ...}`` when fewer
        than two beats are found or when the analysis raises.
    """
    try:
        kind = signal_type.lower() if signal_type else None
        if kind in ("ecg", "ppg"):
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                signal_data, fs=sampling_freq, signal_type=signal_type.upper()
            )
            peaks = wm.r_peaks if kind == "ecg" else wm.systolic_peaks
        else:
            # find_peaks rejects distance < 1, which int(fs * 0.3) produces for
            # sampling rates below ~3.4; clamp so slowly sampled signals are
            # still analysed instead of ending in the error branch.
            min_distance = max(1, int(sampling_freq * 0.3))
            peaks, _ = signal.find_peaks(
                signal_data,
                height=np.mean(signal_data) + np.std(signal_data),
                distance=min_distance,
            )

        if len(peaks) < 2:
            return {"error": "Insufficient beats for analysis"}

        # At least two peaks are guaranteed here, so there is >= 1 interval.
        beat_intervals = np.diff(peaks) / sampling_freq
        beat_variability = np.std(beat_intervals)
        beat_regularity = 1.0 / (1.0 + beat_variability)

        return {
            "num_beats": len(peaks),
            "mean_beat_interval": np.mean(beat_intervals),
            "beat_variability": beat_variability,
            "beat_regularity": beat_regularity,
            "beat_intervals": beat_intervals.tolist(),
        }
    except Exception as e:
        logger.error(f"Error in beat-to-beat analysis: {e}")
        return {"error": f"Beat-to-beat analysis failed: {str(e)}"}
[docs]
def analyze_energy(signal_data, sampling_freq, signal_type=None):
    """Compute time-domain energy statistics and a coarse spectral energy split.

    Args:
        signal_data: 1-D sequence of samples (list or array, any numeric dtype).
        sampling_freq: Sampling frequency passed to ``scipy.signal.welch``.
        signal_type: Accepted for signature parity with the other analysers;
            not used by this function.

    Returns:
        dict with ``total_energy`` (sum of squares), ``mean_energy``,
        ``energy_variance`` (variance of the squared samples), the summed
        Welch PSD bins below 0.1, in [0.1, 1.0) and at or above 1.0
        (``low_freq_energy``, ``mid_freq_energy``, ``high_freq_energy``) and
        ``energy_distribution`` listing those three values in order.
        Returns ``{"error": ...}`` if the analysis raises.
    """
    try:
        # Square in floating point: squaring fixed-width integer samples
        # (e.g. int16 ADC counts) wraps around silently, and plain lists do
        # not support ``**`` at all.
        samples = np.asarray(signal_data, dtype=float)
        squared = samples**2

        total_energy = np.sum(squared)
        mean_energy = np.mean(squared)
        energy_variance = np.var(squared)

        # Band energies are plain sums of the Welch PSD bins in each range.
        freqs, psd = signal.welch(samples, fs=sampling_freq)
        low_freq_energy = np.sum(psd[freqs < 0.1])
        mid_freq_energy = np.sum(psd[(freqs >= 0.1) & (freqs < 1.0)])
        high_freq_energy = np.sum(psd[freqs >= 1.0])

        return {
            "total_energy": total_energy,
            "mean_energy": mean_energy,
            "energy_variance": energy_variance,
            "low_freq_energy": low_freq_energy,
            "mid_freq_energy": mid_freq_energy,
            "high_freq_energy": high_freq_energy,
            "energy_distribution": [low_freq_energy, mid_freq_energy, high_freq_energy],
        }
    except Exception as e:
        logger.error(f"Error in energy analysis: {e}")
        return {"error": f"Energy analysis failed: {str(e)}"}
[docs]
def analyze_envelope(signal_data, sampling_freq, signal_type=None):
    """Summarise the amplitude envelope obtained from the Hilbert transform.

    The envelope is the magnitude of ``scipy.signal.hilbert(signal_data)``.
    ``sampling_freq`` and ``signal_type`` are accepted but not used here.

    Returns:
        dict with the envelope (``upper_envelope`` and ``envelope_data``), its
        negation (``lower_envelope``), mean, std, range, and the number of
        envelope peaks exceeding mean + std (``envelope_peaks``). Returns
        ``{"error": ...}`` if the analysis raises.
    """
    try:
        amplitude = np.abs(signal.hilbert(signal_data))

        amp_mean = np.mean(amplitude)
        amp_std = np.std(amplitude)
        amp_span = np.max(amplitude) - np.min(amplitude)

        # Only envelope maxima above one standard deviation over the mean count.
        peak_threshold = np.mean(amplitude) + np.std(amplitude)
        peak_positions, _ = signal.find_peaks(amplitude, height=peak_threshold)

        result = {}
        result["upper_envelope"] = amplitude.tolist()
        result["lower_envelope"] = (-amplitude).tolist()
        result["envelope_mean"] = amp_mean
        result["envelope_std"] = amp_std
        result["envelope_range"] = amp_span
        result["envelope_peaks"] = len(peak_positions)
        result["envelope_data"] = amplitude.tolist()
        return result
    except Exception as e:
        logger.error(f"Error in envelope analysis: {e}")
        return {"error": f"Envelope analysis failed: {str(e)}"}
[docs]
def analyze_segmentation(signal_data, sampling_freq):
    """Segment the signal at sign changes and describe the segment lengths.

    A crossing is any index where the sign of consecutive samples differs;
    segment lengths are the sample gaps between successive crossings.
    ``sampling_freq`` is accepted but not used.

    Returns:
        dict with ``segments`` (list of gaps), ``num_segments``,
        ``mean_segment_length``, ``segment_variability`` (std of the gaps,
        both 0 when there are no segments) and ``zero_crossings`` (count).
        Returns ``{"error": ...}`` if the analysis raises.
    """
    try:
        sign_steps = np.diff(np.sign(signal_data))
        crossing_idx = np.nonzero(sign_steps)[0]
        gaps = np.diff(crossing_idx)

        has_segments = len(gaps) > 0
        return {
            "segments": gaps.tolist(),
            "num_segments": len(gaps),
            "mean_segment_length": np.mean(gaps) if has_segments else 0,
            "segment_variability": np.std(gaps) if has_segments else 0,
            "zero_crossings": len(crossing_idx),
        }
    except Exception as e:
        logger.error(f"Error in segmentation analysis: {e}")
        return {"error": f"Segmentation analysis failed: {str(e)}"}
[docs]
def analyze_statistical(signal_data, sampling_freq):
    """Return descriptive statistics of the signal amplitudes.

    ``sampling_freq`` is accepted but not used.

    Returns:
        dict with mean, median, std, variance, min, max, range, the 5th,
        25th, 75th and 95th percentiles and the inter-quartile range.
        Returns ``{"error": ...}`` if the computation raises.
    """
    try:
        # The median is reported via np.median below, so the 50th percentile
        # computed here is intentionally discarded.
        p05, p25, _p50, p75, p95 = np.percentile(signal_data, [5, 25, 50, 75, 95])

        summary = {}
        summary["mean"] = np.mean(signal_data)
        summary["median"] = np.median(signal_data)
        summary["std"] = np.std(signal_data)
        summary["variance"] = np.var(signal_data)
        summary["min"] = np.min(signal_data)
        summary["max"] = np.max(signal_data)
        summary["range"] = np.max(signal_data) - np.min(signal_data)
        summary["percentile_5"] = p05
        summary["percentile_25"] = p25
        summary["percentile_75"] = p75
        summary["percentile_95"] = p95
        summary["iqr"] = p75 - p25
        return summary
    except Exception as e:
        logger.error(f"Error in statistical analysis: {e}")
        return {"error": f"Statistical analysis failed: {str(e)}"}
[docs]
def analyze_frequency(signal_data, sampling_freq):
    """Describe the signal's Welch power spectrum.

    Returns:
        dict with ``total_power`` (sum of PSD bins), ``dominant_frequency``
        (frequency of the largest bin), ``spectral_centroid``,
        ``spectral_bandwidth`` (both 0 when the total power is not positive),
        ``power_bands`` (summed PSD below 0.04, in [0.04, 0.15), in
        [0.15, 0.4) and at or above 0.4) plus the raw ``frequencies`` and
        ``psd`` lists. Returns ``{"error": ...}`` if the analysis raises.
    """
    try:
        freqs, psd = signal.welch(signal_data, fs=sampling_freq)

        total_power = np.sum(psd)
        dominant_freq = freqs[np.argmax(psd)]

        # Power-weighted moments are undefined for a zero-power spectrum, so
        # they default to 0 and are only computed when there is power.
        spectral_centroid = 0
        spectral_bandwidth = 0
        if total_power > 0:
            spectral_centroid = np.sum(freqs * psd) / total_power
            weighted_spread = np.sum(((freqs - spectral_centroid) ** 2) * psd)
            spectral_bandwidth = np.sqrt(weighted_spread / total_power)

        def band_power(selector):
            return np.sum(psd[selector])

        power_bands = {
            "very_low": band_power(freqs < 0.04),
            "low": band_power((freqs >= 0.04) & (freqs < 0.15)),
            "high": band_power((freqs >= 0.15) & (freqs < 0.4)),
            "very_high": band_power(freqs >= 0.4),
        }

        return {
            "total_power": total_power,
            "dominant_frequency": dominant_freq,
            "spectral_centroid": spectral_centroid,
            "spectral_bandwidth": spectral_bandwidth,
            "power_bands": power_bands,
            "frequencies": freqs.tolist(),
            "psd": psd.tolist(),
        }
    except Exception as e:
        logger.error(f"Error in frequency analysis: {e}")
        return {"error": f"Frequency analysis failed: {str(e)}"}
[docs]
def analyze_signal_quality_advanced(signal_data, sampling_freq, quality_options=None):
    """Compute the requested signal-quality metrics for one signal.

    Args:
        signal_data: 1-D NumPy array of samples.
        sampling_freq: Sampling frequency; only forwarded to vitalDSP's
            baseline correction inside the "multimodal_artifacts" option.
        quality_options: Collection of option names. When empty or None the
            defaults "quality_index", "artifact_detection", "entropy" and
            "complexity" are used; "snr_estimation", "blind_source" and
            "multimodal_artifacts" must be requested explicitly.

    Returns:
        dict whose keys depend on the selected options:
        ``signal_quality_index``/``snr_db``, ``artifacts_detected``/
        ``artifact_ratio``, ``adaptive_snr``, ``entropy``, ``complexity``,
        ``blind_source_separation`` (nested dict with a ``status``) and
        ``multimodal_artifacts`` (nested dict with a ``status``).
        Returns ``{"error": ...}`` if an unexpected exception escapes.
    """
    try:
        quality_metrics = {}

        # Handle case where quality_options might be None or empty
        if not quality_options:
            quality_options = [
                "quality_index",
                "artifact_detection",
                "entropy",
                "complexity",
            ]

        def _variance_snr_db():
            # Mean-square value over variance, in dB; inf for a zero-variance
            # signal. Shared by "quality_index" and "snr_estimation".
            signal_power = np.mean(signal_data**2)
            noise_power = np.var(signal_data)
            if noise_power > 0:
                return 10 * np.log10(signal_power / noise_power)
            return float("inf")

        def _outlier_mask(n_std):
            # Two-sided test centred on the mean. Comparing abs(x) against
            # mean + k*std (the previous form) flags every sample of a signal
            # with a negative offset and misses low-side outliers.
            samples = np.asarray(signal_data)
            return np.abs(samples - np.mean(samples)) > n_std * np.std(samples)

        if "quality_index" in quality_options:
            snr = _variance_snr_db()
            # Simple quality index: 100 at 20 dB, falling off linearly,
            # clipped to [0, 100].
            quality_index = min(100, max(0, 100 - abs(snr - 20)))
            quality_metrics.update(
                {"signal_quality_index": quality_index, "snr_db": snr}
            )

        if "artifact_detection" in quality_options:
            # Simple artifact detection: samples more than 3 std from the mean
            artifacts = np.sum(_outlier_mask(3))
            artifact_ratio = artifacts / len(signal_data)
            quality_metrics.update(
                {"artifacts_detected": artifacts, "artifact_ratio": artifact_ratio}
            )

        if "snr_estimation" in quality_options:
            quality_metrics.update({"adaptive_snr": _variance_snr_db()})

        if "entropy" in quality_options:
            # Shannon entropy (bits) of the amplitude histogram
            try:
                hist, _ = np.histogram(
                    signal_data, bins=min(50, len(signal_data) // 10)
                )
                hist = hist[hist > 0]  # Remove zero bins
                if len(hist) > 0:
                    prob = hist / np.sum(hist)
                    entropy = -np.sum(prob * np.log2(prob + 1e-10))
                    quality_metrics["entropy"] = float(entropy)
                else:
                    quality_metrics["entropy"] = 0.0
            except Exception:
                # Best effort: e.g. fewer than 10 samples gives bins=0.
                quality_metrics["entropy"] = 0.0

        if "complexity" in quality_options:
            # Sign changes per sample as a simple complexity measure
            try:
                zero_crossings = np.sum(np.diff(np.sign(signal_data)) != 0)
                complexity = zero_crossings / len(signal_data)
                quality_metrics["complexity"] = float(complexity)
            except Exception:
                quality_metrics["complexity"] = 0.0

        if "blind_source" in quality_options:
            # Blind Source Separation using vitalDSP
            try:
                from vitalDSP.signal_quality_assessment.blind_source_separation import (
                    center_signal,
                    whiten_signal,
                    fast_ica,
                )

                # For single-channel signal, create multi-channel by time-delayed
                # versions. This simulates multi-channel data for BSS demonstration.
                n_channels = 3
                delay_samples = max(1, len(signal_data) // 100)

                multi_channel = np.zeros((n_channels, len(signal_data)))
                multi_channel[0, :] = signal_data
                multi_channel[1, delay_samples:] = signal_data[:-delay_samples]
                if len(signal_data) > delay_samples * 2:
                    multi_channel[2, delay_samples * 2 :] = signal_data[
                        : -delay_samples * 2
                    ]

                # Center the signals
                centered_signal, mean_signal = center_signal(multi_channel)
                # Whiten the signals
                whitened_signal, whitening_matrix = whiten_signal(centered_signal.T)

                # Apply FastICA
                try:
                    sources, mixing_matrix, unmixing_matrix = fast_ica(
                        whitened_signal.T, n_components=n_channels, max_iter=200
                    )
                    quality_metrics["blind_source_separation"] = {
                        "n_components": n_channels,
                        "sources_shape": sources.shape,
                        "mixing_matrix_shape": mixing_matrix.shape,
                        "separation_quality": float(
                            np.mean(np.abs(np.corrcoef(sources)))
                        ),
                        "status": "success",
                    }
                except Exception as ica_error:
                    logger.warning(f"FastICA failed: {ica_error}")
                    quality_metrics["blind_source_separation"] = {
                        "status": "failed",
                        "error": str(ica_error),
                    }
            except ImportError:
                logger.warning(
                    "vitalDSP blind_source_separation not available, skipping BSS"
                )
                quality_metrics["blind_source_separation"] = {
                    "status": "module_not_available"
                }
            except Exception as e:
                logger.warning(f"Blind source separation failed: {e}")
                quality_metrics["blind_source_separation"] = {
                    "status": "error",
                    "error": str(e),
                }

        if "multimodal_artifacts" in quality_options:
            # Multi-modal Artifact Detection using vitalDSP
            try:
                from vitalDSP.filtering.artifact_removal import ArtifactRemoval
                from vitalDSP.signal_quality_assessment.artifact_detection import (
                    z_score_artifact_detection,
                    adaptive_threshold_artifact_detection,
                    moving_average_artifact_detection,
                )

                # Method 1: Z-score based detection
                artifacts_zscore = z_score_artifact_detection(
                    signal_data, z_threshold=3.0
                )
                zscore_ratio = np.sum(artifacts_zscore) / len(artifacts_zscore)

                # Method 2: Adaptive threshold detection
                artifacts_adaptive = adaptive_threshold_artifact_detection(
                    signal_data, window_size=max(50, len(signal_data) // 20)
                )
                adaptive_ratio = np.sum(artifacts_adaptive) / len(artifacts_adaptive)

                # Method 3: Moving average detection
                artifacts_ma = moving_average_artifact_detection(
                    signal_data, window_size=max(50, len(signal_data) // 20)
                )
                ma_ratio = np.sum(artifacts_ma) / len(artifacts_ma)

                # Method 4: Baseline wander detection
                ar = ArtifactRemoval(signal_data)
                baseline_removed = ar.baseline_correction(cutoff=0.5, fs=sampling_freq)
                baseline_artifacts = np.abs(signal_data - baseline_removed) > (
                    3 * np.std(signal_data)
                )
                baseline_ratio = np.sum(baseline_artifacts) / len(baseline_artifacts)

                # Consensus detection (artifact if detected by at least 2 methods)
                consensus_artifacts = (
                    artifacts_zscore.astype(int)
                    + artifacts_adaptive.astype(int)
                    + artifacts_ma.astype(int)
                    + baseline_artifacts.astype(int)
                ) >= 2
                consensus_ratio = np.sum(consensus_artifacts) / len(consensus_artifacts)

                quality_metrics["multimodal_artifacts"] = {
                    "zscore_artifact_ratio": float(zscore_ratio),
                    "adaptive_artifact_ratio": float(adaptive_ratio),
                    "moving_average_artifact_ratio": float(ma_ratio),
                    "baseline_artifact_ratio": float(baseline_ratio),
                    "consensus_artifact_ratio": float(consensus_ratio),
                    "total_artifacts_detected": int(np.sum(consensus_artifacts)),
                    "clean_signal_percentage": float((1 - consensus_ratio) * 100),
                    "methods_used": [
                        "z_score",
                        "adaptive_threshold",
                        "moving_average",
                        "baseline_wander",
                    ],
                    "status": "success",
                }
            except ImportError:
                logger.warning(
                    "vitalDSP artifact detection modules not available, using basic detection"
                )
                # Fallback to the same centred 3-std test used above
                artifacts = _outlier_mask(3)
                artifact_ratio = np.sum(artifacts) / len(artifacts)
                quality_metrics["multimodal_artifacts"] = {
                    "artifact_ratio": float(artifact_ratio),
                    "total_artifacts_detected": int(np.sum(artifacts)),
                    "status": "fallback_simple",
                }
            except Exception as e:
                logger.warning(f"Multimodal artifact detection failed: {e}")
                quality_metrics["multimodal_artifacts"] = {
                    "status": "error",
                    "error": str(e),
                }

        return quality_metrics
    except Exception as e:
        logger.error(f"Error in advanced quality analysis: {e}")
        return {"error": f"Advanced quality analysis failed: {str(e)}"}
[docs]
def analyze_advanced_computation(signal_data, sampling_freq, advanced_computation):
    """Compute the selected simple "advanced computation" summaries.

    Args:
        signal_data: 1-D sequence of samples.
        sampling_freq: Accepted for signature parity; not used here.
        advanced_computation: Collection of option names. When empty or None
            all of "anomaly_detection", "bayesian" and "kalman" are run.

    Returns:
        dict with ``anomalies_detected`` (samples more than 2 std from the
        mean), ``bayesian_prior_mean``/``bayesian_prior_std`` (signal mean and
        std) and ``kalman_estimate`` (signal mean), according to the selected
        options. Returns ``{"error": ...}`` if the computation raises.
    """
    try:
        comp_metrics = {}

        # Handle case where advanced_computation might be None or empty
        if not advanced_computation:
            advanced_computation = ["anomaly_detection", "bayesian", "kalman"]

        if "anomaly_detection" in advanced_computation:
            # Two-sided test centred on the mean. Comparing abs(x) against
            # mean + 2*std (the previous form) flags every sample of a signal
            # with a negative offset and misses low-side outliers.
            samples = np.asarray(signal_data)
            deviation = np.abs(samples - np.mean(samples))
            comp_metrics["anomalies_detected"] = np.sum(
                deviation > 2 * np.std(samples)
            )

        if "bayesian" in advanced_computation:
            # Simple Bayesian-like analysis: sample moments used as the prior
            comp_metrics["bayesian_prior_mean"] = np.mean(signal_data)
            comp_metrics["bayesian_prior_std"] = np.std(signal_data)

        if "kalman" in advanced_computation:
            # Simple Kalman-like analysis: the mean stands in for the estimate
            comp_metrics["kalman_estimate"] = np.mean(signal_data)

        return comp_metrics
    except Exception as e:
        logger.error(f"Error in advanced computation analysis: {e}")
        return {"error": f"Advanced computation analysis failed: {str(e)}"}
[docs]
def analyze_feature_engineering(
    signal_data, sampling_freq, feature_engineering, signal_type
):
    """Compute the signal-type-specific engineered features that were requested.

    Args:
        signal_data: 1-D sequence of samples.
        sampling_freq: Accepted for signature parity; not used here.
        feature_engineering: Collection of option names. When empty or None
            all of "ppg_light", "ppg_autonomic" and "ecg_autonomic" are
            considered.
        signal_type: Signal label, matched case-insensitively against "ppg"
            and "ecg". None (or any non-string) matches neither.

    Returns:
        dict containing ``ppg_light_intensity``/``ppg_light_variability``
        (mean/std), ``ppg_autonomic_response`` (std) and/or
        ``ecg_autonomic_response`` (std) for options whose signal type
        matches. Returns ``{"error": ...}`` if the computation raises.
    """
    try:
        feature_metrics = {}

        # Handle case where feature_engineering might be None or empty
        if not feature_engineering:
            feature_engineering = ["ppg_light", "ppg_autonomic", "ecg_autonomic"]

        # Normalise once so all options treat the label the same way:
        # previously "ppg_light" was case-sensitive while the others were not,
        # and a None label raised AttributeError on ``.lower()``.
        kind = signal_type.lower() if isinstance(signal_type, str) else ""

        if "ppg_light" in feature_engineering and kind == "ppg":
            feature_metrics["ppg_light_intensity"] = np.mean(signal_data)
            feature_metrics["ppg_light_variability"] = np.std(signal_data)

        if "ppg_autonomic" in feature_engineering and kind == "ppg":
            feature_metrics["ppg_autonomic_response"] = np.std(signal_data)

        if "ecg_autonomic" in feature_engineering and kind == "ecg":
            feature_metrics["ecg_autonomic_response"] = np.std(signal_data)

        return feature_metrics
    except Exception as e:
        logger.error(f"Error in feature engineering analysis: {e}")
        return {"error": f"Feature engineering analysis failed: {str(e)}"}
[docs]
def analyze_advanced_features(signal_data, sampling_freq, advanced_features):
    """Compute the selected "advanced feature" summaries.

    Options (all run when ``advanced_features`` is empty or None):
    "cross_signal" (correlation between the signal and itself shifted by one
    sample, 0.0 when it cannot be computed), "ensemble" (mean and std),
    "change_detection" (count of first differences larger than their std) and
    "power_analysis" (sum and maximum of the Welch PSD).

    Returns:
        dict of the selected metrics, or ``{"error": ...}`` if the
        computation raises.
    """
    try:
        advanced_metrics = {}

        if not advanced_features:
            advanced_features = [
                "cross_signal",
                "ensemble",
                "change_detection",
                "power_analysis",
            ]

        if "cross_signal" in advanced_features:
            # Every path that cannot produce a valid coefficient reports 0.0.
            lag_one_corr = 0.0
            try:
                if len(signal_data) > 1:
                    earlier = signal_data[:-1]
                    later = signal_data[1:]
                    contains_bad_values = (
                        np.any(np.isnan(earlier))
                        or np.any(np.isnan(later))
                        or np.any(np.isinf(earlier))
                        or np.any(np.isinf(later))
                    )
                    if not contains_bad_values:
                        corr_matrix = np.corrcoef(earlier, later)
                        usable = corr_matrix.shape == (2, 2) and not np.isnan(
                            corr_matrix[0, 1]
                        )
                        if usable:
                            lag_one_corr = float(corr_matrix[0, 1])
            except Exception as e:
                logger.warning(f"Cross-signal correlation failed: {e}")
                lag_one_corr = 0.0
            advanced_metrics["cross_signal_correlation"] = lag_one_corr

        if "ensemble" in advanced_features:
            advanced_metrics["ensemble_mean"] = np.mean(signal_data)
            advanced_metrics["ensemble_std"] = np.std(signal_data)

        if "change_detection" in advanced_features:
            steps = np.diff(signal_data)
            large_steps = np.abs(steps) > np.std(steps)
            advanced_metrics["change_points"] = np.sum(large_steps)

        if "power_analysis" in advanced_features:
            _, spectrum = signal.welch(signal_data, fs=sampling_freq)
            advanced_metrics["total_power"] = np.sum(spectrum)
            advanced_metrics["peak_power"] = np.max(spectrum)

        return advanced_metrics
    except Exception as e:
        logger.error(f"Error in advanced features analysis: {e}")
        return {"error": f"Advanced features analysis failed: {str(e)}"}
[docs]
def analyze_preprocessing(signal_data, sampling_freq, preprocessing):
    """Report what the selected preprocessing steps would do to the signal.

    Options (all run when ``preprocessing`` is empty or None):

    * "noise_reduction": applies vitalDSP wavelet, Savitzky-Golay and median
      denoising and records the std of each removed residual; falls back to
      just ``noise_level`` (signal std) when vitalDSP is missing or a step
      fails.
    * "baseline_correction": records the signal mean before and after
      vitalDSP's ``ArtifactRemoval.baseline_correction``; falls back to just
      ``baseline_offset``.
    * "filtering": records ``sampling_freq / 2`` as both ``signal_bandwidth``
      and ``nyquist_frequency``.

    Returns:
        dict of the collected metrics, or ``{"error": ...}`` if an unexpected
        exception escapes.
    """
    try:
        preprocess_metrics = {}

        if not preprocessing:
            preprocessing = ["noise_reduction", "baseline_correction", "filtering"]

        if "noise_reduction" in preprocessing:
            try:
                from vitalDSP.preprocess.noise_reduction import (
                    wavelet_denoising,
                    savgol_denoising,
                    median_denoising,
                )

                raw_std = float(np.std(signal_data))
                preprocess_metrics["noise_level"] = raw_std
                preprocess_metrics["original_noise_level"] = raw_std

                # Wavelet denoising: residual std and the implied SNR gain.
                wavelet_out = wavelet_denoising(
                    signal_data, wavelet_name="db4", level=3
                )
                preprocess_metrics["wavelet_denoised_noise_level"] = float(
                    np.std(signal_data - wavelet_out)
                )
                snr_ratio = np.std(signal_data) / (
                    np.std(signal_data - wavelet_out) + 1e-10
                )
                preprocess_metrics["wavelet_snr_improvement"] = float(
                    20 * np.log10(snr_ratio)
                )

                # Savitzky-Golay: window is a tenth of the signal, capped at
                # 51 and bumped to the next odd number when even.
                window_length = min(51, len(signal_data) // 10)
                window_length = (
                    window_length + 1 if window_length % 2 == 0 else window_length
                )
                savgol_out = savgol_denoising(
                    signal_data, window_length=window_length, polyorder=3
                )
                preprocess_metrics["savgol_denoised_noise_level"] = float(
                    np.std(signal_data - savgol_out)
                )

                # Median filter: kernel kept within [3, 5].
                kernel_size = max(3, min(5, len(signal_data) // 100))
                median_out = median_denoising(signal_data, kernel_size=kernel_size)
                preprocess_metrics["median_denoised_noise_level"] = float(
                    np.std(signal_data - median_out)
                )

                preprocess_metrics["denoising_methods_available"] = [
                    "wavelet",
                    "savgol",
                    "median",
                ]
            except ImportError:
                logger.warning(
                    "vitalDSP noise_reduction not available, using basic analysis"
                )
                preprocess_metrics["noise_level"] = float(np.std(signal_data))
            except Exception as e:
                logger.warning(f"Noise reduction analysis failed: {e}")
                preprocess_metrics["noise_level"] = float(np.std(signal_data))

        if "baseline_correction" in preprocessing:
            try:
                from vitalDSP.filtering.artifact_removal import ArtifactRemoval

                preprocess_metrics["baseline_offset"] = float(np.mean(signal_data))

                remover = ArtifactRemoval(signal_data)
                corrected = remover.baseline_correction(cutoff=0.5, fs=sampling_freq)
                preprocess_metrics["baseline_corrected_offset"] = float(
                    np.mean(corrected)
                )
                preprocess_metrics["baseline_correction_applied"] = True
            except ImportError:
                logger.warning(
                    "vitalDSP artifact_removal not available, using basic analysis"
                )
                preprocess_metrics["baseline_offset"] = float(np.mean(signal_data))
            except Exception as e:
                logger.warning(f"Baseline correction analysis failed: {e}")
                preprocess_metrics["baseline_offset"] = float(np.mean(signal_data))

        if "filtering" in preprocessing:
            half_rate = float(sampling_freq / 2)
            preprocess_metrics["signal_bandwidth"] = half_rate
            preprocess_metrics["nyquist_frequency"] = float(sampling_freq / 2)

        return preprocess_metrics
    except Exception as e:
        logger.error(f"Error in preprocessing analysis: {e}")
        return {"error": f"Preprocessing analysis failed: {str(e)}"}
# Plot Creation Functions for New Analysis Types
[docs]
def create_beat_to_beat_plots(time_data, signal_data, sampling_freq, signal_type=None):
    """Build a two-row figure: the signal with detected beats, and beat intervals.

    ECG and PPG signals use vitalDSP's ``WaveformMorphology`` peaks
    (``r_peaks`` / ``systolic_peaks``); other signal types use
    ``scipy.signal.find_peaks`` with a mean-plus-one-std height threshold and
    a minimum spacing of 0.3 * ``sampling_freq`` samples.

    Args:
        time_data: Array of time stamps aligned with ``signal_data``; indexed
            with the detected peak positions.
        signal_data: Array of samples.
        sampling_freq: Sampling frequency used to convert peak index
            differences into intervals.
        signal_type: Optional label; "ecg" or "ppg" (any letter case) selects
            the vitalDSP detectors.

    Returns:
        The plotly figure, or ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Beat Detection", "Beat Interval Analysis"),
            vertical_spacing=0.1,
        )

        # Main signal with beats
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=signal_data,
                mode="lines",
                name="Signal",
                line=dict(color="blue"),
            ),
            row=1,
            col=1,
        )

        # Use vitalDSP for ECG/PPG peak detection, scipy for others
        kind = signal_type.lower() if signal_type else None
        if kind in ("ecg", "ppg"):
            from vitalDSP.physiological_features.waveform import WaveformMorphology

            wm = WaveformMorphology(
                signal_data, fs=sampling_freq, signal_type=signal_type.upper()
            )
            peaks = wm.r_peaks if kind == "ecg" else wm.systolic_peaks
        else:
            # find_peaks rejects distance < 1, which int(fs * 0.3) produces for
            # sampling rates below ~3.4; clamp so slowly sampled signals still
            # get a plot instead of the empty fallback figure.
            peaks, _ = signal.find_peaks(
                signal_data,
                height=np.mean(signal_data) + np.std(signal_data),
                distance=max(1, int(sampling_freq * 0.3)),
            )

        if len(peaks) > 0:
            fig.add_trace(
                go.Scatter(
                    x=time_data[peaks],
                    y=signal_data[peaks],
                    mode="markers",
                    name="Beats",
                    marker=dict(color="red", size=8),
                ),
                row=1,
                col=1,
            )

        # Beat intervals need at least two beats; plotted against beat number.
        if len(peaks) > 1:
            intervals = np.diff(peaks) / sampling_freq
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(intervals)),
                    y=intervals,
                    mode="lines+markers",
                    name="Beat Intervals",
                    line=dict(color="green"),
                ),
                row=2,
                col=1,
            )

        fig.update_layout(
            title="Beat-to-Beat Analysis",
            height=600,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            margin=dict(l=60, r=200, t=80, b=60),
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating beat-to-beat plots: {e}")
        return create_empty_figure()
[docs]
def create_energy_plots(time_data, signal_data, sampling_freq):
    """Build a two-row figure: squared signal over time, and its Welch PSD.

    Returns the plotly figure, or ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal Energy", "Power Spectral Density"),
            vertical_spacing=0.1,
        )

        # Row 1: instantaneous energy (sample squared) against time.
        energy_trace = go.Scatter(
            x=time_data,
            y=signal_data**2,
            mode="lines",
            name="Energy",
            line={"color": "purple"},
        )
        fig.add_trace(energy_trace, row=1, col=1)

        # Row 2: Welch power spectral density.
        welch_freqs, welch_psd = signal.welch(signal_data, fs=sampling_freq)
        psd_trace = go.Scatter(
            x=welch_freqs,
            y=welch_psd,
            mode="lines",
            name="PSD",
            line={"color": "orange"},
        )
        fig.add_trace(psd_trace, row=2, col=1)

        # Legend sits to the right of the plot area; the wide right margin
        # leaves room for it.
        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Energy Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating energy plots: {e}")
        return create_empty_figure()
[docs]
def create_envelope_plots(time_data, signal_data, sampling_freq):
    """Build a two-row figure: signal with its Hilbert envelope, and an envelope histogram.

    ``sampling_freq`` is accepted but not used. Returns the plotly figure, or
    ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal with Envelope", "Envelope Analysis"),
            vertical_spacing=0.1,
        )

        # Row 1: raw signal.
        raw_trace = go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line={"color": "blue"},
        )
        fig.add_trace(raw_trace, row=1, col=1)

        # Row 1 overlay: magnitude of the analytic signal.
        amplitude = np.abs(signal.hilbert(signal_data))
        envelope_trace = go.Scatter(
            x=time_data,
            y=amplitude,
            mode="lines",
            name="Envelope",
            line={"color": "red"},
        )
        fig.add_trace(envelope_trace, row=1, col=1)

        # Row 2: distribution of envelope amplitudes.
        distribution = go.Histogram(
            x=amplitude,
            nbinsx=20,
            name="Envelope Distribution",
            marker_color="green",
            opacity=0.7,
        )
        fig.add_trace(distribution, row=2, col=1)

        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Envelope Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating envelope plots: {e}")
        return create_empty_figure()
[docs]
def create_segmentation_plots(time_data, signal_data, sampling_freq):
    """Build a two-row figure: signal with sign-change markers, and a segment-length histogram.

    A crossing is any index where the sign of consecutive samples differs.
    ``sampling_freq`` is accepted but not used. Returns the plotly figure, or
    ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal Segmentation", "Segment Analysis"),
            vertical_spacing=0.1,
        )

        # Row 1: raw signal.
        raw_trace = go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line={"color": "blue"},
        )
        fig.add_trace(raw_trace, row=1, col=1)

        crossing_idx = np.nonzero(np.diff(np.sign(signal_data)))[0]
        n_crossings = len(crossing_idx)

        # Row 1 overlay: mark the sample just before each sign change.
        if n_crossings > 0:
            markers = go.Scatter(
                x=time_data[crossing_idx],
                y=signal_data[crossing_idx],
                mode="markers",
                name="Zero Crossings",
                marker={"color": "red", "size": 6},
            )
            fig.add_trace(markers, row=1, col=1)

        # Row 2: a segment needs two crossings to have a length.
        if n_crossings > 1:
            gaps = np.diff(crossing_idx)
            gap_hist = go.Histogram(
                x=gaps,
                nbinsx=15,
                name="Segment Lengths",
                marker_color="orange",
                opacity=0.7,
            )
            fig.add_trace(gap_hist, row=2, col=1)

        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Segmentation Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating segmentation plots: {e}")
        return create_empty_figure()
[docs]
def create_wavelet_plots(time_data, signal_data, sampling_freq):
    """Build the two-row "Wavelet Analysis" figure.

    No wavelet transform is computed here: the second row is a simplified
    stand-in that plots the squared signal. ``sampling_freq`` is accepted but
    not used. Returns the plotly figure, or ``create_empty_figure()`` if
    anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal", "Wavelet-like Analysis"),
            vertical_spacing=0.1,
        )

        # Row 1: raw signal.
        raw_trace = go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line={"color": "blue"},
        )
        fig.add_trace(raw_trace, row=1, col=1)

        # Row 2: squared signal as the simplified wavelet-like view.
        squared_trace = go.Scatter(
            x=time_data,
            y=signal_data**2,
            mode="lines",
            name="Energy",
            line={"color": "red"},
        )
        fig.add_trace(squared_trace, row=2, col=1)

        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Wavelet Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating wavelet plots: {e}")
        return create_empty_figure()
[docs]
def create_fourier_plots(time_data, signal_data, sampling_freq):
    """Build a two-row figure: the signal, and its FFT magnitude at positive frequencies.

    Returns the plotly figure, or ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal", "Fourier Transform"),
            vertical_spacing=0.1,
        )

        # Row 1: raw signal.
        raw_trace = go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line={"color": "blue"},
        )
        fig.add_trace(raw_trace, row=1, col=1)

        # Row 2: magnitude spectrum; the DC bin and negative frequencies are
        # dropped by keeping strictly positive frequencies only.
        spectrum_mag = np.abs(np.fft.fft(signal_data))
        bin_freqs = np.fft.fftfreq(len(signal_data), 1 / sampling_freq)
        keep = bin_freqs > 0
        spectrum_trace = go.Scatter(
            x=bin_freqs[keep],
            y=spectrum_mag[keep],
            mode="lines",
            name="FFT Magnitude",
            line={"color": "green"},
        )
        fig.add_trace(spectrum_trace, row=2, col=1)

        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Fourier Transform Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating Fourier plots: {e}")
        return create_empty_figure()
[docs]
def create_hilbert_plots(time_data, signal_data, sampling_freq):
    """Build a two-row figure: the signal, and the phase of its analytic signal.

    ``sampling_freq`` is accepted but not used. Returns the plotly figure, or
    ``create_empty_figure()`` if anything raises.
    """
    try:
        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Signal", "Hilbert Transform"),
            vertical_spacing=0.1,
        )

        # Row 1: raw signal.
        raw_trace = go.Scatter(
            x=time_data,
            y=signal_data,
            mode="lines",
            name="Signal",
            line={"color": "blue"},
        )
        fig.add_trace(raw_trace, row=1, col=1)

        # Row 2: instantaneous phase (np.angle of the Hilbert analytic signal).
        instantaneous_phase = np.angle(signal.hilbert(signal_data))
        phase_trace = go.Scatter(
            x=time_data,
            y=instantaneous_phase,
            mode="lines",
            name="Phase",
            line={"color": "red"},
        )
        fig.add_trace(phase_trace, row=2, col=1)

        legend_box = {
            "x": 1.02,
            "y": 1.0,
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "bordercolor": "rgba(0, 0, 0, 0.2)",
            "borderwidth": 1,
        }
        fig.update_layout(
            title="Hilbert Transform Analysis",
            height=600,
            showlegend=True,
            legend=legend_box,
            margin={"l": 60, "r": 200, "t": 80, "b": 60},
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating Hilbert plots: {e}")
        return create_empty_figure()
# Frequency-Domain Plot Creation
[docs]
def create_frequency_plots(time_data, signal_data, sampling_freq):
    """Create a 2-panel frequency analysis figure.

    Row 1 shows the FFT magnitude spectrum at positive frequencies; row 2
    shows the Welch PSD with the VLF (<0.04 Hz), LF (0.04-0.15 Hz),
    HF (0.15-0.40 Hz) and VHF (>=0.40 Hz) ranges overdrawn in colour.

    Args:
        time_data: Accepted for signature parity with the other plot
            builders; not used here.
        signal_data: Sequence of samples; converted to a float array.
        sampling_freq: Sampling frequency in Hz.

    Returns:
        The plotly figure, or ``create_empty_figure()`` if anything raises.
    """
    try:
        # --- FFT magnitude, positive frequencies only ---
        sig = np.asarray(signal_data, dtype=float)
        fft_result = np.fft.fft(sig)
        freqs_fft = np.fft.fftfreq(sig.size, d=1.0 / sampling_freq)
        positive_mask = freqs_fft > 0
        freqs_fft = freqs_fft[positive_mask]
        magnitude = np.abs(fft_result[positive_mask])

        # --- Welch PSD; segment length capped at 1024 samples ---
        nperseg = min(1024, len(sig)) if len(sig) > 0 else 256
        freqs_psd, psd = signal.welch(sig, fs=sampling_freq, nperseg=nperseg)

        # HRV frequency bands as (legend label, line colour, bin mask), in
        # the order the traces are added.
        bands = [
            ("VLF (<0.04 Hz)", "red", freqs_psd < 0.04),
            ("LF (0.04–0.15 Hz)", "green", (freqs_psd >= 0.04) & (freqs_psd < 0.15)),
            ("HF (0.15–0.40 Hz)", "blue", (freqs_psd >= 0.15) & (freqs_psd < 0.40)),
            ("VHF (≥0.40 Hz)", "purple", freqs_psd >= 0.40),
        ]

        # --- Build figure with 2 rows ---
        # The keyword is ``shared_xaxes``; the former ``shared_x`` is not a
        # make_subplots parameter, so the call raised and this function always
        # fell through to the empty figure.
        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=False,
            vertical_spacing=0.12,
            subplot_titles=(
                "Frequency Spectrum (FFT Magnitude)",
                "Welch PSD with HRV Bands",
            ),
        )

        # Row 1: FFT magnitude spectrum
        fig.add_trace(
            go.Scatter(
                x=freqs_fft,
                y=magnitude,
                mode="lines",
                name="FFT Magnitude",
                line=dict(color="blue"),
            ),
            row=1,
            col=1,
        )

        # Row 2: grey baseline PSD for context, then the coloured band segments
        fig.add_trace(
            go.Scatter(
                x=freqs_psd,
                y=psd,
                mode="lines",
                name="PSD (Welch)",
                line=dict(color="rgba(100,100,100,0.6)"),
            ),
            row=2,
            col=1,
        )
        for band_name, band_color, band_mask in bands:
            fig.add_trace(
                go.Scatter(
                    x=freqs_psd[band_mask],
                    y=psd[band_mask],
                    mode="lines",
                    name=band_name,
                    line=dict(color=band_color),
                ),
                row=2,
                col=1,
            )

        # Axes & layout (PSD axis kept linear)
        fig.update_xaxes(title_text="Frequency (Hz)", row=1, col=1)
        fig.update_yaxes(title_text="Magnitude", row=1, col=1)
        fig.update_xaxes(title_text="Frequency (Hz)", row=2, col=1)
        fig.update_yaxes(title_text="Power / Frequency", row=2, col=1)
        fig.update_layout(
            title="Frequency Analysis",
            height=700,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255,255,255,0.9)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
            ),
            margin=dict(l=60, r=200, t=80, b=60),
        )
        return fig
    except Exception as e:
        logger.error(f"Error creating frequency plots: {e}")
        return create_empty_figure()
# Additional Callback Functions for Enhanced Features
[docs]
def register_additional_physiological_callbacks(app):
    """Register the auxiliary physiological-analysis callbacks on ``app``.

    Four callbacks are attached:

    * two toggles that show the HRV / morphology option containers when the
      matching entry is present in ``physio-analysis-categories``;
    * the renderer of ``physio-additional-analysis-section``, which turns the
      content of ``store-physio-features`` into one compact card per feature
      group;
    * the switch that disables the export button while the store is empty.

    Args:
        app: Object exposing a Dash-style ``callback`` decorator.
    """
    # Shared markup constants (built once at registration time).
    row_class = (
        "d-flex justify-content-between align-items-center "
        "py-1 px-2 border-bottom border-light-subtle"
    )
    # Keys whose values are displayed in milliseconds without conversion.
    millisecond_keys = frozenset(
        {
            "mean_rr",
            "rmssd",
            "sdnn",
            "std_nn",
            "median_nn",
            "iqr_nn",
            "sdsd",
            "poincare_sd1",
            "poincare_sd2",
        }
    )
    count_suffixes = ("_peaks", "_beats", "_segments", "_crossings", "_anomalies")
    count_keys = frozenset({"anomalies_detected", "nn50"})
    amplitude_keys = frozenset(
        {
            "peak_to_peak",
            "envelope_mean",
            "envelope_range",
            "rms",
            "mean",
            "median",
            "std",
            "iqr",
            "bayesian_prior_mean",
            "ecg_autonomic_response",
            "noise_level",
        }
    )
    dimensionless_keys = frozenset(
        {
            "spectral_centroid",
            "fourier_peak",
            "hilbert_phase",
            "cross_signal_correlation",
            "signal_quality_index",
        }
    )
    # Feature group -> (icon, bootstrap colour); unknown groups use a default.
    icon_map = {
        "hrv_metrics": ("💓", "danger"),
        "morphology_metrics": ("📊", "info"),
        "beat2beat_metrics": ("🫀", "success"),
        "energy_metrics": ("⚡", "warning"),
        "envelope_metrics": ("📦", "secondary"),
        "segmentation_metrics": ("✂️", "dark"),
        "waveform_metrics": ("🌊", "primary"),
        "statistical_metrics": ("📈", "info"),
        "frequency_metrics": ("🔊", "success"),
        "advanced_features_metrics": ("🚀", "warning"),
        "quality_metrics": ("⚖️", "danger"),
        "transform_metrics": ("🔄", "primary"),
        "advanced_computation_metrics": ("🧠", "dark"),
        "feature_engineering_metrics": ("🔧", "info"),
        "preprocessing_metrics": ("🔧", "secondary"),
        "trend_metrics": ("📈", "success"),
    }

    def _placeholder_panel(icon_class, heading, message):
        """Build the centred icon/heading/message panel used for empty states."""
        return html.Div(
            [
                html.Div(
                    [
                        html.I(
                            className=f"{icon_class} text-muted fs-1 d-block text-center mb-2"
                        ),
                        html.H5(heading, className="text-center text-muted mb-2"),
                        html.P(message, className="text-center text-muted small"),
                    ],
                    className="text-center py-4",
                )
            ]
        )

    def _metric_row(key, display):
        """Build one "Label: value" row of a metric card."""
        return html.Div(
            [
                html.Span(
                    f"{key.replace('_', ' ').title()}: ",
                    className="text-muted fw-bold small",
                ),
                html.Span(display, className="text-dark fw-semibold"),
            ],
            className=row_class,
        )

    def _format_numeric(key, value):
        """Pick the display format of a numeric metric from its key name.

        The checks are ordered; the first matching rule wins (for example
        ``lfnu_power`` is caught by the normalized-power rule before the
        generic ``_power`` rule).
        """
        if key.endswith("_db"):
            return f"{value:.1f} dB"
        if key.endswith(("_ratio", "_strength")):
            # Infinite / undefined ratios are shown as text, not as "inf"/"nan".
            if np.isinf(value):
                return "N/A (HF power = 0)"
            if np.isnan(value):
                return "N/A"
            return f"{value:.3f}"
        if key.endswith(("_freq", "_frequency")):
            return f"{value:.3f} Hz"
        if key == "pnn50" or key.startswith("pnn_"):
            # Percentage values (pNN50, pNN20, ...).
            return f"{value:.2f}%"
        if key.endswith(("nu_power", "_normalized")):
            # Normalized power values (LFnu, HFnu).
            return f"{value:.3f} (n.u.)"
        if key == "cvnn":
            return f"{value:.2f}%"
        if key.endswith("_entropy"):
            return "N/A" if np.isnan(value) else f"{value:.3f}"
        if key.startswith("dfa_") or key.endswith(("_alpha1", "_alpha2")):
            return f"{value:.3f}"
        if key in ("fractal_dimension", "lyapunov_exponent"):
            return f"{value:.3f}"
        if key.endswith("_ms") or key in millisecond_keys:
            # Already in milliseconds: no unit conversion.
            return f"{format_large_number(value, unit='ms')} ms"
        if key in ("mean_beat_interval", "beat_variability"):
            # Stored in seconds, displayed in milliseconds.
            return f"{format_large_number(value * 1000, unit='ms')} ms"
        if key == "beat_regularity":
            return f"{value:.3f}"
        if key.endswith(count_suffixes) or key in count_keys:
            return format_large_number(value, as_integer=True)
        if key.endswith(("_power", "_energy")):
            return f"{format_large_number(value, use_scientific=True)} units²"
        if key.endswith(("_height", "_amplitude")) or key in amplitude_keys:
            return f"{format_large_number(value)} units"
        if key in dimensionless_keys:
            return format_large_number(value)
        if key in ("mean_segment_length", "signal_bandwidth"):
            return f"{format_large_number(value)} samples"
        if isinstance(value, float) and abs(value) < 0.01:
            return f"{value:.2e}"
        return format_large_number(value)

    def _build_metric_card(title, icon, metrics_dict, color="primary"):
        """Build a compact card for one feature group, or ``None`` if it has no rows.

        Non-zero numbers, non-empty strings and non-empty lists produce a
        row; every other value is skipped.
        """
        rows = []
        for key, value in metrics_dict.items():
            if isinstance(value, (int, float)) and value != 0:
                rows.append(_metric_row(key, _format_numeric(key, value)))
            elif isinstance(value, str) and value:
                rows.append(_metric_row(key, value))
            elif isinstance(value, list) and len(value) > 0:
                noun = "values" if isinstance(value[0], (int, float)) else "items"
                rows.append(_metric_row(key, f"{len(value)} {noun}"))
        if not rows:
            return None
        return dbc.Card(
            [
                dbc.CardHeader(
                    [
                        html.Div(
                            [
                                html.Span(icon, className="me-2 fs-5"),
                                html.Span(title, className="fs-6 fw-bold text-dark"),
                            ],
                            className="d-flex align-items-center",
                        )
                    ],
                    className=f"bg-{color} bg-opacity-10 border-0 py-2 px-3",
                ),
                dbc.CardBody(
                    [html.Div(rows, className="small")],
                    className="py-2 px-3",
                ),
            ],
            className="h-100 border-0 shadow-sm",
        )

    @app.callback(
        Output("physio-hrv-options-container", "style"),
        [Input("physio-analysis-categories", "value")],
    )
    def toggle_hrv_options_visibility(analysis_categories):
        """Show the HRV options only while "hrv" is among the selected categories."""
        if analysis_categories and "hrv" in analysis_categories:
            return {"display": "block"}
        return {"display": "none"}

    @app.callback(
        Output("physio-morphology-options-container", "style"),
        [Input("physio-analysis-categories", "value")],
    )
    def toggle_morphology_options_visibility(analysis_categories):
        """Show the morphology options only while "morphology" is selected."""
        if analysis_categories and "morphology" in analysis_categories:
            return {"display": "block"}
        return {"display": "none"}

    @app.callback(
        Output("physio-additional-analysis-section", "children"),
        [Input("store-physio-features", "data")],
    )
    def update_additional_analysis_section(features_data):
        """Render one compact card per feature group found in the store."""
        if not features_data:
            return _placeholder_panel(
                "fas fa-chart-line",
                "Additional Analysis",
                "Run the main analysis to see additional detailed information here.",
            )

        feature_cards = []
        for feature_type, metrics in features_data.items():
            # Only non-empty dicts without an "error" entry are rendered; a
            # non-dict entry (e.g. a bare message string) is skipped instead
            # of crashing on ``.items()``.
            if not isinstance(metrics, dict) or not metrics or "error" in metrics:
                continue
            icon, color = icon_map.get(feature_type, ("📊", "primary"))
            title = feature_type.replace("_", " ").title()
            compact_card = _build_metric_card(title, icon, metrics, color)
            if compact_card:
                feature_cards.append(
                    html.Div(
                        compact_card, className="col-lg-4 col-md-6 col-sm-12 mb-2"
                    )
                )

        if not feature_cards:
            return _placeholder_panel(
                "fas fa-info-circle",
                "No Additional Features",
                "No additional features to display.",
            )

        return html.Div(
            [
                html.Div(
                    [
                        html.H4(
                            "📋 Detailed Feature Analysis",
                            className="text-center mb-3 text-dark",
                        ),
                        html.P(
                            "Comprehensive physiological feature extraction results",
                            className="text-center text-muted mb-3 small",
                        ),
                    ],
                    className="mb-3",
                ),
                html.Div(feature_cards, className="row g-2"),  # g-2: reduced gap
            ]
        )

    @app.callback(
        Output("physio-btn-export-results", "disabled"),
        [Input("store-physio-features", "data")],
    )
    def toggle_export_button(features_data):
        """Disable the export button while the feature store is empty."""
        return not bool(features_data)
# Enhanced vitalDSP Integration Functions
def _import_vitaldsp_modules():
    """Probe the optional vitalDSP packages and log which ones can be imported.

    If a ``vitalDSP`` directory exists three levels above this file, it is
    prepended to ``sys.path`` (at most once, as an absolute path, so repeated
    calls do not pile up duplicate entries). Each package group is then
    imported purely as an availability check: success is logged at INFO,
    failure at WARNING. Nothing is returned and no exception escapes.
    """
    import importlib
    import os
    import sys

    # (package, members that must be importable, success message, failure prefix)
    probes = (
        (
            "vitalDSP.physiological_features",
            ("hrv_analysis", "time_domain", "frequency_domain", "nonlinear"),
            "Successfully imported vitalDSP physiological features modules",
            "Could not import vitalDSP physiological features",
        ),
        (
            "vitalDSP.feature_engineering",
            ("ppg_light_features", "ppg_autonomic_features", "ecg_autonomic_features"),
            "Successfully imported vitalDSP feature engineering modules",
            "Could not import vitalDSP feature engineering",
        ),
        (
            "vitalDSP.signal_quality_assessment",
            ("signal_quality_index", "artifact_detection_removal"),
            "Successfully imported vitalDSP signal quality modules",
            "Could not import vitalDSP signal quality modules",
        ),
        (
            "vitalDSP.transforms",
            ("wavelet_transform", "fourier_transform", "hilbert_transform"),
            "Successfully imported vitalDSP transforms modules",
            "Could not import vitalDSP transforms modules",
        ),
        (
            "vitalDSP.advanced_computation",
            ("anomaly_detection", "bayesian_analysis", "kalman_filter"),
            "Successfully imported vitalDSP advanced computation modules",
            "Could not import vitalDSP advanced computation modules",
        ),
    )

    try:
        vitaldsp_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "..", "..", "vitalDSP")
        )
        # Guard against inserting the same entry again on every call.
        if os.path.exists(vitaldsp_path) and vitaldsp_path not in sys.path:
            sys.path.insert(0, vitaldsp_path)
            logger.info("Added vitalDSP path to sys.path")

        for package, members, success_message, failure_prefix in probes:
            try:
                module = importlib.import_module(package)
                for member in members:
                    # Same resolution as ``from package import member``:
                    # an existing attribute wins, otherwise try the submodule.
                    if not hasattr(module, member):
                        importlib.import_module(f"{package}.{member}")
            except ImportError as e:
                logger.warning(f"{failure_prefix}: {e}")
            else:
                logger.info(success_message)
    except Exception as e:
        # Best-effort probe: never let an unexpected failure break the caller.
        logger.warning(f"Error importing vitalDSP modules: {e}")
        logger.info("Continuing with scipy/numpy fallback implementations")
[docs]
def get_vitaldsp_hrv_analysis(
    signal_data, sampling_freq, hrv_options, signal_type="PPG"
):
    """Run vitalDSP's ``HRVFeatures`` and map its output to the webapp's keys.

    Args:
        signal_data: Sequence of samples (only ``len`` is used here; the data
            is passed to vitalDSP unchanged).
        sampling_freq: Sampling frequency in Hz.
        hrv_options: Container of requested groups; "time_domain",
            "freq_domain" and "nonlinear" are recognised.
        signal_type: Signal label, passed through ``normalize_signal_type``.

    Returns:
        dict: ``{"error": ...}`` when fewer than five seconds of samples are
        supplied; otherwise the mapped metrics of the requested groups. When
        vitalDSP is missing or anything raises, the result of the fallback
        analysis function is returned instead.
    """
    try:
        # Require at least five seconds of data.
        required = 5 * sampling_freq
        n_samples = len(signal_data)
        if n_samples < required:
            logger.warning(
                f"Signal too short for HRV analysis ({n_samples} samples, need at least {required})"
            )
            return {
                "error": f"Signal too short for HRV analysis. Need at least {required} samples."
            }

        normalized_type = normalize_signal_type(signal_type)

        try:
            from vitalDSP.physiological_features.hrv_analysis import HRVFeatures

            # Some HRVFeatures versions may not accept ``signal_type``;
            # retry without it when the constructor rejects the keyword.
            try:
                analyzer = HRVFeatures(
                    signals=signal_data, fs=sampling_freq, signal_type=normalized_type
                )
            except TypeError:
                analyzer = HRVFeatures(signals=signal_data, fs=sampling_freq)
        except ImportError:
            logger.info(
                "vitalDSP HRV module not available, using fallback implementation"
            )
            return analyze_hrv_fallback(signal_data, sampling_freq, hrv_options)

        raw = analyzer.compute_all_features()
        report = {}

        if "time_domain" in hrv_options:
            # (output key, vitalDSP key, convert seconds -> milliseconds?)
            # Interval statistics are treated as seconds and scaled by 1000;
            # counts, percentages and the coefficient of variation pass through.
            time_fields = (
                ("mean_rr", "mean_nn", True),
                ("sdnn", "sdnn", True),
                ("std_nn", "std_nn", True),
                ("rmssd", "rmssd", True),
                ("nn50", "nn50", False),
                ("pnn50", "pnn50", False),
                ("median_nn", "median_nn", True),
                ("iqr_nn", "iqr_nn", True),
                ("cvnn", "cvnn", False),
                ("sdsd", "sdsd", True),
                ("pnn_20", "pnn_20", False),
            )
            for out_key, source_key, in_seconds in time_fields:
                measured = raw.get(source_key, 0)
                report[out_key] = measured * 1000 if in_seconds else measured

        if "freq_domain" in hrv_options:
            # Frequency-domain HRV is flagged as unreliable below two minutes.
            duration_s = n_samples / sampling_freq
            recommended_s = 120
            if duration_s < recommended_s:
                logger.warning(
                    f"Signal duration ({duration_s:.1f}s) is shorter than recommended for frequency domain HRV "
                    f"(minimum {recommended_s}s). Results may be unreliable."
                )
                report["freq_domain_warning"] = (
                    f"Signal too short ({duration_s:.1f}s) for reliable frequency domain HRV. "
                    f"Recommend ≥{recommended_s}s (2-5 minutes)."
                )

            total_band = raw.get("total_power", 0)
            ulf_band = raw.get("ulf_power", 0)
            vlf_band = raw.get("vlf_power", 0)
            lf_band = raw.get("lf_power", 0)
            hf_band = raw.get("hf_power", 0)

            # Normalized LF/HF shares, used only when vitalDSP gives none.
            lf_plus_hf = lf_band + hf_band
            if lf_plus_hf > 0:
                lf_share = lf_band / lf_plus_hf
                hf_share = hf_band / lf_plus_hf
            else:
                lf_share = 0
                hf_share = 0

            report["total_power"] = total_band
            report["ulf_power"] = ulf_band
            report["vlf_power"] = vlf_band
            report["lf_power"] = lf_band
            report["hf_power"] = hf_band
            report["lf_hf_ratio"] = raw.get("lf_hf_ratio", 0)
            report["lfnu_power"] = raw.get("lfnu_power", lf_share)
            report["hfnu_power"] = raw.get("hfnu_power", hf_share)

        if "nonlinear" in hrv_options:
            # Poincaré descriptors are scaled to milliseconds like the
            # time-domain interval statistics.
            sd1_ms = raw.get("poincare_sd1", 0) * 1000
            sd2_ms = raw.get("poincare_sd2", 0) * 1000
            report["poincare_sd1"] = sd1_ms
            report["poincare_sd2"] = sd2_ms
            report["poincare_sd1_sd2_ratio"] = sd1_ms / sd2_ms if sd2_ms != 0 else 0
            # The remaining nonlinear metrics are dimensionless.
            report["dfa_alpha1"] = raw.get("dfa", 0)
            report["dfa_alpha2"] = raw.get("dfa_alpha2", 0)
            report["sample_entropy"] = raw.get("sample_entropy", None)
            report["approximate_entropy"] = raw.get("approximate_entropy", None)
            report["fractal_dimension"] = raw.get("fractal_dimension", 0)
            report["lyapunov_exponent"] = raw.get("lyapunov_exponent", 0)

        return report
    except ImportError:
        logger.info("vitalDSP HRV module not available, using fallback implementation")
        return analyze_hrv(signal_data, sampling_freq, hrv_options)
    except Exception as e:
        logger.error(f"Error in vitalDSP HRV analysis: {e}")
        return analyze_hrv(signal_data, sampling_freq, hrv_options)
[docs]
def get_vitaldsp_morphology_analysis(
    signal_data, sampling_freq, morphology_options, signal_type="PPG"
):
    """Compute morphology metrics, using vitalDSP extractors where available.

    Args:
        signal_data: Samples to analyse. The peak branch indexes it with an
            index array, so it is presumably a NumPy array (not verified here).
        sampling_freq: Sampling frequency in Hz.
        morphology_options: Container of requested groups; "peaks",
            "amplitude", "duration" and "area" are recognised.
        signal_type: Signal label, passed through ``normalize_signal_type``.

    Returns:
        dict: ``{"error": ...}`` when fewer than two seconds of samples are
        supplied; otherwise the metrics of the requested groups. When vitalDSP
        is missing or anything raises, ``analyze_morphology`` is used instead.
    """
    try:
        # Require at least two seconds of data.
        sample_count = len(signal_data)
        required = 2 * sampling_freq
        if sample_count < required:
            logger.warning(
                f"Signal too short for morphology analysis ({sample_count} samples, need at least {required})"
            )
            return {
                "error": f"Signal too short for morphology analysis. Need at least {required} samples."
            }

        normalized_type = normalize_signal_type(signal_type)

        from vitalDSP.feature_engineering.morphology_features import (
            PhysiologicalFeatureExtractor,
        )

        extracted = PhysiologicalFeatureExtractor(
            signal_data, fs=sampling_freq
        ).extract_features(
            signal_type=normalized_type,
            peak_config={"window_size": 5, "slope_unit": "radians"},
        )

        report = {}

        if "peaks" in morphology_options:
            kind = normalized_type.lower() if normalized_type else None
            if kind in ("ecg", "ppg"):
                # ECG -> R peaks, PPG -> systolic peaks, both from vitalDSP.
                from vitalDSP.physiological_features.waveform import WaveformMorphology

                morphology = WaveformMorphology(
                    signal_data, fs=sampling_freq, signal_type=normalized_type
                )
                peaks = (
                    morphology.r_peaks if kind == "ecg" else morphology.systolic_peaks
                )
            else:
                # Any other signal type: generic scipy peak picking above
                # mean + 1 std, with peaks at least 0.3 s apart.
                from scipy import signal as scipy_signal

                peaks, _ = scipy_signal.find_peaks(
                    signal_data,
                    height=np.mean(signal_data) + np.std(signal_data),
                    distance=int(sampling_freq * 0.3),
                )
            found_any = len(peaks) > 0
            report["num_peaks"] = len(peaks)
            report["peak_heights"] = signal_data[peaks].tolist() if found_any else []
            report["peak_positions"] = peaks.tolist() if found_any else []

        if "amplitude" in morphology_options:
            lowest = np.min(signal_data)
            highest = np.max(signal_data)
            report["mean_amplitude"] = np.mean(signal_data)
            report["std_amplitude"] = np.std(signal_data)
            report["min_amplitude"] = lowest
            report["max_amplitude"] = highest
            report["peak_to_peak"] = highest - lowest

        if "duration" in morphology_options:
            report["signal_duration"] = sample_count / sampling_freq
            report["sampling_freq"] = sampling_freq
            report["num_samples"] = sample_count
            report["systolic_duration"] = extracted.get("systolic_duration", 0)
            report["diastolic_duration"] = extracted.get("diastolic_duration", 0)

        if "area" in morphology_options:
            report["total_area"] = np.sum(np.abs(signal_data))
            report["area_under_curve"] = np.sum(signal_data)
            report["systolic_area"] = extracted.get("systolic_area", 0)
            report["diastolic_area"] = extracted.get("diastolic_area", 0)

        return report
    except ImportError:
        logger.info(
            "vitalDSP morphology module not available, using fallback implementation"
        )
        return analyze_morphology(signal_data, sampling_freq, morphology_options)
    except Exception as e:
        logger.error(f"Error in vitalDSP morphology analysis: {e}")
        return analyze_morphology(signal_data, sampling_freq, morphology_options)
[docs]
def get_vitaldsp_signal_quality(
    signal_data, sampling_freq, quality_options, signal_type="PPG"
):
    """Compute signal-quality metrics, using vitalDSP's SQI when available.

    Args:
        signal_data: Samples to assess.
        sampling_freq: Sampling frequency in Hz.
        quality_options: Container of requested groups; "quality_index",
            "artifact_detection" and "snr_estimation" are recognised.
        signal_type: Accepted for interface compatibility; not used by the
            computations below.

    Returns:
        dict: ``{"error": ...}`` when fewer than three seconds of samples are
        supplied; otherwise the metrics of the requested groups (a group whose
        computation fails reports zeros). When vitalDSP is missing or an
        unexpected error occurs, ``analyze_signal_quality_advanced`` is used.
    """
    try:
        # Require at least three seconds of data.
        min_samples_for_quality = 3 * sampling_freq
        if len(signal_data) < min_samples_for_quality:
            logger.warning(
                f"Signal too short for quality analysis ({len(signal_data)} samples, need at least {min_samples_for_quality})"
            )
            return {
                "error": f"Signal too short for quality analysis. Need at least {min_samples_for_quality} samples."
            }

        from vitalDSP.signal_quality_assessment.signal_quality_index import (
            SignalQualityIndex,
        )

        sqi = SignalQualityIndex(signal_data)
        mapped_results = {}

        if "quality_index" in quality_options:
            try:
                # Windowed amplitude-variability SQI with 50 % overlap.
                window_size = min(100, len(signal_data) // 4)
                step_size = max(1, window_size // 2)
                sqi_values, _, _ = sqi.amplitude_variability_sqi(
                    window_size, step_size, scale="minmax"
                )
                if len(sqi_values) > 0:
                    sqi_min = np.min(sqi_values)
                    sqi_max = np.max(sqi_values)
                    if sqi_max > sqi_min:
                        # Rescale to [0, 1] and average over the windows.
                        sqi_normalized = (sqi_values - sqi_min) / (sqi_max - sqi_min)
                        overall_score = float(np.mean(sqi_normalized))
                    else:
                        # Every window scored the same: report a neutral value.
                        overall_score = 0.5
                else:
                    overall_score = 0
                mapped_results.update(
                    {
                        "signal_quality_index": overall_score,
                        "overall_score": overall_score,
                    }
                )
            except Exception as e:
                logger.warning(f"Could not compute amplitude variability SQI: {e}")
                mapped_results.update({"signal_quality_index": 0, "overall_score": 0})

        if "artifact_detection" in quality_options:
            try:
                # Amplitude outliers: samples more than 3 standard deviations
                # away from the mean, in either direction. Measuring the
                # deviation from the mean (rather than |x| against mean + 3σ)
                # keeps the test meaningful for signals with a non-zero or
                # negative baseline.
                samples = np.asarray(signal_data)
                deviation = np.abs(samples - np.mean(samples))
                artifacts = int(np.sum(deviation > 3 * np.std(samples)))
                artifact_ratio = artifacts / len(samples) if len(samples) > 0 else 0
                mapped_results.update(
                    {"artifacts_detected": artifacts, "artifact_ratio": artifact_ratio}
                )
            except Exception as e:
                logger.warning(f"Could not compute artifact detection: {e}")
                mapped_results.update({"artifacts_detected": 0, "artifact_ratio": 0})

        if "snr_estimation" in quality_options:
            try:
                from scipy import signal as scipy_signal

                # Remove the linear baseline, then treat a Savitzky-Golay
                # smoothed copy as the "clean" signal and the residual as noise.
                detrended = scipy_signal.detrend(signal_data)

                # The window must be odd and no longer than the signal:
                # largest odd length <= n, capped at 51.
                n_samples = len(signal_data)
                largest_odd = n_samples if n_samples % 2 else n_samples - 1
                window_length = min(51, largest_odd)

                if window_length >= 5:
                    from scipy.signal import savgol_filter

                    smoothed = savgol_filter(detrended, window_length, polyorder=3)
                    noise = detrended - smoothed
                    signal_power = np.var(smoothed)
                    noise_power = np.var(noise)
                else:
                    # Too short to smooth: assume noise is 10 % of the power.
                    signal_power = np.var(detrended)
                    noise_power = signal_power * 0.1

                snr_db = (
                    10 * np.log10(signal_power / noise_power)
                    if noise_power > 0
                    else 40.0
                )
                # Keep the reported value within 0-60 dB.
                snr_db = float(np.clip(snr_db, 0, 60))
                mapped_results.update({"snr_db": snr_db, "adaptive_snr": snr_db})
            except Exception as e:
                logger.warning(f"Could not compute SNR estimation: {e}")
                mapped_results.update({"snr_db": 0, "adaptive_snr": 0})

        return mapped_results
    except ImportError:
        logger.info(
            "vitalDSP signal quality module not available, using fallback implementation"
        )
        return analyze_signal_quality_advanced(
            signal_data, sampling_freq, quality_options
        )
    except Exception as e:
        logger.error(f"Error in vitalDSP signal quality analysis: {e}")
        return analyze_signal_quality_advanced(
            signal_data, sampling_freq, quality_options
        )
[docs]
def get_vitaldsp_advanced_computation(
    signal_data, sampling_freq, advanced_computation, signal_type="PPG"
):
    """Compute advanced features with vitalDSP, falling back to NumPy statistics.

    Args:
        signal_data: Samples to analyse.
        sampling_freq: Sampling frequency in Hz (only forwarded to the
            fallback analysis function).
        advanced_computation: Container of requested features;
            "anomaly_detection", "bayesian" and "kalman" are recognised.
        signal_type: Accepted for interface compatibility; not used by the
            computations below.

    Returns:
        dict: Subset of ``anomalies_detected``, ``bayesian_prior_mean``,
        ``bayesian_prior_std`` and ``kalman_estimate`` for the requested
        features. Each feature falls back to a simple statistic when its
        vitalDSP module is missing or fails. If an unexpected error escapes,
        ``analyze_advanced_computation`` is used instead.
    """
    try:
        mapped_results = {}

        def _zscore_outlier_count(limit):
            """Count samples deviating from the mean by more than ``limit`` std."""
            samples = np.asarray(signal_data)
            deviation = np.abs(samples - np.mean(samples))
            return int(np.sum(deviation > limit * np.std(samples)))

        if "anomaly_detection" in advanced_computation:
            try:
                from vitalDSP.advanced_computation.anomaly_detection import (
                    AnomalyDetection,
                )

                anomaly_detector = AnomalyDetection(signal_data)
                anomalies = anomaly_detector.detect_anomalies(
                    method="z_score", threshold=2.0
                )
                mapped_results["anomalies_detected"] = len(anomalies)
            except ImportError:
                # Same criterion as the z-score method requested above:
                # |x - mean| > 2 std (deviation from the mean, so a non-zero
                # baseline does not distort the count).
                mapped_results["anomalies_detected"] = _zscore_outlier_count(2.0)
            except Exception as e:
                logger.warning(f"vitalDSP anomaly detection failed: {e}")
                mapped_results["anomalies_detected"] = _zscore_outlier_count(2.0)

        if "bayesian" in advanced_computation:
            try:
                from vitalDSP.advanced_computation.bayesian_analysis import (
                    GaussianProcess,
                )

                # noise=1e-5 (rather than 1e-10) for numerical stability.
                gp = GaussianProcess(length_scale=1.0, noise=1e-5)
                # Three training points: the mean of each third of the signal.
                n_samples = len(signal_data)
                first_cut = n_samples // 3
                second_cut = 2 * n_samples // 3
                X_train = np.array([[0.1], [0.4], [0.7]])
                y_train = np.array(
                    [
                        np.mean(signal_data[:first_cut]),
                        np.mean(signal_data[first_cut:second_cut]),
                        np.mean(signal_data[second_cut:]),
                    ]
                )
                gp.update(X_train, y_train)
                mean, variance = gp.predict(np.array([[0.5]]))
                mapped_results["bayesian_prior_mean"] = float(mean[0])
                mapped_results["bayesian_prior_std"] = float(np.sqrt(variance[0]))
            except ImportError:
                logger.info(
                    "vitalDSP bayesian_analysis not available, using basic statistics"
                )
                mapped_results["bayesian_prior_mean"] = float(np.mean(signal_data))
                mapped_results["bayesian_prior_std"] = float(np.std(signal_data))
            except Exception as e:
                logger.warning(
                    f"vitalDSP bayesian analysis failed: {e}, using fallback"
                )
                mapped_results["bayesian_prior_mean"] = float(np.mean(signal_data))
                mapped_results["bayesian_prior_std"] = float(np.std(signal_data))

        if "kalman" in advanced_computation:
            try:
                from vitalDSP.advanced_computation.kalman_filter import KalmanFilter

                # One-dimensional constant-state model seeded with the
                # signal's mean and variance.
                signal_variance = np.var(signal_data)
                kalman = KalmanFilter(
                    np.array([np.mean(signal_data)]),
                    np.array([[signal_variance]]),
                    np.array([[1e-5]]),
                    np.array([[signal_variance]]),
                )
                measurement_matrix = np.array([[1]])
                transition_matrix = np.array([[1]])
                # Only the first 100 samples are filtered, for efficiency.
                subset = signal_data[: min(100, len(signal_data))]
                filtered_signal = kalman.filter(
                    subset, measurement_matrix, transition_matrix
                )
                mapped_results["kalman_estimate"] = float(np.mean(filtered_signal))
            except ImportError:
                mapped_results["kalman_estimate"] = float(np.mean(signal_data))
            except Exception as e:
                logger.warning(f"vitalDSP kalman filter failed: {e}")
                mapped_results["kalman_estimate"] = float(np.mean(signal_data))

        return mapped_results
    except Exception as e:
        logger.error(f"Error in vitalDSP advanced computation analysis: {e}")
        return analyze_advanced_computation(
            signal_data, sampling_freq, advanced_computation
        )
[docs]
def get_vitaldsp_feature_engineering(
    signal_data, sampling_freq, feature_engineering, signal_type
):
    """Compute feature-engineering metrics with vitalDSP, with NumPy fallbacks.

    Args:
        signal_data: Samples to analyse.
        sampling_freq: Sampling frequency in Hz.
        feature_engineering: Container of requested features; "ppg_light",
            "ppg_autonomic" and "ecg_autonomic" are recognised.
        signal_type: Signal label, compared case-insensitively. A falsy value
            (``None`` or ``""``) is treated as "PPG".

    Returns:
        dict: The requested features that match the signal type (PPG features
        for PPG signals, ECG features for ECG signals). Each feature falls
        back to the signal mean / standard deviation when its vitalDSP module
        is missing or fails. If an unexpected error escapes,
        ``analyze_feature_engineering`` is used instead.
    """
    try:
        # Normalise once; every branch below compares against this value so
        # the "PPG" default also applies when signal_type is None.
        kind = signal_type.lower() if signal_type else "ppg"
        mapped_results = {}

        def _autonomic_response(extractor):
            """Read "autonomic_response" from an extractor, or fall back to std."""
            try:
                # The method name is a guess at the vitalDSP API, hence the
                # hasattr probe.
                if hasattr(extractor, "extract_autonomic_features"):
                    outcome = extractor.extract_autonomic_features()
                    if isinstance(outcome, dict):
                        return outcome.get("autonomic_response", 0)
                    return 0
                return float(np.std(signal_data))
            except Exception:
                return float(np.std(signal_data))

        if "ppg_light" in feature_engineering and kind == "ppg":
            try:
                from vitalDSP.feature_engineering.ppg_light_features import (
                    PPGLightFeatureExtractor,
                )

                # Limitation: the same signal is passed for both channels.
                light_extractor = PPGLightFeatureExtractor(
                    signal_data, signal_data, sampling_freq
                )
                try:
                    pi_values, _ = light_extractor.calculate_perfusion_index()
                    mapped_results["ppg_light_intensity"] = (
                        float(np.mean(pi_values)) if len(pi_values) > 0 else 0
                    )
                except Exception:
                    mapped_results["ppg_light_intensity"] = float(np.mean(signal_data))
                mapped_results["ppg_light_variability"] = float(np.std(signal_data))
            except ImportError:
                mapped_results["ppg_light_intensity"] = float(np.mean(signal_data))
                mapped_results["ppg_light_variability"] = float(np.std(signal_data))
            except Exception as e:
                logger.warning(f"vitalDSP PPG light features failed: {e}")
                mapped_results["ppg_light_intensity"] = float(np.mean(signal_data))
                mapped_results["ppg_light_variability"] = float(np.std(signal_data))

        if "ppg_autonomic" in feature_engineering and kind == "ppg":
            try:
                from vitalDSP.feature_engineering.ppg_autonomic_features import (
                    PPGAutonomicFeatures,
                )

                mapped_results["ppg_autonomic_response"] = _autonomic_response(
                    PPGAutonomicFeatures(signal_data, sampling_freq)
                )
            except ImportError:
                mapped_results["ppg_autonomic_response"] = float(np.std(signal_data))
            except Exception as e:
                logger.warning(f"vitalDSP PPG autonomic features failed: {e}")
                mapped_results["ppg_autonomic_response"] = float(np.std(signal_data))

        if "ecg_autonomic" in feature_engineering and kind == "ecg":
            try:
                from vitalDSP.feature_engineering.ecg_autonomic_features import (
                    ECGAutonomicFeatureExtractor,
                )

                mapped_results["ecg_autonomic_response"] = _autonomic_response(
                    ECGAutonomicFeatureExtractor(signal_data, sampling_freq)
                )
            except ImportError:
                mapped_results["ecg_autonomic_response"] = float(np.std(signal_data))
            except Exception as e:
                logger.warning(f"vitalDSP ECG autonomic features failed: {e}")
                mapped_results["ecg_autonomic_response"] = float(np.std(signal_data))

        return mapped_results
    except Exception as e:
        logger.error(f"Error in vitalDSP feature engineering analysis: {e}")
        return analyze_feature_engineering(
            signal_data, sampling_freq, feature_engineering, signal_type
        )
# Update the main analysis function to use vitalDSP when available
[docs]
def suggest_best_signal_column(df, time_col):
    """Rank the numeric columns of ``df`` as candidate physiological signals.

    Every numeric, non-boolean column other than ``time_col`` is scored with
    simple heuristics (0-6 points): variance above 0.001 / 0.01 / 0.1 (one
    point each), value range within (0.01, 1000), mean within (-100, 100),
    and, for columns longer than 100 samples, a sample-to-sample difference
    whose standard deviation exceeds 0.001. Columns that contain NaN or are
    constant are skipped, as is any column whose analysis raises.

    Args:
        df: pandas DataFrame holding the recorded columns.
        time_col: Name of the time column, excluded from the ranking.

    Returns:
        list[dict]: One entry per candidate with keys ``column``, ``score``,
        ``variance``, ``range`` and ``mean``, sorted by descending score
        (ties keep the DataFrame's column order).
    """
    signal_candidates = []
    for col in df.columns:
        if col == time_col:
            continue
        try:
            dtype = df[col].dtype
            # Accept every numeric width (int16 ADC samples, float16, ...),
            # but not boolean flags.
            if not pd.api.types.is_numeric_dtype(dtype) or pd.api.types.is_bool_dtype(
                dtype
            ):
                continue

            # Work on floats so narrow integer types cannot overflow when
            # the range is computed.
            col_data = df[col].to_numpy(dtype=float)

            # Skip columns containing NaN and constant columns.
            if np.any(np.isnan(col_data)) or np.std(col_data) == 0:
                continue

            variance = float(np.var(col_data))
            range_val = float(np.max(col_data) - np.min(col_data))
            mean_val = float(np.mean(col_data))

            checks = [
                variance > 0.001,
                variance > 0.01,
                variance > 0.1,
                0.01 < range_val < 1000,  # neither negligible nor huge
                -100 < mean_val < 100,  # mean should not be extreme
                # Sample-to-sample variation, only judged on longer columns.
                len(col_data) > 100 and np.std(np.diff(col_data)) > 0.001,
            ]
            signal_candidates.append(
                {
                    "column": col,
                    "score": int(sum(bool(passed) for passed in checks)),
                    "variance": variance,
                    "range": range_val,
                    "mean": mean_val,
                }
            )
        except Exception as e:
            logger.debug(f"Could not analyze column {col}: {e}")
            continue

    # Highest score first; list.sort is stable, so ties keep column order.
    signal_candidates.sort(key=lambda candidate: candidate["score"], reverse=True)
    return signal_candidates
[docs]
def create_signal_quality_plots(time_data, signal_data, sampling_freq, quality_options):
    """Create comprehensive signal quality analysis plots.

    The first row always shows the raw signal together with its mean
    ("Baseline") and mean +/- one standard deviation ("Noise") lines.
    Optional rows are appended depending on ``quality_options``:

    * ``"quality_index"``: |mean| / std over half-overlapping windows.
    * ``"artifact_detection"``: per-sample artifact score (0-2) combining an
      amplitude-outlier test and a sudden-change test; flagged samples are
      also marked on the first row.

    Args:
        time_data: Sample times (array-like, same length as ``signal_data``).
        signal_data: Signal amplitudes (array-like).
        sampling_freq: Sampling frequency. Not used by this function.
        quality_options: Container of option names selecting optional rows.

    Returns:
        The assembled plotly figure, or the result of
        ``create_empty_figure()`` if anything fails while building it.
    """
    try:
        # Arrays are required for the index-based selections below.
        time_data = np.asarray(time_data)
        signal_data = np.asarray(signal_data)

        show_quality = "quality_index" in quality_options
        show_artifacts = "artifact_detection" in quality_options

        titles = ["Signal Quality Overview"]
        if show_quality:
            titles.append("Quality Metrics")
        if show_artifacts:
            titles.append("Artifact Detection")
        num_plots = len(titles)

        fig = make_subplots(
            rows=num_plots,
            cols=1,
            subplot_titles=titles,
            vertical_spacing=0.12,
            specs=[[{"secondary_y": False}] for _ in range(num_plots)],
        )

        def add_stats_box(row, text, border_color):
            # "x domain"/"y domain" make (0.98, 0.98) a fraction of the
            # subplot area instead of a data coordinate; row/col lets plotly
            # pick the axes belonging to that subplot.
            fig.add_annotation(
                x=0.98,
                y=0.98,
                xref="x domain",
                yref="y domain",
                xanchor="right",
                yanchor="top",
                text=text,
                showarrow=False,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=border_color,
                borderwidth=2,
                font=dict(size=11, color="black"),
                row=row,
                col=1,
            )

        current_row = 1

        # Main signal with quality indicators
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=signal_data,
                mode="lines",
                name="Original Signal",
                line=dict(color="#1f77b4", width=2),
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.1)",
            ),
            row=current_row,
            col=1,
        )

        # Baseline and +/-1 sigma noise bands
        baseline = np.mean(signal_data)
        noise_level = np.std(signal_data)
        reference_lines = (
            (baseline, "dash", "gray", "Baseline"),
            (baseline + noise_level, "dot", "orange", "+1σ Noise"),
            (baseline - noise_level, "dot", "orange", "-1σ Noise"),
        )
        for level, dash_style, color, label in reference_lines:
            fig.add_hline(
                y=level,
                line_dash=dash_style,
                line_color=color,
                annotation_text=label,
                annotation_position="right",
                row=current_row,
                col=1,
            )

        snr = 20 * np.log10(np.abs(baseline) / noise_level) if noise_level > 0 else 0
        add_stats_box(
            current_row,
            f"Baseline: {baseline:.3f}<br>Noise Level: {noise_level:.3f}<br>SNR: {snr:.1f} dB",
            "blue",
        )
        current_row += 1

        # Quality metrics analysis
        if show_quality:
            window_size = min(100, len(signal_data) // 10)
            if window_size > 1:
                half_window = window_size // 2
                quality_scores = []
                time_windows = []
                for start in range(0, len(signal_data) - window_size, half_window):
                    window_data = signal_data[start : start + window_size]
                    window_std = np.std(window_data)
                    # Simple quality score: |mean| / std, 0 for a flat window
                    if window_std > 0:
                        quality_scores.append(np.abs(np.mean(window_data)) / window_std)
                    else:
                        quality_scores.append(0)
                    time_windows.append(time_data[start + half_window])

                if quality_scores:
                    fig.add_trace(
                        go.Scatter(
                            x=time_windows,
                            y=quality_scores,
                            mode="lines+markers",
                            name="Quality Score",
                            line=dict(color="#2ca02c", width=3),
                            marker=dict(color="#2ca02c", size=6, symbol="circle"),
                            fill="tonexty",
                            fillcolor="rgba(44, 160, 44, 0.1)",
                        ),
                        row=current_row,
                        col=1,
                    )
                    mean_quality = np.mean(quality_scores)
                    fig.add_hline(
                        y=mean_quality,
                        line_dash="dash",
                        line_color="red",
                        annotation_text=f"Mean Quality: {mean_quality:.2f}",
                        annotation_position="right",
                        row=current_row,
                        col=1,
                    )
                    add_stats_box(
                        current_row,
                        f"Mean Quality: {mean_quality:.2f}<br>Min: {np.min(quality_scores):.2f}<br>Max: {np.max(quality_scores):.2f}",
                        "green",
                    )
            current_row += 1

        # Artifact detection analysis
        if show_artifacts:
            # Method 1: amplitude outliers, measured as deviation from the
            # baseline so signals with a non-zero or negative mean are
            # handled correctly.
            amplitude_artifacts = np.abs(signal_data - baseline) > 3 * noise_level
            # Method 2: sudden sample-to-sample changes
            diff_signal = np.diff(signal_data)
            change_artifacts = np.abs(diff_signal) > 3 * np.std(diff_signal)

            # A jump between samples i-1 and i is attributed to sample i.
            artifact_scores = amplitude_artifacts.astype(int)
            artifact_scores[1:] += change_artifacts
            artifacts = artifact_scores > 0

            fig.add_trace(
                go.Scatter(
                    x=time_data,
                    y=artifact_scores,
                    mode="lines",
                    name="Artifact Score",
                    line=dict(color="#d62728", width=2),
                    fill="tonexty",
                    fillcolor="rgba(214, 39, 40, 0.1)",
                ),
                row=current_row,
                col=1,
            )

            # Highlight flagged samples on the main signal plot
            artifact_regions = np.where(artifacts)[0]
            if len(artifact_regions) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=time_data[artifact_regions],
                        y=signal_data[artifact_regions],
                        mode="markers",
                        name="Detected Artifacts",
                        marker=dict(
                            color="red",
                            size=8,
                            symbol="x",
                            line=dict(color="darkred", width=2),
                        ),
                    ),
                    row=1,
                    col=1,
                )

            total_artifacts = int(np.sum(artifacts))
            artifact_ratio = (
                total_artifacts / len(signal_data) if len(signal_data) > 0 else 0
            )
            add_stats_box(
                current_row,
                f"Total Artifacts: {total_artifacts}<br>Artifact Ratio: {artifact_ratio:.3f}<br>Clean Signal: {100*(1-artifact_ratio):.1f}%",
                "red",
            )

        # Enhanced layout
        fig.update_layout(
            title=dict(
                text="Comprehensive Signal Quality Analysis",
                x=0.5,
                font=dict(size=20, color="#2c3e50"),
            ),
            height=250 * num_plots,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            margin=dict(l=60, r=200, t=80, b=60),
        )

        # Light grid on every subplot
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)"
        )
        for row in range(1, num_plots + 1):
            fig.update_xaxes(row=row, col=1, **grid_style)
            fig.update_yaxes(row=row, col=1, **grid_style)
        return fig
    except Exception as e:
        logger.error(f"Error creating signal quality plots: {e}")
        return create_empty_figure()
[docs]
def create_advanced_features_plots(
    time_data, signal_data, sampling_freq, advanced_features
):
    """Create advanced features analysis plots with cross-signal and ensemble analysis.

    The first row always shows the raw signal. One additional row is
    appended for each of these names found in ``advanced_features``:

    * ``"cross_signal"``: correlation of the signal with lagged copies of
      itself for lags up to ``min(100, len // 4)`` samples.
    * ``"ensemble"``: mean (primary axis) and standard deviation (secondary
      axis) over half-overlapping windows.
    * ``"change_detection"``: moving variance; samples whose first
      difference exceeds 2 standard deviations are marked on the first row.
    * ``"power_analysis"``: Welch power spectral density with band powers.

    Args:
        time_data: Sample times (array-like, same length as ``signal_data``).
        signal_data: Signal amplitudes (array-like).
        sampling_freq: Sampling frequency, used to convert lags to seconds
            and passed to ``scipy.signal.welch`` as ``fs``.
        advanced_features: Container of option names selecting optional rows.

    Returns:
        The assembled plotly figure, or the result of
        ``create_empty_figure()`` if anything fails while building it.
    """
    try:
        # Arrays are required for the index-based selections below.
        time_data = np.asarray(time_data)
        signal_data = np.asarray(signal_data)

        sections = (
            ("cross_signal", "Cross-Signal Analysis"),
            ("ensemble", "Ensemble Analysis"),
            ("change_detection", "Change Detection"),
            ("power_analysis", "Power Analysis"),
        )
        enabled = [key for key, _ in sections if key in advanced_features]
        titles = ["Advanced Features Overview"] + [
            title for key, title in sections if key in enabled
        ]
        num_plots = len(titles)

        # Only the ensemble row gets a secondary y-axis (for the std trace).
        # Patching layout.yaxis2 by hand would instead rewrite the y-axis of
        # the second subplot row and overlay it on the first.
        specs = [[{"secondary_y": False}]] + [
            [{"secondary_y": key == "ensemble"}] for key in enabled
        ]
        fig = make_subplots(
            rows=num_plots,
            cols=1,
            subplot_titles=titles,
            vertical_spacing=0.12,
            specs=specs,
        )

        def add_stats_box(row, text, border_color):
            # Domain references place the box relative to the subplot area;
            # row/col lets plotly resolve the correct axis pair even when a
            # secondary axis shifts the axis numbering.
            fig.add_annotation(
                x=0.98,
                y=0.98,
                xref="x domain",
                yref="y domain",
                xanchor="right",
                yanchor="top",
                text=text,
                showarrow=False,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=border_color,
                borderwidth=2,
                font=dict(size=11, color="black"),
                row=row,
                col=1,
            )

        current_row = 1

        # Main signal with enhanced styling
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=signal_data,
                mode="lines",
                name="Original Signal",
                line=dict(color="#1f77b4", width=2),
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.1)",
            ),
            row=current_row,
            col=1,
        )
        current_row += 1

        # Cross-signal analysis
        if "cross_signal" in enabled:
            max_lag = min(100, len(signal_data) // 4)
            lags = np.arange(-max_lag, max_lag + 1)
            correlations = []
            for lag in lags:
                if lag < 0:
                    corr = np.corrcoef(signal_data[:lag], signal_data[-lag:])[0, 1]
                elif lag > 0:
                    corr = np.corrcoef(signal_data[lag:], signal_data[:-lag])[0, 1]
                else:
                    corr = 1.0
                # corrcoef yields NaN for constant segments
                correlations.append(0 if np.isnan(corr) else corr)

            fig.add_trace(
                go.Scatter(
                    x=lags / sampling_freq,  # Convert to seconds
                    y=correlations,
                    mode="lines",
                    name="Cross-Correlation",
                    line=dict(color="#ff7f0e", width=3),
                    fill="tonexty",
                    fillcolor="rgba(255, 127, 14, 0.1)",
                ),
                row=current_row,
                col=1,
            )

            # NOTE(review): lag 0 is fixed at 1.0, so the maximum is normally
            # found at zero lag and "Periodicity" reads 0 s; consider
            # searching non-zero lags only.
            max_corr = np.max(correlations)
            max_corr_lag = lags[np.argmax(correlations)] / sampling_freq
            add_stats_box(
                current_row,
                f"Max Correlation: {max_corr:.3f}<br>At Lag: {max_corr_lag:.3f}s<br>Periodicity: {abs(max_corr_lag):.3f}s",
                "orange",
            )
            current_row += 1

        # Ensemble analysis
        if "ensemble" in enabled:
            window_size = min(50, len(signal_data) // 10)
            if window_size > 1:
                half_window = window_size // 2
                ensemble_means = []
                ensemble_stds = []
                time_windows = []
                for start in range(0, len(signal_data) - window_size, half_window):
                    window_data = signal_data[start : start + window_size]
                    ensemble_means.append(np.mean(window_data))
                    ensemble_stds.append(np.std(window_data))
                    time_windows.append(time_data[start + half_window])

                if ensemble_means:
                    fig.add_trace(
                        go.Scatter(
                            x=time_windows,
                            y=ensemble_means,
                            mode="lines+markers",
                            name="Ensemble Mean",
                            line=dict(color="#2ca02c", width=3),
                            marker=dict(color="#2ca02c", size=6, symbol="circle"),
                            fill="tonexty",
                            fillcolor="rgba(44, 160, 44, 0.1)",
                        ),
                        row=current_row,
                        col=1,
                    )
                    # Standard deviation goes on this row's secondary axis
                    fig.add_trace(
                        go.Scatter(
                            x=time_windows,
                            y=ensemble_stds,
                            mode="lines+markers",
                            name="Ensemble Std",
                            line=dict(color="#d62728", width=2, dash="dot"),
                            marker=dict(color="#d62728", size=4, symbol="diamond"),
                        ),
                        row=current_row,
                        col=1,
                        secondary_y=True,
                    )
                    fig.update_yaxes(
                        title_text="Standard Deviation",
                        showgrid=False,
                        row=current_row,
                        col=1,
                        secondary_y=True,
                    )

                    mean_ensemble = np.mean(ensemble_means)
                    std_ensemble = np.mean(ensemble_stds)
                    # Avoid dividing by zero for a perfectly flat signal
                    stability = 1 / std_ensemble if std_ensemble > 0 else float("inf")
                    add_stats_box(
                        current_row,
                        f"Mean Ensemble: {mean_ensemble:.3f}<br>Avg Std: {std_ensemble:.3f}<br>Stability: {stability:.1f}",
                        "green",
                    )
            current_row += 1

        # Change detection analysis
        if "change_detection" in enabled:
            # Method 1: first derivative
            diff_signal = np.diff(signal_data)
            change_points = np.abs(diff_signal) > 2 * np.std(diff_signal)

            # Method 2: moving variance
            window_size = min(30, len(signal_data) // 10)
            if window_size > 1:
                half_window = window_size // 2
                moving_variance = []
                time_windows = []
                for start in range(0, len(signal_data) - window_size, half_window):
                    moving_variance.append(
                        np.var(signal_data[start : start + window_size])
                    )
                    time_windows.append(time_data[start + half_window])

                if moving_variance:
                    fig.add_trace(
                        go.Scatter(
                            x=time_windows,
                            y=moving_variance,
                            mode="lines",
                            name="Moving Variance",
                            line=dict(color="#9467bd", width=3),
                            fill="tonexty",
                            fillcolor="rgba(148, 103, 189, 0.1)",
                        ),
                        row=current_row,
                        col=1,
                    )

            # Highlight change points on the main signal plot
            change_regions = np.where(change_points)[0]
            if len(change_regions) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=time_data[change_regions],
                        y=signal_data[change_regions],
                        mode="markers",
                        name="Change Points",
                        marker=dict(
                            color="red",
                            size=8,
                            symbol="diamond",
                            line=dict(color="darkred", width=2),
                        ),
                    ),
                    row=1,
                    col=1,
                )

            total_changes = int(np.sum(change_points))
            change_rate = (
                total_changes / len(diff_signal) if len(diff_signal) > 0 else 0
            )
            add_stats_box(
                current_row,
                f"Change Points: {total_changes}<br>Change Rate: {change_rate:.3f}<br>Stability: {1-change_rate:.3f}",
                "purple",
            )
            current_row += 1

        # Power analysis
        if "power_analysis" in enabled:
            freqs, psd = signal.welch(
                signal_data, fs=sampling_freq, nperseg=min(256, len(signal_data) // 2)
            )
            fig.add_trace(
                go.Scatter(
                    x=freqs,
                    y=psd,
                    mode="lines",
                    name="Power Spectrum",
                    line=dict(color="#e377c2", width=3),
                    fill="tonexty",
                    fillcolor="rgba(227, 119, 194, 0.1)",
                ),
                row=current_row,
                col=1,
            )

            total_power = np.sum(psd)
            peak_power = np.max(psd)
            dominant_freq = freqs[np.argmax(psd)]
            add_stats_box(
                current_row,
                f"Total Power: {total_power:.2e}<br>Peak Power: {peak_power:.2e}<br>Dominant Freq: {dominant_freq:.1f} Hz",
                "pink",
            )

            # Band powers: below 1 Hz, 1-10 Hz, and 10 Hz upwards. Labels are
            # stacked down the left edge so they overlap neither each other
            # nor the statistics box.
            bands = (
                ("Low", freqs < 1.0, "blue"),
                ("Mid", (freqs >= 1.0) & (freqs < 10.0), "green"),
                ("High", freqs >= 10.0, "red"),
            )
            label_y = 0.98
            for band_name, band_mask, border_color in bands:
                band_power = np.sum(psd[band_mask]) if np.any(band_mask) else 0
                if band_power > 0:
                    fig.add_annotation(
                        x=0.02,
                        y=label_y,
                        xref="x domain",
                        yref="y domain",
                        xanchor="left",
                        yanchor="top",
                        text=f"{band_name}: {band_power:.2e}",
                        showarrow=False,
                        bgcolor="rgba(255, 255, 255, 0.8)",
                        bordercolor=border_color,
                        row=current_row,
                        col=1,
                    )
                    label_y -= 0.16

        # Enhanced layout
        fig.update_layout(
            title=dict(
                text="Advanced Features Analysis",
                x=0.5,
                font=dict(size=20, color="#2c3e50"),
            ),
            height=250 * num_plots,
            showlegend=True,
            legend=dict(
                x=0.02,
                y=0.98,
                bgcolor="rgba(255, 255, 255, 0.8)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
        )

        # Light grid on every subplot; secondary_y=False restricts the y-axis
        # update to primary axes so the ensemble std axis keeps its grid off.
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)"
        )
        for row in range(1, num_plots + 1):
            fig.update_xaxes(row=row, col=1, **grid_style)
            fig.update_yaxes(row=row, col=1, secondary_y=False, **grid_style)
        return fig
    except Exception as e:
        logger.error(f"Error creating advanced features plots: {e}")
        return create_empty_figure()
[docs]
def create_comprehensive_dashboard(
    time_data, signal_data, signal_type, sampling_freq, analysis_categories
):
    """Create a comprehensive dashboard with multiple analysis views.

    Builds a 2x2 figure:

    * top-left: the signal with detected peaks and rate statistics,
    * top-right: density histogram of the peak heights,
    * bottom-left: Welch power spectral density with band powers,
    * bottom-right: windowed |mean| / std quality score.

    Peak detection, frequency analysis and quality metrics are each
    best-effort: a failure in one is logged at debug level and the other
    panels are still drawn.

    Args:
        time_data: Sample times (array-like, same length as ``signal_data``).
        signal_data: Signal amplitudes (array-like).
        signal_type: Signal type name; ``"ecg"``/``"ppg"`` (any case) use
            vitalDSP's ``WaveformMorphology`` for peaks, anything else
            (including ``None``) uses ``scipy.signal.find_peaks``.
        sampling_freq: Sampling frequency of the signal.
        analysis_categories: Not used by this function.

    Returns:
        The assembled plotly figure, or the result of
        ``create_empty_figure()`` if an unrecoverable error occurs.
    """
    try:
        # Arrays are required for the index-based selections below.
        time_data = np.asarray(time_data)
        signal_data = np.asarray(signal_data)
        kind = signal_type.lower() if signal_type else ""

        # Create a 2x2 subplot layout for comprehensive view
        fig = make_subplots(
            rows=2,
            cols=2,
            subplot_titles=(
                "Signal Overview",
                "Peak Analysis",
                "Frequency Analysis",
                "Quality Metrics",
            ),
            vertical_spacing=0.1,
            horizontal_spacing=0.1,
            specs=[
                [{"secondary_y": False}, {"secondary_y": False}],
                [{"secondary_y": False}, {"secondary_y": False}],
            ],
        )

        def add_stats_box(row, col, text, border_color):
            # Domain references place the box relative to the subplot area
            # rather than at data coordinates (0.98, 0.98).
            fig.add_annotation(
                x=0.98,
                y=0.98,
                xref="x domain",
                yref="y domain",
                xanchor="right",
                yanchor="top",
                text=text,
                showarrow=False,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=border_color,
                borderwidth=2,
                font=dict(size=10, color="black"),
                row=row,
                col=col,
            )

        # 1. Signal Overview (top-left)
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=signal_data,
                mode="lines",
                name="Signal",
                line=dict(color="#1f77b4", width=2),
                fill="tonexty",
                fillcolor="rgba(31, 119, 180, 0.1)",
            ),
            row=1,
            col=1,
        )

        # Defined up front so the peak histogram below can run even when
        # peak detection fails.
        peaks = np.array([], dtype=int)
        try:
            if kind in ("ecg", "ppg"):
                # Use vitalDSP for ECG/PPG peak detection
                from vitalDSP.physiological_features.waveform import WaveformMorphology

                wm = WaveformMorphology(
                    signal_data, fs=sampling_freq, signal_type=signal_type.upper()
                )
                detected = wm.r_peaks if kind == "ecg" else wm.systolic_peaks
            else:
                # Use scipy for other signal types; find_peaks requires
                # distance >= 1.
                detected, _ = signal.find_peaks(
                    signal_data,
                    height=np.mean(signal_data) + 1.5 * np.std(signal_data),
                    distance=max(1, int(sampling_freq * 0.3)),
                    prominence=np.std(signal_data) * 0.5,
                )
            peaks = np.asarray(detected, dtype=int)

            if len(peaks) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=time_data[peaks],
                        y=signal_data[peaks],
                        mode="markers+text",
                        name="Peaks",
                        text=[f"P{i+1}" for i in range(len(peaks))],
                        textposition="top center",
                        marker=dict(
                            color="red",
                            size=10,
                            symbol="arrow-up",
                            line=dict(color="darkred", width=2),
                        ),
                        textfont=dict(color="red", size=8, family="Arial Black"),
                    ),
                    row=1,
                    col=1,
                )
                mean_interval = (
                    np.mean(np.diff(peaks) / sampling_freq) if len(peaks) > 1 else 0
                )
                heart_rate = 60 / mean_interval if mean_interval > 0 else 0
                add_stats_box(
                    1,
                    1,
                    f"Peaks: {len(peaks)}<br>HR: {heart_rate:.1f} BPM<br>Mean Interval: {mean_interval:.2f}s",
                    "red",
                )
        except Exception as e:
            logger.debug(f"Could not add peaks: {e}")

        # 2. Peak Analysis (top-right)
        if len(peaks) > 0:
            peak_heights = signal_data[peaks]
            # At least one bin: len(peaks) // 2 is 0 for a single peak and
            # np.histogram rejects bins=0.
            hist, bin_edges = np.histogram(
                peak_heights, bins=max(1, min(15, len(peaks) // 2)), density=True
            )
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            fig.add_trace(
                go.Bar(
                    x=bin_centers,
                    y=hist,
                    name="Peak Distribution",
                    marker_color="rgba(255, 165, 0, 0.7)",
                    opacity=0.8,
                ),
                row=1,
                col=2,
            )
            add_stats_box(
                1,
                2,
                f"Mean Height: {np.mean(peak_heights):.3f}<br>Height Std: {np.std(peak_heights):.3f}<br>Range: {np.max(peak_heights) - np.min(peak_heights):.3f}",
                "orange",
            )

        # 3. Frequency Analysis (bottom-left)
        try:
            freqs, psd = signal.welch(
                signal_data, fs=sampling_freq, nperseg=min(256, len(signal_data) // 2)
            )
            fig.add_trace(
                go.Scatter(
                    x=freqs,
                    y=psd,
                    mode="lines",
                    name="Power Spectrum",
                    line=dict(color="#2ca02c", width=2),
                    fill="tonexty",
                    fillcolor="rgba(44, 160, 44, 0.1)",
                ),
                row=2,
                col=1,
            )

            # Band powers: below 1 Hz, 1-10 Hz, and 10 Hz upwards. Labels are
            # stacked vertically so they do not cover each other.
            bands = (
                ("Low", freqs < 1.0, "blue"),
                ("Mid", (freqs >= 1.0) & (freqs < 10.0), "green"),
                ("High", freqs >= 10.0, "red"),
            )
            label_y = 0.98
            for band_name, band_mask, border_color in bands:
                if not np.any(band_mask):
                    continue
                fig.add_annotation(
                    x=0.98,
                    y=label_y,
                    xref="x domain",
                    yref="y domain",
                    xanchor="right",
                    yanchor="top",
                    text=f"{band_name}: {np.sum(psd[band_mask]):.2e}",
                    showarrow=False,
                    bgcolor="rgba(255, 255, 255, 0.8)",
                    bordercolor=border_color,
                    row=2,
                    col=1,
                )
                label_y -= 0.12
        except Exception as e:
            logger.debug(f"Could not create frequency analysis: {e}")

        # 4. Quality Metrics (bottom-right)
        try:
            baseline = np.mean(signal_data)
            noise_level = np.std(signal_data)
            snr = (
                20 * np.log10(np.abs(baseline) / noise_level) if noise_level > 0 else 0
            )

            # Quality score over half-overlapping windows
            window_size = min(50, len(signal_data) // 10)
            if window_size > 1:
                half_window = window_size // 2
                quality_scores = []
                time_windows = []
                for start in range(0, len(signal_data) - window_size, half_window):
                    window_data = signal_data[start : start + window_size]
                    window_std = np.std(window_data)
                    if window_std > 0:
                        quality_scores.append(np.abs(np.mean(window_data)) / window_std)
                    else:
                        quality_scores.append(0)
                    time_windows.append(time_data[start + half_window])

                if quality_scores:
                    fig.add_trace(
                        go.Scatter(
                            x=time_windows,
                            y=quality_scores,
                            mode="lines+markers",
                            name="Quality Score",
                            line=dict(color="#d62728", width=2),
                            marker=dict(color="#d62728", size=4, symbol="circle"),
                        ),
                        row=2,
                        col=2,
                    )
                    mean_quality = np.mean(quality_scores)
                    add_stats_box(
                        2,
                        2,
                        f"Mean Quality: {mean_quality:.2f}<br>SNR: {snr:.1f} dB<br>Noise Level: {noise_level:.3f}",
                        "red",
                    )
        except Exception as e:
            logger.debug(f"Could not create quality metrics: {e}")

        # Enhanced layout; the title must not fail when signal_type is None.
        title_label = str(signal_type).upper() if signal_type else "SIGNAL"
        fig.update_layout(
            title=dict(
                text=f"Comprehensive {title_label} Analysis Dashboard",
                x=0.5,
                font=dict(size=18, color="#2c3e50"),
            ),
            height=800,
            showlegend=True,
            legend=dict(
                x=1.02,
                y=1.0,
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor="rgba(0, 0, 0, 0.2)",
                borderwidth=1,
            ),
            plot_bgcolor="white",
            paper_bgcolor="white",
            margin=dict(l=60, r=200, t=80, b=60),
        )

        # Light grid on every subplot
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor="rgba(128, 128, 128, 0.2)"
        )
        for row in range(1, 3):
            for col in range(1, 3):
                fig.update_xaxes(row=row, col=col, **grid_style)
                fig.update_yaxes(row=row, col=col, **grid_style)
        return fig
    except Exception as e:
        logger.error(f"Error creating comprehensive dashboard: {e}")
        return create_empty_figure()
# # Comprehensive Dashboard Callback
# @app.callback(
# Output("store-physio-analysis", "data"),
# [Input("btn-comprehensive-dashboard", "n_clicks")],
# [
# State("store-uploaded-data", "data"),
# State("signal-type-select", "value"),
# State("start-time", "value"),
# State("end-time", "value"),
# State("analysis-categories", "value"),
# State("hrv-options", "value"),
# State("morphology-options", "value"),
# State("advanced-features", "value"),
# State("quality-options", "value"),
# State("transform-options", "value"),
# State("advanced-computation", "value"),
# State("feature-engineering", "value"),
# State("preprocessing", "value"),
# ],
# )
# def show_comprehensive_dashboard(
# n_clicks,
# data_store,
# signal_type,
# start_time,
# end_time,
# analysis_categories,
# hrv_options,
# morphology_options,
# advanced_features,
# quality_options,
# transform_options,
# advanced_computation,
# feature_engineering,
# preprocessing,
# ):
# """Show comprehensive dashboard view."""
# if not n_clicks or not data_store:
# return no_update
# try:
# # Get data
# df = pd.DataFrame(data_store["data"])
# if df.empty:
# return no_update
# # Get column mapping
# column_mapping = data_store.get("column_mapping", {})
# sampling_freq = data_store.get("sampling_freq", 100)
# # Extract time and signal columns
# time_col = column_mapping.get("time", df.columns[0])
# signal_col = column_mapping.get("signal", df.columns[1])
# # Extract time and signal data
# time_data = df[time_col].values
# signal_data = df[signal_col].values
# # Apply time window
# start_idx = np.searchsorted(time_data, start_time or 0)
# end_idx = np.searchsorted(time_data, end_time or 10)
# if end_idx > start_idx:
# time_data = time_data[start_idx:end_idx]
# signal_data = signal_data[start_idx:end_idx]
# # Auto-detect signal type if needed
# if signal_type == "auto":
# signal_type = detect_physiological_signal_type(
# signal_data, sampling_freq
# )
# # Create comprehensive dashboard
# dashboard_fig = create_comprehensive_dashboard(
# time_data,
# signal_data,
# signal_type,
# sampling_freq,
# analysis_categories or [],
# )
# # Store dashboard data
# dashboard_data = {
# "type": "comprehensive_dashboard",
# "figure": dashboard_fig,
# "signal_type": signal_type,
# "sampling_freq": sampling_freq,
# "time_data": time_data.tolist(),
# "signal_data": signal_data.tolist(),
# }
# return dashboard_data
# except Exception as e:
# logger.error(f"Error creating comprehensive dashboard: {e}")
# return no_update
# Standalone version for testing
[docs]
def physiological_analysis_callback(
    pathname,
    n_clicks,
    slider_value,
    nudge_m10,
    nudge_m1,
    nudge_p1,
    nudge_p10,
    start_time,
    end_time,
    signal_type,
    analysis_categories,
    hrv_options,
    morphology_options,
    advanced_features,
    quality_options,
    transform_options,
    advanced_computation,
    feature_engineering,
    preprocessing,
):
    """Standalone version of physiological analysis callback for testing.

    Loads the most recent dataset from the enhanced data service, cuts the
    requested time window, runs ``perform_physiological_analysis_enhanced``
    and builds the result figures.

    The window starts at ``start_time`` (default 0) and spans
    ``end_time - start_time`` seconds (default 10). When the callback was
    triggered by one of the ``physio-btn-nudge-*`` buttons, the start is
    shifted by -10, -1, +1 or +10 seconds (backward nudges never go below
    0). Unset option lists fall back to defaults.

    Args:
        pathname: Current page path; analysis only runs on "/physiological".
        n_clicks, slider_value, nudge_m10, nudge_m1, nudge_p1, nudge_p10:
            Values of the triggering components. They are not read here;
            the trigger is taken from ``callback_context``.
        start_time: Window start in seconds, or None.
        end_time: Window end in seconds, or None.
        signal_type: Signal type name; defaults to "auto".
        analysis_categories, hrv_options, morphology_options,
        advanced_features, quality_options, transform_options,
        advanced_computation, feature_engineering, preprocessing:
            Option lists forwarded to the analysis routine.

    Returns:
        tuple: ``(main_figure, results_text, analysis_figure,
        analysis_results, None)``. On any precondition failure or error the
        two figures come from ``create_empty_figure()``, the text explains
        the problem and the last two items are None.

    Raises:
        PreventUpdate: If nothing triggered the callback.
    """
    # Determine what triggered this callback
    if not callback_context.triggered:
        raise PreventUpdate
    trigger_id = callback_context.triggered[0]["prop_id"].split(".")[0]

    def message_only(text):
        # Uniform "nothing to show" response
        return (create_empty_figure(), text, create_empty_figure(), None, None)

    # Only run this when we're on the physiological page
    if pathname != "/physiological":
        return message_only("Navigate to Physiological Features page")

    try:
        # Get data from the data service
        from vitalDSP_webapp.services.data.enhanced_data_service import (
            get_enhanced_data_service,
        )

        data_service = get_enhanced_data_service()
        all_data = data_service.get_all_data()
        if not all_data:
            return message_only(
                "No data available. Please upload and process data first."
            )

        # Get the most recent data entry
        latest_data_id = list(all_data.keys())[-1]
        latest_data = all_data[latest_data_id]

        column_mapping = data_service.get_column_mapping(latest_data_id)
        if not column_mapping:
            return message_only(
                "Please process your data on the Upload page first (configure column mapping)"
            )

        df = data_service.get_data(latest_data_id)
        if df is None or df.empty:
            return message_only("Data is empty or corrupted.")

        # Get sampling frequency from the data info
        sampling_freq = latest_data.get("info", {}).get("sampling_freq", 1000)

        # Resolve the analysis window from the inputs
        start_position = start_time if start_time is not None else 0
        duration = (
            (end_time - start_time)
            if (end_time is not None and start_time is not None)
            else 10
        )

        # Shift the window for the nudge buttons; only backward nudges are
        # clamped at 0.
        nudge_offsets = {
            "physio-btn-nudge-m10": -10,
            "physio-btn-nudge-m1": -1,
            "physio-btn-nudge-p1": 1,
            "physio-btn-nudge-p10": 10,
        }
        offset = nudge_offsets.get(trigger_id)
        if offset is not None:
            start_position = start_position + offset
            if offset < 0:
                start_position = max(0, start_position)

        # Set default values if not provided
        start_position = start_position or 0
        duration = duration or 10
        signal_type = signal_type or "auto"
        analysis_categories = analysis_categories or ["hrv", "morphology"]
        hrv_options = hrv_options or ["time_domain"]
        morphology_options = morphology_options or ["peaks"]
        advanced_features = advanced_features or []
        quality_options = quality_options or []
        transform_options = transform_options or []
        advanced_computation = advanced_computation or []
        feature_engineering = feature_engineering or []
        preprocessing = preprocessing or []

        # Extract signal data for analysis
        signal_column = column_mapping.get("signal", "waveform")
        if signal_column not in df.columns:
            return message_only(f"Signal column '{signal_column}' not found in data")

        # Apply time windowing using the resolved window, so nudges take
        # effect and unset start/end times fall back to the defaults instead
        # of raising on None.
        start_sample = max(0, int(start_position * sampling_freq))
        end_sample = min(len(df), int((start_position + duration) * sampling_freq))
        windowed_df = df.iloc[start_sample:end_sample]
        signal_data = windowed_df[signal_column].values
        time_axis = np.arange(len(signal_data)) / sampling_freq

        # Perform comprehensive physiological analysis
        analysis_results = perform_physiological_analysis_enhanced(
            time_axis,
            signal_data,
            signal_type,
            sampling_freq,
            analysis_categories,
            hrv_options,
            morphology_options,
            advanced_features,
            quality_options,
            transform_options,
            advanced_computation,
            feature_engineering,
            preprocessing,
        )

        # Debug: Log analysis results structure
        logger.info(
            f"Analysis results keys: {list(analysis_results.keys()) if analysis_results else 'None'}"
        )
        if analysis_results:
            for key, value in analysis_results.items():
                if isinstance(value, dict):
                    logger.info(f"  {key}: {list(value.keys())}")
                else:
                    logger.info(f"  {key}: {type(value)}")

        # Simple main plot of the windowed signal
        logger.info("Creating main plot...")
        main_plot = go.Figure()
        main_plot.add_trace(
            go.Scatter(
                x=time_axis,
                y=signal_data,
                mode="lines",
                name="Signal",
                line=dict(color="blue", width=1),
            )
        )
        main_plot.update_layout(
            title="Physiological Signal Analysis",
            xaxis_title="Time (s)",
            yaxis_title="Amplitude",
            height=400,
        )
        logger.info(f"Main plot created with {len(main_plot.data)} traces")

        # Also create the comprehensive plot (built and logged, not returned)
        comprehensive_plot = create_comprehensive_analysis_plot(
            time_axis, signal_data, analysis_results
        )
        logger.info(
            f"Comprehensive plot created with {len(comprehensive_plot.data)} traces"
        )

        # Create specialized plots based on analysis categories
        analysis_plots = []
        logger.info(f"Creating analysis plots for categories: {analysis_categories}")

        if "hrv" in analysis_categories and "hrv_metrics" in analysis_results:
            hrv_metrics = analysis_results["hrv_metrics"]
            if "rr_intervals" in hrv_metrics:
                analysis_plots.append(
                    create_hrv_poincare_plot(hrv_metrics["rr_intervals"], hrv_metrics)
                )
                analysis_plots.append(
                    create_hrv_time_series_plot(
                        time_axis, hrv_metrics["rr_intervals"], hrv_metrics
                    )
                )

        if (
            "morphology" in analysis_categories
            and "morphology_metrics" in analysis_results
        ):
            morph_metrics = analysis_results["morphology_metrics"]
            if "peaks" in morph_metrics and "peak_heights" in morph_metrics:
                analysis_plots.append(
                    create_morphology_analysis_plot(
                        time_axis,
                        signal_data,
                        morph_metrics["peaks"],
                        morph_metrics["peak_heights"],
                        sampling_freq,
                    )
                )

        if "energy" in analysis_categories and "energy_metrics" in analysis_results:
            energy_metrics = analysis_results["energy_metrics"]
            if "frequencies" in energy_metrics and "psd" in energy_metrics:
                analysis_plots.append(
                    create_energy_analysis_plot(
                        energy_metrics["frequencies"],
                        energy_metrics["psd"],
                        energy_metrics,
                    )
                )

        if "quality" in analysis_categories and "quality_metrics" in analysis_results:
            analysis_plots.append(
                create_quality_assessment_plot(
                    signal_data, analysis_results["quality_metrics"], time_axis
                )
            )

        if not analysis_plots:
            # Fallback plot so the analysis panel is never empty
            logger.info("No analysis plots created, creating test plot")
            analysis_plots_fig = go.Figure()
            analysis_plots_fig.add_trace(
                go.Scatter(
                    x=time_axis,
                    y=signal_data,
                    mode="lines",
                    name="Test Signal",
                    line=dict(color="green", width=2),
                )
            )
            analysis_plots_fig.update_layout(
                title="Analysis Plots - Test",
                xaxis_title="Time (s)",
                yaxis_title="Amplitude",
                height=400,
            )
        elif len(analysis_plots) > 1:
            # Stack all analysis plots into a single multi-row figure
            analysis_plots_fig = make_subplots(
                rows=len(analysis_plots),
                cols=1,
                subplot_titles=[
                    f"Analysis Plot {i+1}" for i in range(len(analysis_plots))
                ],
                vertical_spacing=0.1,
            )
            for row, plot in enumerate(analysis_plots, start=1):
                for trace in plot.data:
                    analysis_plots_fig.add_trace(trace, row=row, col=1)
            analysis_plots_fig.update_layout(height=300 * len(analysis_plots))
        else:
            analysis_plots_fig = analysis_plots[0]

        # Generate results text
        results_text = f"Comprehensive analysis completed for {signal_type} signal from {start_position}s to {start_position + duration}s"
        if analysis_results:
            feature_count = sum(
                len(metrics)
                for metrics in analysis_results.values()
                if isinstance(metrics, dict)
            )
            results_text += f" - Extracted {feature_count} physiological features across {len(analysis_categories)} categories"

        return main_plot, results_text, analysis_plots_fig, analysis_results, None
    except Exception as e:
        logger.error(f"Error in physiological analysis callback: {e}")
        return message_only(f"Error: {str(e)}")