# Cluster/core/pipeline.py
# 2025-07-17 17:04:56 +08:00

# 545 lines
# 19 KiB
# Python

"""
Pipeline stage analysis and management functionality.

This module provides functions to analyze pipeline node connections and automatically
determine the number of stages in a pipeline. Each stage consists of a model node
with optional preprocessing and postprocessing nodes.

Main Components:
    - Stage detection and analysis
    - Pipeline structure validation
    - Stage configuration generation
    - Connection path analysis

Usage:
    from cluster4npu_ui.core.pipeline import analyze_pipeline_stages, get_stage_count

    stage_count = get_stage_count(node_graph)
    stages = analyze_pipeline_stages(node_graph)
"""
from typing import List, Dict, Any, Optional, Tuple
from .nodes.model_node import ModelNode
from .nodes.preprocess_node import PreprocessNode
from .nodes.postprocess_node import PostprocessNode
from .nodes.input_node import InputNode
from .nodes.output_node import OutputNode
class PipelineStage:
    """Represents a single stage in the pipeline.

    A stage wraps exactly one model node together with the preprocessing and
    postprocessing nodes that were attached to it via ``add_preprocess_node``
    and ``add_postprocess_node``.

    Attributes:
        stage_id: Identifier given by the caller.
        model_node: The model node this stage is built around.
        preprocess_nodes: Preprocessing nodes, in insertion order.
        postprocess_nodes: Postprocessing nodes, in insertion order.
        input_connections: Initialised empty; not populated by this class.
        output_connections: Initialised empty; not populated by this class.
    """

    def __init__(self, stage_id: int, model_node: "ModelNode"):
        self.stage_id = stage_id
        self.model_node = model_node
        self.preprocess_nodes: List["PreprocessNode"] = []
        self.postprocess_nodes: List["PostprocessNode"] = []
        self.input_connections = []
        self.output_connections = []

    def add_preprocess_node(self, node: "PreprocessNode"):
        """Add a preprocessing node to this stage."""
        self.preprocess_nodes.append(node)

    def add_postprocess_node(self, node: "PostprocessNode"):
        """Add a postprocessing node to this stage."""
        self.postprocess_nodes.append(node)

    @staticmethod
    def _safe_node_config(node, getter_name: str, fallback_name: str) -> Any:
        """Return ``node``'s configuration without letting a bad node break the stage.

        Args:
            node: Any node-like object.
            getter_name: Name of the zero-argument method that yields the config.
            fallback_name: ``node_name`` used when nothing better is available.

        Returns:
            The getter's result when the node has one; otherwise a
            ``{'node_name': ...}`` dict built from ``NODE_NAME`` or
            ``fallback_name``. If the lookup raises an ordinary exception the
            dict carries ``fallback_name``.
        """
        try:
            if hasattr(node, getter_name):
                return getattr(node, getter_name)()
            return {'node_name': getattr(node, 'NODE_NAME', fallback_name)}
        except Exception:
            # Best effort: a misbehaving node degrades to a placeholder entry.
            # Only ``Exception`` is caught so that KeyboardInterrupt and
            # SystemExit still reach the caller.
            return {'node_name': fallback_name}

    def get_stage_config(self) -> Dict[str, Any]:
        """Get configuration for this stage.

        Returns:
            Dict with keys ``stage_id``, ``model_config``,
            ``preprocess_configs`` and ``postprocess_configs``. The two list
            entries follow the order in which nodes were added.
        """
        model_config = self._safe_node_config(
            self.model_node, 'get_inference_config', 'Unknown Model'
        )
        preprocess_configs = [
            self._safe_node_config(node, 'get_preprocessing_config', 'Unknown Preprocess')
            for node in self.preprocess_nodes
        ]
        postprocess_configs = [
            self._safe_node_config(node, 'get_postprocessing_config', 'Unknown Postprocess')
            for node in self.postprocess_nodes
        ]
        return {
            'stage_id': self.stage_id,
            'model_config': model_config,
            'preprocess_configs': preprocess_configs,
            'postprocess_configs': postprocess_configs,
        }

    def validate_stage(self) -> Tuple[bool, str]:
        """Validate this stage configuration.

        Each node's ``validate_configuration()`` is called and its result is
        unpacked as ``(is_valid, error)``. Checking stops at the first failure:
        model first, then preprocessing nodes, then postprocessing nodes.

        Returns:
            ``(True, "")`` when every node validates, otherwise ``(False,
            message)`` where the message names the stage and the failing node.
        """
        is_valid, error = self.model_node.validate_configuration()
        if not is_valid:
            return False, f"Stage {self.stage_id} model error: {error}"

        for i, node in enumerate(self.preprocess_nodes):
            is_valid, error = node.validate_configuration()
            if not is_valid:
                return False, f"Stage {self.stage_id} preprocess {i} error: {error}"

        for i, node in enumerate(self.postprocess_nodes):
            is_valid, error = node.validate_configuration()
            if not is_valid:
                return False, f"Stage {self.stage_id} postprocess {i} error: {error}"

        return True, ""
