【Datawhale AI春训营-第三届世界科学智能大赛创新药赛道:RNA逆折叠与功能核酸设计】

一、大赛背景与定位

第三届世界科学智能大赛由上海市科学技术委员会等多部门指导,上海科学智能研究院与复旦大学联合主办,阿里云、复星医药等机构协办,总奖金池达百万元,面向全球开放。大赛聚焦科学智能与高价值产业场景的融合,设置航空安全、材料设计、合成生物、创新药、新能源五大赛道,旨在推动人工智能(AI)技术在复杂科学问题中的突破,加速科研范式变革。
创新药赛道作为核心方向之一,以 “RNA 逆折叠与功能核酸设计” 为主题,要求参赛者基于给定的 RNA 三维骨架结构,设计能够折叠成该结构的 RNA 序列。这一赛题直接关联 RNA 药物开发、mRNA 疫苗优化、生物传感器设计等前沿领域,是科学智能技术在生物医药产业落地的关键突破口。

二、赛题核心:RNA 逆折叠的科学挑战

RNA 的功能高度依赖其三维结构,而 RNA 逆折叠(Inverse RNA Folding)是指从目标三维结构反向设计 RNA 序列的过程。这一问题的挑战性体现在以下方面:

  • 结构 - 序列逆映射难题:RNA 序列的折叠路径受热力学、动力学及环境因素(如离子浓度、温度)的多重影响,需建立跨尺度模型以解析结构与序列的关联。
  • 多目标优化需求:理想的 RNA 序列不仅需匹配目标结构,还需具备生物学稳定性、低免疫原性等特性,需平衡多维度约束条件。
  • 计算复杂度:RNA 分子的构象空间庞大,传统分子动力学模拟难以高效求解,需借助深度学习、强化学习等 AI 方法加速搜索。

技术路径:

  • AI 驱动的生成模型:如上海元码智药近期获批的 “双曲离散扩散模型”,通过双曲等变图神经网络(Hyperbolic Isometric Graph Neural Network)将 RNA 结构嵌入几何空间,结合扩散过程逐步去噪,实现高效序列生成。
  • 物理模型与数据融合:结合第一性原理(如自由能计算)与大规模实验数据(如 siRNA 药物研发数据集),提升模型的泛化能力。

三、评估标准与技术工具

  1. 核心指标:恢复率(Recovery Rate)
    通过比对生成序列与真实折叠成目标结构的 RNA 序列的相似性(如 BLAST 或序列比对工具),评估算法的准确性。恢复率越高,表明模型对结构 - 序列关系的建模越精准。
  2. 算力与数据支持
    复旦大学 CFFF 智算平台为参赛团队提供 400 卡 GPU 的算力支持,并开放高质量科学数据集,包括 6TB 次季节气象预测数据、9.6 万条 siRNA 修饰序列等,助力复杂模型训练。
    现有工具与前沿技术
  3. 传统算法:如 Rosetta(蛋白质 / RNA 结构建模)、NUPACK(核酸分子设计)等,但在三维结构逆折叠中存在效率瓶颈。
  4. AI 创新:AlphaFold3、RoseTTAFold All-Atom 等模型在蛋白质复合物预测中取得进展,其技术思路可迁移至 RNA 设计。

四、应用前景与产业价值

RNA 逆折叠技术的突破将重塑生物医药研发范式:

  1. RNA 药物开发:设计靶向疾病相关 RNA 结构的序列,如反义 RNA、RNA 适配体,用于精准调控基因表达。
  2. mRNA 疫苗优化:通过结构设计提升疫苗的稳定性、递送效率及免疫激活效果,缩短研发周期。
  3. 合成生物学:构建可响应环境信号的 RNA 开关(Riboswitch),用于细胞编程与代谢通路调控。
  4. 生物传感器:设计对特定分子(如病毒蛋白、小分子药物)敏感的 RNA 结构,实现高灵敏度检测。

产业案例:上海元码智药的专利技术已展示出在三维 RNA 结构逆折叠中的潜力,其双曲扩散模型可在有限样本下恢复核苷酸分布,为 RNA 药物的智能化设计提供新工具

五、模型训练

进阶模型训练代码:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv, LayerNorm
from torch_geometric.nn import radius_graph
from Bio import SeqIO
import math

# 配置参数
class Config:
    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 16 if torch.cuda.is_available() else 8  # 根据显存调整
    lr = 0.001
    epochs = 50
    seq_vocab = "AUCG"
    coord_dims = 7  
    hidden_dim = 256
    num_layers = 4  # 减少层数防止显存溢出
    k_neighbors = 20  
    dropout = 0.1
    rbf_dim = 16
    num_heads = 4
    amp_enabled = True  # 混合精度训练

