Unity使用GTCRN实现流式语音增强(一)

流式语音增强

sherpa-onnx 已经添加了 GTCRN 实现了离线语音增强,但流式语音增强一直没有添加,用官方onnxruntime在Unity中实现了一下,有点问题,对比官方效果有差距,但增强效果还是可以的。
在这里插入图片描述
原始音频
在这里插入图片描述
官方效果
在这里插入图片描述
Unity里效果
在这里插入图片描述
Unity里最新效果

在这里插入图片描述
Unity里最新效果7.18,除了开头的尖刺已经和原版python结果差不多了

在这里插入图片描述
Unity里最新效果7.22
在这里插入图片描述

主要代码如下

using MathNet.Numerics;
using MathNet.Numerics.IntegralTransforms;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UnityEngine;
using Debug = UnityEngine.Debug;

public class GtcrnTest2 : MonoBehaviour
{
    string simpleModelPath = "your_model_name.onnx";
    string outputPath = "test_wavs/enh_onnx.wav"; // Path to save the enhanced audio.
    int sampleRate = 16000; // Audio sample rate.
    int n_fft = 512; // FFT window size.
    int hop_length = 256; // Hop length between consecutive frames.
    int win_length = 512; // Window length (usually equals n_fft for STFT).

    // Start is called before the first frame update
    void Start()
    {
        simpleModelPath = Application.streamingAssetsPath + "/gtcrn_simple.onnx";
        outputPath = Application.streamingAssetsPath + "/result.wav";

        float[] rawAudio = ReadWav(Application.streamingAssetsPath + "/mix.wav");
        var window = Window.Hann(win_length).Select(x => (float)Math.Sqrt(x)).ToArray();
        (float[,,] stftResult, int frames) = ComputeSTFT(rawAudio, n_fft, hop_length, win_length, window);
        // 3. 初始化ONNX推理会话
        var session = new InferenceSession(simpleModelPath);

        var convCache = new DenseTensor<float>(new[] { 2, 1, 16, 16, 33 });
        var traCache = new DenseTensor<float>(new[] { 2, 3, 1, 1, 16 });
        var interCache = new DenseTensor<float>(new[] { 2, 1, 33, 16 });

        var outputs = new List<float[,,]>();

        for (int i = 0; i < frames; i++)
        {
            // 准备当前帧
            var input = new DenseTensor<float>(new[] { 1, 257, 1, 2 });
            for (int j = 0; j < 257; j++)
            {
                input[0, j, 0, 0] = stftResult[i, j, 0]; // 实部
                input[0, j, 0, 1] = stftResult[i, j, 1]; // 虚部
            }

            // 准备输入
            var inputs = new List<NamedOnnxValue>
                {
                    NamedOnnxValue.CreateFromTensor("mix", input),
                    NamedOnnxValue.CreateFromTensor("conv_cache", convCache),
                    NamedOnnxValue.CreateFromTensor("tra_cache", traCache),
                    NamedOnnxValue.CreateFromTensor("inter_cache", interCache)
                };

            // 推理
            using (var results = session.Run(inputs))
            {
                var enh = results.First(t => t.Name == "enh").AsTensor<float>();
                convCache = (DenseTensor<float>)results.First(t => t.Name == "conv_cache_out").AsTensor<float>();
                traCache = (DenseTensor<float>)results.First(t => t.Name == "tra_cache_out").AsTensor<float>();
                interCache = (DenseTensor<float>)results.First(t => t.Name == "inter_cache_out").AsTensor<float>();

                // 存储输出
                var frameOutput = new float[1, 257, 2];
                for (int j = 0; j < 257; j++)
                {
                    frameOutput[0, j, 0] = enh[0, j, 0, 0];
                    frameOutput[0, j, 1] = enh[0, j, 0, 1];
                }
                outputs.Add(frameOutput);
            }
        }

        // 6. 合并所有帧
        var allFrames = new float[frames, 257, 2];
        for (int i = 0; i < frames; i++)
        {
            for (int j = 0; j < 257; j++)
            {
                allFrames[i, j, 0] = outputs[i][0, j, 0];
                allFrames[i, j, 1] = outputs[i][0, j, 1];
            }
        }

        // 7. 计算ISTFT
        float[] enhancedAudio = ComputeISTFT(allFrames, n_fft, hop_length, win_length, window);

        SaveClip(1, 16000, enhancedAudio, outputPath);
    }

    // Update is called once per frame
    void Update()
    {

    }

    static (float[,,] result, int frames) ComputeSTFT(float[] audio, int n_fft, int hop, int win, float[] window)
    {
        int frames = (audio.Length - n_fft) / hop + 1;
        var stft = new float[frames, n_fft / 2 + 1, 2]; // [frame, freq, real/imag]

        for (int i = 0; i < frames; i++)
        {
            // 提取帧并加窗
            var frame = new float[n_fft];
            Array.Copy(audio, i * hop, frame, 0, Math.Min(n_fft, audio.Length - i * hop));
            for (int j = 0; j < n_fft; j++) frame[j] *= window[j];

            // 计算FFT (使用MathNet.Numerics)
            var complexFrame = new Complex32[n_fft];
            for (int j = 0; j < n_fft; j++)
            {
                complexFrame[j] = new Complex32(frame[j], 0);
            }
            Fourier.Forward(complexFrame, FourierOptions.Default);

            // 存储结果(仅保留一半)
            for (int j = 0; j <= n_fft / 2; j++)
            {
                stft[i, j, 0] = complexFrame[j].Real;
                stft[i, j, 1] = complexFrame[j].Imaginary;
            }
        }
        return (stft, frames);
    }

