224 lines
6.8 KiB
Python
224 lines
6.8 KiB
Python
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
|
|
|
|
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision.transforms import ToTensor
|
|
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
|
|
|
|
|
|
@dataclass
class BoundingBox:
    """Axis-aligned bounding box in integer pixel coordinates."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[int]:
        """Return bounding box coordinates.

        Returns:
            List[int]: coordinates: [xmin, ymin, xmax, ymax]

        """
        # Fixed: fields are ints, so the return annotation is List[int]
        # (was List[float]); also corrected docstring typo "coodinates".
        return [self.xmin, self.ymin, self.xmax, self.ymax]
|
|
|
|
|
|
@dataclass
class DetectionResult:
    """Detection result from Grounding DINO + mask from SAM."""

    score: float
    label: str
    box: BoundingBox
    # Fixed: np.array is a factory function, not a type; the annotation
    # must be np.ndarray. Filled in later by segment().
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Create a DetectionResult from a detector pipeline dictionary.

        Args:
            detection_dict (Dict): Detection result dictionary with
                "score", "label", and "box" (xmin/ymin/xmax/ymax) keys.

        Returns:
            DetectionResult: Detection result object (mask left as None).

        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
|
|
|
|
|
|
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Use OpenCV to refine a mask by turning it into a polygon.

    Args:
        mask (np.ndarray): Binary segmentation mask.

    Returns:
        List[List[int]]: List of (x, y) coordinates representing the vertices
            of the largest contour's polygon; empty list if the mask has no
            foreground pixels.

    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Fixed: an all-background mask yields no contours and max() would
    # raise ValueError; return an empty polygon instead.
    if not contours:
        return []

    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    return largest_contour.reshape(-1, 2).tolist()
|
|
|
|
|
|
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """Rasterize a polygon into a filled segmentation mask.

    Args:
        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
        image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
        np.ndarray: uint8 mask with the polygon region filled with 255.

    """
    # Start from an all-zero (background) canvas of the requested shape.
    canvas = np.zeros(image_shape, dtype=np.uint8)

    # OpenCV expects the vertex list as an int32 point array.
    vertices = np.array(polygon, dtype=np.int32)

    # Paint the polygon interior white (255).
    cv2.fillPoly(canvas, [vertices], color=(255,))
    return canvas
|
|
|
|
|
|
def load_image(image_str: str) -> Image.Image:
    """Load an RGB image from a URL or local file path.

    Args:
        image_str (str): URL (http/https) or file path to the image.

    Returns:
        PIL.Image.Image: Image object converted to RGB.

    Raises:
        requests.HTTPError: If the URL request returns an error status.

    """
    if image_str.startswith("http"):
        response = requests.get(image_str, stream=True)
        # Fail fast on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image
|
|
|
|
|
|
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
|
|
boxes = []
|
|
for result in results:
|
|
xyxy = result.box.xyxy
|
|
boxes.append(xyxy)
|
|
|
|
return [boxes]
|
|
|
|
|
|
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
|
|
masks = masks.cpu().float()
|
|
masks = masks.permute(0, 2, 3, 1)
|
|
masks = masks.mean(axis=-1)
|
|
masks = (masks > 0).int()
|
|
masks = masks.numpy().astype(np.uint8)
|
|
masks = list(masks)
|
|
|
|
if polygon_refinement:
|
|
for idx, mask in enumerate(masks):
|
|
shape = mask.shape
|
|
polygon = mask_to_polygon(mask)
|
|
mask = polygon_to_mask(polygon, shape)
|
|
masks[idx] = mask
|
|
|
|
return masks
|
|
|
|
|
|
# Module-level model setup: everything below runs once at import time.
# NOTE(review): loading large models as an import side effect blocks import
# and requires network/disk access — consider lazy initialization.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Zero-shot object detector (Grounding DINO), used by detect().
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

# Segment Anything model + its processor, used by segment().
segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)
|
|
|
|
|
|
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> "List[DetectionResult]":
    """Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.

    Args:
        image (Image.Image): Image to run detection on.
        labels (List[str]): Text labels to detect.
        threshold (float, optional): Minimum detection score. Defaults to 0.3.

    Returns:
        List[DetectionResult]: One result per detection above the threshold.

    """
    # Fixed: return annotation was List[Dict[str, Any]] but the function
    # returns DetectionResult objects; the `global` statement was removed
    # because module globals are readable without it.
    # Grounding DINO prompts work best when each label ends with a period.
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]
|
|
|
|
|
|
def segment(
    image: Image.Image, detection_results: "List[DetectionResult]", polygon_refinement: bool = False
) -> List[DetectionResult]:
    """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Args:
        image (Image.Image): Source image.
        detection_results (List[DetectionResult]): Detections whose .mask
            fields will be filled in-place.
        polygon_refinement (bool, optional): Smooth masks via a polygon
            round-trip. Defaults to False.

    Returns:
        List[DetectionResult]: The same detection objects, with masks attached.

    """
    # Fixed: parameter was annotated List[Dict[str, Any]] but the body sets
    # .mask on DetectionResult objects; the pointless `global` was removed
    # (module globals are readable without it).
    boxes = _get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    # post_process_masks returns one entry per image; we segment one image.
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = _refine_masks(masks, polygon_refinement)

    # Attach each mask to its corresponding detection (mutates the inputs).
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results
|
|
|
|
|
|
def grounded_segmentation(
    image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
    """Segment out the objects in an image given a set of labels.

    Args:
        image (Union[Image.Image, str]): Image to load/work on.
        labels (List[str]): Object labels to segment.
        threshold (float, optional): Segmentation threshold. Defaults to 0.3.
        polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.

    Returns:
        Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.

    """
    # Accept either a ready PIL image or a URL/path to load.
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold)

    # Nothing detected: skip SAM entirely and return the image with no results.
    if not detections:
        return ToTensor()(image), []

    return ToTensor()(image), segment(image, detections, polygon_refinement)
|