下面的函数用于统计 Keras 模型的总参数量、可训练参数量和不可训练参数量，直接上程序。
import tensorflow as tf
import numpy
def parameters_count(_model):
    """Return the parameter counts of a Keras model.

    Args:
        _model: a ``tf.keras`` model (anything exposing ``count_params()``,
            ``trainable_weights`` and ``non_trainable_weights``).

    Returns:
        A tuple ``(total_count, trainable_count, non_trainable_count)`` of
        plain Python ints.
    """
    total_count = int(_model.count_params())
    # Builtin ``sum`` over ints yields the int 0 for an empty weight list,
    # whereas ``numpy.sum([])`` returns the float 0.0 (and numpy.int64
    # otherwise), which made the returned tuple's element types inconsistent.
    trainable_count = sum(
        int(tf.keras.backend.count_params(w)) for w in _model.trainable_weights
    )
    non_trainable_count = sum(
        int(tf.keras.backend.count_params(w)) for w in _model.non_trainable_weights
    )
    return total_count, trainable_count, non_trainable_count
举例如下：
# Build ResNet101V2 without its classification head. weights=None skips the
# ImageNet weight download: weights only set values, not shapes, so the
# parameter counts are identical to the pretrained model's.
model = tf.keras.applications.resnet_v2.ResNet101V2(include_top=False, weights=None)
# Print explicitly so the result is shown when run as a script, not only in a REPL.
print(parameters_count(model))
打印出来的结果：(42626560, 42528896, 97664)
即模型的总参数、可训练参数和不可训练参数分别为 42626560、42528896 和 97664，其中总参数 = 可训练参数 + 不可训练参数（42528896 + 97664 = 42626560）。