카테고리 없음

Segmentation MVTec Tutorial

Cuoong 2022. 8. 19. 16:44

Segmentation_transistor_EfficientV2B3.py
0.01MB
Segmentation_transistor_EfficientV2B3.ipynb
1.00MB

 


# %% [markdown]
# ### 라이브러리

# %%
import sys
# Make modules stored in D:\SEGMENTATION importable (presumably the local
# pix2pix.py imported below). Raw string: in a normal literal '\S' is an
# invalid escape sequence and emits a warning (SyntaxWarning on 3.12+).
sys.path.append(r'D:\SEGMENTATION')

# %%
import os, shutil
import tensorflow as tf
import numpy as np
import cv2
import pix2pix
from tensorflow.keras.preprocessing import image_dataset_from_directory  
from IPython.display import clear_output
import matplotlib.pyplot as plt
import tensorflow_addons as tfa

# %% [markdown]
# ### MVTec 이미지 정리

# %%
# save_train_image_path = r"D:\Dataset\Oxford\train_image"
# save_train_mask_path = r"D:\Dataset\Oxford\train_mask"
# save_test_image_path = r"D:\Dataset\Oxford\test_image"
# save_test_mask_path = r"D:\Dataset\Oxford\test_mask"
# Output folders for the flattened train/test split (images + same-named masks).
save_train_image_path = r"D:\SEGMENTATION\transistor\train_image"
save_train_mask_path = r"D:\SEGMENTATION\transistor\train_mask"
save_test_image_path = r"D:\SEGMENTATION\transistor\test_image"
save_test_mask_path = r"D:\SEGMENTATION\transistor\test_mask"

# Source folders read by the copy step below:
#   root_dir         -- walked recursively for *.png images.
#   ground_truth_dir -- the mask for <root_dir>/<sub>/<name>.png is read from
#                       <ground_truth_dir>/<sub>/<name>_mask.png.
# TODO: point these at wherever the MVTec AD 'transistor' category was extracted.
root_dir = r"D:\Dataset\MVTec\transistor\test"
ground_truth_dir = r"D:\Dataset\MVTec\transistor\ground_truth"

# %%
def createFolder(directory):
    """Create *directory* (including missing parents) if it does not exist.

    Errors are reported on stdout instead of raised, so the script keeps
    going; any later write into the folder will then fail on its own.
    """
    try:
        # exist_ok=True replaces the racy "check os.path.exists, then create".
        os.makedirs(directory, exist_ok=True)
    except OSError:
        print('Error: Creating directory. ' + directory)

# %%
createFolder(save_train_image_path)
createFolder(save_train_mask_path)
createFolder(save_test_image_path)
createFolder(save_test_mask_path)

# Flatten the image tree under root_dir into numbered PNGs: every 4th image
# (file_count % 4 == 0) goes to the test split, the rest to train. Each image
# gets a mask with the same file name -- an all-zero mask for images in a
# 'good' folder, otherwise the matching '<basename>_mask.png' copied from the
# same relative folder under ground_truth_dir.
file_count = 0
for (root, dirs, files) in os.walk(root_dir):
    # Sorting dirs in place makes os.walk visit sub-folders in a stable order;
    # with sorted(files) below, the numbering and train/test split become
    # reproducible instead of depending on the filesystem's listing order.
    dirs.sort()
    print(f'root: {root}')
    print(dirs)
    print(files)
    for fileName in sorted(files):
        basename, extension = os.path.splitext(fileName)
        print(f'basename:{basename}, extension:{extension}')
        if extension != '.png':
            continue
        sourcePath = os.path.join(root, fileName)
        newFileName = str(file_count) + ".png"
        if file_count % 4 > 0:
            destImagePath = os.path.join(save_train_image_path, newFileName)
            destMaskPath = os.path.join(save_train_mask_path, newFileName)
        else:
            destImagePath = os.path.join(save_test_image_path, newFileName)
            destMaskPath = os.path.join(save_test_mask_path, newFileName)
        file_count += 1
        cur_dirname = os.path.basename(os.path.dirname(sourcePath))
        print(f'dirname = {cur_dirname}')
        # Produce the mask first so a failure never leaves an image without
        # its mask (which would misalign the zipped image/mask datasets).
        if cur_dirname == 'good':
            img = cv2.imread(sourcePath)
            if img is None:
                raise OSError(f'cv2.imread could not read {sourcePath}')
            height, width = img.shape[:2]
            mask_img = np.zeros((height, width), np.uint8)
            cv2.imwrite(destMaskPath, mask_img)
            print(f'empty mask : {destMaskPath}')
        else:
            # Mirror the image's sub-folder under ground_truth_dir. Building
            # the path with relpath/join avoids str.replace touching any other
            # occurrence of the folder or file name in the path.
            mask_file_name = basename + "_mask.png"
            relative_dir = os.path.relpath(root, root_dir)
            mask_sourcePath = os.path.normpath(
                os.path.join(ground_truth_dir, relative_dir, mask_file_name))
            shutil.copy(mask_sourcePath, destMaskPath)
            print(f'mask copy : {mask_sourcePath} -> {destMaskPath}')
        shutil.copy(sourcePath, destImagePath)
        print(f'image copy : {sourcePath} -> {destImagePath}')
        

