VITS源码解读6-训练&推理

1. train.py

1.1 大体流程

  • 执行main函数,调用多线程和run函数
  • 执行run函数,加载日志、数据集、模型、模型优化器
  • for循环迭代数据batch,每次执行train_and_evaluate函数,训练模型

这里需要注意,源码中加载数据集用的分布式,单卡最好将其删除,用普通的data_loader即可。

1.2 train_and_evaluate函数

训练每次迭代的执行函数

  • g的输入和输出特别多,有2行,分布为:
y_hat, l_length, attn, ids_slice, x_mask, z_mask,\ 
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths)

其中第二行为求kl所需值

  • d的输入和输出为:
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
  • 损失函数计算

g:

        loss_dur = torch.sum(l_length.float())
        loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

        loss_fm = feature_loss(fmap_r, fmap_g)
        loss_gen, losses_gen = generator_loss(y_d_hat_g)
        loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

d:

loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)

1.3 evaluuate函数

这里主要是用到mel图的对比,通过可视化mel图判别生成效果。

来自数据集的spec

      mel = spec_to_mel_torch(
        spec, 
        hps.data.filter_length, 
        hps.data.n_mel_channels, 
        hps.data.sampling_rate,
        hps.data.mel_fmin, 
        hps.data.mel_fmax)

来自生成的音频y

      y_hat_mel = mel_spectrogram_torch(
        y_hat.squeeze(1).float(),
        hps.data.filter_length,
        hps.data.n_mel_channels,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        hps.data.mel_fmin,
        hps.data.mel_fmax)

1.4 train_ms.py 文件

区别于train.py文件用于训练lj,train_ms.py用于训练vctk数据集。

即train_ms适用于多人数据集,模型的输入也多了一个多人的embedding_layer。

2. Inference.py

这个是jupter即 ipynb文件,其通过 SynthesizerTrn类的infer方法和voice_conversion方法实现。

2.1 infer方法

该方法实现tts功能,输入文本的音素化特征,输出对应文本语义的语音效果。

主要步骤如下:

  • 通过文本编码器输入文本音素化向量,得到x, m_p, logs_p, x_mask

如果多人,潜入人物特征

    x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
    if self.n_speakers > 0:
      g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
    else:
      g = None
  • 通过dp输入x,x_mask, 以及人物特征,得到logw
    if self.use_sdp:
      logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
    else:
      logw = self.dp(x, x_mask, g=g)
  • logw与x_mask运算得到w_ceil, 最终得到y_lengths
    w = torch.exp(logw) * x_mask * length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
  • y_lengths得到y_mask,x_mask和y_mask得到注意力矩阵attn
    y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)

    attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)

    attn = commons.generate_path(w_ceil, attn_mask)
  • m_p 和 logs_p 分别与attn相乘,再相加得到z_p
 m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
    
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']

z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
  • z_p送入flow逆向,得到可用于生成音频的潜码z
z = self.flow(z_p, y_mask, g=g, reverse=True)
  • z送入解码器 dec, 得到对应音频y
 o = self.dec((z * y_mask)[:,:,:max_len], g=g)

2.2 voice_conversion

该方法实现vctk数据集下的不同人物音色转换。

  • 通过输入人物的id,得到人物的嵌入特征
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
  • 将原人物的音频y和y长度,及对应id输入,得到其潜码
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
  • 通过流模型输入目标任务的id,在潜码中嵌入目标任务的特征
z_p = self.flow(z, y_mask, g=g_src)
  • 剩下的步骤和tts类似,用z_p合成目标语音
 z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
 o_hat = self.dec(z_hat * y_mask, g=g_tgt)

3.总结

VITS到这里就告于段落了, 后面的VITS2改进了VITS的dp模型(flow变gan),

在cosvoice等模型里面也能见到VITS的主干网络。

因此, VITS是音频tts和vc、sc的核心技术。

### 使用so-vits-svc在Ubuntu上进行推理 #### 安装依赖库 为了能够在Ubuntu上运行so-vits-svc模型,安装必要的Python包和其他依赖项是必需的操作。这通常涉及使用`pip`来安装特定版本的PyTorch以及其他辅助工具,如numpy, scipy等。 ```bash sudo apt-get update && sudo apt-get install -y portaudio19-dev pip install torch==1.7.0+cu101 torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html pip install numpy scipy soundfile webrtcvad pyworld pypinyin unidecode g2p_en onnxruntime ``` #### 下载并配置环境 获取so-vits-svc项目的源码可以通过克隆GitHub仓库完成。之后进入项目目录设置环境变量以及下载预训练模型文件以便后续调用[^1]。 ```bash git clone https://github.com/your-repo/so-vits-svc.git cd so-vits-svc wget http://path.to/pretrained_model.pth export PYTHONPATH=$PYTHONPATH:$(pwd) ``` #### 准备音频数据 对于输入的声音片段,确保其格式符合要求(通常是WAV),并且采样率匹配预期设定。如果有必要的话,可以利用脚本批量转换待处理的语音样本至适当规格[^2]。 ```python import librosa from scipy.io.wavfile import write as wavwrite def preprocess_audio(input_path, output_path): y, sr = librosa.load(input_path, sr=22050) # 调整为所需的采样频率 wavwrite(output_path, sr, y.astype('float32')) ``` #### 执行推理过程 最后一步就是编写简单的测试程序加载之前准备好的模型权重,并传入经过前处理后的音频作为输入执行预测操作。这里给出一个简化版的例子展示基本流程[^3]。 ```python import torch from model import SoVITS_SVC_Model # 假设这是定义了网络结构的一个模块 device = 'cuda' if torch.cuda.is_available() else 'cpu' model = SoVITS_SVC_Model().to(device) checkpoint = torch.load('./pretrained_model.pth', map_location=device) model.load_state_dict(checkpoint['state_dict']) model.eval() with torch.no_grad(): audio_tensor = ... # 加载和预处理过的音频张量 result = model(audio_tensor.unsqueeze(0).to(device)) print("Inference completed.") ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值