目录
  1. 1. 一、Keras 设计哲学与三种 API 风格
    1. 1.1. 1.1 Sequential API —— 线性层堆叠
    2. 1.2. 1.2 Functional API —— DAG 拓扑结构
    3. 1.3. 1.3 Subclassing API —— 完全自定义
  2. 2. 二、层组合模式
    1. 2.1. 2.1 残差块(Residual Block)
    2. 2.2. 2.2 Inception 模块
    3. 2.3. 2.3 Dense Block(密集连接)
  3. 3. 三、迁移学习工作流
    1. 3.1. 3.1 标准迁移学习流程
    2. 3.2. 3.2 Keras 内置预训练模型
    3. 3.3. 3.3 特征提取模式
  4. 4. 四、自定义组件
    1. 4.1. 4.1 自定义 Layer
    2. 4.2. 4.2 自定义 Loss
    3. 4.3. 4.3 自定义 Metric
  5. 5. 五、训练特性与技巧
    1. 5.1. 5.1 compile 详解
    2. 5.2. 5.2 fit 参数详解
    3. 5.3. 5.3 混合精度训练
    4. 5.4. 5.4 多 GPU 训练(Distribution Strategy)
  6. 6. 六、模型优化
    1. 6.1. 6.1 权重剪枝(Weight Pruning)
    2. 6.2. 6.2 量化感知训练(QAT)
    3. 6.3. 6.3 权重聚类(Weight Clustering)
  7. 7. 七、Keras 与 PyTorch 工作流对比
机器学习框架篇-Keras

Keras 是一个用 Python 编写的高级神经网络 API,以 TensorFlow(以及历史支持的 CNTK、Theano)作为后端运行。自 TensorFlow 2.0 起,Keras 被正式作为 TF 的官方高层 API(tf.keras)。Keras 的设计哲学是”为人类设计的 API”——减少认知负担、模块化、可组合、易于扩展。本文将从三种 API 风格、层组合模式、迁移学习、自定义组件,到训练优化和模型部署,系统掌握 Keras。

一、Keras 设计哲学与三种 API 风格

Keras 提供三种构建模型的方式,灵活性与复杂度依次递增。

1.1 Sequential API —— 线性层堆叠

适合简单的线性拓扑结构,每层只有单一输入和单一输出:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
input_shape=(28, 28, 1)),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(10, activation='softmax')
])

# 也可以用 .add() 逐层添加
model = keras.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(784,)))
model.add(layers.Dropout(0.3))
model.add(layers.Dense(10, activation='softmax'))

# Sequential 的限制:
# - 只能单输入单输出
# - 不支持层共享
# - 不支持分支/跳跃连接
# - 不支持多路融合

1.2 Functional API —— DAG 拓扑结构

Functional API 将层视为可调用的函数,接受张量作为输入并返回输出张量,从而构建任意有向无环图(DAG)结构的模型:

# 基础用法
inputs = keras.Input(shape=(784,))
x = layers.Dense(256, activation='relu')(inputs)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# 共享层:多个输入共享同一层
shared_embedding = layers.Embedding(10000, 128)

# 输入 A
input_a = keras.Input(shape=(None,), dtype='int32', name='input_a')
embed_a = shared_embedding(input_a)

# 输入 B — 使用同一 Embedding 层
input_b = keras.Input(shape=(None,), dtype='int32', name='input_b')
embed_b = shared_embedding(input_b)

# 多输入模型
input_text = keras.Input(shape=(100,), name='text_input')
input_image = keras.Input(shape=(64, 64, 3), name='image_input')

# 文本分支
x_text = layers.Embedding(10000, 128)(input_text)
x_text = layers.LSTM(64)(x_text)

# 图像分支
x_image = layers.Conv2D(32, 3, activation='relu')(input_image)
x_image = layers.GlobalAveragePooling2D()(x_image)

# 融合分支
concatenated = layers.Concatenate()([x_text, x_image])
output = layers.Dense(1, activation='sigmoid')(concatenated)

multi_input_model = keras.Model(
inputs=[input_text, input_image],
outputs=output
)

