#!/usr/bin/env python3
"""
Test script to verify node detection methods work correctly.
"""

import os
import sys

# Put this script's own directory at the front of the import path so the
# ``core`` imports below can be resolved from it.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Mock Qt application for testing
# Select Qt's "offscreen" platform plugin so no display is needed. The
# variable must be set before the QApplication below is constructed.
os.environ['QT_QPA_PLATFORM'] = 'offscreen'

# Create a minimal Qt application
from PyQt5.QtWidgets import QApplication

app = QApplication(sys.argv)

# NOTE(review): the project imports are deliberately kept after the
# QApplication is created, exactly as in the original ordering; presumably
# the node classes expect a live application object -- confirm before moving.
from core.pipeline import is_model_node, is_input_node, is_output_node, get_stage_count
from core.nodes.model_node import ModelNode
from core.nodes.input_node import InputNode
from core.nodes.output_node import OutputNode
from core.nodes.preprocess_node import PreprocessNode
from core.nodes.postprocess_node import PostprocessNode


class MockNodeGraph:
    """Mock node graph for testing.

    Keeps nodes in a plain list and exposes them through ``all_nodes()``.
    Instances are what this script hands to ``get_stage_count``.
    """

    def __init__(self) -> None:
        # Nodes in the order they were added.
        self.nodes: list = []

    def all_nodes(self) -> list:
        """Return the nodes added so far (the live list, not a copy)."""
        return self.nodes

    def add_node(self, node) -> None:
        """Append *node* to the end of the graph's node list."""
        self.nodes.append(node)


def test_node_detection():
    """Test node detection methods.

    Builds one node of each type, prints what the ``is_*_node`` helpers
    report for them, and prints the stage count of a mock graph before and
    after a second model node is added. All diagnostics are printed first;
    the checks run afterwards so a failure still shows the full picture.

    Raises:
        AssertionError: if a node is not detected as its own type, if one of
            the cross-detection checks is truthy, or if the stage count is
            not 1 with one model node and 2 with two model nodes.
    """
    print("Testing Node Detection Methods...")

    # Create node instances
    input_node = InputNode()
    model_node = ModelNode()
    output_node = OutputNode()
    # The pre/post-processing nodes are only constructed, as in the original
    # script; nothing below inspects them.
    preprocess_node = PreprocessNode()
    postprocess_node = PostprocessNode()

    # Test detection
    input_as_input = is_input_node(input_node)
    print(f"Input node detection: {input_as_input}")
    model_as_model = is_model_node(model_node)
    print(f"Model node detection: {model_as_model}")
    output_as_output = is_output_node(output_node)
    print(f"Output node detection: {output_as_output}")

    # Test cross-detection (should be False)
    model_as_input = is_input_node(model_node)
    print(f"Model node detected as input: {model_as_input}")
    input_as_model = is_model_node(input_node)
    print(f"Input node detected as model: {input_as_model}")
    output_as_model = is_model_node(output_node)
    print(f"Output node detected as model: {output_as_model}")

    # Test with mock graph
    graph = MockNodeGraph()
    graph.add_node(input_node)
    graph.add_node(model_node)
    graph.add_node(output_node)

    stage_count = get_stage_count(graph)
    print(f"Stage count: {stage_count}")

    # Add another model node
    model_node2 = ModelNode()
    graph.add_node(model_node2)

    stage_count2 = get_stage_count(graph)
    print(f"Stage count after adding second model: {stage_count2}")

    # Previously these results were only printed, so a wrong answer from a
    # detector could not fail the test.
    assert input_as_input, "Input node was not detected as an input node"
    assert model_as_model, "Model node was not detected as a model node"
    assert output_as_output, "Output node was not detected as an output node"
    assert not model_as_input, "Model node was detected as an input node"
    assert not input_as_model, "Input node was detected as a model node"
    assert not output_as_model, "Output node was detected as a model node"

    assert stage_count == 1, f"Expected 1 stage, got {stage_count}"
    assert stage_count2 == 2, f"Expected 2 stages, got {stage_count2}"

    print("✓ Node detection tests passed")


def test_node_properties():
    """Test node properties for detection.

    Purely diagnostic: for a model, an input and an output node it prints the
    Python type, the ``__identifier__`` and ``NODE_NAME`` attributes (the
    string 'None' when the attribute is missing) and whether the node has the
    listed config getter. Nothing is asserted here.
    """
    print("\nTesting Node Properties...")

    # (label used in the output, node class, config getter probed via hasattr)
    probes = (
        ("Model", ModelNode, 'get_inference_config'),
        ("Input", InputNode, 'get_input_config'),
        ("Output", OutputNode, 'get_output_config'),
    )

    # One loop instead of three copy-pasted print groups; the printed text is
    # unchanged.
    for label, node_cls, config_getter in probes:
        node = node_cls()
        print(f"{label} node type: {type(node)}")
        print(f"{label} node identifier: {getattr(node, '__identifier__', 'None')}")
        print(f"{label} node NODE_NAME: {getattr(node, 'NODE_NAME', 'None')}")
        print(f"Has {config_getter}: {hasattr(node, config_getter)}")


def main():
    """Run all tests.

    Runs ``test_node_properties`` and then ``test_node_detection``. When both
    return normally a success banner is printed. If either raises, the error
    message and its traceback are printed and the process exits with
    status 1.
    """
    print("Running Node Detection Tests...")
    print("=" * 50)

    try:
        test_node_properties()
        test_node_detection()
    except Exception as e:
        # Top-level boundary of the script: report the failure and convert it
        # into a non-zero exit status.
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n" + "=" * 50)
        print("All tests passed! ✓")


# Run the tests only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()