AAAI Version
This commit is contained in:
@@ -0,0 +1,223 @@
|
||||
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
|
||||
|
||||
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoundingBox:
|
||||
"""Bounding box representation."""
|
||||
|
||||
xmin: int
|
||||
ymin: int
|
||||
xmax: int
|
||||
ymax: int
|
||||
|
||||
@property
|
||||
def xyxy(self) -> List[float]:
|
||||
"""Return bounding box coordinates.
|
||||
|
||||
Returns:
|
||||
List[float]: coodinates: [xmin, ymin, xmax, ymax]
|
||||
|
||||
"""
|
||||
return [self.xmin, self.ymin, self.xmax, self.ymax]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Detection result from Grounding DINO + Mask from SAM."""
|
||||
|
||||
score: float
|
||||
label: str
|
||||
box: BoundingBox
|
||||
mask: Optional[np.array] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
|
||||
"""Create a DetectionResult from a dictionary.
|
||||
|
||||
Args:
|
||||
detection_dict (Dict): Detection result dictionary.
|
||||
|
||||
Returns:
|
||||
DetectionResult: Detection result object.
|
||||
|
||||
"""
|
||||
return cls(
|
||||
score=detection_dict["score"],
|
||||
label=detection_dict["label"],
|
||||
box=BoundingBox(
|
||||
xmin=detection_dict["box"]["xmin"],
|
||||
ymin=detection_dict["box"]["ymin"],
|
||||
xmax=detection_dict["box"]["xmax"],
|
||||
ymax=detection_dict["box"]["ymax"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
|
||||
"""Use OpenCV to refine a mask by turning it into a polygon.
|
||||
|
||||
Args:
|
||||
mask (np.ndarray): Segmentation mask.
|
||||
|
||||
Returns:
|
||||
List[List[int]]: List of (x, y) coordinates representing the vertices of the polygon.
|
||||
|
||||
"""
|
||||
# Find contours in the binary mask
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Find the contour with the largest area
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Extract the vertices of the contour
|
||||
return largest_contour.reshape(-1, 2).tolist()
|
||||
|
||||
|
||||
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
|
||||
"""Convert a polygon to a segmentation mask.
|
||||
|
||||
Args:
|
||||
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
|
||||
image_shape (tuple): Shape of the image (height, width) for the mask.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Segmentation mask with the polygon filled.
|
||||
|
||||
"""
|
||||
# Create an empty mask
|
||||
mask = np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
# Convert polygon to an array of points
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
|
||||
# Fill the polygon with white color (255)
|
||||
cv2.fillPoly(mask, [pts], color=(255,))
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def load_image(image_str: str) -> Image.Image:
|
||||
"""Load an image from a URL or file path.
|
||||
|
||||
Args:
|
||||
image_str (str): URL or file path to the image.
|
||||
|
||||
Returns:
|
||||
PIL.Image: Image object.
|
||||
|
||||
"""
|
||||
if image_str.startswith("http"):
|
||||
image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
|
||||
else:
|
||||
image = Image.open(image_str).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
|
||||
boxes = []
|
||||
for result in results:
|
||||
xyxy = result.box.xyxy
|
||||
boxes.append(xyxy)
|
||||
|
||||
return [boxes]
|
||||
|
||||
|
||||
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
|
||||
masks = masks.cpu().float()
|
||||
masks = masks.permute(0, 2, 3, 1)
|
||||
masks = masks.mean(axis=-1)
|
||||
masks = (masks > 0).int()
|
||||
masks = masks.numpy().astype(np.uint8)
|
||||
masks = list(masks)
|
||||
|
||||
if polygon_refinement:
|
||||
for idx, mask in enumerate(masks):
|
||||
shape = mask.shape
|
||||
polygon = mask_to_polygon(mask)
|
||||
mask = polygon_to_mask(polygon, shape)
|
||||
masks[idx] = mask
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
detector_id = "IDEA-Research/grounding-dino-tiny"
|
||||
print(f"load object detector pipeline: {detector_id}")
|
||||
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
|
||||
|
||||
segmenter_id = "facebook/sam-vit-base"
|
||||
print(f"load segmentator: {segmenter_id}")
|
||||
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
|
||||
print(f"load processor: {segmenter_id}")
|
||||
processor = AutoProcessor.from_pretrained(segmenter_id)
|
||||
|
||||
|
||||
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[Dict[str, Any]]:
|
||||
"""Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion."""
|
||||
global object_detector, device
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
results = object_detector(image, candidate_labels=labels, threshold=threshold)
|
||||
return [DetectionResult.from_dict(result) for result in results]
|
||||
|
||||
|
||||
def segment(
|
||||
image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False
|
||||
) -> List[DetectionResult]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
global segmentator, processor, device
|
||||
boxes = _get_boxes(detection_results)
|
||||
inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
|
||||
|
||||
outputs = segmentator(**inputs)
|
||||
masks = processor.post_process_masks(
|
||||
masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
|
||||
)[0]
|
||||
|
||||
masks = _refine_masks(masks, polygon_refinement)
|
||||
|
||||
for detection_result, mask in zip(detection_results, masks):
|
||||
detection_result.mask = mask
|
||||
|
||||
return detection_results
|
||||
|
||||
|
||||
def grounded_segmentation(
|
||||
image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
|
||||
) -> Tuple[torch.Tensor, List[DetectionResult]]:
|
||||
"""Segment out the objects in an image given a set of labels.
|
||||
|
||||
Args:
|
||||
image (Union[Image.Image, str]): Image to load/work on.
|
||||
labels (List[str]): Object labels to segment.
|
||||
threshold (float, optional): Segmentation threshold. Defaults to 0.3.
|
||||
polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.
|
||||
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = load_image(image)
|
||||
|
||||
detections = detect(image, labels, threshold)
|
||||
if len(detections) == 0:
|
||||
return ToTensor()(image), []
|
||||
detections = segment(image, detections, polygon_refinement)
|
||||
|
||||
return ToTensor()(image), detections
|
||||
Reference in New Issue
Block a user