# 编译时需为不同输出/输入指定 loss 和 metrics
multi_input_model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy']
)

# 训练时需传入字典或按顺序的列表
# history = multi_input_model.fit(
# {'text_input': x_text, 'image_input': x_image}, y,
# epochs=10
# )

# 多输出模型
base_input = keras.Input(shape=(784,))
x = layers.Dense(256, activation='relu')(base_input)
x = layers.Dense(128, activation='relu')(x)
class_output = layers.Dense(10, activation='softmax', name='class')(x)
aux_output = layers.Dense(1, activation='sigmoid', name='aux')(x)

multi_output_model = keras.Model(
inputs=base_input,
outputs=[class_output, aux_output]
)
multi_output_model.compile(
optimizer='adam',
loss={'class': 'sparse_categorical_crossentropy', 'aux': 'binary_crossentropy'},
loss_weights={'class': 1.0, 'aux': 0.3},
metrics={'class': 'accuracy'}
)

1.3 Subclassing API —— 完全自定义

通过继承 keras.Model 实现完全灵活的控制,适用于需要动态执行流程或非标准前向传播的研究性模型:

class CustomModel(keras.Model):
def __init__(self, num_classes=10, **kwargs):
super().__init__(**kwargs)
# 在 __init__ 中定义所有层
self.conv1 = layers.Conv2D(32, 3, activation='relu')
self.conv2 = layers.Conv2D(64, 3, activation='relu')
self.maxpool = layers.MaxPooling2D()
self.flatten = layers.Flatten()
self.dropout = layers.Dropout(0.5)
self.classifier = layers.Dense(num_classes)
self.loss_tracker = keras.metrics.Mean(name='loss')

def call(self, inputs, training=None):
# 在 call 中定义前向传播逻辑
x = self.conv1(inputs)
x = self.maxpool(x)
x = self.conv2(x)
x = self.maxpool(x)
x = self.flatten(x)
if training:
x = self.dropout(x)
return self.classifier(x)

# 覆盖 train_step 自定义训练逻辑
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred)
grads = tape.gradient(loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
self.compiled_metrics.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}

@property
def metrics(self):
return [self.loss_tracker] + self.compiled_metrics.metrics

# Subclassing 的注意事项:
# - model.summary() 在首次 call 前可能无法显示完整形状
# - model.save() / load_model() 需要 get_config() 实现
# - 不可用 tf.saved_model.save() 自动导出(需要定义 input_signature)

三种 API 风格对比:

特性 Sequential Functional Subclassing
复杂度
拓扑结构 线性 任意 DAG 任意(含动态)
层共享
多输入/输出
可保存性 最好 需额外工作
适用场景 简单分类 多模态/残差 研究/动态架构

二、层组合模式

2.1 残差块(Residual Block)

残差连接通过跳跃连接将输入直接加到输出,解决了深层网络的梯度消失问题:

def residual_block(x, filters, stride=1, use_projection=False):
"""标准 ResNet 残差块"""
shortcut = x

if use_projection:
shortcut = layers.Conv2D(filters, 1, strides=stride,
use_bias=False)(shortcut)
shortcut = layers.BatchNormalization()(shortcut)

x = layers.Conv2D(filters, 3, strides=stride, padding='same',
use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

x = layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)

x = layers.Add()([x, shortcut])
x = layers.ReLU()(x)
return x

# 瓶颈残差块(Bottleneck)— 1x1 降维 + 3x3 + 1x1 升维
def bottleneck_block(x, filters, stride=1, expansion=4):
shortcut = x
bottleneck_filters = filters // expansion

if stride != 1 or x.shape[-1] != filters:
shortcut = layers.Conv2D(filters, 1, strides=stride,
use_bias=False)(shortcut)
shortcut = layers.BatchNormalization()(shortcut)

