github 코드에서 cam으로부터 bounding box를 추출하는 부분을 따왔다.
예상은 했다만 코드 한줄 혹은 for문 하나로 끝나는 건 아니구나
아무튼 나는 saliency map (or heatmap)에서 object localization할거라서 도움이 될 듯 하다.
전체코드
def extract_bbox(images, cams, gt_boxes, threshold=0.2, percentile=100,
color=[(0, 255, 0)]):
# Convert the format of threshold and percentile
if not isinstance(threshold, list):
threshold = [threshold]
if not isinstance(percentile, list):
percentile = [percentile]
assert len(threshold) == len(percentile)
# Generate colors
gt_color = (0, 0, 255) # (0, 0, 255)
line_thickness = 2
from itertools import cycle, islice
color = list(islice(cycle(color), len(percentile)))
# Convert a data format
images = images.clone().numpy().transpose(0, 2, 3, 1)
images = images[:, :, :, ::-1] * 255 # reverse the color representation(RGB -> BGR) and Opencv format
cams = cams.clone().detach().cpu().numpy().transpose(0, 2, 3, 1)
bboxes = []
blended_bboxes = []
for i in range(images.shape[0]):
image, cam, gt_box = images[i].astype('uint8'), cams[i], gt_boxes[i]
image_height, image_width, _ = image.shape
cam = cv2.resize(cam, (image_height, image_width),
interpolation=cv2.INTER_CUBIC)
# Generate a heatmap using jet colormap
cam_max, cam_min = np.amax(cam), np.amin(cam)
normalized_cam = (cam - cam_min) / (cam_max - cam_min) * 255
normalized_cam = normalized_cam.astype('uint8')
heatmap_jet = cv2.applyColorMap(normalized_cam, cv2.COLORMAP_JET)
blend = cv2.addWeighted(heatmap_jet, 0.5, image, 0.5, 0)
blended_bbox = blend.copy()
if not isinstance(gt_box, str):
cv2.rectangle(blended_bbox,
pt1=(gt_box[0], gt_box[1]), pt2=(gt_box[2], gt_box[3]),
color=gt_color, thickness=line_thickness)
# Extract a bbox
for _threshold, _percentile, _color in zip(threshold, percentile, color):
threshold_val = int(_threshold * np.percentile(normalized_cam, q=_percentile))
_, thresholded_gray_heatmap = cv2.threshold(
normalized_cam, threshold_val, maxval=255, type=cv2.THRESH_BINARY)
try:
_, contours, _ = cv2.findContours(thresholded_gray_heatmap,
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
except:
contours, _ = cv2.findContours(thresholded_gray_heatmap,
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
bbox = [0, 0, 224, 224]
if len(contours) > 0:
max_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(max_contour)
bbox = [x, y, x + w, y + h]
cv2.rectangle(blended_bbox,
pt1=(x, y), pt2=(x + w, y + h),
color=_color, thickness=line_thickness)
blended_bbox = blended_bbox[:,:,::-1] / 255.0
blended_bbox = blended_bbox.transpose(2, 0, 1)
bboxes.append(torch.tensor(bbox))
blended_bboxes.append(torch.tensor(blended_bbox))
bboxes = torch.stack(bboxes)
blended_bboxes = torch.stack(blended_bboxes)
return bboxes, blended_bboxes
This code defines a function called extract_bbox that takes in a set of input images, class activation maps (CAMs), and ground truth bounding boxes (gt_boxes). It then generates blended images with bounding boxes drawn around objects of interest based on the CAMs. The function returns the bounding boxes and the blended images with bounding boxes drawn.
def extract_bbox(images, cams, gt_boxes, threshold=0.2, percentile=100,
color=[(0, 255, 0)]):
뭐야 bounding box (gt)가 있어야하는건줄 알았는데 아래부분을 보니까 CAM을 이용해서도 bounding box를 시각화할 수 있었다.... 깜놀..
if not isinstance(threshold, list):
threshold = [threshold]
if not isinstance(percentile, list):
percentile = [percentile]
theshold, percentile을 list형태로 만들어줌
아직 어떻게 쓰이는 지는 파악 못함
gt_color = (0, 0, 255)
line_thickness = 2
from itertools import cycle, islice
color = list(islice(cycle(color), len(percentile)))
색깔과, 선 굵기를 세팅 해줌
cycle이랑 islice는 효율적인 loop을 위한 iterator임
cycle(iterable) : 전달 받은 인수로부터 모든 값을 순서대로 출력 (cycle('123')이면 1 2 3 1 2 3 1 2 ...)
islice(iterable, start, stop, step) : 주어진 iterable을 주어진 position에 따라 자른다 (islice(list,2,8,2)이면 3번째인덱스부터 8번째 인덱스까지 2개씩 스킵하며 출력)
images = images.clone().numpy().transpose(0, 2, 3, 1)
images = images[:, :, :, ::-1] * 255
cams = cams.clone().detach().cpu().numpy().transpose(0, 2, 3, 1)
input image와 CAM을 numpy배열로 변환
color representation 조정
bboxes = []
blended_bboxes = []
final bounding boxes, blended images with bounding boxes를 저장할 리스트 선언
for i in range(images.shape[0]):
image, cam, gt_box = images[i].astype('uint8'), cams[i], gt_boxes[i]
이미지를 순회하며, 각각에 이미지에 맞는 image, cam, gt_box 불러옴
image_height, image_width, _ = image.shape
cam = cv2.resize(cam, (image_height, image_width),
interpolation=cv2.INTER_CUBIC)
CAM의 size를 조정하여 image dimension과 맞춘다.
cam_max, cam_min = np.amax(cam), np.amin(cam)
normalized_cam = (cam - cam_min) / (cam_max - cam_min) * 255
normalized_cam = normalized_cam.astype('uint8')
heatmap_jet = cv2.applyColorMap(normalized_cam, cv2.COLORMAP_JET)
blend = cv2.addWeighted(heatmap_jet, 0.5, image, 0.5, 0)
blended_bbox = blend.copy()
heatmap을 생성함
if not isinstance(gt_box, str):
cv2.rectangle(blended_bbox,
pt1=(gt_box[0], gt_box[1]), pt2=(gt_box[2], gt_box[3]),
color=gt_color, thickness=line_thickness)
gt_box가 있으면 이미지에 그림
for _threshold, _percentile, _color in zip(threshold, percentile, color):
treshold, percentile, color값에 따라 loop를 돈다.
threshold_val = int(_threshold * np.percentile(normalized_cam, q=_percentile))
현재 threshold값 계산
_, thresholded_gray_heatmap = cv2.threshold(
normalized_cam, threshold_val, maxval=255, type=cv2.THRESH_BINARY)
CAM(graycale)에 threshold를 적용하여 binary image를 생성한다.
threshold보다 큰 값을 가진 픽셀만 white(255)값을 가짐
try:
_, contours, _ = cv2.findContours(thresholded_gray_heatmap,
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
except:
contours, _ = cv2.findContours(thresholded_gray_heatmap,
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
contours(continuous curve)를 찾는다.
이 controus들이 이미지 내에서 관심있는 영역을 나타냄
try except은 opencv version에 맞게 return value를 조정한 것
bbox = [0, 0, 224, 224]
default bounding box를 [0,0,224,224]로 세팅
if len(contours) > 0:
max_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(max_contour)
bbox = [x, y, x + w, y + h]
contour가 있으면 가장 큰 영역을 계산함
그럼 가장 큰 contour에 대해 bounding rectangle을 계산하고 bounding box 값을 업데이트한다.
바운딩박스는 [x,y,x+w,y+h]로 구성되어 있으며
(x,y)는 top-left corner, w는 width, h는 height
cv2.rectangle(blended_bbox,
pt1=(x, y), pt2=(x + w, y + h),
color=_color, thickness=line_thickness)
blended_bbox_image에 바운딩박스를 그림
나는 이 좌표를 가지고 crop해야겠지