47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Dict, Protocol
|
|
|
|
|
|
class QuantizationBackend(Protocol):
    """Structural interface that every quantization backend must satisfy.

    Any object exposing a compatible ``analyze`` method conforms; no
    inheritance from this class is required.
    """

    def analyze(self, onnx_path: str, input_mapping: Dict, output_dir: str, **kwargs) -> str:
        """Run quantization and return the path of the generated BIE file.

        Args:
            onnx_path: Path of the ONNX model to quantize.
            input_mapping: Mapping describing the model inputs; its exact
                schema is backend-defined.
            output_dir: Directory handed to the backend for its output.
            **kwargs: Backend-specific options.

        Returns:
            The path of the generated BIE file.
        """
        ...
|
|
|
|
|
|
class KneronQuantizationBackend:
    """Quantization backend that drives the Kneron ``ktc`` toolchain."""

    # Keyword arguments that ``analyze`` forwards positionally to
    # ``ktc.ModelConfig``; all of them are mandatory.
    _REQUIRED_KWARGS = ("model_id", "version", "platform")

    def analyze(
        self,
        onnx_path: str,
        input_mapping: Dict,
        output_dir: str,
        **kwargs,
    ) -> str:
        """Run ``ktc`` analysis on an ONNX model and return its result.

        Args:
            onnx_path: Path of the ONNX model. Only read when no
                ``onnx_model`` keyword argument is supplied.
            input_mapping: Passed unchanged to ``ModelConfig.analysis``.
            output_dir: Passed to ``ModelConfig.analysis`` as ``output_dir``.
            **kwargs: Must contain ``model_id``, ``version`` and ``platform``
                (forwarded to ``ktc.ModelConfig``). May contain
                ``onnx_model``, an already-loaded model used instead of
                loading ``onnx_path``.

        Returns:
            Whatever ``ModelConfig.analysis`` returns; per the
            ``QuantizationBackend`` contract this is the generated BIE path.

        Raises:
            KeyError: If any of ``model_id``, ``version`` or ``platform`` is
                missing from ``kwargs``. The message names every missing key.
        """
        # Fail fast: validate the mandatory options before importing the
        # toolchain or loading/optimizing the model, so a caller mistake is
        # reported immediately and with all missing names at once.
        missing = [key for key in self._REQUIRED_KWARGS if key not in kwargs]
        if missing:
            raise KeyError(
                "missing required keyword argument(s): " + ", ".join(missing)
            )

        # Imported inside the method, so merely importing this module does
        # not require the ``ktc`` package to be installed.
        import ktc

        model = kwargs.get("onnx_model")
        if model is None:
            import onnx

            model = onnx.load(onnx_path)
            # NOTE(review): the optimizer pass is applied only to models
            # loaded here; a caller-supplied ``onnx_model`` is used as-is.
            # The source's indentation was lost, so confirm this nesting.
            model = ktc.onnx_optimizer.onnx2onnx_flow(
                model, eliminate_tail=True, opt_matmul=True
            )

        km = ktc.ModelConfig(
            kwargs["model_id"],
            kwargs["version"],
            kwargs["platform"],
            onnx_model=model,
        )
        return km.analysis(input_mapping, output_dir=output_dir)
|
|
|
|
|
|
def get_quantization_backend(name: str | None = None) -> QuantizationBackend:
    """Return the quantization backend instance to use.

    Args:
        name: Accepted for forward compatibility but currently ignored;
            no backend selection is performed yet.

    Returns:
        A new ``KneronQuantizationBackend`` on every call.
    """
    # Reserved for future backend selection logic; deliberately unused today.
    del name
    backend = KneronQuantizationBackend()
    return backend
|