# 카테고리 없음 (no category)
#
# Segmentation Inference
#
# Cuoong 2022. 8. 31. 09:37
import os, shutil
import tensorflow as tf
import numpy as np
import cv2
# Run the saved segmentation model over every image found under root_dir
# and write a black/white (0/255) mask for each one into save_dir.

# Raw string: '\M' is not a valid escape sequence, so the non-raw literal only
# worked by accident and emits a DeprecationWarning/SyntaxWarning.
model = tf.keras.models.load_model(r'.\Model_Mobile01.h5')  # restore
root_dir = r"D:\DATASET\Oxford\train_image"
save_dir = r"D:\ADC_POC\cat_vs_rabbit\test-images"

# (width, height) the image is resized to before being fed to the model.
INPUT_SIZE = (256, 256)
# The original loop ended with an unconditional `break`, so only one file per
# directory was processed. That limit is kept but made explicit here; set it
# to None to process every file.
MAX_FILES_PER_DIR = 1


def _predict_mask(image_rgb):
    """Return a uint8 mask (values 0 or 255) predicted by `model` for one image.

    The image is resized to INPUT_SIZE and given a leading batch axis before
    prediction. Pixels whose grayscale-converted prediction is > 0 become 0;
    every other pixel becomes 255. The mask has the INPUT_SIZE resolution,
    not the source image's.
    """
    resized = cv2.resize(image_rgb, INPUT_SIZE)
    batch = resized[np.newaxis, ...]
    print(batch.shape)
    # NOTE(review): the uint8 pixels are fed without any scaling; confirm this
    # matches the preprocessing the model was trained with.
    pred_mask = model.predict(batch)
    # COLOR_BGR2GRAY needs a 3-channel input, so this presumably relies on the
    # model emitting 3 output channels — TODO confirm.
    pred_gray = cv2.cvtColor(pred_mask[0], cv2.COLOR_BGR2GRAY)
    return np.where(pred_gray > 0, 0, 255).astype(np.uint8)


# `createFolder` was never defined; create the output directory once, up front.
os.makedirs(save_dir, exist_ok=True)

for root, _dirs, files in os.walk(root_dir):
    processed = 0
    for file_name in files:
        if MAX_FILES_PER_DIR is not None and processed >= MAX_FILES_PER_DIR:
            break
        source_path = os.path.join(root, file_name)
        print(f'sourcePath = {source_path}')
        img = cv2.imread(source_path)
        if img is None:
            # imread returns None for unreadable / non-image files; skip them
            # instead of crashing inside cvtColor.
            print(f'skip (not a readable image): {source_path}')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Outputs are written flat into save_dir, so files with the same name
        # in different subdirectories overwrite each other.
        save_path = os.path.join(save_dir, file_name)
        print(f'savePath = {save_path}')
        mask = _predict_mask(img)
        if not cv2.imwrite(save_path, mask):
            print(f'failed to write: {save_path}')
        processed += 1