AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,223 @@
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
@dataclass
class BoundingBox:
"""Bounding box representation."""
xmin: int
ymin: int
xmax: int
ymax: int
@property
def xyxy(self) -> List[float]:
"""Return bounding box coordinates.
Returns:
List[float]: coodinates: [xmin, ymin, xmax, ymax]
"""
return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
"""Detection result from Grounding DINO + Mask from SAM."""
score: float
label: str
box: BoundingBox
mask: Optional[np.array] = None
@classmethod
def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
"""Create a DetectionResult from a dictionary.
Args:
detection_dict (Dict): Detection result dictionary.
Returns:
DetectionResult: Detection result object.
"""
return cls(
score=detection_dict["score"],
label=detection_dict["label"],
box=BoundingBox(
xmin=detection_dict["box"]["xmin"],
ymin=detection_dict["box"]["ymin"],
xmax=detection_dict["box"]["xmax"],
ymax=detection_dict["box"]["ymax"],
),
)
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
"""Use OpenCV to refine a mask by turning it into a polygon.
Args:
mask (np.ndarray): Segmentation mask.
Returns:
List[List[int]]: List of (x, y) coordinates representing the vertices of the polygon.
"""
# Find contours in the binary mask
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Find the contour with the largest area
largest_contour = max(contours, key=cv2.contourArea)
# Extract the vertices of the contour
return largest_contour.reshape(-1, 2).tolist()
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
"""Convert a polygon to a segmentation mask.
Args:
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
image_shape (tuple): Shape of the image (height, width) for the mask.
Returns:
np.ndarray: Segmentation mask with the polygon filled.
"""
# Create an empty mask
mask = np.zeros(image_shape, dtype=np.uint8)
# Convert polygon to an array of points
pts = np.array(polygon, dtype=np.int32)
# Fill the polygon with white color (255)
cv2.fillPoly(mask, [pts], color=(255,))
return mask
def load_image(image_str: str) -> Image.Image:
"""Load an image from a URL or file path.
Args:
image_str (str): URL or file path to the image.
Returns:
PIL.Image: Image object.
"""
if image_str.startswith("http"):
image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
else:
image = Image.open(image_str).convert("RGB")
return image
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
boxes = []
for result in results:
xyxy = result.box.xyxy
boxes.append(xyxy)
return [boxes]
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
masks = masks.mean(axis=-1)
masks = (masks > 0).int()
masks = masks.numpy().astype(np.uint8)
masks = list(masks)
if polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
masks[idx] = mask
return masks
device = "cuda" if torch.cuda.is_available() else "cpu"
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[Dict[str, Any]]:
"""Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion."""
global object_detector, device
labels = [label if label.endswith(".") else label + "." for label in labels]
results = object_detector(image, candidate_labels=labels, threshold=threshold)
return [DetectionResult.from_dict(result) for result in results]
def segment(
image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False
) -> List[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
global segmentator, processor, device
boxes = _get_boxes(detection_results)
inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
outputs = segmentator(**inputs)
masks = processor.post_process_masks(
masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
)[0]
masks = _refine_masks(masks, polygon_refinement)
for detection_result, mask in zip(detection_results, masks):
detection_result.mask = mask
return detection_results
def grounded_segmentation(
image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
"""Segment out the objects in an image given a set of labels.
Args:
image (Union[Image.Image, str]): Image to load/work on.
labels (List[str]): Object labels to segment.
threshold (float, optional): Segmentation threshold. Defaults to 0.3.
polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.
Returns:
Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.
"""
if isinstance(image, str):
image = load_image(image)
detections = detect(image, labels, threshold)
if len(detections) == 0:
return ToTensor()(image), []
detections = segment(image, detections, polygon_refinement)
return ToTensor()(image), detections