Keras 是一个用 Python 编写的高级神经网络 API,以 TensorFlow(以及历史支持的 CNTK、Theano)作为后端运行。自 TensorFlow 2.0 起,Keras 被正式作为 TF 的官方高层 API(tf.keras)。Keras 的设计哲学是“为人类设计的 API”——减少认知负担、模块化、可组合、易于扩展。本文将从三种 API 风格、层组合模式、迁移学习、自定义组件,到训练优化和模型部署,带你系统掌握 Keras。
一、Keras 设计哲学与三种 API 风格 Keras 提供三种构建模型的方式,灵活性与复杂度依次递增。
1.1 Sequential API —— 线性层堆叠 适合简单的线性拓扑结构,每层只有单一输入和单一输出:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Variant 1: pass the complete layer stack to the constructor.
model = keras.Sequential([
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                  input_shape=(28, 28, 1)),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dropout(rate=0.5),
    layers.Dense(units=10, activation='softmax'),
])

# Variant 2: start from an empty model and append layers one at a time.
# NOTE: this rebinds `model`; the CNN above is only kept for illustration.
model = keras.Sequential()
model.add(layers.Dense(units=256, activation='relu', input_shape=(784,)))
model.add(layers.Dropout(rate=0.3))
model.add(layers.Dense(units=10, activation='softmax'))
1.2 Functional API —— DAG 拓扑结构 Functional API 将层视为可调用的函数,接受张量作为输入并返回输出张量,从而构建任意有向无环图(DAG)结构的模型:
# --- Basic usage: each layer is called on a tensor and returns a tensor ----
inputs = keras.Input(shape=(784,))
x = layers.Dense(units=256, activation='relu')(inputs)
x = layers.Dropout(rate=0.5)(x)
outputs = layers.Dense(units=10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# --- Layer sharing: one Embedding instance applied to two inputs -----------
shared_embedding = layers.Embedding(input_dim=10000, output_dim=128)

input_a = keras.Input(shape=(None,), dtype='int32', name='input_a')
embed_a = shared_embedding(input_a)

input_b = keras.Input(shape=(None,), dtype='int32', name='input_b')
embed_b = shared_embedding(input_b)

# --- Multiple inputs: a text branch and an image branch, concatenated ------
input_text = keras.Input(shape=(100,), name='text_input')
input_image = keras.Input(shape=(64, 64, 3), name='image_input')

x_text = layers.Embedding(input_dim=10000, output_dim=128)(input_text)
x_text = layers.LSTM(units=64)(x_text)

x_image = layers.Conv2D(filters=32, kernel_size=3, activation='relu')(input_image)
x_image = layers.GlobalAveragePooling2D()(x_image)

concatenated = layers.Concatenate()([x_text, x_image])
output = layers.Dense(units=1, activation='sigmoid')(concatenated)

multi_input_model = keras.Model(
    inputs=[input_text, input_image],
    outputs=output,
)
multi_input_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# --- Multiple outputs: two named heads on a shared trunk -------------------
base_input = keras.Input(shape=(784,))
x = layers.Dense(units=256, activation='relu')(base_input)
x = layers.Dense(units=128, activation='relu')(x)

class_output = layers.Dense(units=10, activation='softmax', name='class')(x)
aux_output = layers.Dense(units=1, activation='sigmoid', name='aux')(x)

multi_output_model = keras.Model(
    inputs=base_input,
    outputs=[class_output, aux_output],
)
# Losses, loss weights and metrics are keyed by the output layer names.
multi_output_model.compile(
    optimizer='adam',
    loss={
        'class': 'sparse_categorical_crossentropy',
        'aux': 'binary_crossentropy',
    },
    loss_weights={'class': 1.0, 'aux': 0.3},
    metrics={'class': 'accuracy'},
)
1.3 Subclassing API —— 完全自定义 通过继承 keras.Model 实现完全灵活的控制,适用于需要动态执行流程或非标准前向传播的研究性模型:
class CustomModel(keras.Model):
    """Small CNN classifier with a hand-written training step.

    Uses the TF 2.x ``compiled_loss`` / ``compiled_metrics`` helpers, so the
    model must be compiled before ``fit()`` is called.

    Args:
        num_classes: Number of output units of the final Dense layer. The
            layer has no activation, so the model emits logits.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = layers.Conv2D(32, 3, activation='relu')
        self.conv2 = layers.Conv2D(64, 3, activation='relu')
        # MaxPooling2D has no weights, so one instance is reused twice.
        self.maxpool = layers.MaxPooling2D()
        self.flatten = layers.Flatten()
        self.dropout = layers.Dropout(0.5)
        self.classifier = layers.Dense(num_classes)
        self.loss_tracker = keras.metrics.Mean(name='loss')

    def call(self, inputs, training=None):
        """Run the forward pass and return class logits."""
        x = self.conv1(inputs)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        # Let Dropout decide from the training flag instead of branching in
        # Python; the layer is a no-op when not training.
        x = self.dropout(x, training=training)
        return self.classifier(x)

    def train_step(self, data):
        """Run one optimization step on a batch and return the metric logs.

        Args:
            data: An ``(x, y)`` pair. A third ``sample_weight`` element is
                not handled here.

        Returns:
            Dict mapping each metric name to its current value.
        """
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Include layer regularization losses in the optimized value.
            loss = self.compiled_loss(
                y, y_pred, regularization_losses=self.losses
            )
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # The tracker is what `metrics` exposes as 'loss'; without this
        # update the logged loss would stay at its initial value of 0.
        self.loss_tracker.update_state(loss)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        """Metrics reset by Keras at the start of every epoch."""
        tracked = [self.loss_tracker]
        compiled = getattr(self, 'compiled_metrics', None)
        # compiled_metrics is unset until compile() has been called.
        if compiled is not None:
            tracked = tracked + compiled.metrics
        return tracked
三种 API 风格对比:
| 特性 | Sequential | Functional | Subclassing |
| --- | --- | --- | --- |
| 复杂度 | 低 | 中 | 高 |
| 拓扑结构 | 线性 | 任意 DAG | 任意(含动态) |
| 层共享 | 否 | 是 | 是 |
| 多输入/输出 | 否 | 是 | 是 |
| 可保存性 | 最好 | 好 | 需额外工作 |
| 适用场景 | 简单分类 | 多模态/残差 | 研究/动态架构 |
二、层组合模式 2.1 残差块(Residual Block) 残差连接通过跳跃连接将输入直接加到输出,解决了深层网络的梯度消失问题:
def residual_block(x, filters, stride=1, use_projection=False):
    """Basic ResNet residual block: two 3x3 convolutions plus a shortcut.

    Args:
        x: Input feature-map tensor (channels on the last axis).
        filters: Number of output channels.
        stride: Stride of the first convolution (and of the projection).
        use_projection: Force a 1x1 projection on the shortcut. A projection
            is also applied automatically whenever the stride or the channel
            count would make the shortcut incompatible with the main path.

    Returns:
        Output tensor of the block.
    """
    shortcut = x
    # Same rule as bottleneck_block: without a projection, Add() fails when
    # the main path changes resolution or channel count.
    shape_changes = stride != 1 or x.shape[-1] != filters
    if use_projection or shape_changes:
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Conv2D(filters, 3, strides=stride, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x


def bottleneck_block(x, filters, stride=1, expansion=4):
    """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    Args:
        x: Input feature-map tensor (channels on the last axis).
        filters: Number of output channels.
        stride: Stride of the 3x3 convolution (and of the projection).
        expansion: Ratio between output channels and the inner width.

    Returns:
        Output tensor of the block.
    """
    shortcut = x
    bottleneck_filters = filters // expansion

    # Project the shortcut only when its shape would not match the output.
    if stride != 1 or x.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Conv2D(bottleneck_filters, 1, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(bottleneck_filters, 3, strides=stride, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, 1, use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x
2.2 Inception 模块 Inception 模块通过并行的不同大小卷积核捕获多尺度特征:
def inception_module(x, filters):
    """Simplified Inception module.

    Runs four parallel branches over ``x`` (1x1; 1x1 then 5x5; 1x1 then 3x3;
    3x3 max-pool then 1x1), each producing ``filters`` channels, and
    concatenates them in that order.
    """

    def conv_relu(tensor, kernel_size):
        # 'same'-padded ReLU convolution shared by every branch.
        return layers.Conv2D(
            filters, kernel_size, padding='same', activation='relu'
        )(tensor)

    # Branch order below matches the concatenation order.
    pointwise = conv_relu(x, 1)

    wide = conv_relu(conv_relu(x, 1), 5)

    medium = conv_relu(conv_relu(x, 1), 3)

    pooled = layers.MaxPooling2D(3, strides=1, padding='same')(x)
    pooled = conv_relu(pooled, 1)

    return layers.Concatenate()([pointwise, wide, medium, pooled])
2.3 Dense Block(密集连接) DenseNet 的核心模块,每层的输入包括所有前面层的输出:
def dense_block(x, num_layers, growth_rate):
    """DenseNet dense block.

    Every layer (BN-ReLU-1x1 conv, BN-ReLU-3x3 conv) reads the concatenation
    of the block input and all earlier layer outputs, and contributes
    ``growth_rate`` new channels.

    Args:
        x: Input feature-map tensor (channels on the last axis).
        num_layers: Number of layers in the block; 0 returns ``x`` as is.
        growth_rate: Channels added by each layer.

    Returns:
        Tensor holding the input channels followed by every layer's output.
    """
    # Keep one running concatenation instead of re-concatenating the whole
    # history each iteration. This also avoids calling Concatenate on a
    # single tensor, which some Keras versions reject.
    features = x
    for _ in range(num_layers):
        y = layers.BatchNormalization()(features)
        y = layers.ReLU()(y)
        y = layers.Conv2D(4 * growth_rate, 1, use_bias=False)(y)
        y = layers.BatchNormalization()(y)
        y = layers.ReLU()(y)
        y = layers.Conv2D(growth_rate, 3, padding='same', use_bias=False)(y)
        features = layers.Concatenate()([features, y])
    return features


def transition_layer(x, compression=0.5):
    """DenseNet transition: compress channels with a 1x1 conv, then 2x2 pool.

    Args:
        x: Input feature-map tensor with a known channel count.
        compression: Fraction of channels to keep.

    Returns:
        Tensor with reduced channel count and halved spatial size.
    """
    # Never ask Conv2D for zero filters when compression is very small.
    num_filters = max(1, int(x.shape[-1] * compression))
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(num_filters, 1, use_bias=False)(x)
    x = layers.AveragePooling2D(2, strides=2)(x)
    return x
# 3. Transfer-learning workflow
# 3.1 Standard transfer-learning procedure
# NOTE: num_classes, train_dataset, val_dataset and callbacks are expected
# to be defined by the surrounding code.

# Step 1: load the pretrained backbone without its classification head.
base_model = keras.applications.ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)

# Step 2: freeze the backbone.
base_model.trainable = False

# Step 3: attach a new classification head.
inputs = keras.Input(shape=(224, 224, 3))
# training=False keeps the backbone in inference mode.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(rate=0.3)(x)
x = layers.Dense(units=256, activation='relu')(x)
x = layers.Dropout(rate=0.3)(x)
outputs = layers.Dense(units=num_classes, activation='softmax')(x)
model = keras.Model(inputs, outputs)

# Step 4: train only the new head.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=callbacks,
)

# Step 5: unfreeze the backbone, then re-freeze its first 100 layers.
base_model.trainable = True
frozen_prefix = base_model.layers[:100]
for layer in frozen_prefix:
    layer.trainable = False

# Step 6: recompile with a much smaller learning rate and fine-tune.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=20,
    callbacks=callbacks,
)
# 3.2 Pretrained models bundled with Keras
# Each call below constructs the model with ImageNet weights.

# ResNet family
keras.applications.ResNet50(weights='imagenet')
keras.applications.ResNet101(weights='imagenet')
keras.applications.ResNet152(weights='imagenet')
keras.applications.ResNet50V2(weights='imagenet')

# VGG family
keras.applications.VGG16(weights='imagenet')
keras.applications.VGG19(weights='imagenet')

# Inception family
keras.applications.InceptionV3(weights='imagenet')
keras.applications.InceptionResNetV2(weights='imagenet')

# MobileNet family
keras.applications.MobileNet(weights='imagenet')
keras.applications.MobileNetV2(weights='imagenet')
keras.applications.MobileNetV3Small(weights='imagenet')
keras.applications.MobileNetV3Large(weights='imagenet')

# EfficientNet family
keras.applications.EfficientNetB0(weights='imagenet')
keras.applications.EfficientNetB7(weights='imagenet')
keras.applications.EfficientNetV2S(weights='imagenet')
keras.applications.EfficientNetV2M(weights='imagenet')

# ConvNeXt family
keras.applications.ConvNeXtTiny(weights='imagenet')
keras.applications.ConvNeXtSmall(weights='imagenet')
keras.applications.ConvNeXtBase(weights='imagenet')
keras.applications.ConvNeXtLarge(weights='imagenet')
keras.applications.ConvNeXtXLarge(weights='imagenet')
# 3.3 Feature-extraction mode
# NOTE: train_images and num_classes are expected to be defined by the
# surrounding code.

# pooling='avg' makes the headless backbone emit one vector per image.
base_model = keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    pooling='avg',
)
base_model.trainable = False

# Option A: compute the features once, outside any model.
features = base_model.predict(train_images)

# Option B: wire the frozen backbone into an end-to-end model.
# `features` is rebound here to the symbolic backbone output.
inputs = keras.Input(shape=(224, 224, 3))
features = base_model(inputs)
outputs = layers.Dense(units=num_classes, activation='softmax')(features)
model = keras.Model(inputs, outputs)
# 4. Custom components
# 4.1 Custom Layer
class MyDense(keras.layers.Layer):
    """Minimal fully connected layer: ``activation(inputs @ kernel + bias)``."""

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Resolve a name such as 'relu' (or None) into a callable.
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        """Create the kernel and bias from the last input dimension."""
        in_features = input_shape[-1]
        self.kernel = self.add_weight(
            name='kernel',
            shape=(in_features, self.units),
            initializer='glorot_uniform',
            trainable=True,
        )
        self.bias = self.add_weight(
            name='bias',
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
        )
        self.built = True

    def call(self, inputs):
        """Apply the affine transform and then the activation, if any."""
        projected = tf.matmul(inputs, self.kernel) + self.bias
        if self.activation is None:
            return projected
        return self.activation(projected)

    def get_config(self):
        """Return the constructor arguments for saving and cloning."""
        base_config = super().get_config()
        return {
            **base_config,
            'units': self.units,
            'activation': keras.activations.serialize(self.activation),
        }


class L2NormLayer(keras.layers.Layer):
    """Weight-free layer that L2-normalizes its input along the last axis."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.math.l2_normalize(inputs, axis=-1)
# 4.2 Custom Loss
class FocalLoss(keras.losses.Loss):
    """Focal loss for integer (sparse) class labels.

    Per-sample value: ``alpha * (1 - p_t) ** gamma * -log(p_t)``, where
    ``p_t`` is the predicted probability of the true class.

    Args:
        alpha: Constant scale applied to every sample's loss.
        gamma: Focusing exponent; larger values down-weight easy samples.
        from_logits: If True, ``y_pred`` is passed through softmax first.
        name: Name of the loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, name='focal_loss'):
        super().__init__(name=name, reduction=keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
        self.alpha = alpha
        self.gamma = gamma
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return the per-sample focal loss.

        Args:
            y_true: Integer class ids, shaped like ``y_pred`` without its
                last axis, optionally with a trailing axis of size 1.
            y_pred: Class probabilities (or logits when ``from_logits``),
                classes on the last axis.
        """
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)

        # Keras may hand labels over as (batch, 1). One-hot encoding that
        # gives (batch, 1, C), which broadcasts against (batch, C) into
        # (batch, batch, C). Align labels with y_pred's leading dims first.
        labels = tf.reshape(tf.cast(y_true, tf.int32), tf.shape(y_pred)[:-1])
        # Match y_pred's dtype so the product also works under float16.
        y_true_one_hot = tf.one_hot(
            labels, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype
        )

        pt = tf.reduce_sum(y_true_one_hot * y_pred, axis=-1)
        focal_weight = tf.pow(1.0 - pt, self.gamma)
        # epsilon guards against log(0).
        ce = -tf.math.log(pt + keras.backend.epsilon())
        return self.alpha * focal_weight * ce

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'alpha': self.alpha,
            'gamma': self.gamma,
            'from_logits': self.from_logits,
        })
        return config


