|
import tensorflow as tf
|
|
from tensorflow.keras import layers, models, applications, optimizers, callbacks
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# --- Training configuration -------------------------------------------------
# Spatial size (height, width) every image is resized to before batching.
IMAGE_SIZE = (380, 380)
# Number of images per batch yielded by the directory iterators.
BATCH_SIZE = 8
# Upper bound on the number of epochs for the first training stage.
EPOCHS = 15
# Width of the final softmax layer (one unit per class directory).
NUM_CLASSES = 38
# Root directory containing one sub-directory of images per class.
DATA_DIR = "./PlantVillage-Dataset-master/raw/color"
|
|
|
|
|
|
def create_data_generator():
    """Build the augmenting ``ImageDataGenerator`` used by this script.

    Returns:
        An ``ImageDataGenerator`` that applies EfficientNet input
        preprocessing, random geometric augmentation (rotation, shifts,
        shear, zoom, horizontal/vertical flips) and reserves 5% of the
        images as the validation subset.
    """
    # Random geometric transforms applied to each image.
    augmentation = {
        "rotation_range": 40,
        "width_shift_range": 0.2,
        "height_shift_range": 0.2,
        "shear_range": 0.2,
        "zoom_range": 0.2,
        "horizontal_flip": True,
        "vertical_flip": True,
    }
    generator = ImageDataGenerator(
        preprocessing_function=applications.efficientnet.preprocess_input,
        validation_split=0.05,
        **augmentation,
    )
    return generator
|
|
|
|
|
|
# Training images go through the augmenting generator.
train_datagen = create_data_generator()

# Validation images must NOT be randomly augmented: otherwise val_loss and
# val_accuracy (which drive early stopping, LR scheduling and checkpoint
# selection) are measured on distorted inputs and change from epoch to epoch
# for reasons unrelated to the model. This generator only applies the
# EfficientNet preprocessing. Its validation_split must stay equal to the one
# used by create_data_generator() (0.05) so that the "training" and
# "validation" subsets remain complementary slices of the same file listing.
val_datagen = ImageDataGenerator(
    preprocessing_function=applications.efficientnet.preprocess_input,
    validation_split=0.05,
)

# Shuffled, augmented batches of one-hot labelled images for fitting.
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
    shuffle=True,
)

# Un-augmented batches in a fixed order for validation and evaluation.
val_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
    shuffle=False,
)
|
|
|
|
|
|
def build_model():
    """Assemble the transfer-learning classifier.

    An ImageNet-pretrained EfficientNetB4 (without its top) is frozen and
    followed by global average pooling, a 256-unit ReLU layer, 50% dropout
    and a ``NUM_CLASSES``-way softmax.

    Returns:
        The uncompiled Keras ``Model``.
    """
    input_shape = (*IMAGE_SIZE, 3)

    backbone = applications.EfficientNetB4(
        weights="imagenet",
        include_top=False,
        input_shape=input_shape,
    )
    # The pretrained weights stay fixed for the first training stage.
    backbone.trainable = False

    image_input = layers.Input(shape=input_shape)
    features = backbone(image_input)
    pooled = layers.GlobalAveragePooling2D()(features)
    hidden = layers.Dense(256, activation="relu")(pooled)
    regularized = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(NUM_CLASSES, activation="softmax")(regularized)

    return models.Model(image_input, class_probs)
|
|
|
|
model = build_model()

# Stage 1: train with the backbone frozen (see build_model), using a
# comparatively large learning rate.
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

callbacks_list = [
    # Stop once val_loss has not improved for 5 epochs and roll back to the
    # best weights seen.
    callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
    # Keep only the checkpoint with the best val_accuracy. The file name must
    # match the path loaded by evaluate_model("best_model_initial.h5") at the
    # end of the script; previously the checkpoint was written to
    # "best_model_initial" (no extension), so that load could not find it.
    # The `save_format` keyword was dropped: the format is selected by the
    # file extension.
    callbacks.ModelCheckpoint(
        "best_model_initial.h5",
        save_best_only=True,
        monitor="val_accuracy",
    ),
    # Multiply the learning rate by 0.2 after 3 epochs without val_loss
    # improvement.
    callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.2,
        patience=3,
    ),
]