# 几何特征生成器
class GeometricFeatures:
    @staticmethod
    def rbf(D, D_min=0., D_max=20., D_count=16):
        device = D.device
        D_mu = torch.linspace(D_min, D_max, D_count, device=device)
        D_mu = D_mu.view(*[1]*len(D.shape), -1)
        D_sigma = (D_max - D_min) / D_count
        D_expand = D.unsqueeze(-1)
        return torch.exp(-((D_expand - D_mu)/D_sigma) ** 2)

    @staticmethod
    def dihedrals(X, eps=1e-7):
        X = X.to(torch.float32)
        L = X.shape[0]
        dX = X[1:] - X[:-1]
        U = F.normalize(dX, dim=-1)
        
        # 计算连续三个向量
        u_prev = U[:-2]
        u_curr = U[1:-1]
        u_next = U[2:]

        # 计算法向量
        n_prev = F.normalize(torch.cross(u_prev, u_curr, dim=-1), dim=-1)
        n_curr = F.normalize(torch.cross(u_curr, u_next, dim=-1), dim=-1)

        # 计算二面角
        cosD = (n_prev * n_curr).sum(-1)
        cosD = torch.clamp(cosD, -1+eps, 1-eps)
        D = torch.sign((u_prev * n_curr).sum(-1)) * torch.acos(cosD)

        # 填充处理
        if D.shape[0] < L:
            D = F.pad(D, (0,0,0,L-D.shape[0]), "constant", 0)
        
        return torch.stack([torch.cos(D[:,:5]), torch.sin(D[:,:5])], -1).view(L,-1)

    @staticmethod
    def direction_feature(X):
        dX = X[1:] - X[:-1]
        return F.pad(F.normalize(dX, dim=-1), (0,0,0,1))

# 图构建器
class RNAGraphBuilder:
    @staticmethod
    def build_graph(coord, seq):
        assert coord.shape[1:] == (7,3), f"坐标维度错误: {coord.shape}"
        coord = torch.tensor(coord, dtype=torch.float32)
        
        # 节点特征
        node_feats = [
            coord.view(-1, 7 * 3),  # [L,21]
            GeometricFeatures.dihedrals(coord[:,:6,:]),  # [L,10]
            GeometricFeatures.direction_feature(coord[:,4,:])  # [L,3]
        ]
        x = torch.cat(node_feats, dim=-1)  # [L,34]

        # 边构建
        pos = coord[:,4,:]
        edge_index = radius_graph(pos, r=20.0, max_num_neighbors=Config.k_neighbors)
        
        # 边特征
        row, col = edge_index
        edge_vec = pos[row] - pos[col]
        edge_dist = torch.norm(edge_vec, dim=-1, keepdim=True)
        edge_feat = torch.cat([
            GeometricFeatures.rbf(edge_dist).squeeze(1),  # [E,16]
            F.normalize(edge_vec, dim=-1)  # [E,3]
        ], dim=-1)  # [E,19]

        # 标签
        y = torch.tensor([Config.seq_vocab.index(c) for c in seq], dtype=torch.long)
        
        return Data(x=x, edge_index=edge_index, edge_attr=edge_feat, y=y)

# 模型架构
class RNAGNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 节点特征编码
        self.feat_encoder = nn.Sequential(
            nn.Linear(34, Config.hidden_dim),
            nn.ReLU(),
            LayerNorm(Config.hidden_dim),
            nn.Dropout(Config.dropout)
        )
        
        # 边特征编码(关键修复)
        self.edge_encoder = nn.Sequential(
            nn.Linear(19, Config.hidden_dim),
            nn.ReLU(),
            LayerNorm(Config.hidden_dim),
            nn.Dropout(Config.dropout)
        )

        # Transformer卷积层
        self.convs = nn.ModuleList([
            TransformerConv(
                Config.hidden_dim,
                Config.hidden_dim // Config.num_heads,
                heads=Config.num_heads,
                edge_dim=Config.hidden_dim,  # 匹配编码后维度
                dropout=Config.dropout
            ) for _ in range(Config.num_layers)
        ])

        # 残差连接
        self.mlp_skip = nn.ModuleList([
            nn.Sequential(
                nn.Linear(Config.hidden_dim, Config.hidden_dim),
                nn.ReLU(),
                LayerNorm(Config.hidden_dim)
            ) for _ in range(Config.num_layers)
        ])

        # 分类头
        self.cls_head = nn.Sequential(
            nn.Linear(Config.hidden_dim, Config.hidden_dim),
            nn.ReLU(),
            LayerNorm(Config.hidden_dim),
            nn.Dropout(Config.dropout),
            nn.Linear(Config.hidden_dim, len(Config.seq_vocab))
        )

        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        
        # 边特征编码(关键步骤)
        edge_attr = self.edge_encoder(edge_attr)  # [E,19] -> [E,256]
        
        # 节点编码
        h = self.feat_encoder(x)
        
        # 消息传递
        for i, (conv, skip) in enumerate(zip(self.convs, self.mlp_skip)):
            h_res = conv(h, edge_index, edge_attr=edge_attr)
            h = h + skip(h_res)
            if i < len(self.convs)-1:
                h = F.relu(h)
                h = F.dropout(h, p=Config.dropout, training=self.training)
        
        return self.cls_head(h)

