"""
Postprocessing node implementation for output transformation operations.

This module provides the PostprocessNode class, which handles output
postprocessing operations in the pipeline, including result filtering,
format conversion, and output validation.

Main Components:
    - PostprocessNode: Core postprocessing node implementation
    - Result filtering and validation
    - Output format conversion

Usage:
    from cluster4npu_ui.core.nodes.postprocess_node import PostprocessNode

    node = PostprocessNode()
    node.set_property('output_format', 'JSON')
    node.set_property('confidence_threshold', 0.5)
"""
|
|
|
|
from .base_node import BaseNodeWithProperties
|
|
|
|
|
|
class PostprocessNode(BaseNodeWithProperties):
    """
    Postprocessing node for output transformation operations.

    This node handles various postprocessing operations including result filtering,
    format conversion, confidence thresholding, and output validation.

    The class itself declares the editable properties and exposes helpers that
    read them back: ``validate_configuration``, ``get_postprocessing_config``,
    and rough time/size estimates. The actual postprocessing is presumably done
    by whatever consumes ``get_postprocessing_config()`` (confirm with the caller).
    """

    __identifier__ = 'com.cluster.postprocess_node'
    NODE_NAME = 'Postprocess Node'

    # Output formats offered by the 'output_format' property. Single source of
    # truth for setup_properties, get_supported_formats and validation.
    SUPPORTED_FORMATS = ('JSON', 'XML', 'CSV', 'Binary', 'MessagePack', 'YAML')

    # Operation names accepted in the comma-separated 'operations' property.
    VALID_OPERATIONS = (
        'filter', 'nms', 'format', 'validate', 'transform', 'track', 'aggregate',
    )

    # Heuristic cost model for get_estimated_processing_time (treat as read-only).
    _BASE_TIME_MS_PER_DETECTION = 0.1
    _DEFAULT_OPERATION_FACTOR = 0.1  # used for operations missing from the table
    _OPERATION_TIME_FACTORS = {
        'filter': 0.05,
        'nms': 0.5,
        'format': 0.1,
        'validate': 0.2,
        'transform': 0.1,
        'track': 1.0,
        'aggregate': 0.3,
    }

    # Rough bytes per detection for each output format, used by
    # estimate_output_size (treat as read-only).
    _FORMAT_BYTES_PER_DETECTION = {
        'JSON': 150,        # JSON with metadata
        'XML': 200,         # XML with structure
        'CSV': 50,          # Compact CSV
        'Binary': 30,       # Binary format
        'MessagePack': 40,  # MessagePack
        'YAML': 180,        # YAML with structure
    }

    def __init__(self):
        super().__init__()

        # Setup node connections: one single-connection input, one output.
        self.add_input('input', multi_input=False, color=(255, 140, 0))
        self.add_output('output', color=(0, 255, 0))
        self.set_color(153, 51, 51)

        # Initialize properties
        self.setup_properties()

    def setup_properties(self):
        """Initialize postprocessing-specific properties."""
        # Output format (choice list shared with get_supported_formats)
        self.create_business_property('output_format', 'JSON', list(self.SUPPORTED_FORMATS))

        # Confidence filtering
        self.create_business_property('confidence_threshold', 0.5, {
            'min': 0.0,
            'max': 1.0,
            'step': 0.01,
            'description': 'Minimum confidence threshold for results'
        })
        self.create_business_property('enable_confidence_filter', True, {
            'description': 'Enable confidence-based filtering'
        })

        # NMS (Non-Maximum Suppression)
        self.create_business_property('nms_threshold', 0.4, {
            'min': 0.0,
            'max': 1.0,
            'step': 0.01,
            'description': 'NMS threshold for overlapping detections'
        })
        self.create_business_property('enable_nms', True, {
            'description': 'Enable Non-Maximum Suppression'
        })

        # Result limiting
        self.create_business_property('max_detections', 100, {
            'min': 1,
            'max': 1000,
            'description': 'Maximum number of detections to keep'
        })
        self.create_business_property('top_k_results', 10, {
            'min': 1,
            'max': 100,
            'description': 'Number of top results to return'
        })

        # Class filtering
        self.create_business_property('enable_class_filter', False, {
            'description': 'Enable class-based filtering'
        })
        self.create_business_property('allowed_classes', '', {
            'placeholder': 'comma-separated class names or indices',
            'description': 'Allowed class names or indices'
        })
        self.create_business_property('blocked_classes', '', {
            'placeholder': 'comma-separated class names or indices',
            'description': 'Blocked class names or indices'
        })

        # Output validation
        self.create_business_property('validate_output', True, {
            'description': 'Validate output format and structure'
        })
        self.create_business_property('output_schema', '', {
            'placeholder': 'JSON schema for output validation',
            'description': 'JSON schema for output validation'
        })

        # Coordinate transformation
        self.create_business_property('coordinate_system', 'relative', [
            'relative',  # [0, 1] normalized coordinates
            'absolute',  # Pixel coordinates
            'center',    # Center-based coordinates
            'custom'     # Custom transformation
        ])

        # Post-processing operations
        self.create_business_property('operations', 'filter,nms,format', {
            'placeholder': 'comma-separated: filter,nms,format,validate,transform',
            'description': 'Ordered list of postprocessing operations'
        })

        # Advanced options
        self.create_business_property('enable_tracking', False, {
            'description': 'Enable object tracking across frames'
        })
        self.create_business_property('tracking_method', 'simple', [
            'simple', 'kalman', 'deep_sort', 'custom'
        ])
        self.create_business_property('enable_aggregation', False, {
            'description': 'Enable result aggregation across time'
        })
        self.create_business_property('aggregation_window', 5, {
            'min': 1,
            'max': 100,
            'description': 'Number of frames for aggregation'
        })

    def validate_configuration(self) -> tuple[bool, str]:
        """
        Validate the current node configuration.

        Checks are made in this order:

        1. Both thresholds are real numbers in [0, 1]. NaN and bool are rejected.
        2. ``max_detections`` is an int >= 1.
        3. Every non-empty entry in ``operations`` is a known operation name.
        4. ``output_format`` is one of ``SUPPORTED_FORMATS``.

        Returns:
            Tuple of (is_valid, error_message). error_message is "" when valid
            and otherwise describes the first failing check.
        """
        # Check confidence threshold
        if not self._is_unit_interval(self.get_property('confidence_threshold')):
            return False, "Confidence threshold must be between 0 and 1"

        # Check NMS threshold
        if not self._is_unit_interval(self.get_property('nms_threshold')):
            return False, "NMS threshold must be between 0 and 1"

        # Check max detections (bool is an int subclass, so exclude it explicitly)
        max_detections = self.get_property('max_detections')
        if (not isinstance(max_detections, int)
                or isinstance(max_detections, bool)
                or max_detections < 1):
            return False, "Max detections must be at least 1"

        # Validate operations using the same parser get_postprocessing_config uses,
        # so blank entries (e.g. a trailing comma) are ignored rather than reported
        # as an invalid operation with an empty name.
        ops_list = self._parse_operations_list(self.get_property('operations'))
        invalid_ops = [op for op in ops_list if op not in self.VALID_OPERATIONS]
        if invalid_ops:
            return False, f"Invalid operations: {', '.join(invalid_ops)}"

        # Check output format against the offered choices
        output_format = self.get_property('output_format')
        if output_format not in self.SUPPORTED_FORMATS:
            return False, f"Unsupported output format: {output_format}"

        return True, ""

    @staticmethod
    def _is_unit_interval(value) -> bool:
        """Return True if *value* is an int/float (not bool) within [0, 1]."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        # Chained comparison is False for NaN, so NaN is rejected here.
        return 0.0 <= value <= 1.0

    def get_postprocessing_config(self) -> dict:
        """
        Get postprocessing configuration for pipeline execution.

        ``node_id`` and ``node_name`` come from the base node class
        (``self.id`` / ``self.name()``).

        Returns:
            Dictionary of the current property values. The ``allowed_classes``,
            ``blocked_classes`` and ``operations`` entries are parsed from their
            comma-separated strings into lists of stripped, non-empty items.
        """
        return {
            'node_id': self.id,
            'node_name': self.name(),
            'output_format': self.get_property('output_format'),
            'confidence_threshold': self.get_property('confidence_threshold'),
            'enable_confidence_filter': self.get_property('enable_confidence_filter'),
            'nms_threshold': self.get_property('nms_threshold'),
            'enable_nms': self.get_property('enable_nms'),
            'max_detections': self.get_property('max_detections'),
            'top_k_results': self.get_property('top_k_results'),
            'enable_class_filter': self.get_property('enable_class_filter'),
            'allowed_classes': self._parse_class_list(self.get_property('allowed_classes')),
            'blocked_classes': self._parse_class_list(self.get_property('blocked_classes')),
            'validate_output': self.get_property('validate_output'),
            'output_schema': self.get_property('output_schema'),
            'coordinate_system': self.get_property('coordinate_system'),
            'operations': self._parse_operations_list(self.get_property('operations')),
            'enable_tracking': self.get_property('enable_tracking'),
            'tracking_method': self.get_property('tracking_method'),
            'enable_aggregation': self.get_property('enable_aggregation'),
            'aggregation_window': self.get_property('aggregation_window')
        }

    @staticmethod
    def _split_csv(value_str: str) -> list[str]:
        """Split a comma-separated string into stripped, non-empty items ([] for falsy input)."""
        if not value_str:
            return []
        return [item.strip() for item in value_str.split(',') if item.strip()]

    def _parse_class_list(self, value_str: str) -> list[str]:
        """Parse comma-separated class names or indices (kept as strings)."""
        return self._split_csv(value_str)

    def _parse_operations_list(self, operations_str: str) -> list[str]:
        """Parse comma-separated operations list, preserving order."""
        return self._split_csv(operations_str)

    def get_supported_formats(self) -> list[str]:
        """Get list of supported output formats (a fresh copy each call)."""
        return list(self.SUPPORTED_FORMATS)

    def get_estimated_processing_time(self, num_detections: int = None) -> float:
        """
        Estimate processing time for given number of detections.

        The estimate is ``num_detections * 0.1 * sum(factor per configured
        operation)``. Operations missing from the factor table count as 0.1. With
        no operations configured, the estimate is 0.

        Args:
            num_detections: Number of input detections; defaults to the
                'max_detections' property when None.

        Returns:
            Estimated processing time in milliseconds
        """
        if num_detections is None:
            num_detections = self.get_property('max_detections')

        operations = self._parse_operations_list(self.get_property('operations'))
        total_factor = sum(
            self._OPERATION_TIME_FACTORS.get(op, self._DEFAULT_OPERATION_FACTOR)
            for op in operations
        )

        return num_detections * self._BASE_TIME_MS_PER_DETECTION * total_factor

    def estimate_output_size(self, num_detections: int = None) -> dict:
        """
        Estimate output data size for different formats.

        Args:
            num_detections: Number of detections; defaults to the
                'max_detections' property when None.

        Returns:
            Dictionary mapping each format name to its estimated size in bytes
            (a fixed per-detection size multiplied by num_detections)
        """
        if num_detections is None:
            num_detections = self.get_property('max_detections')

        return {
            format_name: size * num_detections
            for format_name, size in self._FORMAT_BYTES_PER_DETECTION.items()
        }