"""
Pipeline Editor window with stage counting functionality.

This module provides the main pipeline editor interface with visual
node-based pipeline design and an automatically refreshed stage count.

Main Components:
    - StageCountWidget: compact stage-count / validity indicator
    - PipelineEditor: main pipeline editor window
    - Node graph integration (NodeGraphQt, optional)
    - Pipeline validation and analysis

Usage:
    from cluster4npu_ui.ui.windows.pipeline_editor import PipelineEditor

    editor = PipelineEditor()
    editor.show()
"""
import sys

from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
                             QLabel, QStatusBar, QFrame, QPushButton, QAction,
                             QMenuBar, QToolBar, QSplitter, QTextEdit,
                             QMessageBox, QScrollArea)
from PyQt5.QtCore import Qt, QTimer, pyqtSignal
from PyQt5.QtGui import QFont, QPixmap, QIcon, QTextCursor

try:
    from NodeGraphQt import NodeGraph
    from NodeGraphQt.constants import IN_PORT, OUT_PORT
    NODEGRAPH_AVAILABLE = True
except ImportError:
    NODEGRAPH_AVAILABLE = False
    print("NodeGraphQt not available. Install with: pip install NodeGraphQt")

from ...core.pipeline import get_stage_count, analyze_pipeline_stages, get_pipeline_summary
from ...core.nodes.exact_nodes import (
    ExactInputNode, ExactModelNode, ExactPreprocessNode,
    ExactPostprocessNode, ExactOutputNode
)

# Keep the original imports as fallback
try:
    from ...core.nodes.model_node import ModelNode
    from ...core.nodes.preprocess_node import PreprocessNode
    from ...core.nodes.postprocess_node import PostprocessNode
    from ...core.nodes.input_node import InputNode
    from ...core.nodes.output_node import OutputNode
except ImportError:
    # Use ExactNodes as fallback
    ModelNode = ExactModelNode
    PreprocessNode = ExactPreprocessNode
    PostprocessNode = ExactPostprocessNode
    InputNode = ExactInputNode
    OutputNode = ExactOutputNode


# Node classes registered with the graph, in toolbar order.
_EXACT_NODE_CLASSES = (
    ExactInputNode,
    ExactModelNode,
    ExactPreprocessNode,
    ExactPostprocessNode,
    ExactOutputNode,
)


def _describe_stage(index, stage):
    """Return a one-line, human-readable summary of a single pipeline stage.

    Args:
        index: 1-based stage number used in the label.
        stage: A stage entry from ``summary['stages']``. This function reads
            ``stage['model_config']`` (its ``node_name`` key is optional),
            ``stage['preprocess_configs']`` and ``stage['postprocess_configs']``.

    Returns:
        e.g. ``"Stage 1: yolo (with 2 preprocess)"``.
    """
    model_name = stage['model_config'].get('node_name', 'Unknown Model')
    text = f"Stage {index}: {model_name}"
    preprocess_count = len(stage['preprocess_configs'])
    postprocess_count = len(stage['postprocess_configs'])
    if preprocess_count:
        text += f" (with {preprocess_count} preprocess)"
    if postprocess_count:
        text += f" (with {postprocess_count} postprocess)"
    return text