def contrastive_loss(margin=1.0):
    """Build a contrastive loss function for distance-style predictions.

    The returned function computes
    ``mean(y * d**2 + (1 - y) * max(margin - d, 0)**2)`` with ``d = y_pred``.

    Args:
        margin: Distance beyond which pairs labeled 0 contribute no loss.

    Returns:
        A ``loss(y_true, y_pred)`` callable.
    """

    def loss(y_true, y_pred):
        # Integer labels cannot be multiplied with float predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        square_pred = tf.square(y_pred)
        margin_square = tf.square(tf.maximum(margin - y_pred, 0))
        return tf.reduce_mean(
            y_true * square_pred + (1 - y_true) * margin_square
        )

    return loss
# 4.3 Custom Metric
class F1Score(keras.metrics.Metric):
    """Streaming multi-class F1 score computed from per-class counts.

    Args:
        num_classes: Number of classes.
        average: 'macro' for the mean of per-class F1, 'micro' for F1 over
            pooled counts; any other value returns the per-class vector.
        name: Name of the metric.
        **kwargs: Forwarded to ``keras.metrics.Metric``.
    """

    def __init__(self, num_classes, average='macro', name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.average = average
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros')
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros')
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP/FP/FN counts for one batch.

        Args:
            y_true: Integer class ids, one per sample; a trailing axis of
                size 1 is accepted.
            y_pred: Per-class scores with classes on the last axis.
            sample_weight: Optional per-sample weights applied to the counts.
        """
        # Flatten both so (batch, 1) labels cannot broadcast against
        # (batch,) predictions into a (batch, batch) matrix.
        predicted = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
        actual = tf.reshape(tf.cast(y_true, tf.int64), [-1])

        pred_one_hot = tf.one_hot(predicted, self.num_classes, dtype=tf.float32)
        true_one_hot = tf.one_hot(actual, self.num_classes, dtype=tf.float32)

        # Per-sample, per-class indicator matrices of shape (batch, C).
        tp = pred_one_hot * true_one_hot
        fp = pred_one_hot - tp
        fn = true_one_hot - tp

        if sample_weight is not None:
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp, fp, fn = tp * weights, fp * weights, fn * weights

        # Update whole vectors at once: a sliced variable such as
        # `var[c]` offers only assign(), not assign_add().
        self.true_positives.assign_add(tf.reduce_sum(tp, axis=0))
        self.false_positives.assign_add(tf.reduce_sum(fp, axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(fn, axis=0))

    def result(self):
        """Return the F1 score according to ``average``."""
        eps = keras.backend.epsilon()
        precision = self.true_positives / (
            self.true_positives + self.false_positives + eps)
        recall = self.true_positives / (
            self.true_positives + self.false_negatives + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)

        if self.average == 'macro':
            return tf.reduce_mean(f1)
        if self.average == 'micro':
            tp = tf.reduce_sum(self.true_positives)
            fp = tf.reduce_sum(self.false_positives)
            fn = tf.reduce_sum(self.false_negatives)
            p = tp / (tp + fp + eps)
            r = tp / (tp + fn + eps)
            return 2 * p * r / (p + r + eps)
        return f1

    def reset_state(self):
        """Zero all accumulated counts."""
        for counts in (self.true_positives,
                       self.false_positives,
                       self.false_negatives):
            counts.assign(tf.zeros_like(counts))

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
        })
        return config
# 5. Training features and tips
# 5.1 compile() in detail
compile_options = dict(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    loss_weights=None,        # per-output loss weights
    weighted_metrics=None,    # metrics that honor sample/class weights
    run_eagerly=False,        # True runs step by step, useful for debugging
    steps_per_execution=1,    # batches processed per compiled call
    jit_compile=False,        # True enables XLA compilation
)
model.compile(**compile_options)

# Reporting accuracy both unweighted and weighted.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
    weighted_metrics=['accuracy'],
)
# 5.2 fit() parameters in detail
# NOTE: train_dataset, val_dataset and callbacks are expected to be defined
# by the surrounding code.
history = model.fit(
    x=train_dataset,
    y=None,                     # a dataset already yields (inputs, targets)
    batch_size=None,            # batching is done by the dataset itself
    epochs=50,
    verbose='auto',
    callbacks=callbacks,
    # validation_split only works for array/tensor inputs and is overridden
    # by validation_data, so it is left at its default here.
    validation_split=0.0,
    validation_data=val_dataset,
    shuffle=True,
    class_weight={0: 1.0, 1: 3.0},  # weight class 1 three times as much
    sample_weight=None,
    initial_epoch=10,           # resume epoch numbering from epoch 10
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
    # The three arguments below belong to the TF 2.x signature; check that
    # the installed Keras version still accepts them.
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
)
# 5.3 Mixed-precision training
# NOTE: x and num_classes are expected to be defined by the surrounding code.

# Switch the global dtype policy to mixed float16.
keras.mixed_precision.set_global_policy('mixed_float16')

# Keep the output layer in float32.
output_layer = layers.Dense(num_classes, activation='softmax', dtype='float32')
outputs = output_layer(x)

# Show the policy that is now active.
active_policy = keras.mixed_precision.global_policy()
print(active_policy)
# 5.4 Multi-GPU training (Distribution Strategy)
# NOTE: create_model and train_dataset are expected to be defined by the
# surrounding code.
strategy = tf.distribute.MirroredStrategy()
replica_count = strategy.num_replicas_in_sync
print(f'Number of devices: {replica_count} ')

# The model is built and compiled inside the strategy scope.
with strategy.scope():
    model = create_model()
    # The base learning rate is scaled by the number of replicas.
    scaled_lr = 0.001 * replica_count
    model.compile(
        optimizer=keras.optimizers.Adam(scaled_lr),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

model.fit(train_dataset, epochs=10)
# 6. Model optimization
# 6.1 Weight pruning
import tensorflow_model_optimization as tfmot

# Sparsity goes from 0% to 80% between steps 0 and 1000, adjusted every
# 100 steps.
sparsity_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.8,
    begin_step=0,
    end_step=1000,
    frequency=100,
)
pruning_params = {'pruning_schedule': sparsity_schedule}

pruning_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
pruning_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks passed to fit() for the pruning run.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='/tmp/logs'),
]
pruning_model.fit(train_dataset, epochs=5, callbacks=callbacks)

# Strip the pruning wrappers to obtain the final model.
final_model = tfmot.sparsity.keras.strip_pruning(pruning_model)
# 6.2 Quantization-aware training (QAT)
# Wrap the model for quantization-aware training.
qat_model = tfmot.quantization.keras.quantize_model(model)
qat_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# A short fine-tuning run on the wrapped model.
qat_model.fit(train_dataset, epochs=2)

# Convert the fine-tuned model to TFLite.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_qat_model = converter.convert()
# 6.3 Weight clustering
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# 32 clusters, with linearly initialized centroids.
clustering_params = dict(
    number_of_clusters=32,
    cluster_centroids_init=CentroidInitialization.LINEAR,
)

clustered_model = cluster_weights(model, **clustering_params)
# `...` is a placeholder kept from the original text: supply real compile()
# and fit() arguments here.
clustered_model.compile(...)
clustered_model.fit(...)

# Strip the clustering wrappers to obtain the final model.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
七、Keras 与 PyTorch 工作流对比
| 维度 | Keras | PyTorch |
| --- | --- | --- |
| 学习曲线 | 平缓,适合入门 | 中等,需理解更多底层概念 |
| 灵活性 | 中→高(Subclassing) | 高(完全自由) |
| 代码量 | 少(compile+fit 封装) | 多(手动循环) |
| 调试 | 容易(Eager 模式) | 容易(Pythonic) |
| 数据管道 | tf.data(功能强大) | DataLoader(简洁直观) |
| 分布式 | Distribution Strategy(简单) | DDP/FSDP(需手动配置) |
| 生产部署 | TF Serving(成熟) | TorchServe(发展中) |
| 移动端 | TF Lite(最成熟) | ExecuTorch(新) |
| 预训练模型 | Keras Applications | torchvision.models + timm |
| 社区生态 | Google 官方主导 | Meta 主持,社区活跃 |
| 论文复现 | 较少直接提供 | 多数论文提供 PyTorch 代码 |
| 适合谁 | 快速实验、工程落地 | 研究、精细控制 |
# ---------------------------------------------------------------------------
# Keras version: declare the model, compile it, call fit().
# NOTE: x_train and y_train are expected to be defined by the surrounding code.
# ---------------------------------------------------------------------------
model = keras.Sequential([
    keras.layers.Dense(units=256, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(rate=0.5),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dropout(rate=0.3),
    keras.layers.Dense(units=10, activation='softmax'),
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(x_train, y_train, epochs=10, validation_split=0.2)


# ---------------------------------------------------------------------------
# PyTorch version: same network with an explicit training loop.
# NOTE: torch, nn (torch.nn), F (torch.nn.functional) and train_loader are
# expected to be in scope.
# ---------------------------------------------------------------------------
class MLP(nn.Module):
    """Three-layer perceptron (784 -> 256 -> 128 -> 10) that returns logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.3)

    def forward(self, x):
        hidden = self.dropout1(F.relu(self.fc1(x)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        # No softmax here: the raw output is passed to nn.CrossEntropyLoss.
        return self.fc3(hidden)


model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()
Keras 以“为人类设计”的哲学,大幅降低了深度学习的入门门槛。Sequential API 适合快速原型,Functional API 满足复杂架构需求,Subclassing API 则提供了无限制的研究自由度。配合 tf.data 高效数据管道、Distribution Strategy 自动分布式训练、以及 TF Serving / TF Lite 的成熟部署生态,Keras 能够覆盖从实验到生产的完整流程。建议初学者从 Keras 入手建立直觉,进阶后可转向 PyTorch 获取更多灵活性和研究社区的支持——两者并非替代关系,而是互补的工具。