카테고리 없음
TensorFlow 2.x HeatMap (Grad-CAM)
Cuoong
2022. 7. 11. 18:52
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img
from keras.models import load_model
from keras.applications.efficientnet import(
EfficientNetB3,
)
import cv2
# Picture to explain with Grad-CAM.
img_Path = "cat.jpg"

# load_img's target_size is documented as (height, width); the channel count
# is not part of it, so only the two spatial sizes are passed.
# np.array turns the loaded PIL image into an array (RGB order by default).
image = np.array(load_img(img_Path, target_size=(250, 250)))

# Trained classifier whose prediction is being explained. The layer names
# looked up further down must match this model; run model.summary() to
# check them if the lookups fail.
model = load_model("D:\\Model_02-0.8691406846.h5")
# Split the trained model in two at the backbone layer named "top_activation":
#   * last_conv_layer_model: backbone input -> feature maps of that layer
#   * classifier_model:      those feature maps -> class predictions
backbone = model.get_layer("efficientnetb3")
last_conv_layer = backbone.get_layer("top_activation")
last_conv_layer_model = tf.keras.Model(backbone.input, last_conv_layer.output)

# Re-apply the head layers of the trained model, in order, on a fresh input
# that has the feature-map shape (batch axis excluded).
classifier_input = tf.keras.Input(shape=last_conv_layer.output.shape[1:])
x = classifier_input
head_layer_names = ("global_average_pooling2d_1", "dropout_1", "dense_1")
for head_layer in map(model.get_layer, head_layer_names):
    x = head_layer(x)
classifier_model = tf.keras.Model(classifier_input, x)
# Run the forward pass under a gradient tape so the winning class score can
# be differentiated with respect to the backbone feature maps.
with tf.GradientTape() as tape:
    # The model is called on a batch, so add a leading batch axis of size one.
    batch = image[np.newaxis, ...]
    last_conv_layer_output = last_conv_layer_model(batch)
    # The feature maps are an intermediate tensor, not a variable, so the
    # tape has to be told explicitly to track them.
    tape.watch(last_conv_layer_output)
    preds = classifier_model(last_conv_layer_output)
    # Score of the highest-scoring class for the single image in the batch.
    top_pred_index = tf.argmax(preds[0])
    top_class_channel = preds[:, top_pred_index]

# Gradient of that score w.r.t. the feature maps; then drop the batch axis
# from both the gradient and the feature maps.
grads = tape.gradient(top_class_channel, last_conv_layer_output)
grads = grads[0]
last_conv_layer_output = last_conv_layer_output[0]
# Guided weighting: keep a gradient only where both the activation and the
# gradient itself are positive; everything else is zeroed.
guided_grads = (
    tf.cast(last_conv_layer_output > 0, "float32")
    * tf.cast(grads > 0, "float32")
    * grads
)
# One weight per channel: average the guided gradients over both spatial axes.
pooled_guided_grads = tf.reduce_mean(guided_grads, axis=(0, 1))

# Weighted sum of the feature maps over the channel axis, done as a single
# tensor operation instead of a Python loop over every channel (the loop also
# relied on `+=` silently turning the NumPy accumulator into a tf Tensor).
# NOTE(review): the +1.0 reproduces the original np.ones starting value;
# canonical Grad-CAM accumulates from zero -- confirm the offset is intended.
guided_gradcam = 1.0 + tf.reduce_sum(
    last_conv_layer_output * pooled_guided_grads, axis=-1
)

# cv2.resize takes dsize as (width, height). Size the map from the input
# image so the two always match when they are blended later.
img_height, img_width = image.shape[:2]
guided_gradcam = cv2.resize(guided_gradcam.numpy(), (img_width, img_height))

# Discard negative values, then min-max rescale to [0, 1].
guided_gradcam = np.maximum(guided_gradcam, 0)
cam_min = guided_gradcam.min()
cam_span = guided_gradcam.max() - cam_min
if cam_span > 0:
    guided_gradcam = (guided_gradcam - cam_min) / cam_span
else:
    # A constant map would divide by zero and fill the heatmap with NaN.
    guided_gradcam = np.zeros_like(guided_gradcam)
# Turn the normalised [0, 1] map into an 8-bit image and colour it with the
# JET colour map.
heatmap = np.uint8(255 * guided_gradcam)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

# The image was loaded through Keras/PIL in RGB order, whereas OpenCV's
# colour maps and imshow work in BGR. Convert before blending so the photo
# is not displayed with red and blue swapped.
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

# Overlay: 70 % photo, 30 % heatmap.
img_result = cv2.addWeighted(image_bgr, 0.7, heatmap, 0.3, 0)

print(f"image Type {type(image)}, Shape: {image.shape}")
print(f"heatmap Type {type(heatmap)}, Shape: {heatmap.shape}")

# Show both windows and block until a key is pressed.
cv2.imshow("CAM", heatmap)
cv2.imshow("RESULT", img_result)
cv2.waitKey(0)
cv2.destroyAllWindows()