def find_connected_nodes(node, visited=None, direction='forward'):
    """
    Collect every node reachable from ``node`` by following port connections.

    The walk is depth-first: each newly found neighbour is recorded and then
    expanded before the next neighbour is looked at.

    Args:
        node: Starting node (not included in the result).
        visited: Set of nodes to treat as already seen; it is updated in place.
            A fresh set is created when omitted.
        direction: 'forward' follows output ports to their connected inputs;
            any other value follows input ports to their connected outputs.

    Returns:
        List of reachable nodes in discovery order, or an empty list when
        ``node`` was already in ``visited``.
    """
    seen = set() if visited is None else visited
    if node in seen:
        return []
    seen.add(node)

    def neighbours():
        # Lazily yield the node on the far end of every connection.
        if direction == 'forward':
            for out_port in node.outputs():
                for peer in out_port.connected_inputs():
                    yield peer.node()
        else:
            for in_port in node.inputs():
                for peer in in_port.connected_outputs():
                    yield peer.node()

    reachable = []
    for neighbour in neighbours():
        if neighbour in seen:
            continue
        reachable.append(neighbour)
        reachable += find_connected_nodes(neighbour, seen, direction)
    return reachable
def analyze_pipeline_stages(node_graph) -> List[PipelineStage]:
    """
    Analyze a node graph to identify pipeline stages.

    Every model node becomes one stage. Stages are numbered from 1 in order of
    the model's distance from the input nodes (closest first). Each stage also
    receives the preprocessing nodes wired directly into the model and the
    postprocessing nodes wired directly out of it.

    Args:
        node_graph: NodeGraphQt graph object.

    Returns:
        List of PipelineStage objects. Empty when no graph is given or when
        the graph lacks an input node or an output node.
    """
    if not node_graph:
        return []

    all_nodes = node_graph.all_nodes()

    # A node lands in at most one bucket; model detection takes precedence.
    model_nodes = []
    input_nodes = []
    output_nodes = []
    for node in all_nodes:
        if is_model_node(node):
            model_nodes.append(node)
        elif is_input_node(node):
            input_nodes.append(node)
        elif is_output_node(node):
            output_nodes.append(node)

    if not input_nodes or not output_nodes:
        return []  # Invalid pipeline - must have input and output

    try:
        ranked = [
            (model_node, calculate_distance_from_input(model_node, input_nodes))
            for model_node in model_nodes
        ]
        # Stable sort: models at equal distance keep their graph order.
        ranked.sort(key=lambda pair: pair[1])

        # Build into a local list so a failure part-way through cannot leak
        # half of the stages into the fallback result.
        stages = []
        for stage_id, (model_node, _) in enumerate(ranked, 1):
            stage = PipelineStage(stage_id, model_node)
            for preprocess_node in find_preprocess_nodes_for_model(model_node, all_nodes):
                stage.add_preprocess_node(preprocess_node)
            for postprocess_node in find_postprocess_nodes_for_model(model_node, all_nodes):
                stage.add_postprocess_node(postprocess_node)
            stages.append(stage)
        return stages
    except Exception as e:
        # Fallback: one bare stage per model node, in graph order. Starting
        # from a fresh list avoids duplicating stages created before the error.
        print(f"Warning: Pipeline distance calculation failed ({e}), using simple stage creation")
        return [
            PipelineStage(stage_id, model_node)
            for stage_id, model_node in enumerate(model_nodes, 1)
        ]
def calculate_distance_from_input(target_node, input_nodes):
    """Return the smallest path distance from any input node to ``target_node``.

    Args:
        target_node: Node whose position in the pipeline is wanted.
        input_nodes: Iterable of candidate starting nodes.

    Returns:
        The minimum value reported by ``find_shortest_path_distance`` across
        all input nodes, or 0 when no input node reaches the target (including
        when ``input_nodes`` is empty).
    """
    unreachable = float('inf')
    best = min(
        (find_shortest_path_distance(source, target_node) for source in input_nodes),
        default=unreachable,
    )
    return 0 if best == unreachable else best