history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks_list,
)
|
|
|
|
|
|
def fine_tune_model(model, num_unfrozen_layers=10, learning_rate=1e-5):
    """Unfreeze the tail of the EfficientNetB4 backbone and recompile.

    Args:
        model: Model containing a sub-model named ``"efficientnetb4"``.
        num_unfrozen_layers: How many of the backbone's last layers become
            trainable. Defaults to 10.
        learning_rate: Adam learning rate used for fine-tuning. Defaults to
            1e-5.

    Returns:
        The same ``model`` object, recompiled with categorical cross-entropy
        loss and the accuracy metric.
    """
    # Look the backbone up once, by name, instead of mixing a name lookup
    # with a positional `model.layers[1]` access.
    backbone = model.get_layer("efficientnetb4")
    backbone.trainable = True

    # Index of the first layer that is allowed to train.
    first_trainable = max(len(backbone.layers) - num_unfrozen_layers, 0)

    for layer in backbone.layers[:first_trainable]:
        layer.trainable = False

    # BatchNormalization layers in the unfrozen tail stay frozen so their
    # pretrained statistics and scale/offset are not disturbed while
    # fine-tuning with small batches.
    for layer in backbone.layers[first_trainable:]:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # Recompile so the changed `trainable` flags take effect.
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
|
|
|
|
model = fine_tune_model(model)

# `history.epoch` holds the 0-based indices of the epochs that actually ran
# in stage 1 (possibly fewer than EPOCHS because of early stopping), so the
# next epoch index is its length. The previous `history.epoch[-1]` was one
# too small, which made fine-tuning re-use the last stage-1 epoch number, and
# it raised IndexError when stage 1 ran no epochs.
fine_tune_start = len(history.epoch)

# Stage 2: continue training for 10 more epochs with the partially unfrozen
# backbone, keeping only the checkpoint with the best val_accuracy.
fine_tune_history = model.fit(
    train_generator,
    initial_epoch=fine_tune_start,
    epochs=fine_tune_start + 10,
    validation_data=val_generator,
    callbacks=[
        callbacks.ModelCheckpoint(
            "best_model_finetuned.h5",
            save_best_only=True,
            monitor="val_accuracy",
        ),
    ],
)

# Persist the model as it stands after the last fine-tuning epoch.
model.save("plant_disease_efficientnetb4.h5")
|
|
|
|
|
|
def plot_history(history, title):
    """Show training vs. validation curves for one ``fit`` run.

    Draws a 12x4 figure with two side-by-side panels: accuracy on the left
    and loss on the right, each with a train and a validation line.

    Args:
        history: Object whose ``history`` dict has the keys ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss``.
        title: Prefix for both panel titles.
    """
    plt.figure(figsize=(12, 4))

    # (metric key in history.history, axis label) for each panel, in order.
    panels = (("accuracy", "Accuracy"), ("loss", "Loss"))
    for position, (metric, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric])
        plt.plot(history.history[f"val_{metric}"])
        plt.title(f"{title} {label}")
        plt.ylabel(label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.tight_layout()
    plt.show()
|
|
|
|
# Plot the curves of both training stages, in the order they were run.
for stage_history, stage_title in (
    (history, "Initial Training"),
    (fine_tune_history, "Fine-tuning"),
):
    plot_history(stage_history, stage_title)
|
|
|
|
|
|
def evaluate_model(model_path):
    """Load a saved model and print its metrics on the validation generator.

    Args:
        model_path: Path passed to ``models.load_model``.
    """
    restored = models.load_model(model_path)
    val_loss, val_acc = restored.evaluate(val_generator)
    print(f"Validation accuracy: {val_acc*100:.2f}%")
    print(f"Validation loss: {val_loss:.4f}")
|
|
|
|
# Report validation metrics for the best checkpoint of each training stage.
for heading, checkpoint_path in (
    ("Initial model evaluation:", "best_model_initial.h5"),
    ("\nFine-tuned model evaluation:", "best_model_finetuned.h5"),
):
    print(heading)
    evaluate_model(checkpoint_path)