Source code for easy_bbox.bbox

"""
bbox.py

Provides the `Bbox` class and utility functions for manipulating bounding boxes
in various coordinate formats (Pascal VOC, COCO, YOLO, etc.). Supports
transformations, geometric operations, and conversions.
"""

from __future__ import annotations

import math
from enum import Enum
from typing import Callable, List, Optional, Sequence, Tuple

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class DistanceMetric(Enum):
    L1 = "L1"  # Manhattan distance
    L2 = "L2"  # Euclidian distance


class RoundingMethod(Enum):
    ROUND = "round"
    CEIL = "ceil"
    FLOOR = "floor"


[docs] class Bbox(BaseModel): """ A class to represent a Bbox (inherits from Pydantic BaseModel). The bbox is stored in Pascal_VOC format: top-left, bottom-right with a top-left origin (PIL coord system). (meaning that top < bottom) The bottom and right edges are considered excluded from the Bbox for compatibility with array slicing and PIL image cropping features (in case of Int Bboxes). Attributes: left (float): The left coordinate of the bounding box. top (float): The top coordinate of the bounding box. right (float): The right coordinate of the bounding box. bottom (float): The bottom coordinate of the bounding box. """ left: float top: float right: float bottom: float
[docs] @model_validator(mode="after") def check_bbox_validity(self) -> Self: """Checks that the Bbox is valid.""" if self.left > self.right or self.top > self.bottom: raise ValueError("The Bbox is not valid (negative width or height).") return self
# region From methods
[docs] @classmethod def from_tlbr(cls, tlbr: Sequence[float]) -> Bbox: """ Initializes the bounding box from top-left and bottom-right coordinates. Args: tlbr (Sequence[float]): A sequence containing the top-left and bottom-right coordinates of the bounding box in the format (left, top, right, bottom). Returns: Bbox: The Bbox instance. Raises: ValueError: If the length of the sequence is not 4, or if the Bbox is not valid (ie `left > right` or `top > bottom`). Example: >>> bbox = Bbox.from_tlbr((10, 20, 30, 40)) >>> print(bbox.left, bbox.top, bbox.right, bbox.bottom) 10 20 30 40 """ _assert_sequence_len(seq=tlbr) return cls(left=tlbr[0], top=tlbr[1], right=tlbr[2], bottom=tlbr[3])
[docs] @classmethod def from_tlwh(cls, tlwh: Sequence[float]) -> Bbox: """ Initializes the bounding box from top-left and width-height coordinates. Args: tlwh (Sequence[float]): A sequence containing the top-left and width-height coordinates of the bounding box in the format (left, top, width, height). Returns: Bbox: The Bbox instance. Raises: ValueError: If the length of the sequence is not 4, or if the Bbox is not valid (ie `width < 0` or `height < 0`). Example: >>> bbox = Bbox.from_tlwh((10, 20, 20, 30)) >>> print(bbox.left, bbox.top, bbox.right, bbox.bottom) 10 20 30 50 """ _assert_sequence_len(seq=tlwh) return cls( left=tlwh[0], top=tlwh[1], right=tlwh[0] + tlwh[2], bottom=tlwh[1] + tlwh[3], )
[docs] @classmethod def from_cwh(cls, cwh: Sequence[float]) -> Bbox: """ Initializes the bounding box from center and width-height coordinates. Args: cwh (Sequence[float]): A sequence containing the center and width-height coordinates of the bounding box in the format (center_x, center_y, width, height). Returns: Bbox: The Bbox instance. Raises: ValueError: If the length of the sequence is not 4, or if the Bbox is not valid (ie `width < 0` or `height < 0`). Example: >>> bbox = Bbox.from_cwh((20, 35, 20, 30)) >>> print(bbox.left, bbox.top, bbox.right, bbox.bottom) 10 20 30 50 """ _assert_sequence_len(seq=cwh) half_width = cwh[2] / 2 half_height = cwh[3] / 2 return cls( left=cwh[0] - half_width, top=cwh[1] - half_height, right=cwh[0] + half_width, bottom=cwh[1] + half_height, )
from_xyxy = from_tlbr from_pascal_voc = from_tlbr from_list = from_tlbr from_coco = from_tlwh # endregion # region To methods
[docs] def to_tlbr(self) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Top-Left, Bottom-Right format. Returns: Tuple[float, float, float, float]: The bounding box coordinates (x_min, y_min, x_max, y_max). x_min and y_min are the coordinates of the top-left corner of the bounding box. x_max and y_max are the coordinates of the bottom-right corner of the bounding box. """ return (self.left, self.top, self.right, self.bottom)
[docs] def to_list(self) -> List[float]: """ Returns the bounding box coordinates in Top-Left, Bottom-Right format. Returns: List[float]: The bounding box coordinates [x_min, y_min, x_max, y_max]. x_min and y_min are the coordinates of the top-left corner of the bounding box. x_max and y_max are the coordinates of the bottom-right corner of the bounding box. """ return [self.left, self.top, self.right, self.bottom]
[docs] def to_norm_tlbr(self, img_w: int, img_h: int) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Top-Left, Bottom-Right format, normalized based on the image dimensions. Args: img_w (int): The image width in pixels. img_h (int): The image height in pixels. Returns: Tuple[float, float, float, float]: The bounding box coordinates (x_min, y_min, x_max, y_max). x_min and y_min are the coordinates of the top-left corner of the bounding box. x_max and y_max are the coordinates of the bottom-right corner of the bounding box. All the returned values are **NORMALIZED** based on the image dimensions. """ return ( self.left / img_w, self.top / img_h, self.right / img_w, self.bottom / img_h, )
[docs] def to_tlwh(self) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Top-Left, Width-Height format. Returns: Tuple[float, float, float, float]: The bounding box coordinates (x_min, y_min, width, height). x_min and y_min are coordinates of the top-left corner of the bounding box. """ return (self.left, self.top, self.width, self.height)
[docs] def to_norm_tlwh(self, img_w: int, img_h: int) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Top-Left, Width-Height format, normalized based on the image dimensions. Args: img_w (int): The image width in pixels. img_h (int): The image height in pixels. Returns: Tuple[float, float, float, float]: The bounding box coordinates [x_min, y_min, width, height]. x_min and y_min are the coordinates of the top-left corner of the bounding box. All the returned values are **NORMALIZED** based on the image dimensions. """ return ( self.left / img_w, self.top / img_h, self.right / img_w, self.bottom / img_h, )
[docs] def to_cwh(self) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Center, Width-Height format. Returns: Tuple[float, float, float, float]: The bounding box coordinates (x_center, y_center, width, height). """ return (*self.center, self.width, self.height)
[docs] def to_norm_cwh(self, img_w: int, img_h: int) -> Tuple[float, float, float, float]: """ Returns the bounding box coordinates in Center, Width-Height format, normalized based on the image dimensions. Args: img_w (int): The image width in pixels. img_h (int): The image height in pixels. Returns: Tuple[float, float, float, float]: The NORMALIZED bounding box coordinates (x_center, y_center, width, height). """ cx, cy = self.center return (cx / img_w, cy / img_h, self.width / img_w, self.height / img_h)
[docs] def to_polygon( self, ) -> Tuple[ Tuple[float, float], Tuple[float, float], Tuple[float, float], Tuple[float, float], ]: """ Returns the bounding box corners as points. Returns: The corners coordinates in (x, y) format. The order is `top_left > top_right > bottom_right > bottom_left` """ return ( (self.left, self.top), (self.right, self.top), (self.right, self.bottom), (self.left, self.bottom), )
[docs] def to_int_tuple( self, rounding_method: RoundingMethod = RoundingMethod.ROUND ) -> Tuple[int, int, int, int]: """ Returns the bounding box coordinates in Top-Left, Bottom-Right format, with values rounded to int. Args: rounding_method (RoundingMethod): The rounding method to use. Defaults to `RoundingMethod.ROUND`. Returns: Tuple[int, int, int, int]: The coordinates as integers. Raises: ValueError: If a wrong rounding method is provided. """ rounding_func: Callable[[float], int] if rounding_method == RoundingMethod.CEIL: rounding_func = math.ceil elif rounding_method == RoundingMethod.FLOOR: rounding_func = math.floor elif rounding_method == RoundingMethod.ROUND: rounding_func = round else: raise ValueError(f"{rounding_method} is not a valid rounding method.") return ( rounding_func(self.left), rounding_func(self.top), rounding_func(self.right), rounding_func(self.bottom), )
to_pascal_voc = to_tlbr to_xyxy = to_tlbr to_albu = to_norm_tlbr to_coco = to_tlwh to_yolo = to_norm_cwh # endregion # region Transformations
[docs] def shift(self, horizontal_shift: float = 0, vertical_shift: float = 0) -> Bbox: """ Return a shifted Bbox by the specified horizontal and vertical amounts. Args: horizontal_shift (float, optional): The amount to shift the bounding box horizontally. Defaults to 0. vertical_shift (float, optional): The amount to shift the bounding box vertically. Defaults to 0. Returns: Bbox: The shifted Bbox instance. """ return Bbox( left=self.left + horizontal_shift, top=self.top + vertical_shift, right=self.right + horizontal_shift, bottom=self.bottom + vertical_shift, )
[docs] def scale(self, scale_factor: float) -> Bbox: """ Return a scaled Bbox by the specified scale factor. The scaling will be from the center. Args: scale_factor (float): The factor to scale the bounding box by. Width and height will be scaled by this factor. Returns: Bbox: The scaled Bbox instance. Raises: ValueError: If the scale is strictly negative. """ if scale_factor < 0: raise ValueError( "Scaling with a negative value would result in an invalid Bbox." ) cx, cy = self.center new_width = self.width * scale_factor new_height = self.height * scale_factor return Bbox( left=cx - new_width / 2, right=cx + new_width / 2, top=cy - new_height / 2, bottom=cy + new_height / 2, )
[docs] def scale_area(self, scale_factor: float) -> Bbox: """ Return a scaled Bbox such that `new area`/`old area` == `scale_factor`. The scaling will be from the center. Args: scale_factor (float): The factor to scale the bounding box by. The area will be scaled by this factor. (and width and height will be scaled by the square root of this factor.) Returns: Bbox: The scaled Bbox instance. Raises: ValueError: If the scale is strictly negative. """ if scale_factor < 0: raise ValueError( "Scaling with a negative value would result in an invalid Bbox." ) return self.scale(scale_factor=scale_factor**0.5)
[docs] def expand_uniform(self, padding: float) -> Bbox: """ Return an expanded Bbox by the specified padding. Args: padding (float): The amount to expand the bounding box by. Returns: Bbox: The expanded Bbox instance. """ return Bbox( left=self.left - padding, right=self.right + padding, top=self.top - padding, bottom=self.bottom + padding, )
[docs] def expand( self, left: float = 0, top: float = 0, right: float = 0, bottom: float = 0 ) -> Bbox: """ Return an expanded Bbox by the specified padding for each side. Args: left (float, optional): The amount to expand the left side of the bounding box by. Defaults to 0. top (float, optional): The amount to expand the top side of the bounding box by. Defaults to 0. right (float, optional): The amount to expand the right side of the bounding box by. Defaults to 0. bottom (float, optional): The amount to expand the bottom side of the bounding box by. Defaults to 0. Returns: Bbox: The expanded Bbox instance. """ return Bbox( left=self.left - left, right=self.right + right, top=self.top - top, bottom=self.bottom + bottom, )
[docs] def pad_to_square(self) -> Bbox: """Returns a padded Bbox to make it a square.""" width = self.width height = self.height if width > height: diff = (width - height) / 2 return Bbox( left=self.left, right=self.right, top=self.top - diff, bottom=self.bottom + diff, ) if height > width: diff = (height - width) / 2 return Bbox( left=self.left - diff, right=self.right + diff, top=self.top, bottom=self.bottom, ) return self.model_copy()
[docs] def pad_to_aspect_ratio(self, target_ratio: float) -> Bbox: """ Returns a padded Bbox to achieve the target aspect ratio. Args: target_ratio (float): The target aspect ratio. Returns: Bbox: A Bbox instance padded to the correct ratio. Raises: ValueError: If target_ratio is <= 0. """ if target_ratio <= 0: raise ValueError( f"Target ratio cannot be negative or zero. Received {target_ratio}" ) current_ratio = self.aspect_ratio if current_ratio > target_ratio: # Need to increase height new_height = self.width / target_ratio diff = (new_height - self.height) / 2 return Bbox( left=self.left, right=self.right, top=self.top - diff, bottom=self.bottom + diff, ) if current_ratio < target_ratio: # Need to increase width new_width = self.height * target_ratio diff = (new_width - self.width) / 2 return Bbox( left=self.left - diff, right=self.right + diff, top=self.top, bottom=self.bottom, ) return self.model_copy()
[docs] def clip_to_img(self, img_w: int, img_h: int) -> Bbox: """ Returns a clipped Bbox to the image dimensions. Remember that the bottom and right edges are considered excluded from the bbox, so `Bbox(left=-10, top=-20, right=100, bottom=120).clip_to_img(img_w=32, img_h=64)` returns `Bbox(left=0, top=0, right=32, bottom=64)` Args: img_w (int): The image width in pixels. img_h (int): The image height in pixels. Returns: Bbox: The clipped Bbox. """ return Bbox( left=max(0, self.left), top=max(0, self.top), right=min(img_w, self.right), bottom=min(img_h, self.bottom), )
# endregion # region Utility Functions
[docs] def overlaps(self, other: Bbox) -> bool: """ Checks if the current bounding box overlaps with another bounding box. Two bboxes are considered as overlapping if they intersect with a non-zero area. Args: other (Bbox): The other bounding box to check for overlap. Returns: bool: True if the bounding boxes overlap, False otherwise. """ inter = self.intersection(other) return inter is not None and inter.area > 0
[docs] def contains_point(self, x: float, y: float) -> bool: """ Checks if a point is inside the bounding box. Args: x (float): The x-coordinate of the point. y (float): The y-coordinate of the point. Returns: bool: True if the point is inside the bounding box, False otherwise. """ return self.left <= x <= self.right and self.top <= y <= self.bottom
[docs] def contains(self, other: Bbox) -> bool: """ Checks if the other Bbox is contained by this one. Args: other (Bbox): The other bounding box. Returns: bool: True if this bounding box contains the other one. """ return ( self.left <= other.left and self.right >= other.right and self.top <= other.top and self.bottom >= other.bottom )
[docs] def is_inside(self, other: Bbox) -> bool: """ Checks if this Bbox is contained by the other one. Args: other (Bbox): The other bounding box. Returns: bool: True if this bounding box is contained by the other one. """ return ( self.left >= other.left and self.right <= other.right and self.top >= other.top and self.bottom <= other.bottom )
[docs] def union(self, other: Bbox) -> Bbox: """ Calculates the minimal Bbox that englobes this one AND the other. Args: other (Bbox): The other bounding box to calculate the union with. Returns: Bbox: The minimal englobing Bbox. """ return Bbox( left=min(self.left, other.left), top=min(self.top, other.top), right=max(self.right, other.right), bottom=max(self.bottom, other.bottom), )
[docs] def intersection(self, other: Bbox) -> Optional[Bbox]: """ Calculates the intersection with another Bbox. If the resulting Bbox is not valid (ie `left > right` or `top > bottom`, returns None. Args: other (Bbox): The other bounding box to calculate the intersection with. Returns: Optional[Bbox]: The intersection of the two bounding boxes if valid. """ left = max(self.left, other.left) top = max(self.top, other.top) right = min(self.right, other.right) bottom = min(self.bottom, other.bottom) if left > right or top > bottom: return None return Bbox(left=left, top=top, right=right, bottom=bottom)
[docs] def iou(self, other: Bbox) -> float: """Calculates the Intersection over Union (IoU) with another bounding box. Args: other (Bbox): The other Bbox. Returns: float: The IoU between the two bounding boxes. """ # Calculate the intersection area inter = self.intersection(other) intersection_area = inter.area if inter is not None else 0 # Calculate the union area union_area = self.area + other.area - intersection_area if union_area == 0: return 0 # Calculate the IoU return intersection_area / union_area
[docs] def distance_to_point( self, x: float, y: float, dist: DistanceMetric = DistanceMetric.L2 ) -> float: """ Calculates the distance from the bounding box to a point. Args: x (float): The x-coordinate of the point. y (float): The y-coordinate of the point. dist (DistanceMetric): The distance metric to use. Defaults to ``DistanceMetric.L2``. Returns: float: The distance from the bounding box to the point. Raises: ValueError: If a wrong distance metric is provided. """ dx = max(self.left - x, 0, x - self.right) dy = max(self.top - y, 0, y - self.bottom) if dist == DistanceMetric.L1: # Calculate L1 distance (Manhattan distance) return dx + dy if dist == DistanceMetric.L2: # Calculate L2 distance (Euclidean distance) return (dx**2 + dy**2) ** 0.5 raise ValueError(f"{dist} is not a valid distance function.")
[docs] def distance_to_bbox( self, other: Bbox, dist: DistanceMetric = DistanceMetric.L2 ) -> float: """ Calculates the distance between the edges of this Bbox and another Bbox. Distance is zero if the boxes overlap or touch. Args: other (Bbox): The other bounding box. dist (DistanceMetric): The distance metric to use. Defaults to ``DistanceMetric.L2``. Returns: float: The distance between the two bounding boxes. Raises: ValueError: If a wrong distance metric is provided. """ dx = max(0, self.left - other.right, other.left - self.right) dy = max(0, self.top - other.bottom, other.top - self.bottom) if dist == DistanceMetric.L1: # Calculate L1 distance (Manhattan distance) return dx + dy if dist == DistanceMetric.L2: # Calculate L2 distance (Euclidean distance) return (dx**2 + dy**2) ** 0.5 raise ValueError(f"{dist} is not a valid distance function.")
# endregion @property def width(self) -> float: """The width of the Bbox.""" return self.right - self.left @property def height(self): """The height of the Bbox.""" return self.bottom - self.top @property def area(self) -> float: """The area of the Bbox""" return self.width * self.height @property def center(self) -> Tuple[float, float]: """The center of the Bbox in (x, y) format.""" return (self.left + self.right) / 2, (self.top + self.bottom) / 2 @property def aspect_ratio(self) -> float: """The aspect ratio of the Bbox (width over height).""" return self.width / self.height __or__ = union __and__ = intersection
def _assert_sequence_len(seq: Sequence[float], target_len: int = 4) -> None: """ Asserts that the length of the sequence is 4. Args: seq (Sequence[float]): The sequence to check the length of. target_len (int, optional): The target length of the sequence. Defaults to 4. Raises: ValueError: If the length of the sequence is not the target one. """ if len(seq) != target_len: raise ValueError( f"A sequence of len {len(seq)} has been passed. Need a sequence of len {target_len}." )