Source code for vitalDSP.ml_models.pretrained_models

"""
Pre-trained Models for Physiological Signal Analysis

This module provides a repository of pre-trained models for common physiological
signal analysis tasks, enabling transfer learning and quick deployment.

Features:
- Automatic model download and caching
- Pre-trained models for ECG, PPG, EEG
- Model versioning and metadata
- Easy fine-tuning interface

Available Models:
1. ECG Classification (MIT-BIH, PTB-XL)
2. PPG Quality Assessment
3. EEG Sleep Stage Classification
4. Arrhythmia Detection
5. Heart Rate Estimation

Author: vitalDSP
License: MIT
"""

"""
Machine Learning Models Module for Physiological Signal Processing

This module provides comprehensive capabilities for physiological
signal processing including ECG, PPG, EEG, and other vital signs.

Author: vitalDSP Team
Date: 2025-01-27
Version: 1.0.0

Key Features:
- Object-oriented design with comprehensive classes
- Multiple processing methods and functions
- NumPy integration for numerical computations
- Deep learning framework integration
- Signal validation and error handling

Examples:
---------
Basic usage:
    >>> import numpy as np
    >>> from vitalDSP.ml_models.pretrained_models import PretrainedModel
    >>> signals = np.random.randn(10, 187, 1)
    >>> pm = PretrainedModel.from_registry('ecg_classifier_mitbih')
    >>> predictions = pm.predict(signals)
    >>> print(f'Processing result: {predictions}')
"""


import numpy as np
from typing import Optional, Union, Tuple, List, Dict, Any, Callable
from pathlib import Path
import json
import hashlib
import warnings
from urllib.request import urlretrieve
from urllib.error import URLError, HTTPError
import os

try:
    import tensorflow as tf
    from tensorflow import keras

    TENSORFLOW_AVAILABLE = True
except ImportError:
    TENSORFLOW_AVAILABLE = False

try:
    import torch
    import torch.nn as nn

    PYTORCH_AVAILABLE = True
except ImportError:
    PYTORCH_AVAILABLE = False


