# Copyright 2022 AI Singapore
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Records the nodes' outputs to a CSV file."""
import logging
import textwrap
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from peekingduck.pipeline.nodes.abstract_node import AbstractNode
from peekingduck.pipeline.nodes.output.utils.csvlogger import CSVLogger
class Node(AbstractNode):
    """Tracks user-specified parameters and outputs the results in a CSV file.

    Inputs:
        ``all`` (:obj:`List`): A placeholder that represents a flexible input.
        Actual inputs to be written into the CSV file can be configured in
        ``stats_to_track``.

    Outputs:
        |none_output_data|

    Configs:
        stats_to_track (:obj:`List[str]`):
            **default = ["keypoints", "bboxes", "bbox_labels"]**. |br|
            Parameters to log into the CSV file. The chosen parameters must be
            present in the data pool.
        file_path (:obj:`str`):
            **default = "PeekingDuck/data/stats.csv"**. |br|
            Path of the CSV file to be saved. The resulting file name would have
            an appended timestamp.
        logging_interval (:obj:`int`): **default = 1**. |br|
            Interval between each log, in terms of seconds.
    """

    def __init__(
        self, config: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> None:
        """Validates the config values and creates the CSV logger.

        Args:
            config (Optional[Dict[str, Any]]): Node configuration, forwarded to
                ``AbstractNode``.
            **kwargs: Extra config overrides, forwarded to ``AbstractNode``.

        Raises:
            ValueError: If ``file_path`` does not have a ``.csv`` extension.
        """
        super().__init__(config, node_path=__name__, **kwargs)
        self.logger = logging.getLogger(__name__)
        # The config keys are presumably set as attributes by
        # AbstractNode.__init__; coerce them to the types used below.
        self.logging_interval = int(self.logging_interval)  # type: ignore
        self.file_path = Path(self.file_path)  # type: ignore
        # only CSV output is supported, so reject any other extension early
        if self.file_path.suffix != ".csv":
            raise ValueError("Filepath must have a '.csv' extension.")
        # The timestamped path is computed once per node instance, so every
        # logger this node creates (including after a reset) targets it.
        self._file_path_datetime = self._append_datetime_file_path(self.file_path)
        # stats are validated against the data pool on the first run() call
        self._stats_checked = False
        self.stats_to_track: List[str]
        # Pass a copy so the logger can never alias or mutate the node's list.
        self.csv_logger = CSVLogger(
            self._file_path_datetime, list(self.stats_to_track), self.logging_interval
        )

    def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Writes the current state of the tracked statistics into the CSV file
        as a row entry.

        On the first call (and the first call after a reset) the tracked stats
        are checked against the data pool. The logger is then rebuilt with only
        the stats that are present.

        Args:
            inputs (dict): The data pool of the pipeline.

        Returns:
            outputs (dict): An empty dictionary; this node produces no outputs.
        """
        # reset and terminate when there is no more data
        if inputs["pipeline_end"]:
            self._reset()
            return {}

        if not self._stats_checked:
            self._check_tracked_stats(inputs)
            # self.stats_to_track might change after the check, so rebuild the
            # logger with the validated list (copied to avoid aliasing).
            self.csv_logger = CSVLogger(
                self._file_path_datetime,
                list(self.stats_to_track),
                self.logging_interval,
            )

        self.csv_logger.write(inputs, self.stats_to_track)
        return {}

    def _check_tracked_stats(self, inputs: Dict[str, Any]) -> None:
        """Checks whether the user-specified statistics are present in the data
        pool of the pipeline.

        Statistics not present in the data pool are dropped from
        ``stats_to_track`` and a warning listing them is logged. Marks the
        stats as checked so this only runs once until the next reset.

        Args:
            inputs (dict): The data pool of the pipeline.
        """
        valid = [stat for stat in self.stats_to_track if stat in inputs]
        invalid = [stat for stat in self.stats_to_track if stat not in inputs]
        if invalid:
            msg = textwrap.dedent(
                f"""\
                {invalid} are not valid outputs.
                Data pool only has this outputs: {list(inputs.keys())}
                Only {valid} will be logged in the csv file.
                """
            )
            self.logger.warning(msg)
        # update stats_to_track with valid stats found in data pool
        self.stats_to_track = valid
        self._stats_checked = True

    def _get_config_types(self) -> Dict[str, Any]:
        """Returns dictionary mapping the node's config keys to respective types."""
        return {"stats_to_track": List[str], "file_path": str, "logging_interval": int}

    def _reset(self) -> None:
        """Releases the current logger and re-arms the stats check.

        Safe to call repeatedly: if the logger was already released by a
        previous reset, there is nothing to delete.
        """
        # Drop the node's reference to the logger. Guarded because a second
        # end-of-pipeline call would otherwise raise AttributeError on `del`.
        if hasattr(self, "csv_logger"):
            del self.csv_logger
        # initialize for use in run: a later run() re-validates the stats
        # and creates a fresh logger
        self._stats_checked = False

    @staticmethod
    def _append_datetime_file_path(file_path: Path) -> Path:
        """Appends the current local time stamp to the filename.

        Args:
            file_path (Path): Original path, e.g. ``data/stats.csv``.

        Returns:
            Path: Same directory and extension with ``_DDMMYY-HH-MM-SS``
            inserted before the extension, e.g. ``data/stats_240621-15-09-13.csv``.
        """
        current_time = datetime.now()
        # output as '240621-15-09-13' (day, month, 2-digit year, then time)
        time_str = current_time.strftime("%d%m%y-%H-%M-%S")
        # append timestamp to filename before extension
        # Format: filename_timestamp.extension
        file_path_with_timestamp = file_path.with_name(
            f"{file_path.stem}_{time_str}{file_path.suffix}"
        )
        return file_path_with_timestamp