# 数据增强
class CoordTransform:
    @staticmethod
    def random_rotation(coords):
        device = torch.device(Config.device)
        coords_tensor = torch.from_numpy(coords).float().to(device)
        angle = np.random.uniform(0, 2*math.pi)
        rot_mat = torch.tensor([
            [math.cos(angle), -math.sin(angle), 0],
            [math.sin(angle), math.cos(angle), 0],
            [0, 0, 1]
        ], device=device)
        return (coords_tensor @ rot_mat.T).cpu().numpy()

# 数据集类
class RNADataset(torch.utils.data.Dataset):
    def __init__(self, coords_dir, seqs_dir, augment=False):
        self.samples = []
        self.augment = augment
        
        for fname in os.listdir(coords_dir):
            # 加载坐标
            coord = np.load(os.path.join(coords_dir, fname))
            coord = np.nan_to_num(coord, nan=0.0)
            
            # 数据增强
            if self.augment and np.random.rand() > 0.5:
                coord = CoordTransform.random_rotation(coord)
            
            # 加载序列
            seq_id = os.path.splitext(fname)[0]
            seq_path = os.path.join(seqs_dir, f"{seq_id}.fasta")
            seq = str(next(SeqIO.parse(seq_path, "fasta")).seq)
            
            # 构建图
            self.samples.append(RNAGraphBuilder.build_graph(coord, seq))
    
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

# 训练函数
def train(model, loader, optimizer, scheduler, criterion):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=Config.amp_enabled)
    total_loss = 0
    
    for batch in loader:
        batch = batch.to(Config.device)
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast(enabled=Config.amp_enabled):
            logits = model(batch)
            loss = criterion(logits, batch.y)
        
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
    
    scheduler.step()
    return total_loss / len(loader)

# 评估函数
def evaluate(model, loader):
    model.eval()
    total_correct = total_nodes = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(Config.device)
            logits = model(batch)
            preds = logits.argmax(dim=1)
            total_correct += (preds == batch.y).sum().item()
            total_nodes += batch.y.size(0)
    return total_correct / total_nodes

if __name__ == "__main__":
    # 初始化
    torch.manual_seed(Config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(Config.seed)
        torch.backends.cudnn.benchmark = True
    
    # 数据集
    train_set = RNADataset(
        "./RNA_design_public/RNAdesignv1/train/coords",
        "./RNA_design_public/RNAdesignv1/train/seqs",
        augment=True
    )
    
    # 划分数据集
    train_size = int(0.8 * len(train_set))
    val_size = (len(train_set) - train_size) // 2
    test_size = len(train_set) - train_size - val_size
    train_set, val_set, test_set = torch.utils.data.random_split(
        train_set, [train_size, val_size, test_size])
    
    # 数据加载
    train_loader = torch_geometric.loader.DataLoader(
        train_set, 
        batch_size=Config.batch_size, 
        shuffle=True,
        pin_memory=True,
        num_workers=4
    )
    val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=Config.batch_size)
    test_loader = torch_geometric.loader.DataLoader(test_set, batch_size=Config.batch_size)
    
    # 模型初始化
    model = RNAGNN().to(Config.device)
    optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)
    criterion = nn.CrossEntropyLoss()
    
    # 训练循环
    best_acc = 0
    for epoch in range(Config.epochs):
        train_loss = train(model, train_loader, optimizer, scheduler, criterion)
        val_acc = evaluate(model, val_loader)
        
        print(f"Epoch {epoch+1}/{Config.epochs} | Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")
        
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
    
    # 最终测试
    model.load_state_dict(torch.load("best_model.pth"))
    test_acc = evaluate(model, test_loader)
    print(f"\nFinal Test Accuracy: {test_acc:.4f}")

在这里插入图片描述

参赛感受

本次比赛用到了前沿的技术,利用深度学习技术预测RNA折叠,利用阿里云镜像仓库,实现代码开发与codeup代码管理,容器仓库联用。利用企业级完整的技术链进行快速迭代开发比赛。难点就是需要对技术栈要求比较全面,否则很难走完全流程。阿里云容器服务,个人服务因为人多,限制了云镜像构建,导致个人本地构建花了很多时间,希望阿里云后期可以改善这方面的支持。其他的都很不错,希望可以参与更多类似的研究,比赛。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值