使用Rust原生实现小波卡尔曼滤波算法

一、算法原理概述

  1. 小波变换(Wavelet Transform)

    • 通过多尺度分解将信号分为高频(细节)和低频(近似)部分,高频通常包含噪声,低频保留主体信息。
    • 使用Haar小波(计算高效)进行快速分解与重构:
      // Haar小波分解公式
      approx = (x[2i] + x[2i+1]) / 2.0;
      detail = (x[2i] - x[2i+1]) / 2.0;
      
  2. 卡尔曼滤波(Kalman Filter)

    • 预测-更新两阶段递归估计[citation:5][citation:6]:
      • 预测:基于状态方程预估当前状态和误差协方差

        Xpred=A⋅Xprev+B⋅uPpred=A⋅Pprev⋅AT+QXpred​=A⋅Xprev​+B⋅uPpred​=A⋅Pprev​⋅AT+Q

      • 更新:结合观测值修正预测值,计算卡尔曼增益K

        K=Ppred⋅HT/(H⋅Ppred⋅HT+R)Xnew=Xpred+K⋅(z−H⋅Xpred)Pnew=(I−K⋅H)⋅PpredK=Ppred​⋅HT/(H⋅Ppred​⋅HT+R)Xnew​=Xpred​+K⋅(z−H⋅Xpred​)Pnew​=(I−K⋅H)⋅Ppred​

  3. 小波卡尔曼融合

    • 先对原始信号小波分解,对不同频段独立应用卡尔曼滤波,最后重构信号
    • 高频部分使用更高测量噪声协方差R(抑制噪声),低频部分使用更低R(保留主体)。

二、Rust实现代码

cargo.toml

[package]
name = "wavelet-kalman-filter"
version = "0.1.0"
edition = "2021"

[dependencies]
plotters = "0.3.5"
rand = "0.8.5"
log = "0.4.17"
simple_logger = "5.0.0"

main.rs

use log::{info};
use log::LevelFilter;
use simple_logger::SimpleLogger;

// 初始化日志记录,将日志保存到文件
fn init_logger() {
    SimpleLogger::new()
        .with_level(LevelFilter::Info)
        .with_utc_timestamps()
        .init()
        .expect("Failed to initialize logger");
    info!("日志系统初始化完成");
}

// 主结构体
pub struct WaveletKalmanFilter {
    pub a: f64,      // 状态转移系数(如匀速运动时A=1)
    pub q: f64,      // 过程噪声协方差
    pub r_low: f64,  // 低频观测噪声协方差
    pub r_high: f64, // 高频观测噪声协方差
    pub state: f64,  // 当前状态估计
    pub cov: f64,    // 当前误差协方差
}

// 一维Haar小波分解(返回低频近似 + 高频细节)
pub fn haar_decompose(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
    info!("开始Haar小波分解,输入信号长度: {}", signal.len());
    // 对信号进行对称延拓,以缓解小波分析的边界问题
    fn symmetric_extend(signal: &[f64]) -> Vec<f64> {
        if signal.len() <= 1 {
            info!("信号长度小于等于1,直接返回原信号");
            return signal.to_vec();
        }
        let mut extended = Vec::new();
        // 左侧对称部分
        for i in (1..=signal.len() - 1).rev() {
            extended.push(signal[i]);
        }
        // 原始信号部分
        extended.extend_from_slice(signal);
        // 右侧对称部分
        for i in 1..signal.len() {
            extended.push(signal[signal.len() - 1 - i]);
        }
        info!("对称延拓后信号长度: {} 信号内容: {:?}", extended.len(), extended);
        extended
    }

    // 对延拓后的信号进行Haar小波分解
    pub fn haar_decompose(signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let extended = symmetric_extend(signal);
        let mut approx = Vec::new();
        let mut detail = Vec::new();
        for i in (0..extended.len()).step_by(2) {
            let sum = extended[i] + extended.get(i + 1).unwrap_or(&0.0);
            let diff = extended[i] - extended.get(i + 1).unwrap_or(&0.0);
            approx.push(sum / 2.0);
            detail.push(diff / 2.0);
        }
        // 截取与原始信号分解后相同长度的结果
        let half_len = (signal.len() + 1) / 2;
        let approx = approx[extended.len() / 4..extended.len() / 4 + half_len].to_vec();
        let detail = detail[extended.len() / 4..extended.len() / 4 + half_len].to_vec();
        info!(
            "延拓后Haar分解完成,近似分量长度: {}, 细节分量长度: {}",
            approx.len(),
            detail.len()
        );
        (approx, detail)
    }
    let mut approx = Vec::new();
    let mut detail = Vec::new();
    for i in (0..signal.len()).step_by(2) {
        let sum = signal[i] + signal.get(i + 1).unwrap_or(&0.0);
        let diff = signal[i] - signal.get(i + 1).unwrap_or(&0.0);
        approx.push(sum / 2.0);
        detail.push(diff / 2.0);
    }
    info!(
        "直接Haar分解完成,近似分量长度: {}, 细节分量长度: {}",
        approx.len(),
        detail.len()
    );
    (approx, detail)
}