x = layers.Conv2D(bottleneck_filters, 1, use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

x = layers.Conv2D(bottleneck_filters, 3, strides=stride,
padding='same', use_bias=False)(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

x = layers.Conv2D(filters, 1, use_bias=False)(x)
x = layers.BatchNormalization()(x)

x = layers.Add()([x, shortcut])
x = layers.ReLU()(x)
return x

2.2 Inception 模块

Inception 模块通过并行的不同大小卷积核捕获多尺度特征:

def inception_module(x, filters):
"""简化版 Inception 模块"""
# 1x1 卷积分支
branch1 = layers.Conv2D(filters, 1, padding='same', activation='relu')(x)

# 1x1 → 5x5 卷积分支
branch2 = layers.Conv2D(filters, 1, padding='same', activation='relu')(x)
branch2 = layers.Conv2D(filters, 5, padding='same', activation='relu')(branch2)

# 1x1 → 3x3 卷积分支
branch3 = layers.Conv2D(filters, 1, padding='same', activation='relu')(x)
branch3 = layers.Conv2D(filters, 3, padding='same', activation='relu')(branch3)

# 3x3 最大池化 → 1x1 卷积分支
branch4 = layers.MaxPooling2D(3, strides=1, padding='same')(x)
branch4 = layers.Conv2D(filters, 1, padding='same', activation='relu')(branch4)

# 沿通道维拼接
output = layers.Concatenate()([branch1, branch2, branch3, branch4])
return output

2.3 Dense Block(密集连接)

DenseNet 的核心模块,每层的输入包括所有前面层的输出:

def dense_block(x, num_layers, growth_rate):
"""DenseNet 的 Dense Block"""
features_list = [x]

for i in range(num_layers):
# 拼接所有前面的特征图
concat_features = layers.Concatenate()(features_list)

# 瓶颈层:1x1 → 3x3
x = layers.BatchNormalization()(concat_features)
x = layers.ReLU()(x)
x = layers.Conv2D(4 * growth_rate, 1, use_bias=False)(x)

x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Conv2D(growth_rate, 3, padding='same', use_bias=False)(x)

features_list.append(x)

return layers.Concatenate()(features_list)

# Transition Layer(下采样层)
def transition_layer(x, compression=0.5):
num_filters = int(x.shape[-1] * compression)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Conv2D(num_filters, 1, use_bias=False)(x)
x = layers.AveragePooling2D(2, strides=2)(x)
return x

三、迁移学习工作流

3.1 标准迁移学习流程

# 加载 ImageNet 预训练权重
base_model = keras.applications.ResNet50(
weights='imagenet', # 预训练权重
include_top=False, # 不包含分类头(保留卷积基)
input_shape=(224, 224, 3)
)

# 冻结卷积基
base_model.trainable = False

# 添加自定义分类头
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False) # training=False 确保 BN 用 running stats
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(num_classes, activation='softmax')(x)
model = keras.Model(inputs, outputs)

# 阶段 1:仅训练头部
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
model.fit(train_dataset, validation_data=val_dataset,
epochs=10, callbacks=callbacks)

# 阶段 2:逐步解冻(Progressive Unfreezing)
# 解冻模型的最后几层
base_model.trainable = True

# 设置逐层不同的学习率(判别性微调)
for layer in base_model.layers[:100]:
layer.trainable = False # 浅层保持冻结

# 重新编译(trainable 变化后必须重新编译)
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=1e-5), # 降低学习率
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)

# 阶段 2:微调全模型
model.fit(train_dataset, validation_data=val_dataset,
epochs=20, callbacks=callbacks)

3.2 Keras 内置预训练模型

# Keras Applications 提供的模型家族:

# ResNet 家族
keras.applications.ResNet50(weights='imagenet')
keras.applications.ResNet101(weights='imagenet')
keras.applications.ResNet152(weights='imagenet')
keras.applications.ResNet50V2(weights='imagenet')

# VGG 家族
keras.applications.VGG16(weights='imagenet')
keras.applications.VGG19(weights='imagenet')

# Inception 家族
keras.applications.InceptionV3(weights='imagenet')
keras.applications.InceptionResNetV2(weights='imagenet')

# MobileNet 家族(轻量级)
keras.applications.MobileNet(weights='imagenet')
keras.applications.MobileNetV2(weights='imagenet')
keras.applications.MobileNetV3Small(weights='imagenet')
keras.applications.MobileNetV3Large(weights='imagenet')

