# cluster4npu/multi_series_dongle_manager.py

import kp
from collections import defaultdict, deque
from typing import Union, Dict, List, Tuple, Optional, Any
import os
import sys
import time
import threading
import queue
import numpy as np
import cv2
from dataclasses import dataclass
from enum import Enum
@dataclass
class InferenceTask:
    """A single inference request queued for dispatch to a dongle."""
    sequence_id: int               # monotonically increasing submission id
    image_data: np.ndarray         # pixel buffer handed to the device
    image_format: kp.ImageFormat   # pixel layout of image_data
    timestamp: float               # wall-clock time the task was created
@dataclass
class InferenceResult:
    """A completed inference, tagged so results can be re-ordered."""
    sequence_id: int     # id of the InferenceTask that produced this result
    result: Any          # raw result object returned by the device
    dongle_series: str   # series name that served the task, e.g. "KL720"
    timestamp: float     # wall-clock time the result was received
class DongleSeriesSpec:
    """Static capability table for the known Kneron dongle series."""

    # Rough compute throughput per series, in GOPS.
    KL520_GOPS = 3
    KL720_GOPS = 28

    # USB product_id -> {"name": series name, "gops": throughput estimate}
    SERIES_SPECS = {
        0x100: {"name": "KL520", "gops": KL520_GOPS},
        0x720: {"name": "KL720", "gops": KL720_GOPS},
        0x630: {"name": "KL630", "gops": 400},
        0x730: {"name": "KL730", "gops": 1600},
        0x540: {"name": "KL540", "gops": 800},
    }