// 一维Haar小波重构
pub fn haar_recompose(approx: &[f64], detail: &[f64]) -> Vec<f64> {
    info!(
        "开始Haar小波重构,近似分量长度: {}, 细节分量长度: {}",
        approx.len(),
        detail.len()
    );
    let mut output = Vec::with_capacity(approx.len() * 2);
    for i in 0..approx.len() {
        output.push(approx[i] + detail[i]);
        output.push(approx[i] - detail[i]);
    }
    info!("Haar小波重构完成,输出信号长度: {}", output.len());
    output
}

impl WaveletKalmanFilter {
    // 初始化滤波器
    pub fn new(a: f64, q: f64, r_low: f64, r_high: f64) -> Self {
        info!(
            "创建新的WaveletKalmanFilter,a: {}, q: {}, r_low: {}, r_high: {}",
            a, q, r_low, r_high
        );
        WaveletKalmanFilter {
            a,
            q,
            r_low,
            r_high,
            state: 0.0,
            cov: 1.0, // 初始协方差设为较大值(表示不确定性高)
        }
    }

    // 单步预测
    pub fn predict(&mut self) {
        info!("执行单步预测前,状态: {}, 协方差: {}", self.state, self.cov);
        self.state = self.a * self.state; // X_pred = A * X_prev
        self.cov = self.a * self.cov * self.a + self.q; // P_pred = A*P*A^T + Q
        info!("执行单步预测后,状态: {}, 协方差: {}", self.state, self.cov);
    }

    // 单步更新(根据频段选择R)
    pub fn update(&mut self, measurement: f64, is_low_freq: bool) {
        // info!(
        //     "执行单步更新前,测量值: {}, 是否低频: {}",
        //     measurement, is_low_freq
        // );
        let r = if is_low_freq { self.r_low } else { self.r_high };
        let k = self.cov / (self.cov + r); // 卡尔曼增益K
        self.state += k * (measurement - self.state); // X_new = X_pred + K*(z - H*X_pred)
        self.cov = (1.0 - k) * self.cov; // P_new = (I - K*H)*P_pred
        info!("执行单步更新后,状态: {}, 协方差: {}", self.state, self.cov);
    }
}

pub fn wavelet_kalman_filter(
    signal: &[f64],
    levels: usize, // 小波分解层数
    a: f64,        // 状态转移系数
    q: f64,        // 过程噪声
    r_low: f64,    // 低频观测噪声
    r_high: f64,   // 高频观测噪声
) -> Vec<f64> {
    info!(
        "开始小波卡尔曼滤波,输入信号长度: {}, 分解层数: {}",
        signal.len(),
        levels
    );
    // 1. 递归小波分解(多尺度)
    let mut approx = signal.to_vec();
    let mut details = Vec::new();
    for level in 0..levels {
        info!("第 {} 层小波分解", level + 1);
        let (a, d) = haar_decompose(&approx);
        details.push(d);
        approx = a;
    }

    // 2. 对各频段独立应用卡尔曼滤波
    let mut kalman = WaveletKalmanFilter::new(a, q, r_low, r_high);
    kalman.state = approx[0]; // 初始化状态

    // 低频滤波(R较小,信任观测值)
    info!("开始低频滤波,近似分量长度: {}", approx.len());
    for i in 0..approx.len() {
        info!("低频滤波第 {} 步", i + 1);
        kalman.predict();
        kalman.update(approx[i], true); // 标记为低频
        approx[i] = kalman.state;
    }

    // 高频滤波(R较大,抑制噪声)
    info!("开始高频滤波,细节分量数量: {}", details.len());
    for (level, d) in details.iter_mut().enumerate() {
        info!("第 {} 层细节分量滤波,长度: {}", level + 1, d.len());
        kalman = WaveletKalmanFilter::new(a, q, r_low, r_high); // 重置滤波器
        kalman.state = d[0];
        for i in 0..d.len() {
            info!("第 {} 层细节分量第 {} 步滤波", level + 1, i + 1);
            kalman.predict();
            kalman.update(d[i], false); // 标记为高频
            d[i] = kalman.state;
        }
    }

    // 3. 小波重构
    info!("开始小波重构,分解层数: {}", levels);
    for _ in 0..levels {
        approx = haar_recompose(&approx, &details.pop().unwrap());
    }
    info!("小波卡尔曼滤波完成,输出信号长度: {}", approx.len());
    approx
}