# EfficientNet 家族(最佳精度-效率平衡)
keras.applications.EfficientNetB0(weights='imagenet')
keras.applications.EfficientNetB7(weights='imagenet')
keras.applications.EfficientNetV2S(weights='imagenet')
keras.applications.EfficientNetV2M(weights='imagenet')

# 新型架构
keras.applications.ConvNeXtTiny(weights='imagenet')
keras.applications.ConvNeXtSmall(weights='imagenet')
keras.applications.ConvNeXtBase(weights='imagenet')
keras.applications.ConvNeXtLarge(weights='imagenet')
keras.applications.ConvNeXtXLarge(weights='imagenet')

# 每个模型的预处理函数
# 例如 ResNet50 使用:
# preprocess_input = keras.applications.resnet50.preprocess_input
# 例如 EfficientNet 使用:
# preprocess_input = keras.applications.efficientnet.preprocess_input

3.3 特征提取模式

# 模式:将预训练模型作为固定特征提取器
base_model = keras.applications.EfficientNetB0(
include_top=False, weights='imagenet', pooling='avg'
)
base_model.trainable = False

# 一次性提取所有特征(适合小数据集)
features = base_model.predict(train_images) # (N, 1280)
# 然后用传统 ML 或简单的 Dense 分类

# 在线提取(适合大数据集,节约内存)
inputs = keras.Input(shape=(224, 224, 3))
features = base_model(inputs)
outputs = layers.Dense(num_classes, activation='softmax')(features)
model = keras.Model(inputs, outputs)

四、自定义组件

4.1 自定义 Layer

class MyDense(keras.layers.Layer):
"""自定义全连接层(等同于 tf.keras.layers.Dense 的简化版)"""
def __init__(self, units, activation=None, **kwargs):
super().__init__(**kwargs)
self.units = units
self.activation = keras.activations.get(activation)

def build(self, input_shape):
"""首次调用时自动调用,根据输入 shape 创建权重"""
self.kernel = self.add_weight(
name='kernel',
shape=(input_shape[-1], self.units),
initializer='glorot_uniform',
trainable=True
)
self.bias = self.add_weight(
name='bias',
shape=(self.units,),
initializer='zeros',
trainable=True
)
self.built = True # 标记已构建

def call(self, inputs):
"""前向计算"""
output = tf.matmul(inputs, self.kernel) + self.bias
if self.activation is not None:
output = self.activation(output)
return output

def get_config(self):
"""序列化支持(用于 model.save / clone)"""
config = super().get_config()
config.update({'units': self.units, 'activation': keras.activations.serialize(self.activation)})
return config

