"""
Pipeline Editor window with stage counting functionality.

This module provides the main pipeline editor interface with visual node-based
pipeline design and automatic stage counting display.

Main Components:
    - PipelineEditor: Main pipeline editor window
    - Stage counting display in canvas
    - Node graph integration
    - Pipeline validation and analysis

Usage:
    from cluster4npu_ui.ui.windows.pipeline_editor import PipelineEditor

    editor = PipelineEditor()
    editor.show()
"""

import sys

from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
                             QLabel, QStatusBar, QFrame, QPushButton, QAction,
                             QMenuBar, QToolBar, QSplitter, QTextEdit, QMessageBox,
                             QScrollArea)
from PyQt5.QtCore import Qt, QTimer, pyqtSignal
from PyQt5.QtGui import QFont, QPixmap, QIcon, QTextCursor

try:
    from NodeGraphQt import NodeGraph
    from NodeGraphQt.constants import IN_PORT, OUT_PORT
    NODEGRAPH_AVAILABLE = True
except ImportError:
    NODEGRAPH_AVAILABLE = False
    print("NodeGraphQt not available. Install with: pip install NodeGraphQt")

from ...core.pipeline import get_stage_count, analyze_pipeline_stages, get_pipeline_summary
from ...core.nodes.exact_nodes import (
    ExactInputNode, ExactModelNode, ExactPreprocessNode,
    ExactPostprocessNode, ExactOutputNode
)

# Keep the original imports as fallback
try:
    from ...core.nodes.model_node import ModelNode
    from ...core.nodes.preprocess_node import PreprocessNode
    from ...core.nodes.postprocess_node import PostprocessNode
    from ...core.nodes.input_node import InputNode
    from ...core.nodes.output_node import OutputNode
except ImportError:
    # Use ExactNodes as fallback
    ModelNode = ExactModelNode
    PreprocessNode = ExactPreprocessNode
    PostprocessNode = ExactPostprocessNode
    InputNode = ExactInputNode
    OutputNode = ExactOutputNode