// 绘图函数实现
fn draw_signals(
    clean: &[f64],
    noisy: &[f64],
    filtered: &[f64],
) -> Result<(), Box<dyn std::error::Error>> {
    info!(
        "开始绘制信号图,原始信号长度: {}, 含噪信号长度: {}, 滤波信号长度: {}",
        clean.len(),
        noisy.len(),
        filtered.len()
    );
    use plotters::prelude::*;

    let root = BitMapBackend::new("signals.png", (1024, 768)).into_drawing_area();
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("信号处理结果", ("sans-serif", 50).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(0f64..clean.len() as f64, -2f64..2f64)?;

    chart.configure_mesh().draw()?;

    chart
        .draw_series(LineSeries::new(
            clean.iter().enumerate().map(|(i, &y)| (i as f64, y)),
            &RED,
        ))?
        .label("原始信号")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &RED));

    chart
        .draw_series(LineSeries::new(
            noisy.iter().enumerate().map(|(i, &y)| (i as f64, y)),
            &BLUE,
        ))?
        .label("含噪信号")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &BLUE));

    chart
        .draw_series(LineSeries::new(
            filtered.iter().enumerate().map(|(i, &y)| (i as f64, y)),
            &GREEN,
        ))?
        .label("滤波信号")
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &GREEN));

    chart
        .configure_series_labels()
        .background_style(&WHITE.mix(0.8))
        .border_style(&BLACK)
        .draw()?;
    info!("信号图绘制完成,保存至 signals.png");
    Ok(())
}

// 测试函数:添加高斯噪声的信号滤波
fn test_noisy_sine() {
    info!("开始测试含噪正弦信号滤波");
    let clean_signal: Vec<_> = (0..1000).map(|i| (i as f64 * 0.1).sin()).collect();
    let noisy_signal: Vec<_> = clean_signal
        .iter()
        .map(|x| x + rand::random::<f64>())
        .collect();

    let filtered = wavelet_kalman_filter(&noisy_signal, 3, 1.0, 0.01, 0.1, 2.0);

    // 绘图比较(使用plotters库)
    draw_signals(&clean_signal, &noisy_signal, &filtered);
    info!("含噪正弦信号滤波测试完成");
}

// 主函数
fn main() {
    init_logger();
    info!("程序启动");
    test_noisy_sine();
    info!("程序结束");
}

实时小波卡尔曼滤波

use nalgebra::{DMatrix, DVector, Scalar};
use std::f64::consts::PI;
use std::collections::VecDeque;
use std::cell::RefCell;
use std::rc::Rc;

/// 实时小波卡尔曼滤波器结构体
pub struct RealTimeWaveletKalmanFilter {
    // 卡尔曼滤波参数
    state: DVector<f64>,                // 状态向量
    covariance: DMatrix<f64>,           // 状态协方差矩阵
    transition: DMatrix<f64>,           // 状态转移矩阵
    observation: DMatrix<f64>,          // 观测矩阵
    process_noise: DMatrix<f64>,        // 过程噪声协方差
    measurement_noise: DMatrix<f64>,    // 测量噪声协方差
    
    // 小波变换参数
    wavelet_type: WaveletType,          // 小波基类型
    decomposition_level: usize,         // 小波分解层数
    threshold: f64,                     // 小波阈值去噪的阈值
    
    // 实时处理状态
    buffer: VecDeque<f64>,                // 输入缓冲区
    min_buffer_size: usize,              // 最小缓冲区大小
    buffer_size: usize,                  // 目标缓冲区大小
    processed_count: usize,              // 已处理数据计数
    
    // 小波变换状态
    approx_buffers: Vec<Vec<f64>>,       // 多分辨率近似系数缓冲区
    detail_buffers: Vec<Vec<f64>>,       // 多分辨率细节系数缓冲区
}

/// 可用的小波基类型
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WaveletType {
    Haar,
    Daubechies4,
}

