1、简介
对于tensorflow.contrib这个库,tensorflow官方对它的描述是:此目录中的任何代码未经官方支持,可能会随时更改或删除。每个目录下都有指定的所有者。它旨在包含额外功能和贡献,最终会合并到核心TensorFlow中,但其接口可能仍然会发生变化,或者需要进行一些测试,看是否可以获得更广泛的接受。所以slim依然不属于原生tensorflow。那么什么是slim?slim到底有什么用?
slim是一个使构建,训练,评估神经网络变得简单的库。它可以消除原生tensorflow里面很多重复的模板性的代码,让代码更紧凑,更具备可读性。另外slim提供了很多计算机视觉方面的著名模型(VGG, AlexNet等),我们不仅可以直接使用,甚至能以各种方式进行扩展。
Slim由几个独立存在的部分组成,以下为主要的模块:
arg_scope: 提供了一个新的scope,它允许用户定义在这个scope内的许多特殊操作(比如卷积、池化等)的默认参数。
data: 这个模块包含data_decoder、prefetch_queue、dataset_data_provider、tfexample_decoder、dataset、data_provider、parallel_reader。
evaluation: 包含一些评估模型的例程。
layers: 包含使用TensorFlow搭建模型的一些high level layers 。
learning: 包含训练模型的一些例程。
losses: 包含常用的损失函数。
metrics: 包含一些常用的评估指标。
nets: 包含一些常用的网络模型的定义,比如VGG和 AlexNet。
queues: 提供一个上下文管理器,使得开启和关闭一个QueueRunners.更加简单和安全。
regularizers: 包含权重正则化器。
variables: 为变量的创建和操作提供了比较方面的包装器。
2、定义模型
通过组合slim中变量(variables)、网络层(layer)、前缀名(scope),模型可以被简洁的定义。
(1)变量(Variables)定义
在原始的TensorFlow中,创建变量时,要么需要预定义的值,要么需要一个初始化机制(比如从高斯分布中随机采样)。此外,如果需要在一个特定的特备(比如GPU)上创建一个变量,这个变量必须被显式的创建。为了减少创建变量的代码量,slim提供了一系列包装器函数允许调用者轻易的创建变量。
例如,为了创建一个权重变量,它使用截断正太分布初始化、使用L2的正则化损失并且把这个变量放到CPU中,我们只需要简单的做如下声明:
weights = slim.variable('weights',
shape=[10, 10, 3 , 3],
initializer=tf.truncated_normal_initializer(stddev=0.1),
regularizer=slim.l2_regularizer(0.05),
device='/CPU:0')
注意在原本的TensorFlow中,有两种类型的变量:常规(regular)变量和局部(local)变量。大部分的变量都是常规变量,它们一旦被创建,它们就会被保存到磁盘。而局部变量只存在于一个Session的运行期间,它们并不会被保存到磁盘。在Slim中,模型变量代表一个模型中的各种参数,Slim通过定义模型变量,进一步把各种变量区分开来。模型变量在训练过程中不断被训练和调参,在评估和预测时可以从checkpoint文件中加载进来,例如被slim.fully_connected() 或 slim.conv2d()网络层创建的变量。而非模型变量是指那些在训练和评估中用到的但是在预测阶段没有用到的变量,例如global_step变量在训练和评估中用到,但是它并不是一个模型变量。同样,移动平均变量可能反映模型变量,但移动平均值本身不是模型变量。模型变量和常规变量可以被Slim很容易的创建如下:
# 模型变量
weights = slim.model_variable('weights',
shape=[10, 10, 3 , 3],
initializer=tf.truncated_normal_initializer(stddev=0.1),
regularizer=slim.l2_regularizer(0.05),
device='/CPU:0')
model_variables = slim.get_model_variables()
# 常规变量
my_var = slim.variable('my_var',
shape=[20, 1],
initializer=tf.zeros_initializer())
regular_variables_and_model_variables = slim.get_variables()
那这是怎么工作的呢?当你通过Slim的网络层或者直接通过slim.model_variable()创建模型变量时,Slim把模型变量加入到tf.GraphKeys.MODEL_VARIABLES 的collection中。那如果你有属于自己的自定义网络层或者变量创建例程,但是你仍然想要Slim来帮你管理,这时要怎么办呢?Slim提供了一个便利的函数把模型变量加入到它的collection中。
my