# %% [markdown]
# ### DATA LOADING

# %%
IMAGE_SIZE = (256,256)
IMG_SHAPE = IMAGE_SIZE+(3,)
BATCH_SIZE = 64
BUFFER_SIZE = 1000

# Images and masks are loaded unshuffled. With shuffle=False the files are
# read in alphanumeric order, and every image has a mask with the same file
# name, so batch i of an image dataset pairs with batch i of its mask dataset.
train_image_dataset = image_dataset_from_directory(save_train_image_path, shuffle=False, label_mode=None, batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)
# Masks are resized with nearest-neighbour interpolation. The default
# bilinear resize blends pixels along defect borders into intermediate
# values, and convert_to_float's `mask > 0` threshold then marks them as
# defects, slightly dilating every mask.
train_mask_dataset = image_dataset_from_directory(save_train_mask_path, shuffle=False, label_mode=None, color_mode="grayscale", batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, interpolation="nearest")
test_image_dataset = image_dataset_from_directory(save_test_image_path, shuffle=False, label_mode=None, batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)
test_mask_dataset = image_dataset_from_directory(save_test_mask_path, shuffle=False, label_mode=None, color_mode="grayscale", batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, interpolation="nearest")

# %%
# Pair every image batch with its mask batch -> (image, mask) elements.
train_dataset, test_dataset = (
    tf.data.Dataset.zip((images, masks))
    for images, masks in (
        (train_image_dataset, train_mask_dataset),
        (test_image_dataset, test_mask_dataset),
    )
)

# %%
# Take the first raw (image, mask) batch and print it as a quick sanity check.
for image, mask in train_dataset.take(1):
  sample_image = image
  sample_mask = mask
  print(sample_image)
  print(sample_mask)

# %%
# NOTE: these count *batches*, not images -- the datasets were already
# batched (BATCH_SIZE) by image_dataset_from_directory.
TRAIN_LENGTH, TEST_LENGTH = (
    tf.data.experimental.cardinality(split)
    for split in (train_dataset, test_dataset)
)


# %% [markdown]
# ### DATASET Augmentation

# %%
def convert_to_float(image, mask):
    """Normalize one (image, mask) pair for training.

    The image is divided by 255 (so 0-255 pixels land in [0, 1]). The mask is
    binarized: every non-zero pixel becomes 1.0, zero pixels stay 0.0.
    """
    binary_mask = tf.cast(mask > 0, tf.float32)
    scaled_image = image / 255.0
    return scaled_image, binary_mask

def trans1(image, mask):
    """Mirror image and mask horizontally (applied to every batch)."""
    flipped_image, flipped_mask = (
        tf.image.flip_left_right(tensor) for tensor in (image, mask)
    )
    return flipped_image, flipped_mask

def trans2(image, mask):
    """Mirror image and mask vertically (applied to every batch)."""
    flipped_image, flipped_mask = (
        tf.image.flip_up_down(tensor) for tensor in (image, mask)
    )
    return flipped_image, flipped_mask