impl RealTimeWaveletKalmanFilter {
    /// 创建新的实时小波卡尔曼滤波器
    pub fn new(
        initial_state: DVector<f64>,
        initial_covariance: DMatrix<f64>,
        transition: DMatrix<f64>,
        observation: DMatrix<f64>,
        process_noise: DMatrix<f64>,
        measurement_noise: DMatrix<f64>,
        wavelet_type: WaveletType,
        decomposition_level: usize,
        threshold: f64,
    ) -> Self {
        // 计算所需的最小缓冲区大小
        let min_buffer_size = match wavelet_type {
            WaveletType::Haar => 2,
            WaveletType::Daubechies4 => 4,
        };
        
        // 计算目标缓冲区大小
        let buffer_size = 1 << decomposition_level; // 2^decomposition_level
        
        // 初始化多分辨率缓冲区
        let mut approx_buffers = vec![vec![]; decomposition_level];
        let mut detail_buffers = vec![vec![]; decomposition_level];
        
        RealTimeWaveletKalmanFilter {
            state: initial_state,
            covariance: initial_covariance,
            transition,
            observation,
            process_noise,
            measurement_noise,
            wavelet_type,
            decomposition_level,
            threshold,
            buffer: VecDeque::with_capacity(buffer_size * 2),
            min_buffer_size,
            buffer_size,
            processed_count: 0,
            approx_buffers,
            detail_buffers,
        }
    }
    
    /// 添加新数据点并处理
    pub fn add_data_point(&mut self, measurement: f64) -> Option<f64> {
        // 1. 将新数据添加到缓冲区
        self.buffer.push_back(measurement);
        
        // 2. 如果缓冲区有足够数据,进行处理
        if self.buffer.len() >= self.min_buffer_size {
            // 3. 进行小波分解和重构
            if let Some(denoised) = self.process_wavelets() {
                // 4. 卡尔曼预测
                self.predict();
                
                // 5. 卡尔曼更新
                self.update(denoised);
                
                // 6. 返回状态估计值
                return Some(self.state[0]);
            }
        }
        
        None
    }
    
    /// 处理缓冲区中的数据
    fn process_wavelets(&mut self) -> Option<f64> {
        // 当前分解层数
        let level = (self.processed_count % self.decomposition_level) + 1;
        let level = level.min(self.decomposition_level);
        
        // 根据当前层确定处理块大小
        let block_size = self.buffer_size;
        
        if self.buffer.len() < block_size {
            return None;
        }
        
        // 取出要处理的数据块
        let mut block: Vec<f64> = self.buffer.drain(..block_size).collect();
        
        // 多分辨率小波处理
        let mut denoised_block = Vec::with_capacity(block_size);
        
        // 1. 小波分解
        for _ in 0..level {
            if block.len() < self.min_buffer_size {
                break;
            }
            
            let (approx, detail) = match self.wavelet_type {
                WaveletType::Haar => self.haar_decompose(&block),
                WaveletType::Daubechies4 => self.daubechies4_decompose(&block),
            };
            
            // 添加到对应级别的缓冲区
            self.approx_buffers[level - 1].extend_from_slice(&approx);
            self.detail_buffers[level - 1].extend_from_slice(&detail);
            
            block = approx;
        }
        
        // 2. 小波系数阈值处理
        for i in 0..level {
            // 阈值处理近似系数
            self.approx_buffers[i] = self.threshold_coefficients(self.approx_buffers[i].clone());
            
            // 阈值处理细节系数
            self.detail_buffers[i] = self.threshold_coefficients(self.detail_buffers[i].clone());
        }
        
        // 3. 小波重构 - 仅重构最后部分用于实时处理
        let mut reconstructed = self.approx_buffers[level - 1].clone();
        
        for i in (0..level).rev() {
            reconstructed = match self.wavelet_type {
                WaveletType::Haar => self.haar_reconstruct(&reconstructed, &self.detail_buffers[i]),
                WaveletType::Daubechies4 => self.daubechies4_reconstruct(&reconstructed, &self.detail_buffers[i]),
            };
            
            // 只保留最近的数据用于处理
            let keep_count = (self.buffer_size >> i).max(self.min_buffer_size);
            self.approx_buffers[i] = self.approx_buffers[i].split_off(self.approx_buffers[i].len().saturating_sub(keep_count));
            self.detail_buffers[i] = self.detail_buffers[i].split_off(self.detail_buffers[i].len().saturating_sub(keep_count));
        }
        
        // 取最后一个值作为当前去噪后的测量值
        if let Some(last_value) = reconstructed.last() {
            // 更新已处理计数
            self.processed_count += 1;
            
            return Some(*last_value);
        }
        
        None
    }
    
