-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathclassification_training.py
82 lines (68 loc) · 3.17 KB
/
classification_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Created on 17 Jun, 2022 at 11:08
Title: classification_training.py - Model training for Fire vs No Fire Classification
Description:
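- Trains the CNN for the binary Fire vs No Fire classification task and saves the weight checkpoints, training history, final model and an architecture diagram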
@author: Supantha Sen, nrsc, ISRO
"""
# Importing Modules
import os

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf

# Importing Custom Modules
from model_architectures import cnn_model
from dataset_fetching import fetch_data_classification
from data_plotting import training_plot
...
def _training_callbacks(output_dir):
    """Return the ``[checkpoint, earlystop]`` callbacks used by ``train_model``."""
    # Save weights after every epoch in which training accuracy improves.
    # NOTE(review): 'accuracy' is the *training* metric; 'val_accuracy' would
    # select checkpoints on the validation data instead -- confirm intent.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(output_dir, 'saved_weights_epoch{epoch:02d}.h5'),
        monitor='accuracy',
        verbose=1,
        save_best_only=True,
        # Previously misspelled `save_wrights_only`, which never set this option.
        save_weights_only=True,
        mode='auto',
        save_freq='epoch')
    # Stop when training loss has not improved by at least 0.01 for 2 epochs.
    # NOTE(review): like the checkpoint, this watches a training metric
    # ('loss') rather than 'val_loss' -- confirm intent.
    earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                 min_delta=0.01,
                                                 patience=2,
                                                 verbose=1,
                                                 mode='auto',
                                                 baseline=None,
                                                 restore_best_weights=True)
    return [checkpoint, earlystop]


def train_model(val_generator, train_generator, batchsize, *,
                epochs=30, learning_rate=1e-6,
                output_dir='./fire_classification_output'):
    """Build, compile, train and save the fire-classification CNN.

    The network comes from ``cnn_model()``. It is compiled with the Adam
    optimiser, binary cross-entropy loss and an accuracy metric, then fitted
    with a weight checkpoint and an early-stopping callback. The architecture
    diagram (``cnn_model.pdf``), per-epoch weight checkpoints, the training
    history (``trained_model_history.npy``) and the final model
    (``trained_model.h5``) are all written into *output_dir*.

    Parameters
    ----------
    val_generator :
        Validation data, passed to ``model.fit(validation_data=...)``.
    train_generator :
        Training data, passed as the first argument of ``model.fit``.
    batchsize : int
        Forwarded to ``model.fit(batch_size=...)``.
    epochs : int, optional
        Maximum number of epochs (early stopping may end training sooner).
    learning_rate : float, optional
        Learning rate for the Adam optimiser.
    output_dir : str, optional
        Directory receiving all artefacts; created if it does not exist.

    Returns
    -------
    tf.keras.Model
        The trained model. Because EarlyStopping uses
        ``restore_best_weights=True``, its weights may be those of the
        lowest-loss epoch rather than the last one.
    """
    # Every artefact below is written into output_dir, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    model = cnn_model()
    model.summary()

    # Compile the model for binary (Fire vs No Fire) classification
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Visualising model architecture
    tf.keras.utils.plot_model(model,
                              to_file=os.path.join(output_dir, 'cnn_model.pdf'),
                              show_shapes=True,
                              show_layer_names=True,
                              show_layer_activations=True)

    # Fitting the model
    hist = model.fit(train_generator,
                     epochs=epochs,
                     # NOTE(review): Keras ignores batch_size when the data is a
                     # generator/Sequence (the generator fixes the batch size);
                     # kept in case in-memory arrays are ever passed here.
                     batch_size=batchsize,
                     validation_data=val_generator,
                     verbose=1,
                     callbacks=_training_callbacks(output_dir),
                     use_multiprocessing=True)

    # Saving the training history and the trained model. np.save stores the
    # history dict as a pickled object array; reload it with
    # np.load(path, allow_pickle=True).item().
    np.save(os.path.join(output_dir, 'trained_model_history.npy'), hist.history)
    model.save(os.path.join(output_dir, 'trained_model.h5'))
    return model
## Main program
# Guarded so that importing this module (e.g. to reuse train_model) does not
# fetch the dataset and start training as a side effect.
if __name__ == '__main__':
    # Fetching the dataset from the directory
    path = './Fire_vs_NoFire'
    # test_generator is not used by this script.
    val_generator, train_generator, test_generator = fetch_data_classification(path)
    # Training the model
    batchsize = 256
    model = train_model(val_generator, train_generator, batchsize)
    # Plotting the training metrics from the history file saved by train_model
    training_plot('./fire_classification_output/trained_model_history.npy')