class MultiSeriesDongleManager:
    """Load-balanced inference across heterogeneous Kneron dongle series.

    Images submitted through put_input() receive a monotonically increasing
    sequence id, are dispatched to the dongle group with the best
    load/capacity ratio (weighted by each series' aggregate GOPS), and are
    re-ordered on the way out so get_result() yields results in submission
    order.
    """

    def __init__(self, max_queue_size: int = 100, result_buffer_size: int = 1000):
        """
        Args:
            max_queue_size: Bound on the shared input queue; put_input()
                blocks once this many tasks are waiting (back-pressure).
            result_buffer_size: Threshold of buffered out-of-order results
                at which stale entries are pruned.
        """
        # Device management (all dicts keyed by USB product_id)
        self.dongle_groups = {}          # product_id -> device_group
        self.dongle_specs = {}           # product_id -> spec info
        self.model_descriptors = {}      # product_id -> model descriptor
        self.inference_descriptors = {}  # product_id -> inference descriptor (reserved; unused)
        # Load balancing
        self.gops_weights = {}           # product_id -> normalized GOPS weight
        self.current_loads = {}          # product_id -> tasks in flight on that group
        # Threading and queues
        self.input_queue = queue.Queue(maxsize=max_queue_size)
        self.result_queues = {}          # product_id -> per-dongle TASK queue (read by send worker)
        self.ordered_output_queue = queue.Queue()
        # Sequence management
        self.sequence_counter = 0
        self.sequence_lock = threading.Lock()
        self.pending_results = {}        # sequence_id -> InferenceResult awaiting ordered release
        self.next_output_sequence = 0
        self.result_buffer_size = result_buffer_size
        # Worker threads
        self.stop_event = threading.Event()
        self.dispatcher_thread = None
        self.send_threads = {}           # product_id -> thread
        self.receive_threads = {}        # product_id -> thread
        self.result_ordering_thread = None
        # Statistics
        self.stats = {
            'total_dispatched': 0,
            'total_completed': 0,
            'dongle_stats': {}           # product_id -> {'sent': count, 'received': count}
        }

    def scan_and_initialize_devices(self, firmware_paths: Dict[str, Dict[str, str]],
                                    model_paths: Dict[str, str]) -> bool:
        """
        Scan, connect, and initialize all available devices.

        Args:
            firmware_paths: {"KL520": {"scpu": path, "ncpu": path}, "KL720": {...}}
            model_paths: {"KL520": model_path, "KL720": model_path}

        Returns:
            True when at least one supported device group was initialized.
        """
        device_list = kp.core.scan_devices()
        if not device_list or device_list.device_descriptor_number == 0:
            print("No devices found")
            return False
        # Group devices by product_id so one device_group drives each series.
        grouped_devices = defaultdict(list)
        for device in device_list.device_descriptor_list:
            grouped_devices[device.product_id].append(device.usb_port_id)
        print(f"Found device groups: {dict(grouped_devices)}")
        # Connect and initialize each group.
        total_gops = 0
        for product_id, port_ids in grouped_devices.items():
            if product_id not in DongleSeriesSpec.SERIES_SPECS:
                print(f"Unknown product ID: {hex(product_id)}")
                continue
            series_info = DongleSeriesSpec.SERIES_SPECS[product_id]
            series_name = series_info["name"]
            try:
                # Connect device group
                device_group = kp.core.connect_devices(port_ids)
                self.dongle_groups[product_id] = device_group
                self.dongle_specs[product_id] = series_info
                # Initialize statistics
                self.stats['dongle_stats'][product_id] = {'sent': 0, 'received': 0}
                self.current_loads[product_id] = 0
                print(f"Connected to {series_name} group with ports {port_ids}")
                kp.core.set_timeout(device_group=device_group, milliseconds=5000)
                # Upload firmware if provided
                if series_name in firmware_paths:
                    fw_paths = firmware_paths[series_name]
                    print(f"[{series_name}] Uploading firmware...")
                    kp.core.load_firmware_from_file(
                        device_group=device_group,
                        scpu_fw_path=fw_paths["scpu"],
                        ncpu_fw_path=fw_paths["ncpu"]
                    )
                    print(f"[{series_name}] Firmware upload success")
                # Upload model. The descriptor's model id is required by the
                # send worker, so a series with no model cannot serve tasks.
                if series_name in model_paths:
                    print(f"[{series_name}] Uploading model...")
                    model_descriptor = kp.core.load_model_from_file(
                        device_group=device_group,
                        file_path=model_paths[series_name]
                    )
                    self.model_descriptors[product_id] = model_descriptor
                    print(f"[{series_name}] Model upload success")
                # Per-dongle task queue consumed by this series' send worker.
                self.result_queues[product_id] = queue.Queue()
                total_gops += series_info["gops"] * len(port_ids)
            except kp.ApiKPException as e:
                print(f"Failed to initialize {series_name}: {e}")
                return False
        # Guard: every discovered device was unsupported -> nothing to balance,
        # and the weight division below would raise ZeroDivisionError.
        if total_gops == 0:
            print("No supported devices were initialized")
            return False
        # Calculate load balancing weights from each group's aggregate GOPS.
        for product_id, spec in self.dongle_specs.items():
            port_count = len(grouped_devices[product_id])
            effective_gops = spec["gops"] * port_count
            self.gops_weights[product_id] = effective_gops / total_gops
        print(f"Load balancing weights (by GOPS): {self.gops_weights}")
        return True

    def start(self):
        """Start dispatcher, per-dongle send/receive, and ordering threads.

        Raises:
            RuntimeError: if no device groups have been initialized.
        """
        if not self.dongle_groups:
            raise RuntimeError("No dongles initialized. Call scan_and_initialize_devices first.")
        self.stop_event.clear()
        # Dispatcher: routes tasks from input_queue to per-dongle queues.
        self.dispatcher_thread = threading.Thread(target=self._dispatcher_worker, daemon=True)
        self.dispatcher_thread.start()
        # One send thread and one receive thread per dongle series.
        for product_id in self.dongle_groups.keys():
            send_thread = threading.Thread(
                target=self._send_worker,
                args=(product_id,),
                daemon=True
            )
            send_thread.start()
            self.send_threads[product_id] = send_thread
            receive_thread = threading.Thread(
                target=self._receive_worker,
                args=(product_id,),
                daemon=True
            )
            receive_thread.start()
            self.receive_threads[product_id] = receive_thread
        # Ordering thread: releases results in submission order.
        self.result_ordering_thread = threading.Thread(target=self._result_ordering_worker, daemon=True)
        self.result_ordering_thread.start()
        print(f"Started MultiSeriesDongleManager with {len(self.dongle_groups)} dongle series")

    def stop(self):
        """Stop all worker threads and disconnect every device group."""
        print("Stopping MultiSeriesDongleManager...")
        self.stop_event.set()
        # Join all threads with a timeout so shutdown can't hang forever.
        threads_to_join = [
            (self.dispatcher_thread, "Dispatcher"),
            (self.result_ordering_thread, "Result Ordering")
        ]
        for product_id in self.dongle_groups.keys():
            threads_to_join.extend([
                (self.send_threads.get(product_id), f"Send-{hex(product_id)}"),
                (self.receive_threads.get(product_id), f"Receive-{hex(product_id)}")
            ])
        for thread, name in threads_to_join:
            if thread and thread.is_alive():
                thread.join(timeout=2.0)
                if thread.is_alive():
                    print(f"Warning: {name} thread didn't stop cleanly")
        # Disconnect device groups (best effort: report but continue on error).
        for product_id, device_group in self.dongle_groups.items():
            try:
                kp.core.disconnect_devices(device_group)
                print(f"Disconnected {hex(product_id)} device group")
            except kp.ApiKPException as e:
                print(f"Error disconnecting {hex(product_id)}: {e}")
        self.dongle_groups.clear()

    def put_input(self, image: Union[str, np.ndarray], image_format: str = 'BGR565') -> int:
        """
        Submit an image for inference.

        Args:
            image: Path to an image file, or a pre-converted pixel array.
            image_format: Key into the supported format table below.

        Returns:
            int: sequence_id for tracking this inference.

        Raises:
            FileNotFoundError: if a path was given and the image can't be read.
            ValueError: for unsupported image types or format strings.
        """
        if isinstance(image, str):
            image_data = cv2.imread(image)
            # BUGFIX: validate the imread result BEFORE converting; previously
            # a missing file crashed inside cvtColor with an opaque error.
            if image_data is None:
                raise FileNotFoundError(f"Image file not found: {image}")
            # NOTE(review): cv2.imread returns BGR, yet the conversion uses
            # COLOR_RGB2BGR565 — preserved as-is; confirm the channel order.
            image_data = cv2.cvtColor(image_data, cv2.COLOR_RGB2BGR565)
        elif isinstance(image, np.ndarray):
            # Copy so later caller mutations can't race the send worker.
            image_data = image.copy()
        else:
            raise ValueError("Image must be file path or numpy array")
        # Map the user-facing format string onto the SDK enum.
        format_mapping = {
            'BGR565': kp.ImageFormat.KP_IMAGE_FORMAT_RGB565,
            'RGB8888': kp.ImageFormat.KP_IMAGE_FORMAT_RGBA8888,
            'YUYV': kp.ImageFormat.KP_IMAGE_FORMAT_YUYV,
            'RAW8': kp.ImageFormat.KP_IMAGE_FORMAT_RAW8
        }
        image_format_enum = format_mapping.get(image_format)
        if image_format_enum is None:
            raise ValueError(f"Unsupported format: {image_format}")
        # Allocate the next sequence id under the lock.
        with self.sequence_lock:
            sequence_id = self.sequence_counter
            self.sequence_counter += 1
        task = InferenceTask(
            sequence_id=sequence_id,
            image_data=image_data,
            image_format=image_format_enum,
            timestamp=time.time()
        )
        self.input_queue.put(task)  # blocks when the queue is full
        self.stats['total_dispatched'] += 1
        return sequence_id

    def get_result(self, timeout: Optional[float] = None) -> Optional["InferenceResult"]:
        """Get the next inference result in original submission order.

        Args:
            timeout: Seconds to wait for a result. With the default None the
                call is NON-blocking and returns None when nothing is ready.

        Returns:
            The next InferenceResult, or None when none is available in time.
        """
        try:
            return self.ordered_output_queue.get(block=timeout is not None, timeout=timeout)
        except queue.Empty:
            return None

    def _dispatcher_worker(self):
        """Dispatcher thread: assigns tasks to dongles based on load balancing."""
        print("Dispatcher thread started")
        while not self.stop_event.is_set():
            try:
                task = self.input_queue.get(timeout=0.1)
                if task is None:  # Sentinel value
                    continue
                selected_product_id = self._select_optimal_dongle()
                if selected_product_id is None:
                    # BUGFIX: previously the task was silently dropped here,
                    # which left a hole in the sequence and stalled the
                    # ordering worker forever. Re-queue and retry instead.
                    self.input_queue.put(task)
                    time.sleep(0.05)
                    continue
                # Hand the task to the chosen dongle's send worker.
                self.result_queues[selected_product_id].put(task)
                self.current_loads[selected_product_id] += 1
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Error in dispatcher: {e}")
        print("Dispatcher thread stopped")

    def _select_optimal_dongle(self) -> Optional[int]:
        """Return the product_id with the lowest load/capacity ratio, or None."""
        if not self.dongle_groups:
            return None
        best_ratio = float('inf')
        selected_product_id = None
        for product_id in self.dongle_groups.keys():
            current_load = self.current_loads[product_id]
            weight = self.gops_weights[product_id]
            # Lower ratio == more spare capacity relative to throughput.
            load_ratio = current_load / weight if weight > 0 else float('inf')
            if load_ratio < best_ratio:
                best_ratio = load_ratio
                selected_product_id = product_id
        return selected_product_id

    def _send_worker(self, product_id: int):
        """Send thread for one dongle series: drains its task queue to the device."""
        series_name = self.dongle_specs[product_id]["name"]
        print(f"Send worker started for {series_name}")
        device_group = self.dongle_groups[product_id]
        result_queue = self.result_queues[product_id]
        while not self.stop_event.is_set():
            try:
                task = result_queue.get(timeout=0.1)
                if task is None:
                    continue
                # Fresh descriptor per task to avoid cross-task state leaks.
                inference_descriptor = kp.GenericImageInferenceDescriptor(
                    model_id=self.model_descriptors[product_id].models[0].id,
                )
                # inference_number carries the sequence id through the device
                # so the receive worker can match results to submissions.
                inference_descriptor.inference_number = task.sequence_id
                inference_descriptor.input_node_image_list = [
                    kp.GenericInputNodeImage(
                        image=task.image_data,
                        image_format=task.image_format,
                        resize_mode=kp.ResizeMode.KP_RESIZE_ENABLE,
                        padding_mode=kp.PaddingMode.KP_PADDING_CORNER,
                        normalize_mode=kp.NormalizeMode.KP_NORMALIZE_KNERON
                    )
                ]
                kp.inference.generic_image_inference_send(
                    device_group=device_group,
                    generic_inference_input_descriptor=inference_descriptor
                )
                print(f"Task {task.sequence_id} sent successfully to {series_name}")
                self.stats['dongle_stats'][product_id]['sent'] += 1
            except queue.Empty:
                continue
            except kp.ApiKPException as e:
                # A device error is fatal for the whole pipeline.
                print(f"Error in {series_name} send worker: {e}")
                self.stop_event.set()
            except Exception as e:
                print(f"Unexpected error in {series_name} send worker: {e}")
        print(f"Send worker stopped for {series_name}")

    def _receive_worker(self, product_id: int):
        """Receive thread for one dongle series: collects raw device results."""
        series_name = self.dongle_specs[product_id]["name"]
        print(f"Receive worker started for {series_name}")
        device_group = self.dongle_groups[product_id]
        while not self.stop_event.is_set():
            try:
                # Blocks until a result arrives (bounded by set_timeout above).
                raw_result = kp.inference.generic_image_inference_receive(device_group=device_group)
                result = InferenceResult(
                    sequence_id=raw_result.header.inference_number,
                    result=raw_result,
                    dongle_series=series_name,
                    timestamp=time.time()
                )
                # Hand off to the ordering worker and release the load slot.
                self.pending_results[result.sequence_id] = result
                self.current_loads[product_id] = max(0, self.current_loads[product_id] - 1)
                self.stats['dongle_stats'][product_id]['received'] += 1
            except kp.ApiKPException as e:
                if not self.stop_event.is_set():
                    print(f"Error in {series_name} receive worker: {e}")
                    self.stop_event.set()
            except Exception as e:
                print(f"Unexpected error in {series_name} receive worker: {e}")
                # Avoid a tight error loop if the failure repeats immediately.
                time.sleep(0.001)
        print(f"Receive worker stopped for {series_name}")

    def _result_ordering_worker(self):
        """Ordering thread: releases pending results strictly by sequence id.

        NOTE(review): if a result is permanently lost (device error), the
        sequence never advances and pending_results only grows; the pruning
        below removes nothing because it only drops ids already released.
        A skip-after-timeout policy would be needed to recover — confirm
        whether lost results are possible in practice.
        """
        print("Result ordering worker started")
        while not self.stop_event.is_set():
            if self.next_output_sequence in self.pending_results:
                result = self.pending_results.pop(self.next_output_sequence)
                self.ordered_output_queue.put(result)
                self.next_output_sequence += 1
                self.stats['total_completed'] += 1
                # Prune stale (already-released) ids to bound memory.
                if len(self.pending_results) > self.result_buffer_size:
                    oldest_sequences = sorted(self.pending_results.keys())[:self.result_buffer_size // 2]
                    for seq_id in oldest_sequences:
                        if seq_id < self.next_output_sequence:
                            self.pending_results.pop(seq_id, None)
            else:
                time.sleep(0.001)  # Small delay to prevent busy waiting
        print("Result ordering worker stopped")

    def get_statistics(self) -> Dict:
        """Return a snapshot of throughput counters and queue/load state."""
        stats = self.stats.copy()
        stats['pending_results'] = len(self.pending_results)
        stats['input_queue_size'] = self.input_queue.qsize()
        stats['current_loads'] = self.current_loads.copy()
        stats['gops_weights'] = self.gops_weights.copy()
        return stats

    def print_statistics(self):
        """Print current statistics to stdout."""
        stats = self.get_statistics()
        print(f"\n=== MultiSeriesDongleManager Statistics ===")
        print(f"Total dispatched: {stats['total_dispatched']}")
        print(f"Total completed: {stats['total_completed']}")
        print(f"Pending results: {stats['pending_results']}")
        print(f"Input queue size: {stats['input_queue_size']}")
        print(f"\nPer-dongle statistics:")
        for product_id, dongle_stats in stats['dongle_stats'].items():
            series_name = self.dongle_specs[product_id]["name"]
            current_load = stats['current_loads'][product_id]
            weight = stats['gops_weights'][product_id]
            print(f"  {series_name}: sent={dongle_stats['sent']}, received={dongle_stats['received']}, "
                  f"current_load={current_load}, weight={weight:.3f}")
# Example usage and test
if __name__ == "__main__":
    # Per-series firmware binaries.
    firmware_paths = {
        "KL520": {
            "scpu": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\firmware\KL520\fw_scpu.bin",
            "ncpu": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\firmware\KL520\fw_ncpu.bin"
        },
        "KL720": {
            "scpu": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\firmware\KL720\fw_scpu.bin",
            "ncpu": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\firmware\KL720\fw_ncpu.bin"
        }
    }
    # Per-series compiled model files.
    model_paths = {
        "KL520": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\models\KL520\yolov5-noupsample_w640h640_kn-model-zoo\kl520_20005_yolov5-noupsample_w640h640.nef",
        "KL720": r"C:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\models\KL720\yolov5-noupsample_w640h640_kn-model-zoo\kl720_20005_yolov5-noupsample_w640h640.nef"
    }
    image_path = r"c:\Users\mason\Downloads\kneron_plus_v3.1.2\kneron_plus\res\images\people_talk_in_street_640x640.bmp"
    try:
        # Bring up the manager and all connected dongle groups.
        manager = MultiSeriesDongleManager()
        if not manager.scan_and_initialize_devices(firmware_paths, model_paths):
            print("Failed to initialize devices")
            sys.exit(1)
        manager.start()
        # Submit a batch of identical test images.
        num_images = 50
        print(f"\nSubmitting {num_images} images for inference...")
        t_start = time.time()
        submitted_ids = []
        for idx in range(1, num_images + 1):
            submitted_ids.append(manager.put_input(image_path, 'BGR565'))
            if idx % 10 == 0:
                print(f"Submitted {idx} images")
        # Drain results, bailing out on the first timeout.
        print(f"\nCollecting results...")
        collected = []
        for idx in range(1, num_images + 1):
            res = manager.get_result(timeout=10.0)
            if not res:
                print(f"Timeout waiting for result {idx}")
                break
            collected.append(res)
            if idx % 10 == 0:
                print(f"Received {idx} results")
        t_end = time.time()
        # Report throughput and verify ordering guarantees.
        print(f"\n=== Test Results ===")
        print(f"Total time: {t_end - t_start:.2f} seconds")
        print(f"Average FPS: {len(collected) / (t_end - t_start):.2f}")
        print(f"Results received: {len(collected)}/{num_images}")
        is_ordered = all(r.sequence_id == s for r, s in zip(collected, submitted_ids))
        print(f"Results in correct order: {is_ordered}")
        manager.print_statistics()
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Only tear down if the manager was actually constructed.
        if 'manager' in locals():
            manager.stop()