# Model registry
#
# Maps a model name to its metadata dict. Keys read by the code in this module:
#   - "backend", "url", "version": used by PretrainedModel.from_registry to pick
#     the loader, decide between download and placeholder, and build the cache
#     file name ("<name>_v<version>").
#   - "architecture", "input_shape", "task", "n_classes": used by
#     PretrainedModel._create_placeholder_model to build the demo network.
#     "n_classes" is only read for the classification tasks.
#   - "signal_type", "task", "architecture": filter keys in ModelHub.list_models.
#   - any other key (e.g. "accuracy", "f1_score", "mae") can be looked up by
#     ModelHub.compare_models; entries without the requested key are skipped.
# Every "url" below is None, so from_registry falls back to a placeholder model
# whenever no cached file exists.
MODEL_REGISTRY = {
    "ecg_classifier_mitbih": {
        "description": "ECG arrhythmia classifier trained on MIT-BIH dataset",
        "task": "classification",
        "signal_type": "ecg",
        "input_shape": (187, 1),
        "n_classes": 5,
        "classes": ["Normal", "Supraventricular", "Ventricular", "Fusion", "Unknown"],
        "architecture": "cnn1d",
        "backend": "tensorflow",
        "url": None,  # Placeholder - would be actual URL in production
        "size_mb": 2.5,
        "accuracy": 0.98,
        "version": "1.0.0",
        "reference": "MIT-BIH Arrhythmia Database",
    },
    "ecg_classifier_ptbxl": {
        "description": "Multi-label ECG classifier trained on PTB-XL dataset",
        "task": "multi_label_classification",
        "signal_type": "ecg",
        "input_shape": (1000, 12),  # 12-lead ECG
        "n_classes": 71,
        "architecture": "resnet1d",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 15.0,
        "f1_score": 0.92,
        "version": "1.0.0",
        "reference": "PTB-XL Database",
    },
    "ppg_quality_assessment": {
        "description": "PPG signal quality classifier",
        "task": "classification",
        "signal_type": "ppg",
        "input_shape": (250, 1),
        "n_classes": 3,
        "classes": ["Excellent", "Acceptable", "Unacceptable"],
        "architecture": "cnn1d",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 1.8,
        "accuracy": 0.94,
        "version": "1.0.0",
    },
    "eeg_sleep_stage": {
        "description": "EEG sleep stage classifier",
        "task": "classification",
        "signal_type": "eeg",
        "input_shape": (3000, 1),  # 30 seconds at 100 Hz
        "n_classes": 5,
        "classes": ["Wake", "N1", "N2", "N3", "REM"],
        "architecture": "lstm",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 5.2,
        "accuracy": 0.87,
        "version": "1.0.0",
        "reference": "Sleep-EDF Database",
    },
    "heart_rate_estimator": {
        "description": "End-to-end heart rate estimator from PPG/ECG",
        "task": "regression",
        "signal_type": "ppg",
        "input_shape": (500, 1),
        "architecture": "cnn_lstm",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 3.1,
        "mae": 2.3,  # MAE in BPM
        "version": "1.0.0",
    },
    "ecg_autoencoder": {
        "description": "Autoencoder for ECG anomaly detection and denoising",
        "task": "autoencoder",
        "signal_type": "ecg",
        "input_shape": (187, 1),
        "latent_dim": 32,
        "architecture": "convolutional_autoencoder",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 1.5,
        "reconstruction_loss": 0.002,
        "version": "1.0.0",
    },
    "multimodal_transformer": {
        "description": "Transformer for multi-lead ECG analysis",
        "task": "classification",
        "signal_type": "ecg",
        "input_shape": (1000, 12),
        "n_classes": 6,
        "classes": ["NORM", "MI", "STTC", "CD", "HYP", "Other"],
        "architecture": "transformer",
        "backend": "tensorflow",
        "url": None,
        "size_mb": 25.0,
        "accuracy": 0.91,
        "version": "1.0.0",
    },
}


