Uncategorized
Binary Classification
Cuoong
2022. 7. 29. 10:43
import os
import tensorflow as tf
import datetime
from tensorflow.keras.preprocessing import image_dataset_from_directory
# Base model resolution
# EfficientNetB0 224
# EfficientNetB1 240
# EfficientNetB2 260
# EfficientNetB3 300
# EfficientNetB4 380
# EfficientNetB5 456
# EfficientNetB6 528
# EfficientNetB7 600
# Constants
DATASETPATH = r"D:\딥러닝 데이터셋\Tire Textures\training_data"
SAVEPATH = r"D:\딥러닝 데이터셋\Tire Textures\training_data_RESULT"
BATCH_SIZE = 32
IMG_SIZE = (300, 300)  # native resolution for EfficientNetB3 (see table above)
IMG_SHAPE = IMG_SIZE + (3,)
AUTOTUNE = tf.data.AUTOTUNE
# Split: 20% validation, 20% test, and the remaining 60% of the dataset for training.
# The training set is then repeated 25x for on-the-fly augmentation (see splitDataset).
def createFolder(directory):
    # Helper to create a directory if it does not exist (not called in main()).
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
    except OSError:
        print('Error: Creating directory. ' + directory)
def loadingDataset(path):
    # The directory must contain one subfolder per class name.
    # label_mode defaults to 'int', so two class folders yield 0/1 labels,
    # which is what the sigmoid head below expects.
    print(f'Directory : {os.listdir(path)}')
    total_dataset = image_dataset_from_directory(path,
                                                 shuffle=True,
                                                 batch_size=BATCH_SIZE,
                                                 image_size=IMG_SIZE)
    class_names = total_dataset.class_names
    class_count = len(class_names)
    print(f'Class Count : {class_count}')
    for images, labels in total_dataset.take(1):
        print(f'image count : {len(images)}')
        print(f'image size : {images[0].shape}')
    return total_dataset, class_count
def splitDataset(dataset):
    # Check how many batches are available.
    total_batches = tf.data.experimental.cardinality(dataset)
    print(f'Number of total batches : {total_batches}')
    valid_batches = int(total_batches // 5)  # 20%
    test_batches = int(total_batches // 5)   # 20%
    train_batches = total_batches - valid_batches - test_batches  # remaining ~60%
    train_dataset = dataset.take(train_batches)
    total_dataset_trainSkip = dataset.skip(train_batches)
    valid_dataset = total_dataset_trainSkip.take(valid_batches)
    total_dataset_trainValidSkip = total_dataset_trainSkip.skip(valid_batches)
    test_dataset = total_dataset_trainValidSkip.take(test_batches)
    total_dataset_trainValidTestSkip = total_dataset_trainValidSkip.skip(test_batches)
    train_dataset_repeat = train_dataset.repeat(25)  # repeat the training set 25x
    print('Number of train batches: %d' % tf.data.experimental.cardinality(train_dataset))
    print('Number of train repeat batches: %d' % tf.data.experimental.cardinality(train_dataset_repeat))
    print('Number of validation batches: %d' % tf.data.experimental.cardinality(valid_dataset))
    print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
    print('Number of leftover batches: %d' % tf.data.experimental.cardinality(total_dataset_trainValidTestSkip))  # should be 0
    train_dataset_prefetch = train_dataset_repeat.prefetch(buffer_size=AUTOTUNE)
    valid_dataset_prefetch = valid_dataset.prefetch(buffer_size=AUTOTUNE)
    test_dataset_prefetch = test_dataset.prefetch(buffer_size=AUTOTUNE)
    return train_dataset_prefetch, valid_dataset_prefetch, test_dataset_prefetch
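# Aside (a hedged sketch, not part of the original pipeline): because
# image_dataset_from_directory(shuffle=True) reshuffles on each pass by default,
# a take()/skip() split like the one above is not guaranteed to stay disjoint
# across epochs. One safer alternative is the built-in validation_split/subset
# arguments with a fixed (here arbitrary) seed, e.g.:
#   train_ds = image_dataset_from_directory(DATASETPATH, validation_split=0.2,
#                                           subset='training', seed=42,
#                                           batch_size=BATCH_SIZE, image_size=IMG_SIZE)
#   valid_ds = image_dataset_from_directory(DATASETPATH, validation_split=0.2,
#                                           subset='validation', seed=42,
#                                           batch_size=BATCH_SIZE, image_size=IMG_SIZE)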
def main():
    total_dataset, class_count = loadingDataset(DATASETPATH)
    train_dataset, valid_dataset, test_dataset = splitDataset(total_dataset)
    # On-the-fly augmentation; these layers are active only during training.
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomFlip('vertical'),
        # Rotation factors are fractions of 2*pi; +/-0.08 is roughly +/-29 degrees.
        tf.keras.layers.experimental.preprocessing.RandomRotation((-0.08, 0.08), fill_mode='nearest'),
        #tf.keras.layers.experimental.preprocessing.CenterCrop(200, 200),
        tf.keras.layers.experimental.preprocessing.RandomTranslation(0.05, 0.05, fill_mode='nearest', interpolation='bilinear', seed=1, fill_value=0.0),
        tf.keras.layers.experimental.preprocessing.RandomZoom((0.1, 0.2)),
        tf.keras.layers.experimental.preprocessing.RandomContrast(0.1)
    ])
    # EfficientNet's preprocess_input is a pass-through; rescaling is built into the model.
    preprocess_input = tf.keras.applications.efficientnet.preprocess_input
    base_model = tf.keras.applications.efficientnet.EfficientNetB3(input_shape=IMG_SHAPE,
                                                                   include_top=False,
                                                                   weights='imagenet')
    image_batch, label_batch = next(iter(train_dataset))
    feature_batch = base_model(image_batch)
    #print(feature_batch.shape)
    base_model.trainable = False  # freeze the backbone for the first training phase
    #base_model.summary()
    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    feature_batch_average = global_average_layer(feature_batch)
    #print(feature_batch_average.shape)
    prediction_layer = tf.keras.layers.Dense(1, activation='sigmoid')  # single sigmoid unit for binary classification
    prediction_batch = prediction_layer(feature_batch_average)
    #print(prediction_batch.shape)
    inputs = tf.keras.Input(shape=IMG_SHAPE)
    x = data_augmentation(inputs)
    x = preprocess_input(x)
    # training=False keeps BatchNormalization in inference mode, even after the
    # backbone is unfrozen for fine-tuning later.
    x = base_model(x, training=False)
    x = global_average_layer(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    outputs = prediction_layer(x)
    model = tf.keras.Model(inputs, outputs)
    model.summary()
    print("Number of trainable_variables: ", len(model.trainable_variables))
    # Preview the augmented images
    # import matplotlib.pyplot as plt
    # for image, _ in train_dataset.take(1):
    #     plt.figure(figsize=(10, 10))
    #     for i in range(BATCH_SIZE):
    #         ax = plt.subplot(7, 7, i + 1)
    #         augmented_image = data_augmentation(tf.expand_dims(image[i], 0))
    #         plt.imshow(augmented_image[0] / 255)
    #         plt.axis('off')
    #     plt.show()
    base_learning_rate = 0.001
    # Binary focal cross-entropy down-weights easy examples:
    # FL(p_t) = -(1 - p_t)^gamma * log(p_t)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
                  loss=tf.keras.losses.BinaryFocalCrossentropy(
                      gamma=2.0,
                      from_logits=False,  # the model already ends in a sigmoid
                      label_smoothing=0.2),
                  metrics=['accuracy', tf.keras.metrics.Recall(), tf.keras.metrics.Precision(),
                           tf.keras.metrics.TrueNegatives(), tf.keras.metrics.FalsePositives()]
                  # metrics=['accuracy']
                  )
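    # Aside (a hedged sketch, not in the original): evaluate() returns values in the
    # order the metrics were passed to compile(), so the positional unpacking below
    # is easy to get wrong silently. With TF >= 2.2, return_dict=True yields named
    # results instead (duplicate metric types get suffixed names such as 'recall_1'):
    #   results = model.evaluate(valid_dataset, return_dict=True)
    #   print(results['accuracy'], results['recall'])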
    # Baseline evaluation with the ImageNet weights, before any training.
    #metrics0 = model.evaluate(valid_dataset)
    loss0, accuracy0, recall0, precision0, truenegatives0, falsepositives0 = model.evaluate(valid_dataset)
    print("initial loss:{:.2f}".format(loss0))
    print("initial accuracy:{:.2f}".format(accuracy0))
    print("initial recall:{:.2f}".format(recall0))
    print("initial precision:{:.2f}".format(precision0))
    print("initial TrueNegatives:{:.2f}".format(truenegatives0))
    print("initial FalsePositives:{:.2f}".format(falsepositives0))
    # Callbacks
    log_dir = SAVEPATH + "\\logs\\fit\\" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    # Phase 1: train with the base model weights frozen.
    initial_epochs = 1
    history = model.fit(train_dataset,
                        epochs=initial_epochs,
                        validation_data=valid_dataset,
                        callbacks=[tensorboard_callback])
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    # Phase 2: fine-tuning. Unfreeze the backbone, then re-freeze all layers
    # below fine_tune_at.
    base_model.trainable = True
    print("Number of layers in the base model: ", len(base_model.layers))
    fine_tune_at = 300
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False
    # Recompile with a 10x lower learning rate for fine-tuning.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate / 10),
                  loss=tf.keras.losses.BinaryFocalCrossentropy(
                      gamma=2.0,
                      from_logits=False,
                      label_smoothing=0.2),
                  metrics=['accuracy', tf.keras.metrics.Recall(), tf.keras.metrics.Precision(),
                           tf.keras.metrics.TrueNegatives(), tf.keras.metrics.FalsePositives()]
                  )
    model.summary()
    print("Number of trainable_variables: ", len(model.trainable_variables))
    fine_tune_epochs = 1
    total_epochs = initial_epochs + fine_tune_epochs
    # Callbacks
    weight_filePath = SAVEPATH + "\\" + "Model_" + "{epoch:02d}-{val_loss:.10f}.h5"
    modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(filepath=weight_filePath, monitor="val_loss", save_best_only=False, verbose=1)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=25, factor=0.1, verbose=1)
    earlyStopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True, verbose=1)
    log_dir = SAVEPATH + "\\logs\\fit\\" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    history_fine = model.fit(train_dataset,
                             epochs=total_epochs,
                             initial_epoch=history.epoch[-1],
                             validation_data=valid_dataset,
                             callbacks=[modelcheckpoint, reduce_lr, earlyStopping, tensorboard_callback])
    acc += history_fine.history['accuracy']
    val_acc += history_fine.history['val_accuracy']
    loss += history_fine.history['loss']
    val_loss += history_fine.history['val_loss']
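    # Aside (a minimal sketch, not in the original post): the acc/val_acc/loss/val_loss
    # lists collected above are never used; they could be plotted to inspect both
    # training phases, e.g.:
    #   import matplotlib.pyplot as plt
    #   plt.plot(acc, label='training accuracy')
    #   plt.plot(val_acc, label='validation accuracy')
    #   plt.axvline(initial_epochs - 1, linestyle='--', label='start of fine-tuning')
    #   plt.legend()
    #   plt.show()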
    # Final evaluation on the held-out test set (renamed variables avoid
    # clobbering the loss/accuracy history lists above).
    test_loss, test_accuracy, test_recall, test_precision, test_truenegatives, test_falsepositives = model.evaluate(test_dataset)
    print('Test accuracy :', test_accuracy)
    print("THE END")
if __name__ == "__main__":
    main()
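# Aside (a minimal sketch, not in the original post): ModelCheckpoint writes .h5
# files into SAVEPATH during fine-tuning; a checkpoint can be reloaded for
# inference with tf.keras.models.load_model. The file name below is hypothetical:
#   model = tf.keras.models.load_model(SAVEPATH + "\\Model_01-0.0123456789.h5")
#   model.evaluate(test_dataset)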