-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsegmentation_training.py
85 lines (67 loc) · 3.04 KB
/
segmentation_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Created on 16 Jun, 2022 at 16:42
Title: segmentation_training.py - Model training for Fire Segmentation
Description:
- Training the model for fire segmentation task
@author: Supantha Sen, nrsc, ISRO
"""
# Importing Modules
import os

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf

# Importing Custom Modules
from model_architectures import unet_model
from dataset_fetching import fetch_data_segmentation
from data_plotting import training_plot
def train_model(val_generator, train_generator, batchsize, *,
                num_train_samples=1603, num_val_samples=400, epochs=30,
                output_dir='./fire_segmentation_output'):
    """Build, compile, train and save the U-Net fire-segmentation model.

    Parameters
    ----------
    val_generator :
        Validation data handed to ``model.fit(validation_data=...)``. It is
        expected to yield complete batches itself (presumably the generator
        returned by ``fetch_data_segmentation`` -- confirm).
    train_generator :
        Training data handed to ``model.fit``; also expected to yield
        complete batches itself.
    batchsize : int
        Batch size the generators were built with. Used only to turn the
        sample counts into ``steps_per_epoch`` / ``validation_steps``.
    num_train_samples, num_val_samples : int
        Number of training / validation samples. The defaults are the
        values that were previously hard-coded here.
    epochs : int
        Maximum number of training epochs; early stopping may end sooner.
    output_dir : str
        Directory receiving the architecture plot, per-epoch checkpoints,
        the training history and the final model. Created if missing.

    Returns
    -------
    The trained Keras model returned by ``unet_model()`` after fitting.
    """
    # Every artifact below is written here; create it up front so that
    # plot_model / np.save / model.save do not fail on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    model = unet_model()
    model.summary()

    # Compile the model
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Visualising model architecture (plot_model needs pydot + graphviz)
    tf.keras.utils.plot_model(model,
                              to_file=os.path.join(output_dir, 'unet_model.pdf'),
                              show_shapes=True,
                              show_layer_names=True,
                              show_layer_activations=True)

    # Saving weight checkpoints; '{epoch:02d}' is filled in by Keras.
    # NOTE(review): both callbacks monitor *training* metrics even though
    # validation data is supplied -- confirm this is intended.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(output_dir, 'saved_weights_epoch{epoch:02d}.h5'),
        monitor='accuracy',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,  # was misspelled 'save_wrights_only' (ignored)
        mode='auto',
        save_freq='epoch')
    earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                 min_delta=0.0005,
                                                 patience=3,
                                                 verbose=1,
                                                 mode='auto',
                                                 baseline=None,
                                                 restore_best_weights=True)

    # The generators already produce batches, so batch_size is not passed to
    # fit(); batchsize only determines how many batches make up one epoch.
    # Clamp to 1 so a batch larger than the dataset never yields 0 steps.
    steps_per_epoch = max(1, num_train_samples // batchsize)
    validation_steps = max(1, num_val_samples // batchsize)

    # Fitting the model
    hist = model.fit(train_generator,
                     epochs=epochs,
                     validation_data=val_generator,
                     verbose=1,
                     steps_per_epoch=steps_per_epoch,
                     validation_steps=validation_steps,
                     callbacks=[earlystop, checkpoint],
                     use_multiprocessing=True)

    # Saving the training history (a dict, stored as a pickled object array;
    # load it back with np.load(..., allow_pickle=True)) and the final model.
    np.save(os.path.join(output_dir, 'trained_model_history.npy'), hist.history)
    model.save(os.path.join(output_dir, 'trained_model.h5'))
    return model
# Main Program -- guarded so that importing this module does not start a
# full training run.
if __name__ == '__main__':
    path = './Fire_Segmentation'
    # Unpacked as (validation, training), matching train_model's argument order
    val_generator, train_generator = fetch_data_segmentation(path)

    # Training the model
    batchsize = 8
    model = train_model(val_generator, train_generator, batchsize)

    # Plotting the training metrics from the history file written by train_model
    training_plot('./fire_segmentation_output/trained_model_history.npy')