def find_shortest_path_distance(start_node, target_node, visited=None, distance=0):
    """Find shortest path distance between two nodes, following output ports.

    The graph is explored breadth-first, so each node is expanded at most once
    and the first time the target is met is along a shortest path.

    Args:
        start_node: Node to start from.
        target_node: Node to reach.
        visited: Optional set of nodes that must not be traversed.
            ``start_node`` is added to it; nodes discovered during the search
            are tracked in a private copy.
        distance: Value reported for ``start_node`` itself; every hop adds 1.

    Returns:
        ``distance`` plus the number of hops on a shortest path, or
        ``float('inf')`` when the target cannot be reached. A node whose
        connections cannot be read is treated as a dead end.
    """
    from collections import deque

    if visited is None:
        visited = set()
    if start_node == target_node:
        return distance
    if start_node in visited:
        return float('inf')
    visited.add(start_node)

    seen = set(visited)
    queue = deque([(start_node, distance)])
    while queue:
        current, current_distance = queue.popleft()
        try:
            if not hasattr(current, 'outputs'):
                continue
            for output in current.outputs():
                if not hasattr(output, 'connected_inputs'):
                    continue
                for connected_input in output.connected_inputs():
                    if not hasattr(connected_input, 'node'):
                        continue
                    neighbour = connected_input.node()
                    if neighbour in seen:
                        continue
                    if neighbour == target_node:
                        return current_distance + 1
                    seen.add(neighbour)
                    queue.append((neighbour, current_distance + 1))
        except Exception:
            # Best effort: skip the rest of this node's connections but keep
            # searching from the nodes that were already queued.
            continue
    return float('inf')
def find_preprocess_nodes_for_model(model_node, all_nodes):
    """Find preprocessing nodes that connect to the given model node.

    Only nodes wired directly into one of the model's input ports are
    considered, and only instances of ``PreprocessNode`` are kept.

    Args:
        model_node: Node whose input ports are inspected.
        all_nodes: Unused; kept for interface compatibility.

    Returns:
        List of distinct preprocessing nodes in discovery order. A node
        connected through several ports appears once.
    """
    preprocess_nodes = []
    for input_port in model_node.inputs():
        for connected_output in input_port.connected_outputs():
            connected_node = connected_output.node()
            if not isinstance(connected_node, PreprocessNode):
                continue
            # Compare by identity so one node feeding several ports is
            # listed once, whatever equality the node class defines.
            if not any(known is connected_node for known in preprocess_nodes):
                preprocess_nodes.append(connected_node)
    return preprocess_nodes
def find_postprocess_nodes_for_model(model_node, all_nodes):
    """Find postprocessing nodes that the given model node connects to.

    Only nodes wired directly to one of the model's output ports are
    considered, and only instances of ``PostprocessNode`` are kept.

    Args:
        model_node: Node whose output ports are inspected.
        all_nodes: Unused; kept for interface compatibility.

    Returns:
        List of distinct postprocessing nodes in discovery order. A node
        connected through several ports appears once.
    """
    postprocess_nodes = []
    for output in model_node.outputs():
        for connected_input in output.connected_inputs():
            connected_node = connected_input.node()
            if not isinstance(connected_node, PostprocessNode):
                continue
            # Compare by identity so one node fed from several ports is
            # listed once, whatever equality the node class defines.
            if not any(known is connected_node for known in postprocess_nodes):
                postprocess_nodes.append(connected_node)
    return postprocess_nodes
def is_model_node(node):
    """Check if a node is a model node using multiple detection methods.

    A node counts as a model node when the substring 'model' appears
    (case-insensitively) in its ``__identifier__``, ``type_``, ``NODE_NAME``
    or in ``str(type(node))``, or when it exposes ``get_inference_config``.

    Args:
        node: Any object; attributes that are missing are simply skipped.

    Returns:
        True when any of the checks matches, otherwise False.
    """
    keyword = 'model'

    # Attribute values are coerced with str() so a non-string value (for
    # example None) cannot raise while being inspected.
    for attribute in ('__identifier__', 'type_', 'NODE_NAME'):
        if hasattr(node, attribute) and keyword in str(getattr(node, attribute)).lower():
            return True

    # str(type(node)) holds the module path and the class name, so this also
    # matches class names such as ExactModelNode.
    if keyword in str(type(node)).lower():
        return True

    # Duck-typing fallback: anything offering an inference config is a model.
    return hasattr(node, 'get_inference_config')