# 自定义带正则化的层
class L2NormLayer(keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def call(self, inputs):
return tf.math.l2_normalize(inputs, axis=-1)

4.2 自定义 Loss

class FocalLoss(keras.losses.Loss):
"""Focal Loss 实现"""
def __init__(self, alpha=0.25, gamma=2.0, from_logits=False, name='focal_loss'):
super().__init__(name=name, reduction=keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
self.alpha = alpha
self.gamma = gamma
self.from_logits = from_logits

def call(self, y_true, y_pred):
if self.from_logits:
y_pred = tf.nn.softmax(y_pred, axis=-1)

# 获取真实类的预测概率
y_true_one_hot = tf.one_hot(tf.cast(y_true, tf.int32), depth=tf.shape(y_pred)[-1])
pt = tf.reduce_sum(y_true_one_hot * y_pred, axis=-1)
# Focal loss
focal_weight = tf.pow(1.0 - pt, self.gamma)
ce = -tf.math.log(pt + keras.backend.epsilon())

loss = self.alpha * focal_weight * ce
return loss

def get_config(self):
config = super().get_config()
config.update({'alpha': self.alpha, 'gamma': self.gamma, 'from_logits': self.from_logits})
return config

# 或更简单的函数式 Loss
def contrastive_loss(margin=1.0):
def loss(y_true, y_pred):
# y_pred: 距离度量;y_true: 1=相似, 0=不相似
square_pred = tf.square(y_pred)
margin_square = tf.square(tf.maximum(margin - y_pred, 0))
return tf.reduce_mean(
y_true * square_pred + (1 - y_true) * margin_square
)
return loss

4.3 自定义 Metric

class F1Score(keras.metrics.Metric):
"""F1 分数 Metric"""
def __init__(self, num_classes, average='macro', name='f1_score', **kwargs):
super().__init__(name=name, **kwargs)
self.num_classes = num_classes
self.average = average
self.true_positives = self.add_weight(name='tp', shape=(num_classes,), initializer='zeros')
self.false_positives = self.add_weight(name='fp', shape=(num_classes,), initializer='zeros')
self.false_negatives = self.add_weight(name='fn', shape=(num_classes,), initializer='zeros')

def update_state(self, y_true, y_pred, sample_weight=None):
y_pred = tf.argmax(y_pred, axis=1)
y_true = tf.cast(y_true, tf.int64)

for c in range(self.num_classes):
pred_c = tf.cast(tf.equal(y_pred, c), tf.float32)
true_c = tf.cast(tf.equal(y_true, c), tf.float32)

tp = tf.reduce_sum(pred_c * true_c)
fp = tf.reduce_sum(pred_c * (1 - true_c))
fn = tf.reduce_sum((1 - pred_c) * true_c)

self.true_positives[c].assign_add(tp)
self.false_positives[c].assign_add(fp)
self.false_negatives[c].assign_add(fn)

def result(self):
precision = self.true_positives / (self.true_positives + self.false_positives + keras.backend.epsilon())
recall = self.true_positives / (self.true_positives + self.false_negatives + keras.backend.epsilon())
f1 = 2 * precision * recall / (precision + recall + keras.backend.epsilon())

if self.average == 'macro':
return tf.reduce_mean(f1)
elif self.average == 'micro':
tp = tf.reduce_sum(self.true_positives)
fp = tf.reduce_sum(self.false_positives)
fn = tf.reduce_sum(self.false_negatives)
p = tp / (tp + fp + keras.backend.epsilon())
r = tp / (tp + fn + keras.backend.epsilon())
return 2 * p * r / (p + r + keras.backend.epsilon())
return f1

def reset_state(self):
self.true_positives.assign(tf.zeros(self.num_classes))
self.false_positives.assign(tf.zeros(self.num_classes))
self.false_negatives.assign(tf.zeros(self.num_classes))

五、训练特性与技巧

5.1 compile 详解

model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
loss_weights=None, # 多输出时的权重
weighted_metrics=None, # 使用 sample_weight 时计算的 metrics
run_eagerly=False, # True 用于调试(关闭 tf.function 编译)
steps_per_execution=1, # 每 N 步更新一次(减少通信)
jit_compile=False, # 启用 XLA 编译
)

# weighted_metrics 示例:
model.compile(
optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'],
weighted_metrics=['accuracy'] # 带权重的 accuracy
)
# fit 时传入 sample_weight,weighted_metrics 会考虑这些权重

5.2 fit 参数详解

history = model.fit(
x=train_dataset,
y=None, # datasets 模式不需要
batch_size=None, # datasets 已设置
epochs=50,
verbose='auto',
callbacks=callbacks,
validation_split=0.2, # 从训练数据中分出验证集(仅 ndarray 模式)
validation_data=val_dataset,
shuffle=True,
class_weight={0: 1.0, 1: 3.0}, # 类别权重(处理不平衡)
sample_weight=None, # 样本权重
initial_epoch=10, # 从第 11 个 epoch 开始(用于恢复训练)
steps_per_epoch=None, # 每 epoch 的 step 数(default=len(dataset))
validation_steps=None,
validation_batch_size=None,
validation_freq=1, # 每 N 个 epoch 验证一次
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)

5.3 混合精度训练

# 全局启用混合精度
keras.mixed_precision.set_global_policy('mixed_float16')

# 或按层控制
# 大部分层使用 float16,最后一层保持 float32
outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

# 损失缩放由 optimizer 自动处理(使用 mixed_float16 policy 时)
# 无需手动配置

# 验证当前 policy
print(keras.mixed_precision.global_policy())
# 输出: <Policy "mixed_float16", loss_scale="dynamic">

5.4 多 GPU 训练(Distribution Strategy)

# 单机多卡
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

with strategy.scope():
model = create_model()
model.compile(
optimizer=keras.optimizers.Adam(0.001 * strategy.num_replicas_in_sync),
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# 注意:学习率通常按 replica 数等比放大

model.fit(train_dataset, epochs=10)

六、模型优化

6.1 权重剪枝(Weight Pruning)

import tensorflow_model_optimization as tfmot

# 定义剪枝配置
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
initial_sparsity=0.0, # 初始稀疏度
final_sparsity=0.8, # 最终稀疏度(80% 权重归零)
begin_step=0, # 开始剪枝的步数
end_step=1000, # 结束剪枝的步数
frequency=100 # 剪枝频率(每 N 步)
)
}