def trans3(image, mask):
    """Rotate image and mask by -0.2 rad (reflect padding), every batch.

    The image uses bilinear interpolation. The mask uses nearest-neighbour so
    its labels are not blended along defect borders; blended values would
    survive convert_to_float's `mask > 0` threshold and dilate the defect.
    """
    image = tfa.image.rotate(image, -.2, fill_mode="reflect", interpolation="bilinear")
    mask = tfa.image.rotate(mask, -.2, fill_mode="reflect", interpolation="nearest")
    return image, mask

def trans4(image, mask):
    """Rotate image and mask by +0.2 rad (reflect padding), every batch.

    The image uses bilinear interpolation. The mask uses nearest-neighbour so
    its labels are not blended along defect borders; blended values would
    survive convert_to_float's `mask > 0` threshold and dilate the defect.
    """
    image = tfa.image.rotate(image, .2, fill_mode="reflect", interpolation="bilinear")
    mask = tfa.image.rotate(mask, .2, fill_mode="reflect", interpolation="nearest")
    return image, mask

def trans5(image, mask):
    """Blur the image with a 5x5 mean filter; return the mask unchanged.

    Blurring is a photometric augmentation: no defect moves, so the label
    must not change. Filtering the mask too would smear its edges, and
    convert_to_float's `mask > 0` threshold would then grow every defect
    region by up to two pixels.
    """
    image = tfa.image.mean_filter2d(image, filter_shape=5)
    return image, mask

def trans6(image, mask):
    """Apply a 5x5 median filter to the image; return the mask unchanged.

    This is a photometric augmentation, so the label must stay as-is. A 5x5
    median filter on the binary mask would erase thin or very small defect
    regions, teaching the model that those pixels are background.
    """
    image = tfa.image.median_filter2d(image, filter_shape=5)
    return image, mask

def trans7(image, mask):
    """Adjust image sharpness (factor 0.1); return the mask unchanged.

    This is a photometric augmentation, so the label must stay as-is.
    Running the sharpness operation on the mask would alter its edge values,
    which convert_to_float's `mask > 0` threshold then turns into label changes.
    """
    image = tfa.image.sharpness(image, 0.1)
    return image, mask

# Fixed (non-random) augmentations: each ds* applies one transform
# combination to every training batch.
ds1, ds2 = train_dataset.map(trans1).map(trans3), train_dataset.map(trans2).map(trans4)
ds3, ds4 = train_dataset.map(trans1).map(trans4), train_dataset.map(trans2).map(trans3)
ds5, ds6, ds7 = train_dataset.map(trans5), train_dataset.map(trans6), train_dataset.map(trans7)

