# Cluster/core/nodes/postprocess_node.py
# Last modified: 2025-07-17 17:04:56 +08:00

"""
Postprocessing node implementation for output transformation operations.
This module provides the PostprocessNode class which handles output postprocessing
operations in the pipeline, including result filtering, format conversion, and
output validation.
Main Components:
- PostprocessNode: Core postprocessing node implementation
- Result filtering and validation
- Output format conversion
Usage:
from cluster4npu_ui.core.nodes.postprocess_node import PostprocessNode
node = PostprocessNode()
node.set_property('output_format', 'JSON')
node.set_property('confidence_threshold', 0.5)
"""
from typing import Optional

from .base_node import BaseNodeWithProperties
class PostprocessNode(BaseNodeWithProperties):
    """
    Postprocessing node for output transformation operations.

    This node handles various postprocessing operations including result
    filtering, format conversion, confidence thresholding, and output
    validation.
    """
    __identifier__ = 'com.cluster.postprocess_node'
    NODE_NAME = 'Postprocess Node'

    # Output formats offered by the 'output_format' property and reported by
    # get_supported_formats(); one definition so the two cannot drift.
    SUPPORTED_FORMATS = ('JSON', 'XML', 'CSV', 'Binary', 'MessagePack', 'YAML')

    # Operations accepted in the 'operations' property, mapped to the
    # per-detection time factor (in ms multiples of the base time) used by
    # get_estimated_processing_time(). The key set doubles as the whitelist
    # checked in validate_configuration().
    OPERATION_TIME_FACTORS = {
        'filter': 0.05,
        'nms': 0.5,
        'format': 0.1,
        'validate': 0.2,
        'transform': 0.1,
        'track': 1.0,
        'aggregate': 0.3,
    }

    def __init__(self):
        super().__init__()
        # Setup node connections: a single non-multi input and one output port.
        self.add_input('input', multi_input=False, color=(255, 140, 0))
        self.add_output('output', color=(0, 255, 0))
        self.set_color(153, 51, 51)
        # Initialize properties
        self.setup_properties()

    def setup_properties(self):
        """Initialize postprocessing-specific properties."""
        # Output format
        self.create_business_property('output_format', 'JSON',
                                      list(self.SUPPORTED_FORMATS))
        # Confidence filtering
        self.create_business_property('confidence_threshold', 0.5, {
            'min': 0.0,
            'max': 1.0,
            'step': 0.01,
            'description': 'Minimum confidence threshold for results'
        })
        self.create_business_property('enable_confidence_filter', True, {
            'description': 'Enable confidence-based filtering'
        })
        # NMS (Non-Maximum Suppression)
        self.create_business_property('nms_threshold', 0.4, {
            'min': 0.0,
            'max': 1.0,
            'step': 0.01,
            'description': 'NMS threshold for overlapping detections'
        })
        self.create_business_property('enable_nms', True, {
            'description': 'Enable Non-Maximum Suppression'
        })
        # Result limiting
        self.create_business_property('max_detections', 100, {
            'min': 1,
            'max': 1000,
            'description': 'Maximum number of detections to keep'
        })
        self.create_business_property('top_k_results', 10, {
            'min': 1,
            'max': 100,
            'description': 'Number of top results to return'
        })
        # Class filtering
        self.create_business_property('enable_class_filter', False, {
            'description': 'Enable class-based filtering'
        })
        self.create_business_property('allowed_classes', '', {
            'placeholder': 'comma-separated class names or indices',
            'description': 'Allowed class names or indices'
        })
        self.create_business_property('blocked_classes', '', {
            'placeholder': 'comma-separated class names or indices',
            'description': 'Blocked class names or indices'
        })
        # Output validation
        self.create_business_property('validate_output', True, {
            'description': 'Validate output format and structure'
        })
        self.create_business_property('output_schema', '', {
            'placeholder': 'JSON schema for output validation',
            'description': 'JSON schema for output validation'
        })
        # Coordinate transformation
        self.create_business_property('coordinate_system', 'relative', [
            'relative',   # [0, 1] normalized coordinates
            'absolute',   # Pixel coordinates
            'center',     # Center-based coordinates
            'custom'      # Custom transformation
        ])
        # Post-processing operations
        self.create_business_property('operations', 'filter,nms,format', {
            'placeholder': 'comma-separated: filter,nms,format,validate,transform',
            'description': 'Ordered list of postprocessing operations'
        })
        # Advanced options
        self.create_business_property('enable_tracking', False, {
            'description': 'Enable object tracking across frames'
        })
        self.create_business_property('tracking_method', 'simple', [
            'simple', 'kalman', 'deep_sort', 'custom'
        ])
        self.create_business_property('enable_aggregation', False, {
            'description': 'Enable result aggregation across time'
        })
        self.create_business_property('aggregation_window', 5, {
            'min': 1,
            'max': 100,
            'description': 'Number of frames for aggregation'
        })

    def validate_configuration(self) -> tuple[bool, str]:
        """
        Validate the current node configuration.

        Returns:
            Tuple of (is_valid, error_message); error_message is '' when valid.
        """
        # Check confidence threshold. bool is a subclass of int, so exclude it
        # explicitly to reject accidental True/False values.
        confidence_threshold = self.get_property('confidence_threshold')
        if (isinstance(confidence_threshold, bool)
                or not isinstance(confidence_threshold, (int, float))
                or not 0 <= confidence_threshold <= 1):
            return False, "Confidence threshold must be between 0 and 1"
        # Check NMS threshold
        nms_threshold = self.get_property('nms_threshold')
        if (isinstance(nms_threshold, bool)
                or not isinstance(nms_threshold, (int, float))
                or not 0 <= nms_threshold <= 1):
            return False, "NMS threshold must be between 0 and 1"
        # Check max detections
        max_detections = self.get_property('max_detections')
        if (isinstance(max_detections, bool)
                or not isinstance(max_detections, int)
                or max_detections < 1):
            return False, "Max detections must be at least 1"
        # Validate the operations string using the same parser that
        # get_postprocessing_config() uses, so empty segments produced by a
        # stray trailing comma are tolerated consistently in both places.
        operations = self._parse_operations_list(self.get_property('operations'))
        invalid_ops = [op for op in operations
                       if op not in self.OPERATION_TIME_FACTORS]
        if invalid_ops:
            return False, f"Invalid operations: {', '.join(invalid_ops)}"
        return True, ""

    def get_postprocessing_config(self) -> dict:
        """
        Get postprocessing configuration for pipeline execution.

        Returns:
            Dictionary containing the full postprocessing configuration, with
            the comma-separated string properties already parsed into lists.
        """
        return {
            'node_id': self.id,
            'node_name': self.name(),
            'output_format': self.get_property('output_format'),
            'confidence_threshold': self.get_property('confidence_threshold'),
            'enable_confidence_filter': self.get_property('enable_confidence_filter'),
            'nms_threshold': self.get_property('nms_threshold'),
            'enable_nms': self.get_property('enable_nms'),
            'max_detections': self.get_property('max_detections'),
            'top_k_results': self.get_property('top_k_results'),
            'enable_class_filter': self.get_property('enable_class_filter'),
            'allowed_classes': self._parse_class_list(self.get_property('allowed_classes')),
            'blocked_classes': self._parse_class_list(self.get_property('blocked_classes')),
            'validate_output': self.get_property('validate_output'),
            'output_schema': self.get_property('output_schema'),
            'coordinate_system': self.get_property('coordinate_system'),
            'operations': self._parse_operations_list(self.get_property('operations')),
            'enable_tracking': self.get_property('enable_tracking'),
            'tracking_method': self.get_property('tracking_method'),
            'enable_aggregation': self.get_property('enable_aggregation'),
            'aggregation_window': self.get_property('aggregation_window')
        }

    @staticmethod
    def _split_csv(value_str: str) -> list[str]:
        """Split a comma-separated string, stripping whitespace and dropping
        empty segments. Returns [] for None/empty input."""
        if not value_str:
            return []
        return [item.strip() for item in value_str.split(',') if item.strip()]

    def _parse_class_list(self, value_str: str) -> list[str]:
        """Parse comma-separated class names or indices."""
        return self._split_csv(value_str)

    def _parse_operations_list(self, operations_str: str) -> list[str]:
        """Parse comma-separated operations list."""
        return self._split_csv(operations_str)

    def get_supported_formats(self) -> list[str]:
        """Get list of supported output formats (a fresh copy each call)."""
        return list(self.SUPPORTED_FORMATS)

    def get_estimated_processing_time(self, num_detections: Optional[int] = None) -> float:
        """
        Estimate processing time for given number of detections.

        Args:
            num_detections: Number of input detections; defaults to the
                configured 'max_detections' property when None.

        Returns:
            Estimated processing time in milliseconds.
        """
        if num_detections is None:
            num_detections = self.get_property('max_detections')
        # Base processing time (ms per detection)
        base_time = 0.1
        operations = self._parse_operations_list(self.get_property('operations'))
        # Unknown operations contribute a conservative default factor of 0.1.
        total_factor = sum(self.OPERATION_TIME_FACTORS.get(op, 0.1)
                           for op in operations)
        return num_detections * base_time * total_factor

    def estimate_output_size(self, num_detections: Optional[int] = None) -> dict:
        """
        Estimate output data size for different formats.

        Args:
            num_detections: Number of detections; defaults to the configured
                'max_detections' property when None.

        Returns:
            Dictionary with estimated sizes in bytes for each supported format.
        """
        if num_detections is None:
            num_detections = self.get_property('max_detections')
        # Estimated bytes per detection for each format (rough heuristics).
        format_sizes = {
            'JSON': 150,         # JSON with metadata
            'XML': 200,          # XML with structure
            'CSV': 50,           # Compact CSV
            'Binary': 30,        # Binary format
            'MessagePack': 40,   # MessagePack
            'YAML': 180          # YAML with structure
        }
        return {
            format_name: size * num_detections
            for format_name, size in format_sizes.items()
        }