1. 新建一个Conda虚拟环境
conda create -n mamba python=3.10
2. 进入该环境
conda activate mamba
3. 安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio
直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio
CSDN资源
下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可
pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
4. 安装triton和transformers库
pip install triton==2.3.1
pip install transformers==4.43.3
5. 安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
6. 最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本
causal-conv1d —— 1.4.0
mamba-ssm —— 2.2.2
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装
pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
7. 安装好环境后,验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行
# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
except ImportError:
causal_conv1d_fn = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
RMSNormGated, LayerNorm = None, None
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
class Mamba2Simple(nn.Module):
def __init__(
self,
d_model,
d_state=128,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
ngroups=1,
A_init_range=(1, 16),
dt_min=