# 包裹模型以支持剪枝
pruning_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

pruning_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)

# 训练时使用 PruningSummary callback
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir='/tmp/logs'),
]

pruning_model.fit(train_dataset, epochs=5, callbacks=callbacks)

# 导出时剥离剪枝包装
final_model = tfmot.sparsity.keras.strip_pruning(pruning_model)

6.2 量化感知训练(QAT)

# 量化感知训练使模型适应 int8 量化
qat_model = tfmot.quantization.keras.quantize_model(model)

qat_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)

qat_model.fit(train_dataset, epochs=2)

# 转换为 TF Lite
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_qat_model = converter.convert()

6.3 权重聚类(Weight Clustering)

# 将权重聚类为有限个离散值,减少模型大小
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
'number_of_clusters': 32, # 权重值聚类为 32 个中心
'cluster_centroids_init': CentroidInitialization.LINEAR
}

clustered_model = cluster_weights(model, **clustering_params)
clustered_model.compile(...)
clustered_model.fit(...)

# 导出
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

七、Keras 与 PyTorch 工作流对比

维度 Keras PyTorch
学习曲线 平缓,适合入门 中等,需理解更多底层概念
灵活性 中→高(Subclassing) 高(完全自由)
代码量 少(compile+fit 封装) 多(手动循环)
调试 容易(Eager 模式) 容易(Pythonic)
数据管道 tf.data(功能强大) DataLoader(简洁直观)
分布式 Distribution Strategy(简单) DDP/FSDP(需手动配置)
生产部署 TF Serving(成熟) TorchServe(发展中)
移动端 TF Lite(最成熟) ExecuTorch(新)
预训练模型 Keras Applications torchvision.models + timm
社区生态 Google 官方主导 Meta 主持,社区活跃
论文复现 较少直接提供 多数论文提供 PyTorch 代码
适合谁 快速实验、工程落地 研究、精细控制
# 相同的 MLP 在两种框架中的实现对比

# === Keras ===
model = keras.Sequential([
keras.layers.Dense(256, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.5),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dropout(0.3),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, validation_split=0.2)

# === PyTorch === (等效实现)
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout(0.5)
self.dropout2 = nn.Dropout(0.3)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.dropout1(x)
x = F.relu(self.fc2(x))
x = self.dropout2(x)
return self.fc3(x)

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
model.train()
for x_batch, y_batch in train_loader:
optimizer.zero_grad()
output = model(x_batch)
loss = criterion(output, y_batch)
loss.backward()
optimizer.step()

Keras 以”为人类设计”的哲学,大幅降低了深度学习的入门门槛。Sequential API 适合快速原型,Functional API 满足复杂架构需求,Subclassing API 则提供了无限制的研究自由度。配合 tf.data 高效数据管道、Distribution Strategy 自动分布式训练、以及 TF Serving / TF Lite 的成熟部署生态,Keras 能够覆盖从实验到生产的完整流程。建议初学者从 Keras 入手建立直觉,进阶后可转向 PyTorch 获取更多灵活性和研究社区的支持——两者并非替代关系,而是互补的工具。

文章作者: Leo·Cheung
文章链接: http://tufusi.com/2022/04/10/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A1%86%E6%9E%B6%E7%AF%87-Keras/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 ONE·PIECE
打赏
  • 微信
  • 支付宝

评论