跳至主要内容

Keras_Tuner_鐵人賽分享

Open In Colab

18.Keras Tuner

  • 分享於ithome鐵人賽文章的實作範例。
  • Keras Tuner是 Keras 團隊的一個模組,可自動執行神經網絡的超參數調整
  • 為了進行比較,首先使用預先選擇的超參數訓練基線模型,然後使用調整後的超參數重做該過程。
  • 範例改寫自Tensorflow 提供的官方教學

下載資料集

採用 Fashion MNIST dataset

from tensorflow import keras

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()

# Normalize
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
  • 建立 Baseline model
baseline_model = keras.Sequential()
baseline_model.add(keras.layers.Flatten(input_shape=(28, 28)))
baseline_model.add(keras.layers.Dense(units=512, activation='relu', name='dense_1'))
baseline_model.add(keras.layers.Dropout(0.2))
baseline_model.add(keras.layers.Dense(10, activation='softmax'))
baseline_model.compile(
optimizer = keras.optimizers.Adam(learning_rate=0.001),
loss = keras.losses.SparseCategoricalCrossentropy(),
metrics = ['accuracy']
)
baseline_model.summary()
# 訓練及評估模型
baseline_model.fit(img_train, label_train, epochs=10, validation_split=0.2)
baseline_eval_dict = baseline_model.evaluate(img_test, label_test, return_dict=True)
  • 建立顯示結果的輔助函數,以便稍後進行比較。
def print_results(model, model_name, eval_dict):
print(f'\n{model_name}:')
print(f'number of units in 1st Dense layer: {model.get_layer("dense_1").units}')
print(f'learning rate for the optimizer: {model.optimizer.lr.numpy()}')

for key,value in eval_dict.items():
print(f'{key}: {value}')

# Print results for baseline model
print_results(baseline_model, 'BASELINE MODEL', baseline_eval_dict)

Keras Tuner

要使用 Keras Tuner 執行超調,您需要:

  • 定義模型
  • 選擇要調整的超參數
  • 定義其搜索空間
  • 定義搜索策略
!pip install -q -U keras-tuner
# Import required packages
import tensorflow as tf
import kerastuner as kt

定義模型

  • 當您構建用於超調的模型時,除了模型架構之外,您還定義了超參數搜索空間。您為超調設置的模型稱為HyperModel

  • 您可以通過兩種方法定義超模型:

    • 通過使用模型構建器功能。
    • 通過HyperModel繼承 Keras Tuner API的類別。
  • 另外兩個預定義的HyperModel類 - HyperXceptionHyperResNet可用於計算機視覺應用程序。

  • 本範例由模型構建器函數來定義圖像分類模型。

  • Int()用來定義密集單元的搜索空間的最小值和最大值。

  • Choice()用於設定學習率。

def model_builder(hp):

model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))

# 設定搜索值範圍
hp_units = hp.Int('units', min_value=32, max_value=512, step=32)

model.add(
keras.layers.Dense(
units=hp_units,
activation='relu',
name='dense_kt1'
)
)

model.add(keras.layers.Dropout(0.2))
model.add(keras.layers.Dense(10, activation='softmax'))

# 設定學習率範圍 0.01, 0.001, or 0.0001
hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

model.compile(
optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']
)
return model

實例化 Tuner 並執行超調

# Instantiate the tuner
tuner = kt.Hyperband(
model_builder,
objective='val_accuracy',
max_epochs=10,
factor=3,
directory='kt_dir',
project_name='kt_hyperband'
)
# Display hypertuning settings
tuner.search_space_summary()

定義了一個 EarlyStopping,當驗證的loss在5個epoch沒改善時停止訓練。

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

執行調參,需要些時間。

# Perform hypertuning
tuner.search(img_train, label_train, epochs=10, validation_split=0.2, callbacks=[stop_early])

您可以使用get_best_hyperparameters() 方法獲得性能最佳的模型。

# Get the optimal hyperparameters from the results
best_hps=tuner.get_best_hyperparameters()[0]

print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")

建立與訓練模型

現在最佳模型儲存在best_hps中,即可進行自動調參後的套用作業。

# Build the model with the optimal hyperparameters
h_model = tuner.hypermodel.build(best_hps)
h_model.summary()
# Train the hypertuned model
h_model.fit(img_train, label_train, epochs=10, validation_split=0.2)
# Evaluate the hypertuned model against the test set
h_eval_dict = h_model.evaluate(img_test, label_test, return_dict=True)
# Print results of the baseline and hypertuned model
print_results(baseline_model, 'BASELINE MODEL', baseline_eval_dict)
print_results(h_model, 'HYPERTUNED MODEL', h_eval_dict)

參考