# Source code for dabble.check_nearby_objs

# Copyright 2022 AI Singapore
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Checks if detected objects are near each other.
"""

from itertools import combinations
from typing import Any, Dict, Optional

import numpy as np

from peekingduck.pipeline.nodes.abstract_node import AbstractNode


class Node(AbstractNode):
    """Checks if any objects are near each other.

    It does so by comparing the 3D locations of all objects to see which ones
    are near each other. If the distance between two objects is less than
    ``near_threshold``, both would be flagged as near with ``tag_msg``. These
    flags can be accessed by the ``flags`` key of :term:`obj_attrs`.

    Inputs:
        |obj_3D_locs_data|

    Outputs:
        |obj_attrs_data|
        :mod:`dabble.check_nearby_objs` produces the ``flags`` attribute
        which contains either the ``tag_msg`` for objects that are near each
        other or an empty string for objects with no other objects nearby.

    Configs:
        near_threshold (:obj:`float`): **default = 2.0**. |br|
            Threshold of distance, in metres, between two objects. Objects
            with distance less than ``near_threshold`` would be considered
            as 'near'.
        tag_msg (:obj:`str`): **default = "TOO CLOSE!"**. |br|
            Tag to identify objects which are near others.

    .. versionchanged:: 1.2.0
        :mod:`dabble.check_nearby_objs` used to return ``obj_tags``
        (:obj:`List[str]`) as an output data type, which has been deprecated
        and now subsumed under :term:`obj_attrs`. The same attribute is
        accessed by using the ``flags`` key of :term:`obj_attrs`.
    """

    def __init__(
        self, config: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> None:
        # NOTE(review): ``near_threshold`` and ``tag_msg`` are read as
        # attributes in ``run``; presumably AbstractNode populates them from
        # the node config — confirm against AbstractNode.
        super().__init__(config, node_path=__name__, **kwargs)

    def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compares the 3D locations of all objects to see which objects are
        close to each other. If an object is close to another, tag it.

        Args:
            inputs (Dict[str, Any]): Dictionary with key ``"obj_3D_locs"``,
                a sized, indexable sequence with one location per object.
                Locations must support element-wise subtraction (presumably
                NumPy arrays of 3D coordinates — confirm with the upstream
                node).

        Returns:
            outputs (Dict[str, Any]): ``{"obj_attrs": {"flags": obj_flags}}``
            where ``obj_flags[i]`` is ``tag_msg`` if object ``i`` lies
            strictly closer than ``near_threshold`` to at least one other
            object, and ``""`` otherwise.
        """
        locs = inputs["obj_3D_locs"]
        obj_flags = [""] * len(locs)
        # Euclidean distance is symmetric (b - a == -(a - b) exactly), so each
        # unordered pair is checked once and both members are flagged, rather
        # than evaluating the norm for both (i, j) and (j, i).
        for idx_1, idx_2 in combinations(range(len(locs)), 2):
            if obj_flags[idx_1] and obj_flags[idx_2]:
                # Both objects are already tagged; this pair cannot change
                # the result, so skip the distance computation.
                continue
            dist_bet = np.linalg.norm(locs[idx_1] - locs[idx_2])
            if dist_bet < self.near_threshold:
                obj_flags[idx_1] = self.tag_msg
                obj_flags[idx_2] = self.tag_msg
        return {"obj_attrs": {"flags": obj_flags}}

    def _get_config_types(self) -> Dict[str, Any]:
        """Returns dictionary mapping the node's config keys to respective types."""
        return {"near_threshold": float, "tag_msg": str}