def is_input_node(node):
    """Check if a node is an input node using multiple detection methods.

    A node counts as an input node when the substring 'input' appears
    (case-insensitively) in its ``__identifier__``, ``type_``, ``NODE_NAME``
    or in ``str(type(node))``, or when it exposes ``get_input_config``.

    Args:
        node: Any object; attributes that are missing are simply skipped.

    Returns:
        True when any of the checks matches, otherwise False.
    """
    keyword = 'input'

    # Attribute values are coerced with str() so a non-string value (for
    # example None) cannot raise while being inspected.
    for attribute in ('__identifier__', 'type_', 'NODE_NAME'):
        if hasattr(node, attribute) and keyword in str(getattr(node, attribute)).lower():
            return True

    # str(type(node)) holds the module path and the class name, so this also
    # matches class names such as ExactInputNode.
    if keyword in str(type(node)).lower():
        return True

    # Duck-typing fallback: anything offering an input config is an input.
    return hasattr(node, 'get_input_config')
def is_output_node(node):
    """Check if a node is an output node using multiple detection methods.

    A node counts as an output node when the substring 'output' appears
    (case-insensitively) in its ``__identifier__``, ``type_``, ``NODE_NAME``
    or in ``str(type(node))``, or when it exposes ``get_output_config``.

    Args:
        node: Any object; attributes that are missing are simply skipped.

    Returns:
        True when any of the checks matches, otherwise False.
    """
    keyword = 'output'

    # Attribute values are coerced with str() so a non-string value (for
    # example None) cannot raise while being inspected.
    for attribute in ('__identifier__', 'type_', 'NODE_NAME'):
        if hasattr(node, attribute) and keyword in str(getattr(node, attribute)).lower():
            return True

    # str(type(node)) holds the module path and the class name, so this also
    # matches class names such as ExactOutputNode.
    if keyword in str(type(node)).lower():
        return True

    # Duck-typing fallback: anything offering an output config is an output.
    return hasattr(node, 'get_output_config')
def get_stage_count(node_graph) -> int:
    """
    Count the stages of a pipeline, one per detected model node.

    Unlike a full stage analysis, this does not require the graph to contain
    input or output nodes; it only counts what ``is_model_node`` accepts.

    Args:
        node_graph: NodeGraphQt graph object, or a falsy value for "no graph".

    Returns:
        Number of model nodes in the graph; 0 when ``node_graph`` is falsy.
    """
    if not node_graph:
        return 0
    return sum(1 for candidate in node_graph.all_nodes() if is_model_node(candidate))
def validate_pipeline_structure(node_graph) -> Tuple[bool, str]:
    """
    Check that a pipeline graph contains the node types it cannot work without.

    Only presence is checked: at least one input, one output and one model
    node. Connectivity between them is deliberately not verified here.

    Args:
        node_graph: NodeGraphQt graph object, or a falsy value for "no graph".

    Returns:
        Tuple of (is_valid, error_message); the message is empty on success
        and otherwise names the first missing requirement.
    """
    if not node_graph:
        return False, "No pipeline graph provided"

    nodes = node_graph.all_nodes()
    input_total = sum(1 for candidate in nodes if is_input_node(candidate))
    output_total = sum(1 for candidate in nodes if is_output_node(candidate))
    model_total = sum(1 for candidate in nodes if is_model_node(candidate))

    # Checked in this order so the first missing kind decides the message.
    requirements = (
        (input_total, "Pipeline must have at least one input node"),
        (output_total, "Pipeline must have at least one output node"),
        (model_total, "Pipeline must have at least one model node"),
    )
    for total, complaint in requirements:
        if total == 0:
            return False, complaint
    return True, ""
def is_node_connected_to_pipeline(node, input_nodes, output_nodes):
    """Tell whether ``node`` sits on a path from an input to an output.

    Args:
        node: Node under test.
        input_nodes: Candidate upstream nodes.
        output_nodes: Candidate downstream nodes.

    Returns:
        True only when some input node has a path to ``node`` and ``node``
        has a path to some output node.
    """
    fed_by_input = False
    for source in input_nodes:
        if has_path_between_nodes(source, node):
            fed_by_input = True
            break

    # The downstream side is examined even when the upstream side failed.
    feeds_output = False
    for sink in output_nodes:
        if has_path_between_nodes(node, sink):
            feeds_output = True
            break

    return fed_by_input and feeds_output