class PretrainedModel:
    """
    Wrapper for pre-trained physiological signal analysis models.

    Provides a unified interface for loading, using, and fine-tuning
    pre-trained models regardless of architecture or backend.

    Any backend string other than ``"tensorflow"`` is treated as PyTorch.
    Fine-tuning, evaluation and feature extraction are only implemented for
    the TensorFlow backend and raise ``NotImplementedError`` otherwise.

    Examples
    --------
    >>> from vitalDSP.ml_models.pretrained_models import PretrainedModel
    >>> import numpy as np
    >>>
    >>> # Load pre-trained ECG classifier
    >>> model = PretrainedModel.from_registry('ecg_classifier_mitbih')
    >>> print(model.info())
    >>>
    >>> # Make predictions
    >>> ecg_signals = np.random.randn(10, 187, 1)
    >>> predictions = model.predict(ecg_signals)
    >>> print(f"Predicted classes: {predictions}")
    >>>
    >>> # Fine-tune on your data
    >>> X_train = np.random.randn(100, 187, 1)
    >>> y_train = np.random.randint(0, 5, 100)
    >>> model.fine_tune(X_train, y_train, epochs=10)
    """

    # Loss and metrics used when (re)compiling a Keras model, keyed by task.
    # Shared by the placeholder builder and fine_tune so both agree.
    _TASK_TRAINING_CONFIG = {
        "classification": ("sparse_categorical_crossentropy", ("accuracy",)),
        "multi_label_classification": ("binary_crossentropy", ("binary_accuracy",)),
    }
    # Regression, autoencoder and any unlisted task fall back to this.
    _DEFAULT_TRAINING_CONFIG = ("mse", ("mae",))

    def __init__(
        self, model: Any, metadata: Dict[str, Any], backend: str = "tensorflow"
    ):
        """
        Initialize pre-trained model.

        Parameters
        ----------
        model : object
            The trained model
        metadata : dict
            Model metadata (architecture, task, etc.)
        backend : str, default='tensorflow'
            Deep learning backend (case-insensitive)
        """
        self.model = model
        self.metadata = metadata
        self.backend = backend.lower()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @classmethod
    def _training_config(cls, task: Optional[str]) -> Tuple[str, List[str]]:
        """Return ``(loss, metrics)`` appropriate for ``task``."""
        loss, metrics = cls._TASK_TRAINING_CONFIG.get(
            task, cls._DEFAULT_TRAINING_CONFIG
        )
        # Hand out a fresh list so callers can never mutate the class constant.
        return loss, list(metrics)

    @staticmethod
    def _require_backend(backend: str) -> None:
        """Raise ImportError if the library behind ``backend`` is not installed."""
        if backend == "tensorflow":
            if not TENSORFLOW_AVAILABLE:
                raise ImportError("TensorFlow not installed")
        elif not PYTORCH_AVAILABLE:
            raise ImportError("PyTorch not installed")

    @classmethod
    def _load_from_disk(cls, path: Path, backend: str) -> Any:
        """Deserialize a model stored at ``path`` with the given backend."""
        cls._require_backend(backend)
        if backend == "tensorflow":
            return keras.models.load_model(str(path))
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load model files from sources you trust.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which rejects whole-model pickles written by save() - confirm the
        # supported PyTorch range before relying on this path.
        return torch.load(str(path))

    @staticmethod
    def _to_class_labels(predictions: np.ndarray) -> np.ndarray:
        """Convert raw classifier outputs to integer class labels."""
        predictions = np.asarray(predictions)
        if predictions.ndim > 1:
            if predictions.shape[-1] > 1:
                # One score per class: pick the best class.
                return np.argmax(predictions, axis=-1)
            # Single sigmoid unit: drop only the unit axis so a batch of one
            # sample still yields a 1-D result (squeeze() would return 0-d).
            predictions = predictions[..., 0]
        return (np.atleast_1d(predictions) > 0.5).astype(int)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    @classmethod
    def from_registry(
        cls,
        model_name: str,
        cache_dir: Optional[str] = None,
        force_download: bool = False,
    ) -> "PretrainedModel":
        """
        Load model from registry.

        Parameters
        ----------
        model_name : str
            Name of the model in the registry
        cache_dir : str, optional
            Directory to cache downloaded models. If None, use
            ``~/.vitaldsp/pretrained_models``.
        force_download : bool, default=False
            Force re-download even if cached

        Returns
        -------
        PretrainedModel
            Loaded pre-trained model

        Raises
        ------
        ValueError
            If model not found in registry
        ImportError
            If the backend needed to load or build the model is not installed
        RuntimeError
            If downloading the model fails
        """
        import copy

        if model_name not in MODEL_REGISTRY:
            available = ", ".join(MODEL_REGISTRY.keys())
            raise ValueError(
                f"Model '{model_name}' not found in registry. "
                f"Available models: {available}"
            )

        # Deep copy so edits to model.metadata (e.g. its "classes" list)
        # never leak back into the shared registry.
        metadata = copy.deepcopy(MODEL_REGISTRY[model_name])

        # Setup cache directory
        if cache_dir is None:
            cache_dir = Path.home() / ".vitaldsp" / "pretrained_models"
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

        model_path = cache_dir / f"{model_name}_v{metadata['version']}"
        backend = metadata["backend"]

        if model_path.exists() and not force_download:
            model = cls._load_from_disk(model_path, backend)
        elif metadata["url"] is None:
            # For demonstration, create a simple model
            warnings.warn(
                f"Model '{model_name}' URL not available. "
                "Creating a placeholder model for demonstration."
            )
            model = cls._create_placeholder_model(metadata)
        else:
            model = cls._download_model(metadata["url"], model_path, backend)

        return cls(model, metadata, backend)

    @staticmethod
    def _create_placeholder_model(metadata: Dict[str, Any]) -> Any:
        """
        Create a placeholder model for demonstration.

        In production, this would not be needed as models would be downloaded.
        The network is untrained; it only has the input/output shapes and
        the loss implied by the metadata.

        Raises
        ------
        NotImplementedError
            If the backend is not TensorFlow
        ImportError
            If TensorFlow is not installed
        ValueError
            If ``metadata["task"]`` is not a supported task
        """
        if metadata["backend"] != "tensorflow":
            raise NotImplementedError(
                "PyTorch placeholder models not yet implemented"
            )
        if not TENSORFLOW_AVAILABLE:
            raise ImportError("TensorFlow not installed")

        architecture = metadata["architecture"]
        input_shape = tuple(metadata["input_shape"])
        task = metadata["task"]
        layers = keras.layers

        inputs = keras.Input(shape=input_shape)
        x = inputs

        if "cnn" in architecture:
            for filters, kernel_size in ((32, 7), (64, 5)):
                x = layers.Conv1D(filters, kernel_size, padding="same")(x)
                x = layers.BatchNormalization()(x)
                x = layers.Activation("relu")(x)
                x = layers.MaxPooling1D(2)(x)
            x = layers.GlobalAveragePooling1D()(x)
        elif "lstm" in architecture:
            x = layers.LSTM(64, return_sequences=True)(x)
            x = layers.LSTM(32)(x)
        elif "transformer" in architecture:
            # Simple transformer-like architecture
            x = layers.MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
            x = layers.GlobalAveragePooling1D()(x)
        else:
            # Architectures without a dedicated branch (e.g. "resnet1d",
            # "convolutional_autoencoder") must still collapse the time axis,
            # otherwise the dense head emits one prediction per timestep.
            x = layers.GlobalAveragePooling1D()(x)

        x = layers.Dense(128, activation="relu")(x)

        # Output layer
        if task == "classification":
            outputs = layers.Dense(metadata["n_classes"], activation="softmax")(x)
        elif task == "multi_label_classification":
            outputs = layers.Dense(metadata["n_classes"], activation="sigmoid")(x)
        elif task == "regression":
            outputs = layers.Dense(1, activation="linear")(x)
        elif task == "autoencoder":
            # Decoder part (simplified): project back to exactly as many
            # values as the input holds so the reconstruction matches it.
            x = layers.Dense(int(np.prod(input_shape)), activation="linear")(x)
            outputs = layers.Reshape(input_shape)(x)
        else:
            raise ValueError(f"Unsupported task for placeholder model: {task!r}")

        loss, metrics = PretrainedModel._training_config(task)
        model = keras.Model(inputs, outputs)
        model.compile(optimizer="adam", loss=loss, metrics=metrics)
        return model

    @staticmethod
    def _download_model(url: str, save_path: Path, backend: str) -> Any:
        """
        Download model from URL, store it at ``save_path`` and load it.

        Raises
        ------
        RuntimeError
            If the download or the subsequent load fails. The original
            exception is attached as ``__cause__``.
        """
        save_path = Path(save_path)
        # Download next to the final path and rename on success, so an
        # interrupted transfer never looks like a valid cached model.
        tmp_path = save_path.with_name(save_path.name + ".part")
        try:
            print(f"Downloading model from {url}...")
            urlretrieve(url, str(tmp_path))
            tmp_path.replace(save_path)
            print("Download complete.")
            return PretrainedModel._load_from_disk(save_path, backend)
        except Exception as e:
            try:
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; report the original failure
            raise RuntimeError(f"Failed to download model: {e}") from e

    # ------------------------------------------------------------------
    # Inference / training
    # ------------------------------------------------------------------
    def predict(
        self, X: np.ndarray, batch_size: int = 32, return_proba: bool = False
    ) -> np.ndarray:
        """
        Make predictions on new data.

        Parameters
        ----------
        X : np.ndarray
            Input signals
        batch_size : int, default=32
            Batch size for prediction (used by the TensorFlow backend only)
        return_proba : bool, default=False
            Return class probabilities (classification only)

        Returns
        -------
        np.ndarray
            Predictions. For the ``"classification"`` task with
            ``return_proba=False`` these are integer class labels; otherwise
            the raw model output is returned.
        """
        if self.backend == "tensorflow":
            predictions = self.model.predict(X, batch_size=batch_size, verbose=0)
        else:  # pytorch
            self.model.eval()
            with torch.no_grad():
                X_tensor = torch.FloatTensor(X)
                predictions = self.model(X_tensor).cpu().numpy()

        # .get(): metadata may be empty when load() found no metadata file.
        if self.metadata.get("task") == "classification" and not return_proba:
            predictions = self._to_class_labels(predictions)

        return predictions

    def fine_tune(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None,
        epochs: int = 10,
        batch_size: int = 32,
        learning_rate: float = 1e-4,
        freeze_layers: Optional[int] = None,
        verbose: int = 1,
    ) -> Dict[str, List[float]]:
        """
        Fine-tune model on new data.

        Parameters
        ----------
        X_train : np.ndarray
            Training signals
        y_train : np.ndarray
            Training labels
        X_val : np.ndarray, optional
            Validation signals
        y_val : np.ndarray, optional
            Validation labels. Validation is used only when both ``X_val``
            and ``y_val`` are given.
        epochs : int, default=10
            Number of training epochs
        batch_size : int, default=32
            Batch size
        learning_rate : float, default=1e-4
            Learning rate (smaller than training from scratch)
        freeze_layers : int, optional
            Number of initial layers to freeze
        verbose : int, default=1
            Verbosity level

        Returns
        -------
        dict
            Training history

        Raises
        ------
        NotImplementedError
            If the backend is not TensorFlow
        """
        if self.backend != "tensorflow":
            raise NotImplementedError("PyTorch fine-tuning not yet implemented")

        # Freeze layers if requested
        if freeze_layers is not None:
            for layer in self.model.layers[: max(freeze_layers, 0)]:
                layer.trainable = False

        # Compile with lower learning rate and a task-appropriate loss
        loss, metrics = self._training_config(self.metadata.get("task"))
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        # Monitor val_loss only when validation data is really passed to fit();
        # otherwise the callbacks would watch a quantity that never appears.
        has_validation = X_val is not None and y_val is not None
        monitor = "val_loss" if has_validation else "loss"
        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor=monitor,
                patience=5,
                restore_best_weights=True,
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor=monitor,
                factor=0.5,
                patience=3,
                min_lr=1e-7,
            ),
        ]

        history = self.model.fit(
            X_train,
            y_train,
            validation_data=(X_val, y_val) if has_validation else None,
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=verbose,
        )
        return history.history

    def evaluate(
        self, X_test: np.ndarray, y_test: np.ndarray, batch_size: int = 32
    ) -> Dict[str, float]:
        """
        Evaluate model on test data.

        Parameters
        ----------
        X_test : np.ndarray
            Test signals
        y_test : np.ndarray
            Test labels
        batch_size : int, default=32
            Batch size

        Returns
        -------
        dict
            Evaluation metrics keyed by metric name

        Raises
        ------
        NotImplementedError
            If the backend is not TensorFlow
        """
        if self.backend != "tensorflow":
            raise NotImplementedError("PyTorch evaluation not yet implemented")

        # return_dict=True lets Keras pair names with values itself; zipping
        # metrics_names with the raw result breaks when the result is a bare
        # scalar and can mislabel entries on newer Keras versions.
        results = self.model.evaluate(
            X_test, y_test, batch_size=batch_size, verbose=0, return_dict=True
        )
        return {name: float(value) for name, value in results.items()}

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save(self, filepath: str):
        """
        Save model to file.

        The metadata is written next to the model as
        ``<stem>_metadata.json`` and always records the backend in use so
        that :meth:`load` picks the matching loader.

        Parameters
        ----------
        filepath : str
            Path to save model
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        if self.backend == "tensorflow":
            self.model.save(str(filepath))
        else:
            torch.save(self.model, str(filepath))

        # Write a copy that includes the backend; self.metadata is untouched.
        metadata_to_save = {**self.metadata, "backend": self.backend}
        metadata_path = filepath.parent / f"{filepath.stem}_metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata_to_save, f, indent=2)

    @classmethod
    def load(cls, filepath: str) -> "PretrainedModel":
        """
        Load model from file.

        Parameters
        ----------
        filepath : str
            Path to saved model

        Returns
        -------
        PretrainedModel
            Loaded model. If no ``<stem>_metadata.json`` file exists next to
            the model, metadata is empty and TensorFlow is assumed.

        Raises
        ------
        ImportError
            If the required backend is not installed
        """
        filepath = Path(filepath)

        metadata_path = filepath.parent / f"{filepath.stem}_metadata.json"
        if metadata_path.exists():
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        else:
            metadata = {}

        # JSON has no tuples; restore the shape type used by the registry.
        if isinstance(metadata.get("input_shape"), list):
            metadata["input_shape"] = tuple(metadata["input_shape"])

        backend = metadata.get("backend", "tensorflow")
        model = cls._load_from_disk(filepath, backend)
        return cls(model, metadata, backend)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------
    def info(self) -> str:
        """
        Get model information.

        Returns
        -------
        str
            Formatted model information, one ``Key: value`` line per
            metadata entry. Sequences longer than five items are abbreviated
            to their first three.
        """
        lines = ["Pre-trained Model Information", "=" * 50]
        for key, value in self.metadata.items():
            if isinstance(value, (list, tuple)) and len(value) > 5:
                value = f"{value[:3]}... ({len(value)} total)"
            lines.append(f"{key.replace('_', ' ').title()}: {value}")
        return "\n".join(lines) + "\n"

    def get_layer_names(self) -> List[str]:
        """Get list of layer names in the model."""
        if self.backend == "tensorflow":
            return [layer.name for layer in self.model.layers]
        return [name for name, _ in self.model.named_modules()]

    def get_features(
        self, X: np.ndarray, layer_name: Optional[str] = None, batch_size: int = 32
    ) -> np.ndarray:
        """
        Extract features from intermediate layer.

        Parameters
        ----------
        X : np.ndarray
            Input signals
        layer_name : str, optional
            Name of layer to extract features from. If None, use
            second-to-last layer.
        batch_size : int, default=32
            Batch size

        Returns
        -------
        np.ndarray
            Extracted features

        Raises
        ------
        NotImplementedError
            If the backend is not TensorFlow
        """
        if self.backend != "tensorflow":
            raise NotImplementedError(
                "PyTorch feature extraction not yet implemented"
            )

        if layer_name is None:
            # Use second-to-last layer (before output)
            layer_name = self.model.layers[-2].name

        # Sub-model sharing the original weights, cut at the requested layer
        feature_model = keras.Model(
            inputs=self.model.input, outputs=self.model.get_layer(layer_name).output
        )
        return feature_model.predict(X, batch_size=batch_size, verbose=0)
class ModelHub:
    """
    Central hub for managing pre-trained models.

    Provides utilities for:
    - Listing available models
    - Downloading models
    - Managing model cache
    - Model comparison

    Examples
    --------
    >>> from vitalDSP.ml_models.pretrained_models import ModelHub
    >>>
    >>> # List all available models
    >>> hub = ModelHub()
    >>> models = hub.list_models()
    >>> print(models)
    >>>
    >>> # Filter models by signal type
    >>> ecg_models = hub.list_models(signal_type='ecg', task='classification')
    >>> print(ecg_models)
    >>>
    >>> # Download specific model
    >>> model = hub.get_model('ecg_classifier_mitbih')
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize model hub.

        Parameters
        ----------
        cache_dir : str, optional
            Directory for caching models. Defaults to
            ``~/.vitaldsp/pretrained_models``. Created if missing.
        """
        if cache_dir is None:
            cache_dir = Path.home() / ".vitaldsp" / "pretrained_models"
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def list_models(
        self,
        signal_type: Optional[str] = None,
        task: Optional[str] = None,
        architecture: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        List available models with optional filtering.

        Parameters
        ----------
        signal_type : str, optional
            Filter by signal type ('ecg', 'ppg', 'eeg')
        task : str, optional
            Filter by task ('classification', 'regression', 'autoencoder')
        architecture : str, optional
            Filter by architecture ('cnn1d', 'lstm', 'transformer')

        Returns
        -------
        list of dict
            List of model metadata; each dict also carries the registry key
            under ``"name"``.
        """
        # A falsy filter value (None or "") means "do not filter on this key".
        wanted = {
            "signal_type": signal_type,
            "task": task,
            "architecture": architecture,
        }
        return [
            {"name": name, **metadata}
            for name, metadata in MODEL_REGISTRY.items()
            if all(
                not expected or metadata.get(key) == expected
                for key, expected in wanted.items()
            )
        ]

    def get_model(
        self, model_name: str, force_download: bool = False
    ) -> "PretrainedModel":
        """
        Get pre-trained model.

        Parameters
        ----------
        model_name : str
            Name of model
        force_download : bool, default=False
            Force re-download

        Returns
        -------
        PretrainedModel
            Loaded model
        """
        return PretrainedModel.from_registry(
            model_name, cache_dir=str(self.cache_dir), force_download=force_download
        )

    @staticmethod
    def _remove_entry(entry: Path) -> None:
        """Delete one cache entry, whether it is a file or a directory tree."""
        import shutil

        # Some saved-model formats are directories; unlink() cannot remove
        # those. Symlinks are unlinked rather than followed.
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()

    def clear_cache(self, model_name: Optional[str] = None):
        """
        Clear model cache.

        Parameters
        ----------
        model_name : str, optional
            Specific model to clear; every cache entry whose name starts
            with this string is removed. If None, clear all.
        """
        pattern = f"{model_name}*" if model_name is not None else "*"
        # Materialize the matches first so deletion does not race the scan.
        for entry in list(self.cache_dir.glob(pattern)):
            self._remove_entry(entry)

    def get_cache_size(self) -> float:
        """
        Get total size of cached models in MB.

        Returns
        -------
        float
            Total cache size in MB, counting files in nested directories
        """
        # Walk recursively: a directory's own st_size says nothing about the
        # size of the files stored inside it.
        total_size = sum(
            path.stat().st_size
            for path in self.cache_dir.rglob("*")
            if path.is_file()
        )
        return total_size / (1024 * 1024)  # Convert to MB

    def compare_models(
        self, model_names: List[str], metric: str = "accuracy"
    ) -> Dict[str, float]:
        """
        Compare models by a specific metric.

        Parameters
        ----------
        model_names : list of str
            Names of models to compare
        metric : str, default='accuracy'
            Metric to compare

        Returns
        -------
        dict
            Model names and their metric values. Unknown models and models
            that do not report ``metric`` are omitted.
        """
        return {
            name: MODEL_REGISTRY[name][metric]
            for name in model_names
            if name in MODEL_REGISTRY and metric in MODEL_REGISTRY[name]
        }
# Convenience function
def load_pretrained_model(
    model_name: str, cache_dir: Optional[str] = None, force_download: bool = False
) -> PretrainedModel:
    """
    Load a pre-trained model from the registry in a single call.

    Thin shortcut for ``PretrainedModel.from_registry``; all arguments are
    forwarded unchanged.

    Parameters
    ----------
    model_name : str
        Name of model from registry
    cache_dir : str, optional
        Cache directory
    force_download : bool, default=False
        Force re-download

    Returns
    -------
    PretrainedModel
        Loaded model

    Examples
    --------
    >>> from vitalDSP.ml_models.pretrained_models import load_pretrained_model
    >>> import numpy as np
    >>>
    >>> # Load ECG classifier
    >>> model = load_pretrained_model('ecg_classifier_mitbih')
    >>>
    >>> # Make predictions
    >>> ecg_signal = np.random.randn(1, 187, 1)
    >>> prediction = model.predict(ecg_signal)
    >>> print(f"Predicted class: {prediction[0]}")
    """
    loaded = PretrainedModel.from_registry(
        model_name,
        cache_dir=cache_dir,
        force_download=force_download,
    )
    return loaded