    /// 预测步骤
    pub fn predict(&mut self) {
        // 传统卡尔曼预测
        self.state = &self.transition * &self.state;
        self.covariance = &self.transition * &self.covariance * self.transition.transpose() + &self.process_noise;
    }
    
    /// 更新步骤,使用去噪后的测量值
    pub fn update(&mut self, measurement: f64) {
        // 创建测量向量
        let measurement_vector = DVector::from_vec(vec![measurement]);
        
        // 传统卡尔曼更新
        let innovation = &measurement_vector - &self.observation * &self.state;
        let innovation_covariance = &self.observation * &self.covariance * self.observation.transpose() + &self.measurement_noise;
        
        if let Some(kalman_gain) = self.compute_kalman_gain(&innovation_covariance) {
            self.state = &self.state + &kalman_gain * innovation;
            let identity = DMatrix::identity(self.covariance.nrows(), self.covariance.ncols());
            self.covariance = (identity - &kalman_gain * &self.observation) * &self.covariance;
        }
    }
    
    /// 计算卡尔曼增益 (带有奇异值处理)
    fn compute_kalman_gain(&self, innovation_covariance: &DMatrix<f64>) -> Option<DMatrix<f64>> {
        if let Some(inv) = innovation_covariance.clone().try_inverse() {
            Some(&self.covariance * self.observation.transpose() * inv)
        } else {
            // 使用伪逆作为后备
            let svd = innovation_covariance.svd(true, true);
            svd.pseudo_inverse(1e-12)
                .map(|inv| &self.covariance * self.observation.transpose() * inv)
        }
    }
    
    /// Haar小波分解(使用对称边界延拓)
    fn haar_decompose(&self, signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let n = signal.len();
        if n < 2 {
            // 对非常短的信号直接返回
            return (signal.to_vec(), vec![0.0; n]);
        }
        
        let mut approx = Vec::with_capacity(n / 2);
        let mut detail = Vec::with_capacity(n / 2);
        
        // 对称边界延拓实现
        for i in (0..n).step_by(2) {
            // 对称边界处理:当处理边界点时,使用对称点进行计算
            let left_boundary = if i == 0 { i + 1 } else { i - 1 };
            let right_boundary = if i + 1 >= n { n - 2 } else { i + 1 };
            
            let s0 = if i < n { signal[i] } else { signal[n - (i - n) % (n - 1) - 1] };
            let s1 = if i + 1 < n { signal[i + 1] } else { signal[right_boundary] };
            
            approx.push((s0 + s1) / 2.0);
            detail.push((s0 - s1) / 2.0);
        }
        
        (approx, detail)
    }
    