# Start from the un-augmented training batches so the original images are
# trained on too, then append every augmented variant.
ds_img = train_dataset
for augmented in (ds1, ds2, ds3, ds4, ds5, ds6, ds7):
    ds_img = ds_img.concatenate(augmented)

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = (
    ds_img
    .map(convert_to_float, num_parallel_calls=AUTOTUNE)
    # Cache ONE pass, then repeat. Caching after .repeat(100) would keep all
    # 100 identical copies in memory.
    .cache()
    .repeat(100)
    # Shuffle individual samples instead of whole batches. Each incoming
    # batch holds a single augmentation type, and a 1000-element buffer of
    # BATCH_SIZE-image batches would hold 1000 * BATCH_SIZE images in memory.
    .unbatch()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

test_ds = (
    test_dataset
    .map(convert_to_float, num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

# for image in train_image_dataset.take(1):
#     print(image[0])
#     break

# %%
# Report how many batches each final pipeline yields.
for _pipeline in (train_ds, test_ds):
    print(tf.data.experimental.cardinality(_pipeline))

# %%
def display(display_list):
  """Plot up to three images side by side in one 15x15 figure.

  Panels are titled 'Input Image', 'True Mask', 'Predicted Mask' in order;
  each tensor/array is converted with array_to_img before plotting.
  """
  plt.figure(figsize=(15, 15))
  panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
  panel_count = len(display_list)
  for idx, item in enumerate(display_list):
    plt.subplot(1, panel_count, idx + 1)
    plt.title(panel_titles[idx])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
    plt.axis('off')
  plt.show()

# %%
# Pull one batch from the final training pipeline and show its third example.
for image, mask in train_ds.take(1):
  print(image[1].shape)
  sample_image = image[2]
  sample_mask = mask[2]
  print(f'sample_image.shaep = {sample_image.shape}')
  print(f'sample_mask.shaep = {sample_mask.shape}')
  # Float 0/1 mask -> uint8 so the dump shows integer labels.
  sample_mask = tf.cast(sample_mask, tf.uint8)
  print(sample_mask)
  print(sample_mask.shape)
  mask_peak = np.max(sample_mask)
  print(f'mask_image = {sample_mask.shape}, mask_image={mask_peak}')
display([sample_image, sample_mask])

# %% [markdown]
# ### Encoder(MobileNetV2 or EfficientNetV2B3)

# %%
# convert_to_float binarizes the masks (mask > 0 -> 1.0), so there are two
# classes -- 0 = background, 1 = defect -- and the model needs one logit each.
OUTPUT_CHANNELS = 2
# base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)
# layer_names = [
#     'block_1_expand_relu',   # 64x64
#     'block_3_expand_relu',   # 32x32
#     'block_6_expand_relu',   # 16x16
#     'block_13_expand_relu',  # 8x8
#     'block_16_project',      # 4x4
# ]

# ImageNet-pretrained EfficientNetV2B3 backbone without its classifier head.
# With the default include_preprocessing=True, Keras documents that this model
# rescales internally and expects raw pixel values in [0, 255].
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
# Use the activations of these layers as the U-Net skip connections
# (spatial sizes for a 256x256 input).
layer_names = [
    'block1b_project_activation',   # 128x128
    'block2c_expand_activation',    # 64x64
    'block4a_expand_activation',    # 32x32
    'block6a_expand_activation',    # 16x16
    'top_activation',               # 8x8
]

layers = [base_model.get_layer(name).output for name in layer_names]

# Build the feature extractor: one output per layer listed above.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

# Freeze the pretrained encoder; only the decoder layers are trained.
down_stack.trainable = False

# %% [markdown]
# ### Decoder (pix2pix)

# %%
# pix2pix decoder blocks, one per encoder skip level. Each
# pix2pix.upsample(filters, size) presumably doubles the spatial size (a
# stride-2 transposed conv in the TF pix2pix example) -- confirm in pix2pix.py.
# Sizes below assume the EfficientNetV2B3 encoder on a 256x256 input, whose
# bottleneck (top_activation) is 8x8.
up_stack = [
    pix2pix.upsample(512, 3),  # 8x8 -> 16x16
    pix2pix.upsample(256, 3),  # 16x16 -> 32x32
    pix2pix.upsample(128, 3),  # 32x32 -> 64x64
    pix2pix.upsample(64, 3),   # 64x64 -> 128x128
]

# %%
def unet_model(output_channels):
  """Build a U-Net from the frozen `down_stack` encoder and `up_stack` decoder.

  Args:
    output_channels: number of per-pixel logits (classes) to predict.

  Returns:
    A tf.keras.Model mapping images of shape IMG_SHAPE, scaled to [0, 1], to
    raw per-pixel logits with `output_channels` channels. The final stride-2
    transposed conv doubles the resolution of the last (128x128) skip level,
    giving 256x256 outputs for the 256x256 input.
  """
  inputs = tf.keras.layers.Input(shape=IMG_SHAPE)
  # convert_to_float feeds pixels in [0, 1], but EfficientNetV2's built-in
  # preprocessing (include_preprocessing=True, the default) expects [0, 255].
  # Undo the scaling so the frozen ImageNet features see the range they were
  # trained on. NOTE: remove this if switching back to the MobileNetV2
  # encoder, which does not preprocess its input internally.
  x = tf.keras.layers.Rescaling(255.0)(inputs)

  # Downsample through the encoder; the deepest output is the bottleneck.
  skips = down_stack(x)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsample and concatenate the matching skip connection at each level.
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  # Final layer of the model: no activation, so it outputs logits.
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')  # 128x128 -> 256x256

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

# %% [markdown]
# ### Training

# %%
model = unet_model(OUTPUT_CHANNELS)
# Per-pixel classification. The last layer has no activation (raw logits),
# hence from_logits=True. Masks hold class ids (0/1), hence the Sparse loss.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# %%
model.summary()

# %%
def create_mask(pred_mask):
  """Turn a batch of per-pixel logits into the first example's label mask.

  Takes the argmax over the last (class) axis, keeps a trailing channel axis
  so the result can be displayed as an image, and returns batch element 0.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(class_ids, axis=-1)[0]

# %%
def show_predictions(model, dataset=None, num=1):
  """Plot input image, true mask and predicted mask side by side.

  Args:
    model: segmentation model that returns per-pixel class logits.
    dataset: optional batched (image, mask) dataset; only its first batch is
      used. When None, the module-level `sample_image` / `sample_mask` are
      shown instead.
    num: number of examples from that first batch to show. It is capped at
      the batch size so a small batch does not raise IndexError.
  """
  if dataset is not None:
    for image, mask in dataset.take(1):
      print(image.shape)
      for index in range(min(num, image.shape[0])):
        # model.predict expects a batch, so add a leading batch axis.
        input_image = image[index][tf.newaxis, ...]
        print(input_image.shape)
        pred_mask = model.predict(input_image)
        print(pred_mask.shape)
        display([image[index], mask[index], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

# %%
show_predictions(model, dataset=train_ds, num=2)

# %%
class DisplayCallback(tf.keras.callbacks.Callback):
  """After every epoch, clear the cell output and plot a sample prediction."""

  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    # Use the model this callback is attached to (Keras sets self.model in
    # fit()) instead of depending on a module-level `model` global.
    show_predictions(self.model)
    # Message (Korean): "prediction example after epoch N"; epochs are 0-based.
    print ('\n에포크 이후 예측 예시 {}\n'.format(epoch+1))

# %%
# Save the model in the working directory after every epoch. With
# save_best_only=False, `monitor` does not affect which epochs are written.
# The "Model_Mobile" prefix is kept for continuity even though the encoder is
# EfficientNetV2B3. os.path.join keeps the path valid on non-Windows systems.
weight_filePath = os.path.join(".", "Model_Mobile" + "{epoch:02d}.h5")
modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(filepath=weight_filePath, monitor="val_loss", save_best_only=False, verbose=1)

# %%
import datetime

# %%
# Cut the learning rate 10x after 25 epochs without val_loss improvement.
# NOTE(review): reduce_lr is not in the callbacks list passed to model.fit
# in the training cell, so it currently has no effect.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=25, factor=0.1, verbose=1)
#earlyStopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=50, restore_best_weights=True, verbose=1)
# One TensorBoard run directory per launch: ./logs/fit/<YYYYmmdd-HHMMSS>.
log_dir = os.path.join(".", "logs", "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# %%
EPOCHS = 100
BATCH_SIZE = 64
VAL_SUBSPLITS = 5
# NOTE(review): VALIDATION_STEPS is never passed to fit(). TEST_LENGTH is
# already a batch count, so dividing by BATCH_SIZE again is likely wrong.
VALIDATION_STEPS = TEST_LENGTH // BATCH_SIZE // VAL_SUBSPLITS

# Train on the augmented pipeline, validating on the whole test set each epoch.
model_history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=test_ds,
    callbacks=[DisplayCallback(), modelcheckpoint, tensorboard_callback],
)

# %%
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

# One x value per epoch that actually ran (fewer than EPOCHS if training
# stops early, e.g. with the commented-out EarlyStopping), numbered from 1
# to match the checkpoint names and the per-epoch log message.
epochs = range(1, len(loss) + 1)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

# %% [markdown]
# ### Predict

# %%
# model = tf.keras.models.load_model(modelcheckpoint.filepath) #restore

# %%
show_predictions(model, dataset=test_ds, num=10)