    static float[] ComputeISTFT(float[,,] stft, int n_fft, int hop, int win, float[] window)
    {
        int frames = stft.GetLength(0);
        int outputLength = (frames - 1) * hop + n_fft;
        var output = new float[outputLength];
        var scale = window.Select(w => w * w).Sum(); // 用于归一化

        for (int i = 0; i < frames; i++)
        {
            // 重建完整频谱
            var fullSpectrum = new Complex32[n_fft];
            for (int j = 0; j <= n_fft / 2; j++)
            {
                fullSpectrum[j] = new Complex32(stft[i, j, 0], stft[i, j, 1]);
                if (j > 0 && j < n_fft / 2)
                {
                    fullSpectrum[n_fft - j] = fullSpectrum[j].Conjugate();
                }
            }

            // 逆FFT
            Fourier.Inverse(fullSpectrum, FourierOptions.Default);

            // 加窗并重叠相加
            int pos = i * hop;
            for (int j = 0; j < n_fft; j++)
            {
                if (pos + j < output.Length)
                {
                    output[pos + j] += fullSpectrum[j].Real * window[j] / scale;
                }
            }
        }
        return output;
    }

    float[] ReadWav(string filePath)
    {
        using (FileStream fs = new FileStream(filePath, FileMode.Open, FileAccess.Read))
        using (BinaryReader reader = new BinaryReader(fs))
        {
            // 读取WAV文件头
            string riff = new string(reader.ReadChars(4));    // "RIFF"
            int fileSize = reader.ReadInt32();                // 文件总大小-8
            string wave = new string(reader.ReadChars(4));    // "WAVE"
            string fmt = new string(reader.ReadChars(4));     // "fmt "
            int fmtSize = reader.ReadInt32();                 // fmt块大小(至少16)

            // 读取音频格式信息
            short audioFormat = reader.ReadInt16();           // 1=PCM
            short numChannels = reader.ReadInt16();           // 通道数
            int sampleRate = reader.ReadInt32();              // 采样率
            int byteRate = reader.ReadInt32();                // 字节率
            short blockAlign = reader.ReadInt16();            // 块对齐
            short bitsPerSample = reader.ReadInt16();         // 采样深度

            // 验证文件格式
            if (riff != "RIFF" || wave != "WAVE" || fmt != "fmt ")
                throw new Exception("无效的WAV文件头");

            // 跳过fmt块的额外信息(如果有)
            if (fmtSize > 16)
                reader.ReadBytes(fmtSize - 16);

            // 查找数据块
            string dataChunkId;
            do
            {
                dataChunkId = new string(reader.ReadChars(4));
                if (dataChunkId != "data")
                    reader.ReadBytes(reader.ReadInt32()); // 跳过非数据块
            } while (dataChunkId != "data");

            int dataSize = reader.ReadInt32(); // 数据块大小(字节)

            // 验证音频参数
            if (audioFormat != 1)
                throw new Exception("仅支持PCM格式");
            if (numChannels != 1)
                throw new Exception("仅支持单声道音频");
            if (sampleRate != 16000)
                throw new Exception("仅支持16kHz采样率");
            if (bitsPerSample != 16)
                throw new Exception("仅支持16位采样深度");

            // 读取PCM数据并转换为float
            int sampleCount = dataSize / 2; // 16位 = 2字节/样本
            float[] floatData = new float[sampleCount];

            for (int i = 0; i < sampleCount; i++)
            {
                // 小端序读取16位样本
                byte lowByte = reader.ReadByte();
                byte highByte = reader.ReadByte();
                short pcmValue = (short)((highByte << 8) | lowByte);

                // 将16位PCM值转换为[-1.0, 1.0]范围的float
                floatData[i] = pcmValue / 32768.0f;
            }

            return floatData;
        }
    }

    void SaveClip(int channels, int frequency, float[] data, string filePath)
    {
        using (FileStream fileStream = new FileStream(filePath, FileMode.Create))
        {
            using (BinaryWriter writer = new BinaryWriter(fileStream))
            {
                // 写入RIFF头部标识
                writer.Write("RIFF".ToCharArray());
                // 写入文件总长度(后续填充)
                writer.Write(0);
                writer.Write("WAVE".ToCharArray());
                // 写入fmt子块
                writer.Write("fmt ".ToCharArray());
                writer.Write(16); // PCM格式块长度
                writer.Write((short)1); // PCM编码类型
                writer.Write((short)channels);
                writer.Write(frequency);
                writer.Write(frequency * channels * 2); // 字节率
                writer.Write((short)(channels * 2)); // 块对齐
                writer.Write((short)16); // 位深度
                                         // 写入data子块
                writer.Write("data".ToCharArray());
                writer.Write(data.Length * 2); // 音频数据字节数
                                               // 写入PCM数据(float转为short)
                foreach (float sample in data)
                {
                    // 转换过程代码可能有误,此处把音量放大100倍
                    writer.Write((short)(sample * 32767 * 100));
                }
                // 返回填充文件总长度
                fileStream.Position = 4;
                writer.Write((int)(fileStream.Length - 8));
            }
        }
    }
}

最后是工程地址
gtcrn-unity

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

地狱为王

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值