OmegaConf Usage Guide with Examples

OmegaConf is a Python library for handling configuration files and configuration data, particularly well suited to managing hierarchical configuration structures. It is the core configuration component of the Hydra framework, but it can also be used on its own.

Key Features

  • Config merging: combine multiple configuration sources
  • Variable interpolation: reference other values from within the config
  • Runtime modification: change the config after it has been loaded
  • Multiple formats: YAML, JSON, and plain Python dicts
  • Type safety: type checking and conversion (see the sketch after this list)
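
Type safety is easiest to see with a structured config, where a dataclass provides the schema. A minimal sketch (the ServerConfig class here is illustrative, not part of the library):

from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class ServerConfig:
    host: str = "localhost"
    port: int = 8080

cfg = OmegaConf.structured(ServerConfig)
cfg.port = "9090"   # strings that parse as the declared type are converted
print(cfg.port)     # 9090 (an int, not a str)
# cfg.port = "abc"  # would raise a ValidationError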

Installation

pip install omegaconf

1. Basic Usage

(1) Creating a config from a dict

from omegaconf import OmegaConf

# create from a dict
cfg = OmegaConf.create({"a": 1, "b": {"c": 2}})
print(cfg.a)  # output: 1
print(cfg.b.c)  # output: 2
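
OmegaConf.create is not limited to dicts; it also accepts YAML strings and lists, which is handy for quick experiments:

# create from a YAML string
cfg = OmegaConf.create("a: 1\nb:\n  c: 2")
print(cfg.b.c)  # 2

# create a ListConfig from a list
lst = OmegaConf.create([1, {"x": 2}])
print(lst[1].x)  # 2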

(2) Loading from a YAML file

Suppose we have a config.yaml file:

server:
  host: localhost
  port: 8080
database:
  name: test
  user: admin

Load the config:

cfg = OmegaConf.load("config.yaml")
print(cfg.server.host)  # output: localhost
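
The reverse direction also works: OmegaConf.save writes a config back out as YAML. A minimal round-trip sketch (config_out.yaml is just an illustrative file name):

cfg = OmegaConf.load("config.yaml")
cfg.server.port = 9090
OmegaConf.save(config=cfg, f="config_out.yaml")  # serialize the modified config to YAML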

(3) Merging multiple configs

base_cfg = OmegaConf.create({"a": 1, "b": 2})
override_cfg = OmegaConf.create({"b": 3, "c": 4})

merged = OmegaConf.merge(base_cfg, override_cfg)
print(merged)  # {'a': 1, 'b': 3, 'c': 4}

A common pattern is layering an environment-specific file over a base file. Assuming base_config.yaml sets app.log_level: INFO and database.host: localhost, while dev_config.yaml sets app.log_level: DEBUG and database.host: dev.db.example.com:

base = OmegaConf.load("base_config.yaml")
dev = OmegaConf.load("dev_config.yaml")
cfg = OmegaConf.merge(base, dev)

print(cfg.app.log_level)  # DEBUG (overrides base's INFO)
print(cfg.database.host)  # dev.db.example.com (overrides base's localhost)
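
Merging also pairs naturally with command-line overrides. OmegaConf.from_cli parses dot-list arguments from sys.argv, so running python app.py server.port=9090 would override the base value in this sketch:

from omegaconf import OmegaConf

base = OmegaConf.create({"server": {"host": "localhost", "port": 8080}})
cli = OmegaConf.from_cli()        # parses arguments like server.port=9090
cfg = OmegaConf.merge(base, cli)  # CLI values win over the base config
print(cfg.server.port)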

(4) Accessing and modifying a config

cfg = OmegaConf.create({"a": 1, "b": {"c": 2}})

# access
print(cfg.a)  # 1
print(cfg.b.c)  # 2

# modify
cfg.a = 10
cfg.b.c = 20

# add a new field (configs are open by default; see the struct-mode sketch below)
cfg.d = "new field"

(5) Variable interpolation (less commonly used)

cfg = OmegaConf.create({
    "server": {
        "host": "localhost",
        "port": 8080,
        "url": "${server.host}:${server.port}"
    }
})
print(cfg.server.url)  # output: localhost:8080
  • Environment variable interpolation
cfg = OmegaConf.create({
    "database": {
        "password": "${oc.env:DB_PASSWORD,default_password}"
    }
})
# use the environment variable DB_PASSWORD if it is set, otherwise fall back to default_password
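
Beyond the built-in oc.env resolver, you can register custom resolver functions with OmegaConf.register_new_resolver. A minimal sketch (the resolver name "plus" is an arbitrary illustrative choice):

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("plus", lambda a, b: a + b)
cfg = OmegaConf.create({"x": 1, "y": 2, "total": "${plus:${x},${y}}"})
print(cfg.total)  # 3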

2. The to_container Method

OmegaConf.to_container() is the core method for converting an OmegaConf config object into native Python data structures. It turns OmegaConf's container types (DictConfig and ListConfig) into standard Python dicts and lists.

self.conf = OmegaConf.load(self.vc_config)
(Pdb) type(self.conf)
<class 'omegaconf.dictconfig.DictConfig'>

The resulting conf object has type omegaconf.dictconfig.DictConfig, which supports attribute-style access such as cfg.a. If you need native Python types, convert with OmegaConf.to_container; the result is then a plain dict whose values are accessed as cfg['a'].

self.conf = OmegaConf.to_container(self.conf, resolve=True)
(Pdb) type(self.conf)
<class 'dict'>

Example:

from omegaconf import OmegaConf

# create an OmegaConf config object
cfg = OmegaConf.create({
    "server": {
        "host": "localhost",
        "port": 8080
    },
    "features": ["auth", "logging"]
})

# convert to native Python containers
python_dict = OmegaConf.to_container(cfg)

print(type(python_dict))  # <class 'dict'>
print(python_dict)
# output: {'server': {'host': 'localhost', 'port': 8080}, 'features': ['auth', 'logging']}
  • resolve (default: False): whether to resolve all interpolations during the conversion
cfg = OmegaConf.create({
    "host": "localhost",
    "port": 8080,
    "url": "${host}:${port}"
})

# without resolving interpolations
result = OmegaConf.to_container(cfg, resolve=False)
print(result["url"])  # ${host}:${port}

# resolving interpolations
result = OmegaConf.to_container(cfg, resolve=True)
print(result["url"])  # localhost:8080