    /// Daubechies-4小波分解(使用对称边界延拓)
    fn daubechies4_decompose(&self, signal: &[f64]) -> (Vec<f64>, Vec<f64>) {
        // Daubechies-4小波系数
        const H0: f64 = (1.0 + 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const H1: f64 = (3.0 + 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const H2: f64 = (3.0 - 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const H3: f64 = (1.0 - 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        
        let n = signal.len();
        if n < 4 {
            return (signal.to_vec(), vec![0.0; n]);
        }
        
        let mut approx = Vec::with_capacity(n / 2);
        let mut detail = Vec::with_capacity(n / 2);
        
        // 对称边界延拓函数
        let symmetric_extend = |idx: isize, signal: &[f64]| {
            let n = signal.len();
            if idx < 0 {
                // 左边界对称延拓
                let symmetric_idx = (-idx - 1) as usize;
                if symmetric_idx < n { signal[symmetric_idx] } else { *signal.last().unwrap() }
            } else if idx >= n as isize {
                // 右边界对称延拓
                let symmetric_idx = 2 * n as isize - 2 - idx;
                if symmetric_idx < 0 {
                    signal[0]
                } else {
                    let idx_usize = symmetric_idx as usize;
                    if idx_usize < n { signal[idx_usize] } else { *signal.last().unwrap() }
                }
            } else {
                signal[idx as usize]
            }
        };
        
        // 分解滤波
        for i in (0..n).step_by(2) {
            // 计算边界点位置
            let idx_minus2 = i as isize - 2;
            let idx_minus1 = i as isize - 1;
            let idx_0 = i as isize;
            let idx_plus1 = i as isize + 1;
            
            // 使用对称延拓获取值
            let s_minus2 = symmetric_extend(idx_minus2, signal);
            let s_minus1 = symmetric_extend(idx_minus1, signal);
            let s0 = symmetric_extend(idx_0, signal);
            let s1 = symmetric_extend(idx_plus1, signal);
            
            // 近似系数 (低通)
            let a = s_minus2 * H0 +
                    s_minus1 * H1 +
                    s0 * H2 +
                    s1 * H3;
            
            // 细节系数 (高通)
            let d = s_minus2 * (-H3) +
                    s_minus1 * H2 +
                    s0 * (-H1) +
                    s1 * H0;
            
            approx.push(a);
            detail.push(d);
        }
        
        (approx, detail)
    }
    
    /// Haar小波重构(使用对称边界延拓)
    fn haar_reconstruct(&self, approx: &[f64], detail: &[f64]) -> Vec<f64> {
        let n = approx.len().max(detail.len()) * 2;
        let mut signal = vec![0.0; n];
        
        // 对称边界延拓函数
        let extend = |arr: &[f64], idx: isize| {
            if idx < 0 {
                // 左边界对称延拓
                let symmetric_idx = (-idx - 1) as usize;
                if symmetric_idx < arr.len() { arr[symmetric_idx] } else { *arr.last().unwrap() }
            } else if idx >= arr.len() as isize {
                // 右边界对称延拓
                let symmetric_idx = 2 * arr.len() as isize - 2 - idx;
                if symmetric_idx < 0 {
                    arr[0]
                } else {
                    let idx_usize = symmetric_idx as usize;
                    if idx_usize < arr.len() { arr[idx_usize] } else { *arr.last().unwrap() }
                }
            } else {
                arr[idx as usize]
            }
        };
        
        for i in 0..n {
            if i % 2 == 0 {
                let a = extend(approx, (i / 2) as isize);
                let d = extend(detail, (i / 2) as isize);
                signal[i] = a + d;
            } else {
                let a = extend(approx, (i / 2) as isize);
                let d = extend(detail, (i / 2) as isize);
                signal[i] = a - d;
            }
        }
        
        signal
    }
    
    /// Daubechies-4小波重构(使用对称边界延拓)
    fn daubechies4_reconstruct(&self, approx: &[f64], detail: &[f64]) -> Vec<f64> {
        // Daubechies-4重构滤波器系数
        const G0: f64 = (1.0 - 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const G1: f64 = (3.0 - 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const G2: f64 = (3.0 + 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        const G3: f64 = (1.0 + 3.0_f64.sqrt()) / (4.0 * 2.0_f64.sqrt());
        
        let n = approx.len().max(detail.len()) * 2;
        let mut signal = vec![0.0; n];
        
        // 对称边界延拓函数
        let extend = |arr: &[f64], idx: isize| {
            if idx < 0 {
                // 左边界对称延拓
                let symmetric_idx = (-idx - 1) as usize;
                if symmetric_idx < arr.len() { arr[symmetric_idx] } else { *arr.last().unwrap() }
            } else if idx >= arr.len() as isize {
                // 右边界对称延拓
                let symmetric_idx = 2 * arr.len() as isize - 2 - idx;
                if symmetric_idx < 0 {
                    arr[0]
                } else {
                    let idx_usize = symmetric_idx as usize;
                    if idx_usize < arr.len() { arr[idx_usize] } else { *arr.last().unwrap() }
                }
            } else {
                arr[idx as usize]
            }
        };
        
        // 重构滤波
        for i in 0..n {
            // 计算输入点的索引位置
            let half_i = i as f64 / 2.0;
            let base_idx = half_i.floor() as isize;
            
            // 偶数索引使用向下取整,奇数索引使用向上取整
            let a_idx = base_idx - 1;
            let a_idx1 = base_idx;
            
            let a0 = extend(approx, a_idx);
            let d0 = extend(detail, a_idx);
            let a1 = extend(approx, a_idx1);
            let d1 = extend(detail, a_idx1);
            
            let value = 
                a0 * G2 + 
                d0 * G3 + 
                a1 * G0 + 
                d1 * G1;
            
            signal[i] = value;
        }
        
        signal
    }
    
    /// 小波系数阈值处理函数
    fn threshold_coefficients(&self, mut coefficients: Vec<f64>) -> Vec<f64> {
        // 实现软阈值函数
        coefficients.iter_mut().for_each(|c| {
            if *c > self.threshold {
                *c -= self.threshold;
            } else if *c < -self.threshold {
                *c += self.threshold;
            } else {
                *c = 0.0;
            }
        });
        coefficients
    }
    
    /// 获取当前状态估计
    pub fn get_state(&self) -> &DVector<f64> {
        &self.state
    }
    
    /// 获取当前协方差矩阵
    pub fn get_covariance(&self) -> &DMatrix<f64> {
        &self.covariance
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::Rng;
    
    // 生成含噪声的正弦波信号
    fn generate_sine_with_noise(length: usize, noise_level: f64) -> Vec<f64> {
        let mut rng = rand::thread_rng();
        (0..length)
            .map(|i| {
                let t = i as f64 * 0.1;
                t.sin() + noise_level * (rng.gen::<f64>() - 0.5)
            })
            .collect()
    }
    
    #[test]
    fn test_symmetric_boundary() {
        // 测试对称边界延拓
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        
        let mut filter = RealTimeWaveletKalmanFilter::new(
            DVector::from_vec(vec![0.0]),
            DMatrix::from_diagonal(&DVector::from_vec(vec![1.0])),
            DMatrix::from_row_slice(1, 1, &[1.0]),
            DMatrix::from_row_slice(1, 1, &[1.0]),
            DMatrix::from_diagonal(&DVector::from_vec(vec![0.001])),
            DMatrix::from_element(1, 1, 0.1),
            WaveletType::Daubechies4,
            2,
            0.2,
        );
        
        // 测试左边界延拓 (索引-1应该对应索引0)
        assert_eq!(filter.daubechies4_decompose(&signal).0.len(), 2);
        
        // 测试短信号边界处理
        let short_signal = vec![1.0];
        assert_eq!(filter.haar_decompose(&short_signal).0.len(), 1);
        
        // 测试对称延拓后的边界一致性
        let decomposed = filter.daubechies4_decompose(&signal);
        assert!(decomposed.0.len() > 0);
    }
    
    #[test]
    fn test_realtime_boundary_performance() {
        // 测试对称边界处理下的实时性能
        let mut filter = RealTimeWaveletKalmanFilter::new(
            DVector::from_vec(vec![0.0]),
            DMatrix::from_diagonal(&DVector::from_vec(vec![1.0])),
            DMatrix::from_row_slice(1, 1, &[1.0]),
            DMatrix::from_row_slice(1, 1, &[1.0]),
            DMatrix::from_diagonal(&DVector::from_vec(vec![0.001])),
            DMatrix::from_element(1, 1, 0.1),
            WaveletType::Daubechies4,
            3,
            0.2,
        );
        
        // 生成带阶跃边界的信号
        let measurements: Vec<f64> = (0..500)
            .map(|i| {
                if i < 100 {
                    1.0
                } else if i < 300 {
                    5.0
                } else {
                    2.0
                }
            })
            .collect();
        
        // 处理所有数据点并测量性能
        let mut count = 0;
        let start = std::time::Instant::now();
        for m in &measurements {
            if filter.add_data_point(*m).is_some() {
                count += 1;
            }
        }
        let duration = start.elapsed();
        
        println!("带边界的信号处理耗时: {:?} ({}点/毫秒)", 
                 duration, 
                 measurements.len() as f64 / duration.as_millis() as f64);
        
        assert!(count > 450, "应产生大部分估计值");
        assert!(duration.as_millis() < 20, "处理过慢");
    }
    
    #[test]
    fn test_realtime_processing_with_boundary() {
        // 初始化参数 (追踪一维位置)
        let initial_state = DVector::from_vec(vec![0.0]);
        let initial_covariance = DMatrix::from_diagonal(&DVector::from_vec(vec![1.0]));
        let transition = DMatrix::from_row_slice(1, 1, &[0.95]); // 状态衰减因子
        let observation = DMatrix::from_row_slice(1, 1, &[1.0]); // 只能观测位置
        let process_noise = DMatrix::from_diagonal(&DVector::from_vec(vec![0.001]));
        let measurement_noise = DMatrix::from_element(1, 1, 0.1);
        
        // 创建实时小波卡尔曼滤波器
        let mut filter = RealTimeWaveletKalmanFilter::new(
            initial_state,
            initial_covariance,
            transition,
            observation,
            process_noise,
            measurement_noise,
            WaveletType::Daubechies4,
            3,
            0.3,
        );
        
        // 生成含噪声的正弦波测量数据,添加边界突变
        let n = 128;
        let mut measurements = generate_sine_with_noise(n, 0.5);
        
        // 添加边界突变 (10%幅度的突变)
        for i in 0..n {
            if i == 0 || i == n - 1 {
                measurements[i] += 1.0;
            }
        }
        
        // 处理前几个点应该不会有输出
        assert!(filter.add_data_point(measurements[0]).is_none());
        assert!(filter.add_data_point(measurements[1]).is_none());
        assert!(filter.add_data_point(measurements[2]).is_none());
        
        // 处理数据点
        let mut estimates = Vec::with_capacity(n);
        for (i, &m) in measurements.iter().enumerate() {
            if let Some(est) = filter.add_data_point(m) {
                estimates.push(est);
                
                // 输出最后几个估计值检查边界
                if i > n - 5 {
                    println!("Boundary estimate at {}: {:.4} (measured {:.4})", i, est, m);
                }
            }
        }
        
        // 验证输出数量
        assert!(estimates.len() > n - 10, "应产生几乎与输入相同数量的估计值");
        
        // 验证输出质量
        let original_signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.1).sin()).collect();
        let mut mse = 0.0;
        
        for (est, orig) in estimates.iter().zip(&original_signal[original_signal.len() - estimates.len()..]) {
            mse += (est - orig).powi(2);
        }
        mse /= estimates.len() as f64;
        
        println!("带边界延拓的实时处理MSE: {:.6}", mse);
        assert!(mse < 0.1, "去噪后MSE应较低");
        
        // 验证边界失真程度
        let boundary_distortion_start = (estimates[0] - original_signal[original_signal.len() - estimates.len()]).abs();
        let boundary_distortion_end = (estimates[estimates.len()-1] - *original_signal.last().unwrap()).abs();
        
        println!("开始边界失真: {:.4}, 结束边界失真: {:.4}", boundary_distortion_start, boundary_distortion_end);
        assert!(boundary_distortion_start < 0.5, "开始边界失真过大");
        assert!(boundary_distortion_end < 0.5, "结束边界失真过大");
    }
    
    #[test]
    fn test_realtime_haar_processing() {
        // 初始化参数
        let initial_state = DVector::from_vec(vec![0.0]);
        let initial_covariance = DMatrix::from_diagonal(&DVector::from_vec(vec![1.0]));
        let transition = DMatrix::from_row_slice(1, 1, &[0.98]);
        let observation = DMatrix::from_row_slice(1, 1, &[1.0]);
        let process_noise = DMatrix::from_diagonal(&DVector::from_vec(vec![0.001]));
        let measurement_noise = DMatrix::from_element(1, 1, 0.15);
        
        // 创建实时小波卡尔曼滤波器 (使用Haar小波)
        let mut filter = RealTimeWaveletKalmanFilter::new(
            initial_state,
            initial_covariance,
            transition,
            observation,
            process_noise,
            measurement_noise,
            WaveletType::Haar,
            4, // 分解层数
            0.4, // 阈值
        );
        
        // 生成阶跃信号加噪声 (添加边界阶跃)
        let mut measurements = vec![];
        for i in 0..100 {
            if i < 30 {
                measurements.push(1.0 + rand::thread_rng().gen_range(-0.5..0.5));
            } else if i < 70 {
                measurements.push(-1.0 + rand::thread_rng().gen_range(-0.5..0.5));
            } else {
                measurements.push(2.0 + rand::thread_rng().gen_range(-0.5..0.5));
            }
        }
        
        // 添加边界突变
        measurements[0] += 1.0;
        measurements[99] += 1.0;
        
        // 处理数据点
        let mut estimates = Vec::with_capacity(measurements.len());
        for (i, &m) in measurements.iter().enumerate() {
            if let Some(est) = filter.add_data_point(m) {
                estimates.push(est);
                if i == 0 || i == 99 {
                    println!("边界点 {}: 测量值={:.4}, 估计值={:.4}", i, m, est);
                }
            }
        }
        
        // 验证阶跃点
        let step1 = estimates[35].abs();
        let step2 = estimates[75].abs();
        
        assert!(step1 > 0.8, "未检测到第一个阶跃");
        assert!(step2 > 1.5, "未检测到第二个阶跃");
        assert!(step2 - step1 > 1.0, "阶跃过渡不清晰");
        
        // 验证边界失真
        let start_distortion = (estimates[0] - 2.0).abs(); // 因为添加了+1的边界突变,理想值应为2
        let end_distortion = (estimates[99] - 3.0).abs(); // 应为3
        
        println!("开始边界失真: {:.4}, 结束边界失真: {:.4}", start_distortion, end_distortion);
        assert!(start_distortion < 1.0, "开始边界失真过大");
        assert!(end_distortion < 1.0, "结束边界失真过大");
    }
}
 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值