import tensorflow as tf import numpy as np import os
# Let tf.data pick parallelism / prefetch depth dynamically at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# Training hyperparameters.
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 0.001

# Spatial size of the input images.
IMAGE_SIZE = (32, 32)

# Set the global Keras dtype policy before any layer is constructed, so that
# layers created later in this file pick up 'mixed_float16' by default.
tf.keras.mixed_precision.set_global_policy('mixed_float16')
def load_cifar10():
    """Load CIFAR-10 through Keras and flatten the two splits into one tuple.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` exactly as produced by
        ``tf.keras.datasets.cifar10.load_data()``.
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return train_images, train_labels, test_images, test_labels
# Load the arrays once at module import; both input pipelines read from them.
x_train, y_train, x_test, y_test = load_cifar10()

# Random transforms applied to training images only. The layers run in the
# order listed here.
_augmentation_layers = [
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
    tf.keras.layers.RandomTranslation(0.1, 0.1),
    tf.keras.layers.RandomContrast(0.1),
]
data_augmentation = tf.keras.Sequential(
    _augmentation_layers,
    name='data_augmentation',
)
def preprocess(image, label, training=False):
    """Convert one example into the form the model consumes.

    The image is cast to float32 and kept in the [0, 255] pixel range while
    it is (optionally) augmented, then handed to the ResNet-v2
    ``preprocess_input`` helper, which rescales it to [-1, 1].

    Augmenting in the [0, 255] range matters: ``RandomContrast`` clips its
    output to [0, 255], so feeding it [0, 1] data (as this function used to)
    leaves over-bright values unclipped. It also removes a redundant
    ``/ 255`` followed by ``* 255`` round trip.

    Args:
        image: Image tensor with pixel values in [0, 255].
        label: Label for the image; passed through untouched.
        training: When true, apply ``data_augmentation`` to the image.

    Returns:
        ``(image, label)`` with the image preprocessed as float32.
    """
    image = tf.cast(image, tf.float32)
    if training:
        image = data_augmentation(image, training=True)
        # The augmentation layers may return their compute dtype (float16
        # under the mixed-precision policy this file sets); cast back so both
        # branches yield float32.
        image = tf.cast(image, tf.float32)
    image = tf.keras.applications.resnet_v2.preprocess_input(image)
    return image, label
def _prepare_train_example(image, label):
    """Map function for the training pipeline: preprocess with augmentation."""
    return preprocess(image, label, training=True)


# Training input pipeline: shuffle individual examples, augment them, then
# form fixed-size batches (the final partial batch is dropped) and prefetch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(50000)
    .map(_prepare_train_example, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(AUTOTUNE)
)
def _prepare_eval_example(image, label):
    """Map function for the evaluation pipeline: preprocess, no augmentation."""
    return preprocess(image, label, training=False)


# Evaluation input pipeline: no shuffling, and the last partial batch is kept.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(_prepare_eval_example, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
def _residual_block(x, filters, stride=1):
    """Apply one two-convolution residual block to ``x``.

    The main path is Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN. The shortcut is
    the identity unless the block changes the spatial stride or the channel
    count, in which case a strided 1x1 convolution plus BN projects it so the
    two branches can be added.

    Args:
        x: Input feature-map tensor.
        filters: Number of output channels of the block.
        stride: Stride of the first convolution (and of the projection).

    Returns:
        The block output after the residual add and final ReLU.
    """
    shortcut = x
    if stride != 1 or x.shape[-1] != filters:
        shortcut = tf.keras.layers.Conv2D(
            filters, 1, strides=stride, use_bias=False
        )(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)

    x = tf.keras.layers.Conv2D(
        filters, 3, strides=stride, padding='same', use_bias=False
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(filters, 3, padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Add()([x, shortcut])
    x = tf.keras.layers.ReLU()(x)
    return x


def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build the small ResNet-style classifier used by this script.

    Architecture: a 64-filter stem convolution, six residual blocks in three
    stages (64, 128, 256 filters; the 128 and 256 stages start with a
    stride-2 block), global average pooling, dropout and a float32 logits
    layer.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
            Defaults to the 32x32 RGB shape that was previously hard-coded.
        num_classes: Number of output logits. Defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to logits.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Stem.
    # NOTE(review): this convolution keeps its bias although BatchNormalization
    # follows directly, unlike the residual blocks (use_bias=False). Left
    # unchanged to keep the set of weights identical.
    x = tf.keras.layers.Conv2D(
        64, 3, padding='same',
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),
    )(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # (filters, stride) for each residual block, in order.
    block_plan = (
        (64, 1),
        (64, 1),
        (128, 2),
        (128, 1),
        (256, 2),
        (256, 1),
    )
    for filters, stride in block_plan:
        x = _residual_block(x, filters, stride=stride)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    # The logits layer is pinned to float32 rather than following the global
    # dtype policy.
    outputs = tf.keras.layers.Dense(
        num_classes,
        dtype='float32',
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),
    )(x)

    return tf.keras.Model(inputs, outputs)
# Synchronous data-parallel training across the devices MirroredStrategy finds.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

# The model and optimizer are both created inside the strategy scope.
with strategy.scope():
    model = create_model()

    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=LEARNING_RATE,
        weight_decay=1e-4,
    )
    # The model emits raw logits, so the loss applies the softmax itself.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    tracked_metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top5_accuracy'),
    ]

    model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=tracked_metrics,
    )
# Stop when validation accuracy has not improved for 10 epochs and roll the
# model back to its best weights.
_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=10,
    restore_best_weights=True,
)

# Halve the learning rate after 5 epochs without validation-loss improvement,
# never going below 1e-7.
_reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
)

# Keep only the checkpoint with the highest validation accuracy.
_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'checkpoints/cifar10_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
)

# TensorBoard logging, with weight histograms every epoch.
_tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir='./logs/cifar10',
    histogram_freq=1,
)

callbacks = [_early_stopping, _reduce_lr, _checkpoint, _tensorboard]
# Train the model.
# NOTE(review): the test split doubles as validation data here, so early
# stopping and checkpoint selection see the same data that is reported as the
# final test score below — consider a separate validation split.
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,
    callbacks=callbacks,
    verbose=1,
)

# evaluate() returns the loss followed by the compiled metrics, in order.
_eval_results = model.evaluate(test_dataset)
test_loss, test_acc, test_top5 = _eval_results
print(f'Test accuracy: {test_acc:.4f}, Top-5 accuracy: {test_top5:.4f}')
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 32, 32, 3], dtype=tf.uint8)])
def serving_fn(image):
    """Serving signature: raw uint8 image batch in, class predictions out.

    The uint8 pixels are cast to float32 and passed straight to the ResNet-v2
    ``preprocess_input`` helper, which expects the [0, 255] range. The earlier
    ``/ 255.0`` followed by ``* 255.0`` round trip was a no-op apart from
    float rounding and has been removed from the exported graph.

    Args:
        image: uint8 tensor of shape ``[batch, 32, 32, 3]``.

    Returns:
        A dict with ``'probabilities'`` (softmax over the model's logits) and
        ``'class_ids'`` (argmax of those probabilities along the last axis).
    """
    pixels = tf.cast(image, tf.float32)
    model_input = tf.keras.applications.resnet_v2.preprocess_input(pixels)
    logits = model(model_input, training=False)
    probs = tf.nn.softmax(logits)
    return {'probabilities': probs, 'class_ids': tf.argmax(probs, axis=-1)}
# Export the trained model as a SavedModel, with serving_fn registered as its
# default serving signature.
export_path = './export/cifar10_classifier/1'
_serving_signatures = {'serving_default': serving_fn}
tf.saved_model.save(model, export_path, signatures=_serving_signatures)
print(f'Model saved to {export_path}')
# Convert the exported SavedModel to TF Lite with the converter's default
# optimizations enabled, then write the resulting flatbuffer to disk.
converter = tf.lite.TFLiteConverter.from_saved_model(export_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('cifar10_model.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)

print(f'TF Lite model size: {len(tflite_model) / 1024:.1f} KB')
|