import tensorflow as tf
from tensorflow.keras import layers, models, applications, optimizers, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# Hyperparameters
IMAGE_SIZE = (380, 380)  # recommended input size for EfficientNetB4
BATCH_SIZE = 8
EPOCHS = 15  # 20/10/5
NUM_CLASSES = 38  # the PlantVillage dataset has 38 classes (including healthy leaves)
DATA_DIR = "./PlantVillage-Dataset-master/raw/color"  # replace with your dataset path
# Data augmentation and preprocessing
def create_data_generator():
    # Use EfficientNet's dedicated preprocessing function
    return ImageDataGenerator(
        preprocessing_function=applications.efficientnet.preprocess_input,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        validation_split=0.05  # hold out 5% of the data as a validation set
    )
# Create the data generator
train_datagen = create_data_generator()
# Training data stream
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
    shuffle=True
)
# Validation data stream
val_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
    shuffle=False
)
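# Optional add-on (not in the original script): persist the class-name-to-index
# mapping that flow_from_directory derives from the folder names, so that
# predictions can be mapped back to readable labels later. The output filename
# "class_indices.json" is an arbitrary choice.
import json
with open("class_indices.json", "w") as f:
    json.dump(train_generator.class_indices, f, indent=2)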
# Build the model
def build_model():
    # Load the pretrained base model
    base_model = applications.EfficientNetB4(
        weights="imagenet",
        include_top=False,
        input_shape=(*IMAGE_SIZE, 3)
    )
    # Freeze the pretrained layers (initial training phase)
    base_model.trainable = False
    # Custom classification head
    inputs = layers.Input(shape=(*IMAGE_SIZE, 3))
    # training=False keeps the batch-norm layers in inference mode,
    # which also helps when parts of the base model are unfrozen later
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model
model = build_model()
# Compile the model
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)
# Callbacks
callbacks_list = [
    callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True
    ),
    callbacks.ModelCheckpoint(
        "best_model_initial.h5",  # HDF5 format, matching the path loaded during evaluation below
        save_best_only=True,
        monitor="val_accuracy"
    ),
    callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.2,
        patience=3
    )
]
# Initial training (only the custom head is trained)
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks_list
)
# Unfreeze part of the base model for fine-tuning
def fine_tune_model(model):
    # Unfreeze the top convolutional layers
    base_model = model.get_layer("efficientnetb4")
    base_model.trainable = True
    for layer in base_model.layers[:-10]:  # keep only the last 10 layers trainable
        layer.trainable = False
    # Recompile the model with a much smaller learning rate
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-5),
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model
model = fine_tune_model(model)
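# Added sanity check (not in the original script): confirm that only the top
# of the network contributes trainable weights after unfreezing.
print(f"Trainable weight tensors after unfreezing: {len(model.trainable_weights)}")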
# Fine-tuning run
fine_tune_history = model.fit(
    train_generator,
    initial_epoch=len(history.epoch),  # continue the epoch numbering after the initial run
    epochs=len(history.epoch) + 10,    # train for 10 more epochs
    validation_data=val_generator,
    callbacks=[
        callbacks.ModelCheckpoint(
            "best_model_finetuned.h5",
            save_best_only=True,
            monitor="val_accuracy"
        )
    ]
)
# Save the final model
model.save("plant_disease_efficientnetb4.h5")
# Visualize the training curves
def plot_history(history, title):
    plt.figure(figsize=(12, 4))
    # Accuracy curves
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title(f'{title} Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    # Loss curves
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title(f'{title} Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.tight_layout()
    plt.show()
# Plot both training phases
plot_history(history, "Initial Training")
plot_history(fine_tune_history, "Fine-tuning")
# Evaluate the saved checkpoints on the validation set
def evaluate_model(model_path):
    model = models.load_model(model_path)
    loss, acc = model.evaluate(val_generator)
    print(f"Validation accuracy: {acc*100:.2f}%")
    print(f"Validation loss: {loss:.4f}")
print("Initial model evaluation:")
evaluate_model("best_model_initial.h5")
print("\nFine-tuned model evaluation:")
evaluate_model("best_model_finetuned.h5") |