class StageCountWidget(QWidget):
    """Compact widget that displays the stage count and pipeline validity."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.stage_count = 0
        self.pipeline_valid = True
        self.pipeline_error = ""

        self.setup_ui()
        self.setFixedSize(200, 80)

    def setup_ui(self):
        """Build the stage, status and (initially hidden) error labels."""
        layout = QVBoxLayout()
        layout.setContentsMargins(10, 5, 10, 5)

        # Stage count label
        self.stage_label = QLabel("Stages: 0")
        self.stage_label.setFont(QFont("Arial", 11, QFont.Bold))
        self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")

        # Status label
        self.status_label = QLabel("Ready")
        self.status_label.setFont(QFont("Arial", 9))
        self.status_label.setStyleSheet("color: #666666;")

        # Error label (initially hidden)
        self.error_label = QLabel("")
        self.error_label.setFont(QFont("Arial", 8))
        self.error_label.setStyleSheet("color: #D32F2F;")
        self.error_label.setWordWrap(True)
        self.error_label.setMaximumHeight(30)
        self.error_label.hide()

        layout.addWidget(self.stage_label)
        layout.addWidget(self.status_label)
        layout.addWidget(self.error_label)
        self.setLayout(layout)

        # A plain QWidget subclass does not paint style-sheet backgrounds or
        # borders unless WA_StyledBackground is set (or paintEvent overridden).
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.setStyleSheet("""
            StageCountWidget {
                background-color: #F5F5F5;
                border: 1px solid #E0E0E0;
                border-radius: 5px;
            }
        """)

    def update_stage_count(self, count: int, valid: bool = True, error: str = ""):
        """Update the stage count display.

        Args:
            count: Number of detected stages.
            valid: Whether the pipeline passed analysis.
            error: Error text shown below the status when ``valid`` is False.
        """
        self.stage_count = count
        self.pipeline_valid = valid
        self.pipeline_error = error

        self.stage_label.setText(f"Stages: {count}")

        if not valid:
            self.stage_label.setStyleSheet("color: #D32F2F; font-weight: bold;")
            self.status_label.setText("Invalid Pipeline")
            self.status_label.setStyleSheet("color: #D32F2F;")
            self.error_label.setText(error)
            self.error_label.show()
            return

        self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")
        if count == 0:
            self.status_label.setText("No stages defined")
            self.status_label.setStyleSheet("color: #FF8F00;")
        else:
            self.status_label.setText(f"Pipeline ready ({count} stage{'s' if count != 1 else ''})")
            self.status_label.setStyleSheet("color: #2E7D32;")
        # Hide any error left over from a previous invalid state, whatever the count.
        self.error_label.hide()


class PipelineEditor(QMainWindow):
    """
    Main pipeline editor window with stage counting functionality.

    This window provides a visual node-based pipeline editor with automatic
    stage detection. The stage count is shown in a StageCountWidget at the
    top of the right-hand panel and in the status bar.
    """

    # Signals
    pipeline_changed = pyqtSignal()
    # Emitted after every analysis run with the current stage count.
    stage_count_changed = pyqtSignal(int)

    def __init__(self, parent=None):
        super().__init__(parent)

        self.node_graph = None
        self.stage_count_widget = None
        self.analysis_timer = None
        self.previous_stage_count = 0  # Track previous stage count for comparison

        self.setup_ui()
        self.setup_node_graph()
        self.setup_analysis_timer()

        # Connect signals
        self.pipeline_changed.connect(self.analyze_pipeline)

        # Initial analysis
        print("Pipeline Editor initialized")
        self.analyze_pipeline()

    # ------------------------------------------------------------------ UI --

    def setup_ui(self):
        """Setup the main UI components."""
        self.setWindowTitle("Pipeline Editor - Cluster4NPU")
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        main_layout = QVBoxLayout()
        central_widget.setLayout(main_layout)

        splitter = QSplitter(Qt.Horizontal)
        main_layout.addWidget(splitter)

        # Left panel: hosts the node graph (filled in by setup_node_graph)
        self.graph_widget = QWidget()
        self.graph_layout = QVBoxLayout()
        self.graph_widget.setLayout(self.graph_layout)
        splitter.addWidget(self.graph_widget)

        # Right panel for properties and tools
        right_panel = QWidget()
        right_panel.setMaximumWidth(300)
        right_layout = QVBoxLayout()
        right_panel.setLayout(right_layout)

        # Stage count widget (first item, i.e. top of the right panel)
        self.stage_count_widget = StageCountWidget()
        right_layout.addWidget(self.stage_count_widget)

        # Properties panel
        properties_label = QLabel("Properties")
        properties_label.setFont(QFont("Arial", 10, QFont.Bold))
        right_layout.addWidget(properties_label)

        self.properties_text = QTextEdit()
        self.properties_text.setMaximumHeight(200)
        self.properties_text.setReadOnly(True)
        right_layout.addWidget(self.properties_text)

        # Pipeline info panel
        info_label = QLabel("Pipeline Info")
        info_label.setFont(QFont("Arial", 10, QFont.Bold))
        right_layout.addWidget(info_label)

        self.info_text = QTextEdit()
        self.info_text.setReadOnly(True)
        right_layout.addWidget(self.info_text)

        splitter.addWidget(right_panel)
        splitter.setSizes([800, 300])

        self.create_toolbar()
        self.create_status_bar()
        self.apply_styling()

    def create_toolbar(self):
        """Create the toolbar with node-creation and pipeline operations."""
        toolbar = self.addToolBar("Pipeline Operations")

        node_actions = (
            ("Add Input", self.add_input_node),
            ("Add Model", self.add_model_node),
            ("Add Preprocess", self.add_preprocess_node),
            ("Add Postprocess", self.add_postprocess_node),
            ("Add Output", self.add_output_node),
        )
        for text, handler in node_actions:
            action = QAction(text, self)
            action.triggered.connect(handler)
            toolbar.addAction(action)

        toolbar.addSeparator()

        validate_action = QAction("Validate Pipeline", self)
        validate_action.triggered.connect(self.validate_pipeline)
        toolbar.addAction(validate_action)

        clear_action = QAction("Clear Pipeline", self)
        clear_action.triggered.connect(self.clear_pipeline)
        toolbar.addAction(clear_action)

    def create_status_bar(self):
        """Create the status bar."""
        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)
        self.status_bar.showMessage("Ready")

    def setup_node_graph(self):
        """Create the NodeGraphQt graph, register node types and hook signals.

        When NodeGraphQt is not installed, an error label is shown in the
        graph area instead and ``self.node_graph`` stays ``None``.
        """
        if not NODEGRAPH_AVAILABLE:
            error_label = QLabel("NodeGraphQt not available. Please install it to use the pipeline editor.")
            error_label.setAlignment(Qt.AlignCenter)
            error_label.setStyleSheet("color: red; font-size: 14px;")
            self.graph_layout.addWidget(error_label)
            return

        self.node_graph = NodeGraph()

        print("Registering nodes with NodeGraphQt...")
        for node_cls in _EXACT_NODE_CLASSES:
            try:
                self.node_graph.register_node(node_cls)
                print(f"✓ Registered {node_cls.__name__} with identifier {node_cls.__identifier__}")
            except Exception as e:
                # Best effort: one failing node type should not block the rest.
                print(f"✗ Failed to register {node_cls.__name__}: {e}")

        # NodeGraphQt's signal set differs between releases, so every signal
        # is connected only if the graph actually exposes it.
        graph_signals = (
            ('node_created', self.on_node_created),
            ('node_deleted', self.on_node_deleted),
            ('nodes_deleted', self.on_nodes_deleted),
            ('connection_changed', self.on_connection_changed),
            ('connection_sliced', self.on_connection_changed),
            ('port_connected', self.on_connection_changed),
            ('port_disconnected', self.on_connection_changed),
        )
        for signal_name, slot in graph_signals:
            if not self._connect_graph_signal(signal_name, slot):
                print(f"Node graph has no '{signal_name}' signal; skipping")

        self.graph_layout.addWidget(self.node_graph.widget)
        print("Node graph setup completed successfully")

    def _connect_graph_signal(self, signal_name, slot):
        """Connect ``self.node_graph.<signal_name>`` to *slot* if it exists.

        Returns:
            True if a connection was made, False if the signal is missing.
        """
        signal = getattr(self.node_graph, signal_name, None)
        if signal is None:
            return False
        signal.connect(slot)
        return True

    def setup_analysis_timer(self):
        """Create the single-shot timer that debounces pipeline analysis."""
        # Parented to the window so it is destroyed together with it.
        self.analysis_timer = QTimer(self)
        self.analysis_timer.setSingleShot(True)
        self.analysis_timer.timeout.connect(self.analyze_pipeline)
        self.analysis_timer.setInterval(500)  # 500ms delay

    def apply_styling(self):
        """Apply custom styling to the editor."""
        # Toolbar entries are rendered as QToolButton widgets; QAction itself
        # is not a widget and cannot be targeted by a style sheet.
        self.setStyleSheet("""
            QMainWindow {
                background-color: #FAFAFA;
            }
            QToolBar {
                background-color: #FFFFFF;
                border: 1px solid #E0E0E0;
                spacing: 5px;
                padding: 5px;
            }
            QToolBar QToolButton {
                padding: 5px 10px;
                margin: 2px;
                border: 1px solid #E0E0E0;
                border-radius: 3px;
                background-color: #FFFFFF;
            }
            QToolBar QToolButton:hover {
                background-color: #F5F5F5;
            }
            QTextEdit {
                border: 1px solid #E0E0E0;
                border-radius: 3px;
                padding: 5px;
                background-color: #FFFFFF;
            }
            QLabel {
                color: #333333;
            }
        """)

    # --------------------------------------------------------- Node adding --

    def _add_node(self, base_identifier, class_name, node_type):
        """Create a node from the toolbar and schedule a re-analysis.

        Tries ``base_identifier``, then ``base_identifier.class_name``, then
        ``base_identifier.class_name.class_name``.

        Returns:
            The created node, or None if there is no graph or creation failed.
        """
        if not self.node_graph:
            return None
        print(f"Adding {node_type} via toolbar...")
        identifiers = [
            base_identifier,
            f"{base_identifier}.{class_name}",
            f"{base_identifier}.{class_name}.{class_name}",
        ]
        node = self.create_node_with_fallback(identifiers, node_type)
        self.schedule_analysis()
        return node

    def add_input_node(self):
        """Add an input node to the pipeline."""
        return self._add_node('com.cluster.input_node', 'ExactInputNode', "Input Node")

    def add_model_node(self):
        """Add a model node to the pipeline."""
        return self._add_node('com.cluster.model_node', 'ExactModelNode', "Model Node")

    def add_preprocess_node(self):
        """Add a preprocess node to the pipeline."""
        return self._add_node('com.cluster.preprocess_node', 'ExactPreprocessNode', "Preprocess Node")

    def add_postprocess_node(self):
        """Add a postprocess node to the pipeline."""
        return self._add_node('com.cluster.postprocess_node', 'ExactPostprocessNode', "Postprocess Node")

    def add_output_node(self):
        """Add an output node to the pipeline."""
        return self._add_node('com.cluster.output_node', 'ExactOutputNode', "Output Node")

    def create_node_with_fallback(self, identifiers, node_type):
        """Create a node using the first identifier the graph accepts.

        Args:
            identifiers: Candidate node type identifiers, tried in order.
            node_type: Human-readable node kind used in log messages.

        Returns:
            The created node, or None if every identifier failed.
        """
        last_error = None
        for identifier in identifiers:
            try:
                node = self.node_graph.create_node(identifier)
            except Exception as e:
                last_error = e
                continue
            print(f"✓ Successfully created {node_type} with identifier: {identifier}")
            return node
        print(f"Failed to create {node_type} with any identifier: {identifiers} (last error: {last_error})")
        return None

    # ------------------------------------------------- Pipeline operations --

    def validate_pipeline(self):
        """Validate the current pipeline and report the result in a dialog."""
        if not self.node_graph:
            return

        print("🔍 Validating pipeline...")
        summary = get_pipeline_summary(self.node_graph)

        if summary['valid']:
            print(f"Pipeline validation passed - {summary['stage_count']} stages, {summary['total_nodes']} nodes")
            QMessageBox.information(self, "Pipeline Validation",
                                    f"Pipeline is valid!\n\n"
                                    f"Stages: {summary['stage_count']}\n"
                                    f"Total nodes: {summary['total_nodes']}")
        else:
            error = summary.get('error', 'Unknown error')
            print(f"Pipeline validation failed: {error}")
            QMessageBox.warning(self, "Pipeline Validation",
                                f"Pipeline validation failed:\n\n{error}")

    def clear_pipeline(self):
        """Clear the entire pipeline."""
        if self.node_graph:
            print("🗑️ Clearing entire pipeline...")
            self.node_graph.clear_session()
            self.schedule_analysis()

    def schedule_analysis(self):
        """(Re)start the debounce timer; analysis runs when it fires."""
        if self.analysis_timer:
            self.analysis_timer.start()

    def analyze_pipeline(self):
        """Analyze the current pipeline and refresh every stage-count display."""
        if not self.node_graph:
            return

        try:
            summary = get_pipeline_summary(self.node_graph)
            current_stage_count = summary['stage_count']

            self.print_pipeline_analysis(summary, current_stage_count)

            self.stage_count_widget.update_stage_count(
                current_stage_count,
                summary['valid'],
                summary.get('error', '')
            )

            self.update_info_panel(summary)

            if summary['valid']:
                self.status_bar.showMessage(f"Pipeline ready - {current_stage_count} stages")
            else:
                self.status_bar.showMessage(f"Pipeline invalid - {summary.get('error', 'Unknown error')}")

            self.previous_stage_count = current_stage_count
            self.stage_count_changed.emit(current_stage_count)

        except Exception as e:
            # Keep the editor usable if analysis fails; surface the error in the UI.
            print(f"X Pipeline analysis error: {str(e)}")
            self.stage_count_widget.update_stage_count(0, False, f"Analysis error: {str(e)}")
            self.status_bar.showMessage(f"Analysis error: {str(e)}")

    def print_pipeline_analysis(self, summary, current_stage_count):
        """Print a detailed pipeline analysis to the terminal.

        Reports a change relative to ``self.previous_stage_count`` (if any),
        followed by node statistics and per-stage details from ``summary``.
        """
        previous = self.previous_stage_count
        if current_stage_count != previous:
            if previous == 0 and current_stage_count > 0:
                print(f"Initial stage count: {current_stage_count}")
            else:
                change = current_stage_count - previous
                if change > 0:
                    print(f"Stage count increased: {previous} → {current_stage_count} (+{change})")
                else:
                    print(f"Stage count decreased: {previous} → {current_stage_count} ({change})")

        # Always print current pipeline status for clarity
        print("Current Pipeline Status:")
        print(f"  • Stages: {current_stage_count}")
        for label, key in (("Total Nodes", 'total_nodes'),
                           ("Model Nodes", 'model_nodes'),
                           ("Input Nodes", 'input_nodes'),
                           ("Output Nodes", 'output_nodes'),
                           ("Preprocess Nodes", 'preprocess_nodes'),
                           ("Postprocess Nodes", 'postprocess_nodes')):
            print(f"  • {label}: {summary[key]}")
        print(f"  • Valid: {'V' if summary['valid'] else 'X'}")

        if not summary['valid'] and summary.get('error'):
            print(f"  • Error: {summary['error']}")

        stages = summary.get('stages')
        if stages:
            print("Stage Details:")
            for i, stage in enumerate(stages, 1):
                print(f"  {_describe_stage(i, stage)}")
        elif current_stage_count > 0:
            print(f"{current_stage_count} stage(s) detected but details not available")

        print("─" * 50)  # Separator line

    def update_info_panel(self, summary):
        """Render the analysis summary into the read-only Pipeline Info panel."""
        lines = [
            "Pipeline Analysis:",
            "",
            f"Stage Count: {summary['stage_count']}",
            f"Valid: {'Yes' if summary['valid'] else 'No'}",
        ]
        if summary.get('error'):
            lines.append(f"Error: {summary['error']}")
        lines += [
            "",
            "Node Statistics:",
            f"- Total Nodes: {summary['total_nodes']}",
            f"- Input Nodes: {summary['input_nodes']}",
            f"- Model Nodes: {summary['model_nodes']}",
            f"- Preprocess Nodes: {summary['preprocess_nodes']}",
            f"- Postprocess Nodes: {summary['postprocess_nodes']}",
            f"- Output Nodes: {summary['output_nodes']}",
            "",
            "Stages:",
        ]
        # `or []` also guards against an explicit None for 'stages'.
        lines += [f"  {_describe_stage(i, stage)}"
                  for i, stage in enumerate(summary.get('stages') or [], 1)]

        self.info_text.setPlainText("\n".join(lines))

    # ------------------------------------------------------ Graph handlers --

    def on_node_created(self, node):
        """Handle node creation."""
        print(f"+ Node added: {self.get_node_type_name(node)}")
        self.schedule_analysis()

    def on_node_deleted(self, node):
        """Handle deletion of a single node."""
        print(f"- Node removed: {self.get_node_type_name(node)}")
        self.schedule_analysis()

    def on_nodes_deleted(self, nodes):
        """Handle deletion of several nodes (objects or node ids)."""
        node_types = [self.get_node_type_name(node) for node in nodes]
        print(f"- Multiple nodes removed: {', '.join(node_types)}")
        self.schedule_analysis()

    def on_connection_changed(self, *ports):
        """Handle any connection change.

        Accepts a variable number of arguments because the connected signals
        differ in arity (e.g. a port pair vs. a single list of slices).
        """
        print(f"🔗 Connection changed: {' <-> '.join(str(p) for p in ports)}")
        self.schedule_analysis()

    def get_node_type_name(self, node):
        """Get a readable name for a node object (or pass through a node id string)."""
        if isinstance(node, str):
            # Some deletion signals deliver node ids rather than node objects.
            return node
        if hasattr(node, 'NODE_NAME'):
            return node.NODE_NAME
        if hasattr(node, '__identifier__'):
            identifier = node.__identifier__
            for keyword, readable in (('model', "Model Node"),
                                      ('input', "Input Node"),
                                      ('output', "Output Node"),
                                      ('preprocess', "Preprocess Node"),
                                      ('postprocess', "Postprocess Node")):
                if keyword in identifier:
                    return readable
        # Fallback to class name
        return type(node).__name__

    def get_current_stage_count(self):
        """Get the most recently displayed stage count."""
        return self.stage_count_widget.stage_count if self.stage_count_widget else 0

    def get_pipeline_summary(self):
        """Get the current pipeline summary.

        Without a node graph, returns an invalid summary that still contains
        every key the rest of this class reads.
        """
        if self.node_graph:
            # Bare name resolves to the module-level core.pipeline function,
            # not this method.
            return get_pipeline_summary(self.node_graph)
        return {
            'stage_count': 0,
            'valid': False,
            'error': 'No pipeline graph',
            'total_nodes': 0,
            'input_nodes': 0,
            'model_nodes': 0,
            'preprocess_nodes': 0,
            'postprocess_nodes': 0,
            'output_nodes': 0,
            'stages': [],
        }


def main():
    """Launch the pipeline editor standalone for manual testing."""
    from PyQt5.QtWidgets import QApplication

    app = QApplication(sys.argv)
    editor = PipelineEditor()
    editor.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()