class StageCountWidget(QWidget):
    """Widget to display stage count information in the pipeline editor."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.stage_count = 0
        self.pipeline_valid = True
        self.pipeline_error = ""
        self.setup_ui()
        self.setFixedSize(200, 80)

    def setup_ui(self):
        """Setup the stage count widget UI (stage, status and error labels)."""
        layout = QVBoxLayout()
        layout.setContentsMargins(10, 5, 10, 5)

        # Stage count label
        self.stage_label = QLabel("Stages: 0")
        self.stage_label.setFont(QFont("Arial", 11, QFont.Bold))
        self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")

        # Status label
        self.status_label = QLabel("Ready")
        self.status_label.setFont(QFont("Arial", 9))
        self.status_label.setStyleSheet("color: #666666;")

        # Error label (initially hidden)
        self.error_label = QLabel("")
        self.error_label.setFont(QFont("Arial", 8))
        self.error_label.setStyleSheet("color: #D32F2F;")
        self.error_label.setWordWrap(True)
        self.error_label.setMaximumHeight(30)
        self.error_label.hide()

        layout.addWidget(self.stage_label)
        layout.addWidget(self.status_label)
        layout.addWidget(self.error_label)
        self.setLayout(layout)

        # A plain QWidget subclass ignores background/border style-sheet rules
        # unless it paints them itself or WA_StyledBackground is set.
        self.setAttribute(Qt.WA_StyledBackground, True)
        self.setStyleSheet("""
            StageCountWidget {
                background-color: #F5F5F5;
                border: 1px solid #E0E0E0;
                border-radius: 5px;
            }
        """)

    def update_stage_count(self, count: int, valid: bool = True, error: str = ""):
        """Update the stage count display.

        Args:
            count: Number of stages to show.
            valid: Whether the pipeline is valid; selects red vs. green styling.
            error: Message shown in the error label when ``valid`` is False.
        """
        self.stage_count = count
        self.pipeline_valid = valid
        self.pipeline_error = error

        self.stage_label.setText(f"Stages: {count}")

        if not valid:
            self.stage_label.setStyleSheet("color: #D32F2F; font-weight: bold;")
            self.status_label.setText("Invalid Pipeline")
            self.status_label.setStyleSheet("color: #D32F2F;")
            self.error_label.setText(error)
            self.error_label.show()
            return

        self.stage_label.setStyleSheet("color: #2E7D32; font-weight: bold;")
        if count == 0:
            self.status_label.setText("No stages defined")
            self.status_label.setStyleSheet("color: #FF8F00;")
        else:
            plural = 's' if count != 1 else ''
            self.status_label.setText(f"Pipeline ready ({count} stage{plural})")
            self.status_label.setStyleSheet("color: #2E7D32;")
        self.error_label.hide()


class PipelineEditor(QMainWindow):
    """
    Main pipeline editor window with stage counting functionality.

    This window provides a visual node-based pipeline editor with automatic
    stage detection and counting displayed in the canvas.
    """

    # Signals
    pipeline_changed = pyqtSignal()
    stage_count_changed = pyqtSignal(int)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.node_graph = None
        self.stage_count_widget = None
        self.analysis_timer = None
        self.previous_stage_count = 0  # Track previous stage count for comparison

        self.setup_ui()
        self.setup_node_graph()
        self.setup_analysis_timer()

        # Connect signals
        self.pipeline_changed.connect(self.analyze_pipeline)

        # Initial analysis
        print("Pipeline Editor initialized")
        self.analyze_pipeline()

    def setup_ui(self):
        """Setup the main UI components: graph area, side panel, toolbar, status bar."""
        self.setWindowTitle("Pipeline Editor - Cluster4NPU")
        self.setGeometry(100, 100, 1200, 800)

        # Create central widget
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # Create main layout
        main_layout = QVBoxLayout()
        central_widget.setLayout(main_layout)

        # Create splitter for main content
        splitter = QSplitter(Qt.Horizontal)
        main_layout.addWidget(splitter)

        # Left panel for node graph
        self.graph_widget = QWidget()
        self.graph_layout = QVBoxLayout()
        self.graph_widget.setLayout(self.graph_layout)
        splitter.addWidget(self.graph_widget)

        # Right panel for properties and tools
        right_panel = QWidget()
        right_panel.setMaximumWidth(300)
        right_layout = QVBoxLayout()
        right_panel.setLayout(right_layout)

        # Stage count widget (first item, i.e. top of the right panel)
        self.stage_count_widget = StageCountWidget()
        right_layout.addWidget(self.stage_count_widget)

        # Properties panel
        properties_label = QLabel("Properties")
        properties_label.setFont(QFont("Arial", 10, QFont.Bold))
        right_layout.addWidget(properties_label)

        self.properties_text = QTextEdit()
        self.properties_text.setMaximumHeight(200)
        self.properties_text.setReadOnly(True)
        right_layout.addWidget(self.properties_text)

        # Pipeline info panel
        info_label = QLabel("Pipeline Info")
        info_label.setFont(QFont("Arial", 10, QFont.Bold))
        right_layout.addWidget(info_label)

        self.info_text = QTextEdit()
        self.info_text.setReadOnly(True)
        right_layout.addWidget(self.info_text)

        splitter.addWidget(right_panel)

        # Set splitter proportions
        splitter.setSizes([800, 300])

        self.create_toolbar()
        self.create_status_bar()
        self.apply_styling()

    def create_toolbar(self):
        """Create the toolbar with node-creation and pipeline operations."""
        toolbar = self.addToolBar("Pipeline Operations")

        def add_action(text, slot):
            # Parent the action to the window so it lives as long as the window.
            action = QAction(text, self)
            action.triggered.connect(slot)
            toolbar.addAction(action)

        # Add nodes actions
        add_action("Add Input", self.add_input_node)
        add_action("Add Model", self.add_model_node)
        add_action("Add Preprocess", self.add_preprocess_node)
        add_action("Add Postprocess", self.add_postprocess_node)
        add_action("Add Output", self.add_output_node)

        toolbar.addSeparator()

        # Pipeline actions
        add_action("Validate Pipeline", self.validate_pipeline)
        add_action("Clear Pipeline", self.clear_pipeline)

    def create_status_bar(self):
        """Create the status bar."""
        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)
        self.status_bar.showMessage("Ready")

    def setup_node_graph(self):
        """Create the NodeGraph, register the Exact* node types and wire its signals.

        When NodeGraphQt is not installed, an error label is shown in the graph
        area instead and ``self.node_graph`` stays None.
        """
        if not NODEGRAPH_AVAILABLE:
            error_label = QLabel("NodeGraphQt not available. Please install it to use the pipeline editor.")
            error_label.setAlignment(Qt.AlignCenter)
            error_label.setStyleSheet("color: red; font-size: 14px;")
            self.graph_layout.addWidget(error_label)
            return

        self.node_graph = NodeGraph()

        # Register node types - use ExactNode classes
        print("Registering nodes with NodeGraphQt...")
        node_classes = (ExactInputNode, ExactModelNode, ExactPreprocessNode,
                        ExactPostprocessNode, ExactOutputNode)
        for node_class in node_classes:
            # One failed registration should not prevent the others.
            try:
                self.node_graph.register_node(node_class)
                print(f"✓ Registered {node_class.__name__} with identifier {node_class.__identifier__}")
            except Exception as e:
                print(f"✗ Failed to register {node_class.__name__}: {e}")
        print("Node graph setup completed successfully")

        # Connect node graph signals. Each is guarded: a signal missing from the
        # installed NodeGraphQt version is skipped instead of aborting setup.
        self._connect_graph_signal('node_created', self.on_node_created)
        self._connect_graph_signal('node_deleted', self.on_node_deleted)
        self._connect_graph_signal('nodes_deleted', self.on_nodes_deleted)
        if not self._connect_graph_signal('connection_changed', self.on_connection_changed):
            # NodeGraphQt's NodeGraph documents per-port (Port, Port) signals;
            # use them when a combined connection signal is not available.
            self._connect_graph_signal('port_connected', self.on_connection_changed)
            self._connect_graph_signal('port_disconnected', self.on_connection_changed)
        self._connect_graph_signal('connection_sliced', self.on_connection_changed)

        # Add node graph widget to layout
        self.graph_layout.addWidget(self.node_graph.widget)

    def _connect_graph_signal(self, signal_name, slot):
        """Connect ``slot`` to ``self.node_graph.<signal_name>`` if that signal exists.

        Returns:
            True if the signal was found and connected, False otherwise.
        """
        signal = getattr(self.node_graph, signal_name, None)
        if signal is None:
            return False
        signal.connect(slot)
        return True

    def setup_analysis_timer(self):
        """Setup a single-shot timer that debounces pipeline analysis (500 ms)."""
        self.analysis_timer = QTimer(self)
        self.analysis_timer.setSingleShot(True)
        self.analysis_timer.timeout.connect(self.analyze_pipeline)
        self.analysis_timer.setInterval(500)  # 500ms delay

    def apply_styling(self):
        """Apply custom styling to the editor.

        Toolbar entries are rendered as QToolButton widgets; QAction itself is
        not a widget and cannot be targeted by a style-sheet selector.
        """
        self.setStyleSheet("""
            QMainWindow {
                background-color: #FAFAFA;
            }
            QToolBar {
                background-color: #FFFFFF;
                border: 1px solid #E0E0E0;
                spacing: 5px;
                padding: 5px;
            }
            QToolBar QToolButton {
                padding: 5px 10px;
                margin: 2px;
                border: 1px solid #E0E0E0;
                border-radius: 3px;
                background-color: #FFFFFF;
            }
            QToolBar QToolButton:hover {
                background-color: #F5F5F5;
            }
            QTextEdit {
                border: 1px solid #E0E0E0;
                border-radius: 3px;
                padding: 5px;
                background-color: #FFFFFF;
            }
            QLabel {
                color: #333333;
            }
        """)

    def _add_node(self, base_identifier, class_name, node_type):
        """Create one node via the toolbar, trying several identifier spellings.

        Args:
            base_identifier: Registered ``__identifier__`` prefix, e.g. 'com.cluster.input_node'.
            class_name: Node class name appended to form the longer identifier variants.
            node_type: Human-readable label used in log messages.

        Returns:
            The created node, or None if no graph exists or creation failed.
        """
        if not self.node_graph:
            return None
        print(f"Adding {node_type} via toolbar...")
        # Try multiple identifier formats
        identifiers = [
            base_identifier,
            f"{base_identifier}.{class_name}",
            f"{base_identifier}.{class_name}.{class_name}",
        ]
        node = self.create_node_with_fallback(identifiers, node_type)
        self.schedule_analysis()
        return node

    def add_input_node(self):
        """Add an input node to the pipeline."""
        self._add_node('com.cluster.input_node', 'ExactInputNode', "Input Node")

    def add_model_node(self):
        """Add a model node to the pipeline."""
        self._add_node('com.cluster.model_node', 'ExactModelNode', "Model Node")

    def add_preprocess_node(self):
        """Add a preprocess node to the pipeline."""
        self._add_node('com.cluster.preprocess_node', 'ExactPreprocessNode', "Preprocess Node")

    def add_postprocess_node(self):
        """Add a postprocess node to the pipeline."""
        self._add_node('com.cluster.postprocess_node', 'ExactPostprocessNode', "Postprocess Node")

    def add_output_node(self):
        """Add an output node to the pipeline."""
        self._add_node('com.cluster.output_node', 'ExactOutputNode', "Output Node")

    def create_node_with_fallback(self, identifiers, node_type):
        """Try to create a node with each identifier in turn.

        Returns:
            The first node created successfully, or None if every identifier failed
            (the last error is printed to help diagnose registration problems).
        """
        last_error = None
        for identifier in identifiers:
            # Broad catch kept on purpose: the exception type raised for an
            # unknown identifier is library-specific.
            try:
                node = self.node_graph.create_node(identifier)
            except Exception as e:
                last_error = e
                continue
            print(f"✓ Successfully created {node_type} with identifier: {identifier}")
            return node
        print(f"Failed to create {node_type} with any identifier: {identifiers}")
        if last_error is not None:
            print(f"  Last error: {last_error}")
        return None

    def validate_pipeline(self):
        """Validate the current pipeline configuration and report the result in a dialog."""
        if not self.node_graph:
            return

        print("🔍 Validating pipeline...")
        try:
            summary = get_pipeline_summary(self.node_graph)
        except Exception as e:
            # This runs as a Qt slot; an exception escaping it aborts a PyQt5 app.
            print(f"X Pipeline validation error: {e}")
            QMessageBox.warning(self, "Pipeline Validation",
                                f"Pipeline validation failed:\n\n{e}")
            return

        if summary['valid']:
            print(f"Pipeline validation passed - {summary['stage_count']} stages, {summary['total_nodes']} nodes")
            QMessageBox.information(self, "Pipeline Validation",
                                    f"Pipeline is valid!\n\n"
                                    f"Stages: {summary['stage_count']}\n"
                                    f"Total nodes: {summary['total_nodes']}")
        else:
            error = summary.get('error', 'Unknown error')
            print(f"Pipeline validation failed: {error}")
            QMessageBox.warning(self, "Pipeline Validation",
                                f"Pipeline validation failed:\n\n{error}")

    def clear_pipeline(self):
        """Clear the entire pipeline."""
        if self.node_graph:
            print("🗑️ Clearing entire pipeline...")
            self.node_graph.clear_session()
            self.schedule_analysis()

    def schedule_analysis(self):
        """Schedule pipeline analysis after a delay (restarts a pending timer)."""
        if self.analysis_timer is not None:
            self.analysis_timer.start()

    def analyze_pipeline(self):
        """Analyze the current pipeline and update stage count, info panel and status bar."""
        if not self.node_graph:
            return

        try:
            summary = get_pipeline_summary(self.node_graph)
            current_stage_count = summary['stage_count']

            # Print detailed pipeline analysis
            self.print_pipeline_analysis(summary, current_stage_count)

            # Update stage count widget
            self.stage_count_widget.update_stage_count(
                current_stage_count,
                summary['valid'],
                summary.get('error', '')
            )

            # Update info panel
            self.update_info_panel(summary)

            # Update status bar
            if summary['valid']:
                self.status_bar.showMessage(f"Pipeline ready - {current_stage_count} stages")
            else:
                self.status_bar.showMessage(f"Pipeline invalid - {summary.get('error', 'Unknown error')}")

            # Update previous count for next comparison
            self.previous_stage_count = current_stage_count

            # Emit signal
            self.stage_count_changed.emit(current_stage_count)

        except Exception as e:
            # UI boundary: report instead of letting the exception escape the slot.
            print(f"X Pipeline analysis error: {str(e)}")
            self.stage_count_widget.update_stage_count(0, False, f"Analysis error: {str(e)}")
            self.status_bar.showMessage(f"Analysis error: {str(e)}")

    @staticmethod
    def _format_stage(index, stage):
        """Return a one-line description of stage ``index`` (1-based) from a summary stage dict.

        Missing keys fall back to 'Unknown Model' / no pre- or postprocessing
        instead of raising.
        """
        model_config = stage.get('model_config') or {}
        model_name = model_config.get('node_name', 'Unknown Model')
        preprocess_count = len(stage.get('preprocess_configs') or [])
        postprocess_count = len(stage.get('postprocess_configs') or [])

        text = f" Stage {index}: {model_name}"
        if preprocess_count > 0:
            text += f" (with {preprocess_count} preprocess)"
        if postprocess_count > 0:
            text += f" (with {postprocess_count} postprocess)"
        return text

    def print_pipeline_analysis(self, summary, current_stage_count):
        """Print detailed pipeline analysis to terminal."""
        previous = self.previous_stage_count

        # Report how the stage count changed since the last analysis
        if current_stage_count != previous:
            if previous == 0 and current_stage_count > 0:
                print(f"Initial stage count: {current_stage_count}")
            else:
                change = current_stage_count - previous
                if change > 0:
                    print(f"Stage count increased: {previous} → {current_stage_count} (+{change})")
                else:
                    print(f"Stage count decreased: {previous} → {current_stage_count} ({change})")

        # Always print current pipeline status for clarity
        print("Current Pipeline Status:")
        print(f" • Stages: {current_stage_count}")
        print(f" • Total Nodes: {summary['total_nodes']}")
        print(f" • Model Nodes: {summary['model_nodes']}")
        print(f" • Input Nodes: {summary['input_nodes']}")
        print(f" • Output Nodes: {summary['output_nodes']}")
        print(f" • Preprocess Nodes: {summary['preprocess_nodes']}")
        print(f" • Postprocess Nodes: {summary['postprocess_nodes']}")
        print(f" • Valid: {'V' if summary['valid'] else 'X'}")

        if not summary['valid'] and summary.get('error'):
            print(f" • Error: {summary['error']}")

        # Print stage details if available
        stages = summary.get('stages') or []
        if stages:
            print("Stage Details:")
            for i, stage in enumerate(stages, 1):
                print(self._format_stage(i, stage))
        elif current_stage_count > 0:
            print(f"{current_stage_count} stage(s) detected but details not available")

        print("─" * 50)  # Separator line

    def update_info_panel(self, summary):
        """Update the pipeline info panel with analysis results."""
        lines = [
            "Pipeline Analysis:",
            f"Stage Count: {summary['stage_count']}",
            f"Valid: {'Yes' if summary['valid'] else 'No'}",
        ]
        if summary.get('error'):
            lines.append(f"Error: {summary['error']}")
        lines += [
            "",
            "Node Statistics:",
            f"- Total Nodes: {summary['total_nodes']}",
            f"- Input Nodes: {summary['input_nodes']}",
            f"- Model Nodes: {summary['model_nodes']}",
            f"- Preprocess Nodes: {summary['preprocess_nodes']}",
            f"- Postprocess Nodes: {summary['postprocess_nodes']}",
            f"- Output Nodes: {summary['output_nodes']}",
            "",
            "Stages:",
        ]
        lines.extend(self._format_stage(i, stage)
                     for i, stage in enumerate(summary.get('stages') or [], 1))
        self.info_text.setPlainText("\n".join(lines))

    def on_node_created(self, node):
        """Handle node creation."""
        node_type = self.get_node_type_name(node)
        print(f"+ Node added: {node_type}")
        self.schedule_analysis()

    def on_node_deleted(self, node):
        """Handle node deletion."""
        node_type = self.get_node_type_name(node)
        print(f"- Node removed: {node_type}")
        self.schedule_analysis()

    def on_nodes_deleted(self, nodes):
        """Handle multiple node deletion."""
        node_types = [self.get_node_type_name(node) for node in nodes]
        print(f"- Multiple nodes removed: {', '.join(node_types)}")
        self.schedule_analysis()

    def on_connection_changed(self, input_port=None, output_port=None):
        """Handle connection changes.

        The second argument is optional because this slot is also connected to
        ``connection_sliced``, which emits a single (list) argument; a required
        second parameter would raise TypeError when that signal fires.
        """
        if output_port is None:
            print(f"🔗 Connection changed: {input_port}")
        else:
            print(f"🔗 Connection changed: {input_port} <-> {output_port}")
        self.schedule_analysis()

    def get_node_type_name(self, node):
        """Get a readable name for the node type."""
        if isinstance(node, str):
            # Some deletion signals may report node ids rather than node objects
            # (assumed for NodeGraphQt's nodes_deleted — confirm); show the id.
            return node
        if hasattr(node, 'NODE_NAME'):
            return node.NODE_NAME
        identifier = getattr(node, '__identifier__', None)
        if isinstance(identifier, str):
            # Convert identifier to readable name; first matching keyword wins.
            for keyword, label in (('model', "Model Node"),
                                   ('input', "Input Node"),
                                   ('output', "Output Node"),
                                   ('preprocess', "Preprocess Node"),
                                   ('postprocess', "Postprocess Node")):
                if keyword in identifier:
                    return label
        # Fallback to class name
        return type(node).__name__

    def get_current_stage_count(self):
        """Get the current stage count."""
        return self.stage_count_widget.stage_count if self.stage_count_widget else 0

    def get_pipeline_summary(self):
        """Get the current pipeline summary.

        Without a node graph, returns an invalid summary with every count at 0,
        so callers can index the same keys as a real summary.
        """
        if self.node_graph:
            # Resolves to the module-level core.pipeline function: method names
            # are not in scope inside a method body.
            return get_pipeline_summary(self.node_graph)
        return {
            'stage_count': 0,
            'valid': False,
            'error': 'No pipeline graph',
            'total_nodes': 0,
            'input_nodes': 0,
            'model_nodes': 0,
            'preprocess_nodes': 0,
            'postprocess_nodes': 0,
            'output_nodes': 0,
            'stages': [],
        }


def main():
    """Main function for testing the pipeline editor."""
    from PyQt5.QtWidgets import QApplication

    app = QApplication(sys.argv)
    editor = PipelineEditor()
    editor.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()