def has_path_between_nodes(start_node, end_node, visited=None):
    """Tell whether ``end_node`` can be reached from ``start_node``.

    Connections are followed from output ports only. A port is read through
    ``connected_inputs()`` when it has that method, otherwise through
    ``connected_ports()`` when available.

    Args:
        start_node: Node to start from.
        end_node: Node to reach.
        visited: Set of nodes already explored; it is updated in place and
            shared across the whole search. A fresh set is created when
            omitted.

    Returns:
        True when the nodes are equal or a path exists; False otherwise,
        including when reading a node's connections raises an exception.
    """
    seen = set() if visited is None else visited
    if start_node == end_node:
        return True
    if start_node in seen:
        return False
    seen.add(start_node)

    def downstream():
        # Lazily yield the node behind every outgoing connection.
        if not hasattr(start_node, 'outputs'):
            return
        for port in start_node.outputs():
            if hasattr(port, 'connected_inputs'):
                links = port.connected_inputs()
            elif hasattr(port, 'connected_ports'):
                # Alternative connection method
                links = port.connected_ports()
            else:
                continue
            for link in links:
                if hasattr(link, 'node'):
                    yield link.node()

    try:
        return any(
            has_path_between_nodes(neighbour, end_node, seen)
            for neighbour in downstream()
        )
    except Exception:
        # If there's any error accessing connections, assume no path
        return False
def _node_matches_kind(node, keyword, config_getter):
    """Tell whether ``node`` looks like a node of the kind named by ``keyword``.

    Args:
        node: Any object; attributes that are missing are simply skipped.
        keyword: Lower-case substring searched for in ``__identifier__``,
            ``type_``, ``NODE_NAME`` and ``str(type(node))``.
        config_getter: Attribute name whose mere presence also counts as a
            match.

    Returns:
        True when any check matches, otherwise False.
    """
    # Values are coerced with str() so a non-string attribute cannot raise.
    for attribute in ('__identifier__', 'type_', 'NODE_NAME'):
        if hasattr(node, attribute) and keyword in str(getattr(node, attribute)).lower():
            return True
    # str(type(node)) holds the module path and the class name, so prefixed
    # class names (for example "Exact...") are matched as well.
    if keyword in str(type(node)).lower():
        return True
    return hasattr(node, config_getter)


def get_pipeline_summary(node_graph) -> Dict[str, Any]:
    """
    Get a summary of the pipeline structure.

    Each node is counted in at most one category, tested in this order:
    input, output, model, preprocess, postprocess.

    Args:
        node_graph: NodeGraphQt graph object, or a falsy value for "no graph".

    Returns:
        Dictionary with the keys ``stage_count``, ``valid``, ``error``,
        ``stages``, ``total_nodes``, ``input_nodes``, ``output_nodes``,
        ``model_nodes``, ``preprocess_nodes`` and ``postprocess_nodes``.
        ``error`` is None when the pipeline is valid. The same keys are
        present, with empty or zero values, when no graph is given.
    """
    if not node_graph:
        # Same shape as the regular result so callers can index any key.
        return {
            'stage_count': 0,
            'valid': False,
            'error': 'No pipeline graph',
            'stages': [],
            'total_nodes': 0,
            'input_nodes': 0,
            'output_nodes': 0,
            'model_nodes': 0,
            'preprocess_nodes': 0,
            'postprocess_nodes': 0,
        }

    all_nodes = node_graph.all_nodes()

    input_count = 0
    output_count = 0
    model_count = 0
    preprocess_count = 0
    postprocess_count = 0
    for node in all_nodes:
        if is_input_node(node):
            input_count += 1
        elif is_output_node(node):
            output_count += 1
        elif is_model_node(node):
            model_count += 1
        elif _node_matches_kind(node, 'preprocess', 'get_preprocessing_config'):
            preprocess_count += 1
        elif _node_matches_kind(node, 'postprocess', 'get_postprocessing_config'):
            postprocess_count += 1

    stages = analyze_pipeline_stages(node_graph)
    is_valid, error = validate_pipeline_structure(node_graph)

    return {
        'stage_count': len(stages),
        'valid': is_valid,
        'error': error if not is_valid else None,
        'stages': [stage.get_stage_config() for stage in stages],
        'total_nodes': len(all_nodes),
        'input_nodes': input_count,
        'output_nodes': output_count,
        'model_nodes': model_count,
        'preprocess_nodes': preprocess_count,
        'postprocess_nodes': postprocess_count,
    }