liriope commited on
Commit
417a395
·
verified ·
1 Parent(s): 27c82e7

Upload 2 files

Browse files
Files changed (2) hide show
  1. plant_disease_efficientnetb4.h5 +3 -0
  2. train_v2.py +185 -0
plant_disease_efficientnetb4.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f34b8da5c996362a6d20582a090f1f9a67926e591922156a780046c66493fed
3
+ size 98030480
train_v2.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, models, applications, optimizers, callbacks
3
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
4
+ import matplotlib.pyplot as plt
5
+
6
+ # 参数设置
7
+ IMAGE_SIZE = (380, 380) # EfficientNetB4的推荐输入尺寸
8
+ BATCH_SIZE = 8
9
+ EPOCHS = 15 # 20/10/5
10
+ NUM_CLASSES = 38 # PlantVillage数据集有38个类别(包含健康叶片)
11
+ DATA_DIR = "./PlantVillage-Dataset-master/raw/color" # 替换为你的数据集路径
12
+
13
+ # 数据增强和预处理
14
+ def create_data_generator():
15
+ # 使用EfficientNet的专用预处理方法
16
+ return ImageDataGenerator(
17
+ preprocessing_function=applications.efficientnet.preprocess_input,
18
+ rotation_range=40,
19
+ width_shift_range=0.2,
20
+ height_shift_range=0.2,
21
+ shear_range=0.2,
22
+ zoom_range=0.2,
23
+ horizontal_flip=True,
24
+ vertical_flip=True,
25
+ validation_split=0.05 # 保留5%数据作为验证集
26
+ )
27
+
28
+ # 创建数据生成器
29
+ train_datagen = create_data_generator()
30
+
31
+ # 训练数据流
32
+ train_generator = train_datagen.flow_from_directory(
33
+ DATA_DIR,
34
+ target_size=IMAGE_SIZE,
35
+ batch_size=BATCH_SIZE,
36
+ class_mode="categorical",
37
+ subset="training",
38
+ shuffle=True
39
+ )
40
+
41
+ # 验证数据流
42
+ val_generator = train_datagen.flow_from_directory(
43
+ DATA_DIR,
44
+ target_size=IMAGE_SIZE,
45
+ batch_size=BATCH_SIZE,
46
+ class_mode="categorical",
47
+ subset="validation",
48
+ shuffle=False
49
+ )
50
+
51
+ # 构建模型
52
+ def build_model():
53
+ # 加载预训练基模型
54
+ base_model = applications.EfficientNetB4(
55
+ weights="imagenet",
56
+ include_top=False,
57
+ input_shape=(*IMAGE_SIZE, 3)
58
+ )
59
+
60
+ # 冻结预训练层(初始训练阶段)
61
+ base_model.trainable = False
62
+
63
+ # 自定义顶层
64
+ inputs = layers.Input(shape=(*IMAGE_SIZE, 3))
65
+ x = base_model(inputs)
66
+ x = layers.GlobalAveragePooling2D()(x)
67
+ x = layers.Dense(256, activation="relu")(x)
68
+ x = layers.Dropout(0.5)(x)
69
+ outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
70
+
71
+ model = models.Model(inputs, outputs)
72
+ return model
73
+
74
+ model = build_model()
75
+
76
+ # 编译模型
77
+ model.compile(
78
+ optimizer=optimizers.Adam(learning_rate=1e-3),
79
+ loss="categorical_crossentropy",
80
+ metrics=["accuracy"]
81
+ )
82
+
83
+ # 回调函数
84
+ callbacks_list = [
85
+ callbacks.EarlyStopping(
86
+ monitor="val_loss",
87
+ patience=5,
88
+ restore_best_weights=True
89
+ ),
90
+ callbacks.ModelCheckpoint(
91
+ "best_model_initial", # 去后缀或使用.keras
92
+ save_best_only=True,
93
+ monitor="val_accuracy",
94
+ save_format="tf" # 显式指定保存格式
95
+ ),
96
+ callbacks.ReduceLROnPlateau(
97
+ monitor="val_loss",
98
+ factor=0.2,
99
+ patience=3
100
+ )
101
+ ]
102
+
103
+ # 初始训练(仅训练自定义顶层)
104
+ history = model.fit(
105
+ train_generator,
106
+ epochs=EPOCHS,
107
+ validation_data=val_generator,
108
+ callbacks=callbacks_list
109
+ )
110
+
111
+ # 解冻部分层进行微调
112
+ def fine_tune_model(model):
113
+ # 解冻顶层卷积块
114
+ model.get_layer("efficientnetb4").trainable = True
115
+ for layer in model.layers[1].layers[:-10]: # 保留最后10层可训练
116
+ layer.trainable = False
117
+
118
+ # 重新编译模型(使用更小的学习率)
119
+ model.compile(
120
+ optimizer=optimizers.Adam(learning_rate=1e-5),
121
+ loss="categorical_crossentropy",
122
+ metrics=["accuracy"]
123
+ )
124
+ return model
125
+
126
+ model = fine_tune_model(model)
127
+
128
+ # 微调训练
129
+ fine_tune_history = model.fit(
130
+ train_generator,
131
+ initial_epoch=history.epoch[-1],
132
+ epochs=history.epoch[-1] + 10, # 再训练10个epoch
133
+ validation_data=val_generator,
134
+ callbacks=[
135
+ callbacks.ModelCheckpoint(
136
+ "best_model_finetuned.h5",
137
+ save_best_only=True,
138
+ monitor="val_accuracy"
139
+ )
140
+ ]
141
+ )
142
+
143
+ # 保存最终模型
144
+ model.save("plant_disease_efficientnetb4.h5")
145
+
146
+ # 可视化训练过程
147
+ def plot_history(history, title):
148
+ plt.figure(figsize=(12, 4))
149
+
150
+ # 准确率曲线
151
+ plt.subplot(1, 2, 1)
152
+ plt.plot(history.history['accuracy'])
153
+ plt.plot(history.history['val_accuracy'])
154
+ plt.title(f'{title} Accuracy')
155
+ plt.ylabel('Accuracy')
156
+ plt.xlabel('Epoch')
157
+ plt.legend(['Train', 'Validation'], loc='upper left') # 与第一个文件一致
158
+
159
+ # 损失曲线
160
+ plt.subplot(1, 2, 2)
161
+ plt.plot(history.history['loss'])
162
+ plt.plot(history.history['val_loss'])
163
+ plt.title(f'{title} Loss')
164
+ plt.ylabel('Loss')
165
+ plt.xlabel('Epoch')
166
+ plt.legend(['Train', 'Validation'], loc='upper left') # 统一图例位置
167
+
168
+ plt.tight_layout()
169
+ plt.show()
170
+ # 修改调用方式(替换最后两行plot_training调用)
171
+ plot_history(history, "Initial Training")
172
+ plot_history(fine_tune_history, "Fine-tuning")
173
+
174
+ # 评估模型
175
+ def evaluate_model(model_path):
176
+ model = models.load_model(model_path)
177
+ loss, acc = model.evaluate(val_generator)
178
+ print(f"Validation accuracy: {acc*100:.2f}%")
179
+ print(f"Validation loss: {loss:.4f}")
180
+
181
+ print("Initial model evaluation:")
182
+ evaluate_model("best_model_initial.h5")
183
+
184
+ print("\nFine-tuned model evaluation:")
185
+ evaluate_model("best_model_finetuned.h5")