使用Python写Caffe网络的配置文件

当网络比较长时,在.prototxt文件中写Caffe网络容易出错,下面是使用Python写Caffe网络的配置文件(以手写体识别mnist为例),Python代码:

# -*- coding: utf-8 -*-
"""
使用python写Caffe的网络,生成Caffe的网络配置文件
"""

import sys
caffe_root="/home/pcb/caffe/"
sys.path.insert(0,caffe_root+"python")
from caffe import layers as L
from caffe import params as P
import caffe

def lenet(lmdb, batch_size,model):

    n=caffe.NetSpec()
    #数据层
    if model==False:
        n.data,n.label=L.Data(batch_size=batch_size,backend=P.Data.LMDB,source=lmdb,
                          include=dict(phase=0),transform_param=dict(scale=1./255),ntop=2)
    if model==True:
        n.data, n.label = L.Data(batch_size=batch_size, backend=P.Data.LMDB, source=lmdb,
                                 include=dict(phase=1), transform_param=dict(scale=1. / 255), ntop=2)
    n.conv1=L.Convolution(n.data,param = [dict(lr_mult=1), dict(lr_mult=2)],kernel_size=5,stride=1,num_output=20,weight_filler=dict(type="xavier"),
                          bias_filler=dict(type='constant'))
    n.pool1=L.Pooling(n.conv1,kernel_size=2,stride=2,pool=P.Pooling.MAX)
    n.conv2 = L.Convolution(n.pool1,param = [dict(lr_mult=1), dict(lr_mult=2)], kernel_size=5,
                            num_output=50, weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'))
    n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)
    n.ip1 = L.InnerProduct(n.pool2, param = [dict(lr_mult=1), dict(lr_mult=2)],
                           num_output=500, weight_filler=dict(type='xavier'), bias_filler=dict(type='constant'))
    n.relu1 = L.ReLU(n.ip1, in_place=True)
    n.ip2 = L.InnerProduct(n.relu1,param = [dict(lr_mult=1), dict(lr_mult=2)],num_output=10,
                           weight_filler=dict(type='xavier'),bias_filler=dict(type='constant'))
    n.accuracy=L.Accuracy(n.ip2,n.label,include=dict(phase=1))
    n.loss = L.SoftmaxWithLoss(n.ip2, n.label)

    return n.to_proto()

#网络配置文件train.prototxt所存放的地址
with open('/home/pcb/caffe/python/Project/lenet_train.prototxt', 'w') as f:
#存放用于训练的lmdb数据源的地址
    f.write(str(lenet('/home/pcb/caffe/examples/mnist/mnist_train_lmdb', 64, False)))
#网络配置文件test.prototxt所存放的地址
with open('/home/pcb/caffe/python/Project/lenet_test.prototxt', 'w') as f:
#存放用于测试的lmdb数据源的地址
   f.write(str(lenet('/home/pcb/caffe/examples/mnist/mnist_test_lmdb', 100,True)))

运行之后就会在相应的目录下生成
lenet_train.prototxt和lenet_test.prototxt文件

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值