diff --git a/configs/_base_/datasets/kn_cityscapes.py b/configs/_base_/datasets/kn_cityscapes.py new file mode 100644 index 0000000..e15ad34 --- /dev/null +++ b/configs/_base_/datasets/kn_cityscapes.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +img_norm_cfg = dict( + mean=[128., 128., 128.], std=[256., 256., 256.], to_rgb=True) +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + # img_scale=(2048, 1024), + img_scale=(1024, 512), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/train', + ann_dir='gtFine/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/configs/stdc/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py b/configs/stdc/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py new file mode 100644 index 0000000..9d12e27 --- /dev/null +++ b/configs/stdc/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py @@ -0,0 +1,14 @@ +checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth' # noqa +_base_ = [ + '../_base_/models/stdc.py', '../_base_/datasets/kn_cityscapes.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +lr_config = dict(warmup='linear', warmup_iters=1000) +data = dict( + samples_per_gpu=12, + workers_per_gpu=4, +) +model = dict( + backbone=dict( + backbone_cfg=dict( + init_cfg=dict(type='Pretrained', checkpoint=checkpoint)))) diff --git a/docs_kneron/stdc_step_by_step.md b/docs_kneron/stdc_step_by_step.md new file mode 100644 index 0000000..ed9d1d7 --- /dev/null +++ b/docs_kneron/stdc_step_by_step.md @@ -0,0 +1,331 @@ +# Step 1: Environment + +## Step 1-1: Prerequisites + +- Python 3.6+ +- PyTorch 1.3+ (We recommend you installing PyTorch using Conda following the [Official PyTorch Installation Instruction](https://pytorch.org/)) +- (Optional) CUDA 9.2+ (If you installed PyTorch with cuda using Conda following the [Official PyTorch Installation Instruction](https://pytorch.org/), you can skip CUDA installation) +- (Optional, used to build from source) GCC 5+ +- [mmcv-full](https://mmcv.readthedocs.io/en/latest/#installation) (Note: not `mmcv`!) + +**Note:** You need to run `pip uninstall mmcv` first if you have `mmcv` installed. +If mmcv and mmcv-full are both installed, there will be `ModuleNotFoundError`. + +## Step 1-2: Install MMSegmentationKN + +### Step 1-2-1: Install PyTorch + +You can follow [Official PyTorch Installation Instruction](https://pytorch.org/) to install PyTorch using Conda: + +```shell +conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -y +``` + +### Step 1-2-2: Install mmcv-full + +We recommend you installing mmcv-full using pip: + +```shell +pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html +``` + +Please replace `cu113` and `torch1.11.0` in the url to your desired one. For example, to install the `mmcv-full` with `CUDA 11.1` and `PyTorch 1.9.0`, use the following command: + +```shell +pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html +``` + +If you see error messages while installing mmcv-full, please check if your installation instruction matches your installed version of PyTorch and Cuda, and see [MMCV pip Installation Instruction](https://github.com/open-mmlab/mmcv#install-with-pip) for different versions of MMCV compatible to different PyTorch and CUDA versions. + +### Step 1-2-3: Clone MMSegmentationKN Repository + +```shell +git clone https://github.com/kneron/MMSegmentationKN.git +cd MMSegmentationKN +``` + +### Step 1-2-4: Install Required Python Packages for Building and Installing MMSegmentationKN + +```shell +pip install -r requirements_kneron.txt +pip install -v -e . # or "python setup.py develop" +``` + +# Step 2: Training Models on Standard Datasets + +MMSegmentationKN provides many existing and existing semantic segmentation models in [Model Zoo](https://mmsegmentation.readthedocs.io/en/latest/model_zoo.html), and supports several standard datasets like CityScapes, Pascal Context, Coco Stuff, ADE20K, etc. Here we demonstrate how to train *STDC-Seg*, a semantic segmentation algorithm, on *CityScapes*, a well-known semantic segmentation dataset. + +## Step 2-1: Download CityScapes Dataset + +1. Go to [CityScapes Official Website](https://www.cityscapes-dataset.com) and click *Download* link on the top of the page. If you're not logged in, it will navigate you to login page. +2. If it is the first time you visiting CityScapes website, to download CityScapes dataset, you have to register an account. +3. Click the *Register* link and it will navigate you to the registeration page. +4. Fill in all the *required* fields, accept the terms and conditions, and click the *Register* button. If everything goes well, you will see *Registration Successful* on the page and recieve a registration confirmation mail in your email inbox. +5. Click on the link provided in the confirmation mail, login with your newly registered account and password, and you should be able to download the CityScapes dataset. +6. Download *leftImg8bit_trainvaltest.zip* (images) and *gtFine_trainvaltest.zip* (labels) and place them onto your server. + +## Step 2-2: Dataset Preparation + +We suggest that you extract the zipped files to somewhere outside the project directory and symlink (`ln`) the dataset root to `MMSegmentationKN/data` so you can use the dataset outside this project, as shown below: + +```shell +# Replace all "path/to/your" below with where you want to put the dataset! + +# Extracting Cityscapes +mkdir -p path/to/your/cityscapes +unzip leftImg8bit_trainvaltest.zip -d path/to/your/cityscapes +unzip gtFine_trainvaltest.zip -d path/to/your/cityscapes + +# symlink dataset to MMSegmentationKN/data # where "MMSegmentationKN" is the repository you cloned in step 0-4 +mkdir -p MMSegmentationKN/data +ln -s $(realpath path/to/your/cityscapes) MMSegmentationKN/data + +# Replace all "path/to/your" above with where you want to put the dataset! +``` + +Then, we need *cityscapesScripts* to preprocess the CityScapes dataset. If you completely followed our [Step 1-2-4](#step-1-2-4-install-required-python-packages-for-building-and-installing-mmsegmentationkn), you should have python package *cityscapesScripts* installed (if no, execute `pip install cityscapesScripts` command). + +```shell +# Replace "path/to/your" with where you want to put the dataset! +export CITYSCAPES_DATASET=$(realpath path/to/your/cityscapes) +csCreateTrainIdLabelImgs +``` + +Wait several minutes and you'll see something like this: + +```plain +Processing 5000 annotation files +Progress: 100.0 % +``` + +The files inside the dataset folder should be something like: + +```plain +MMSegmentationKN/data/cityscapes +├── gtFine +│ ├── test +│ │ ├── ... +│ ├── train +│ │ ├── ... +│ ├── val +│ │ ├── frankfurt +│ │ │ ├── frankfurt_000000_000294_gtFine_color.png +│ │ │ ├── frankfurt_000000_000294_gtFine_instanceIds.png +│ │ │ ├── frankfurt_000000_000294_gtFine_labelIds.png +│ │ │ ├── frankfurt_000000_000294_gtFine_labelTrainIds.png +│ │ │ ├── frankfurt_000000_000294_gtFine_polygons.png +│ │ │ ├── ... +│ │ ├── ... +├── leftImg8bit +│ ├── test +│ │ ├── ... +│ ├── train +│ │ ├── ... +│ ├── val +│ │ ├── frankfurt +│ │ │ ├── frankfurt_000000_000294_leftImg8bit.png +│ │ ├── ... +... +``` + +It's recommended that you *symlink* the dataset folder to mmdetection folder. However, if you place your dataset folder at different place and do not want to symlink, you have to change the corresponding paths in the config file. + +Now the dataset should be ready for training. + + +## Step 2-3: Train STDC-Seg on CityScapes + +Short-Term Dense Concatenate Network (STDC network) is a light-weight network structure for convolutional neural network. If we apply this network structure to semantic segmentation task, it's called STDC-Seg. It's first introduced in [Rethinking BiSeNet For Real-time Semantic Segmentation +](https://arxiv.org/abs/2104.13188). Please check the paper if you want to know the algorithm details. + +We only need a configuration file to train a deep learning model in either the original MMSegmentation or MMSegmentationKN. STDC-Seg is provided in the original MMSegmentation repository, but the original configuration file needs some modification due to our hardware limitation so that we can apply the trained model to our Kneron dongle. + +To make a configuration file compatible with our device, we have to: + +* Change the mean and std value in image normalization to `mean=[128., 128., 128.]` and `std=[256., 256., 256.]`. +* Shrink the input size during inference phase. The original CityScapes image size is too large (2048(w)x1024(h)) for our device; 1024(w)x512(h) might be good for our device. + +To achieve this, you can modify the `img_scale` in `test_pipeline` and `img_norm_cfg` in the configuration file `configs/_base_/datasets/cityscapes.py`. + +Luckily, here in MMSegmentationKN, we provide a modified STDC-Seg configuration file (`configs/stdc/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py`) so we can easily apply the trained model to our device. + +To train STDC-Seg compatible with our device, just execute: + +```shell +cd MMSegmentationKN +python tools/train.py configs/stdc/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py +``` + +And MMSegmentationKN will generate `work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes` folder and save the configuration file and all checkpoints there. + +# Step 3: Test Trained Model +`tools/test.py` is a script that generates inference results from test set with our pytorch model and evaluates the results to see if our pytorch model is well trained (if `--eval` argument is given). Note that it's always good to evluate our pytorch model before deploying it. + +```shell +python tools/test.py \ + work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py \ + work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.pth \ + --eval mIoU +``` +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py` can be your training config. +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.pth` can be your model checkpoint. + +The expected result of the command above should be something similar to the following text (the numbers may slightly differ): +``` +... ++---------------+-------+-------+ +| Class | IoU | Acc | ++---------------+-------+-------+ +| road | 97.49 | 98.59 | +| sidewalk | 80.17 | 88.71 | +| building | 89.52 | 95.25 | +| wall | 57.92 | 66.99 | +| fence | 55.5 | 70.15 | +| pole | 38.93 | 47.51 | +| traffic light | 49.95 | 59.97 | +| traffic sign | 62.1 | 70.05 | +| vegetation | 89.02 | 95.27 | +| terrain | 60.18 | 72.26 | +| sky | 91.84 | 96.34 | +| person | 68.98 | 84.35 | +| rider | 47.79 | 60.98 | +| car | 91.63 | 96.48 | +| truck | 74.31 | 83.52 | +| bus | 80.24 | 86.83 | +| train | 66.45 | 76.78 | +| motorcycle | 48.69 | 58.18 | +| bicycle | 65.81 | 81.68 | ++---------------+-------+-------+ +Summary: + ++------+-------+-------+ +| aAcc | mIoU | mAcc | ++------+-------+-------+ +| 94.3 | 69.29 | 78.42 | ++------+-------+-------+ +``` + +# Step 4: Export ONNX and Verify + +## Step 4-1: Export ONNX + +`tools/pytorch2onnx_kneron.py` is a script provided by MMSegmentationKN to help users to convert our trained pytorch model to ONNX: +```shell +python tools/pytorch2onnx_kneron.py \ + work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py \ + --checkpoint work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.pth \ + --output-file work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.onnx +``` +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py` can be your training config. +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.pth` can be your model checkpoint. +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.onnx` can be any other path. Here for convenience, the ONNX file is placed in the same folder of our pytorch checkpoint. + +## Step 4-2: Verify ONNX + +`tools/deploy_test_kneron.py` is a script provided by MMSegmentationKN to help users to verify if our exported ONNX generates similar outputs with what our PyTorch model does: +```shell +python tools/deploy_test_kneron.py \ + work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py \ + work_dirs/kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.onnx \ + --eval mIoU +``` +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/kn_stdc1_in1k-pre_512x1024_80k_cityscapes.py` can be your training config. +* `kn_stdc1_in1k-pre_512x1024_80k_cityscapes/latest.pth` can be your exported ONNX file. + +The expected result of the command above should be something similar to the following text (the numbers may slightly differ): + +``` +``` + +Note that the ONNX results may differ from the PyTorch results due to some implementation differences between PyTorch and ONNXRuntime. + +# Step 5: Convert ONNX File to [NEF](http://doc.kneron.com/docs/#toolchain/manual/#5-nef-workflow) Model for Kneron Platform + +### Step 5-1: Install Kneron toolchain docker: + +* check [document](http://doc.kneron.com/docs/#toolchain/manual/#1-installation) + +### Step 5-2: Mout Kneron toolchain docker + +* Mount a folder (e.g. '/mnt/hgfs/Competition') to toolchain docker container as `/data1`. The converted ONNX in Step 3 should be put here. All the toolchain operation should happen in this folder. +``` +sudo docker run --rm -it -v /mnt/hgfs/Competition:/data1 kneron/toolchain:latest +``` + +### Step 5-3: Import KTC and required lib in python shell +* Here we demonstrate how to go through all Kneron Toolchain (KTC) flow through Python API: +```python +import ktc +import numpy as np +import os +import onnx +from PIL import Image +``` + +### Step 5-4: Optimize the onnx model +```python +onnx_path = '/data1/latest.onnx' +m = onnx.load(onnx_path) +m = ktc.onnx_optimizer.onnx2onnx_flow(m) +onnx.save(m,'latest.opt.onnx') +``` + +### Step 5-5: Configure and load data necessary for ktc, and check if onnx is ok for toolchain +```python +# npu (only) performance simulation +km = ktc.ModelConfig((&)model_id_on_public_field, "0001", "720", onnx_model=m) +eval_result = km.evaluate() +print("\nNpu performance evaluation result:\n" + str(eval_result)) +``` + +### Step 5-6: quantize the onnx model +We [sampled 3 images from Cityscapes dataset](https://www.kneron.com/tw/support/education-center/?folder=MMLab/MMSegmentationKN/&download=41) (3 images) as quantization data. To test our quantized model: +1. Download the zip file +2. Extract the zip file as a folder named `cityscapes_minitest` +3. Put the `cityscapes_minitest` into docker mounted folder (the path in docker container should be `/data1/cityscapes_minitest`) + +The following script will do some preprocess(should be the same as training code) on our quantization data, and put it in a list: + +```python +import os +from os import walk + +img_list = [] +for (dirpath, dirnames, filenames) in walk("/data1/cityscapes_minitest"): + for f in filenames: + fullpath = os.path.join(dirpath, f) + + image = Image.open(fullpath) + image = image.convert("RGB") + image = Image.fromarray(np.array(image)[...,::-1]) + img_data = np.array(image.resize((1024, 512), Image.BILINEAR)) / 256 - 0.5 + print(fullpath) + img_list.append(img_data) +``` + +Then perform quantization. The BIE model will be generated at `/data1/output.bie`. + +```python +# fixed-point analysis +bie_model_path = km.analysis({"input": img_list}) +print("\nFixed-point analysis done. Save bie model to '" + str(bie_model_path) + "'") +``` + +### Step 5-7: Compile + +The final step is compile the BIE model into an NEF model. +```python +# compile +nef_model_path = ktc.compile([km]) +print("\nCompile done. Save Nef file to '" + str(nef_model_path) + "'") +``` + +You can find the NEF file at `/data1/batch_compile/models_720.nef`. `models_720.nef` is the final compiled model. + +# Step 6: Run [NEF](http://doc.kneron.com/docs/#toolchain/manual/#5-nef-workflow) model on KL720 + +* Check Kneron PLUS official document: + * Python version: + http://doc.kneron.com/docs/#plus_python/#_top + * C version: + http://doc.kneron.com/docs/#plus_c/getting_started/ \ No newline at end of file diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py index c688180..a3a2933 100644 --- a/mmseg/apis/__init__.py +++ b/mmseg/apis/__init__.py @@ -1,11 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .inference import inference_segmentor, init_segmentor, show_result_pyplot +from .inference import ( + inference_segmentor, + inference_segmentor_kn, + init_segmentor, + init_segmentor_kn, + show_result_pyplot, +) from .test import multi_gpu_test, single_gpu_test from .train import (get_root_logger, init_random_seed, set_random_seed, train_segmentor) __all__ = [ - 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', - 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', + 'get_root_logger', 'set_random_seed', 'train_segmentor', + 'init_segmentor', 'init_segmentor_kn', 'inference_segmentor', + 'inference_segmentor_kn', 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot', 'init_random_seed' ] diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 9069438..648f255 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -7,6 +7,7 @@ from mmcv.runner import load_checkpoint from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor +from mmseg.models.segmentors import ONNXRuntimeSegmentorKN def init_segmentor(config, checkpoint=None, device='cuda:0'): @@ -40,6 +41,32 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): return model +def init_segmentor_kn(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if checkpoint is None or not checkpoint.endswith(".onnx"): + return init_segmentor(config, checkpoint, device) + try: + _, device_id = device.split(":") + device_id = int(device_id) + except Exception: + device_id = None if device == 'cpu' else 0 + model = ONNXRuntimeSegmentorKN( + checkpoint, cfg=config, device_id=device_id + ).eval() + return model + + class LoadImage: """A simple pipeline to load image.""" @@ -99,6 +126,20 @@ def inference_segmentor(model, img): return result +@torch.no_grad() +def inference_segmentor_kn(model, img): + if isinstance(model, ONNXRuntimeSegmentorKN): + cfg = model.cfg + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + return model(return_loss=False, rescale=True, **data) + else: + return inference_segmentor(model, img) + + def show_result_pyplot(model, img, result, diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py index 387c858..f7ede7b 100644 --- a/mmseg/models/segmentors/__init__.py +++ b/mmseg/models/segmentors/__init__.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseSegmentor +from .base import BaseSegmentor, ONNXRuntimeSegmentorKN from .cascade_encoder_decoder import CascadeEncoderDecoder from .encoder_decoder import EncoderDecoder -__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] +__all__ = [ + 'BaseSegmentor', + 'ONNXRuntimeSegmentorKN', + 'EncoderDecoder', + 'CascadeEncoderDecoder' +] diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 9b22a7c..339b6b3 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -2,12 +2,16 @@ import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict +from typing import Any, Iterable, Union +from os import path as osp import mmcv import numpy as np import torch import torch.distributed as dist from mmcv.runner import BaseModule, auto_fp16 +from mmseg.core import get_classes, get_palette +from mmseg.ops import resize class BaseSegmentor(BaseModule, metaclass=ABCMeta): @@ -284,3 +288,163 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta): warnings.warn('show==False and out_file is not specified, only ' 'result image will be returned') return img + + +class ONNXRuntimeSegmentorKN(BaseSegmentor): + + def __init__( + self, + onnx_file: str, + cfg: Any, + device_id: Union[int, None] = 0): + super(ONNXRuntimeSegmentorKN, self).__init__() + import onnxruntime as ort + + # get the custom op path + ort_custom_op_path = '' + try: + from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + except (ImportError, ModuleNotFoundError): + warnings.warn( + 'If input model has custom op from mmcv, you may ' + 'have to build mmcv with ONNXRuntime from source.') + session_options = ort.SessionOptions() + # register custom op for onnxruntime + if osp.exists(ort_custom_op_path): + session_options.register_custom_ops_library(ort_custom_op_path) + providers = ['CPUExecutionProvider'] + provider_options = [{}] + is_cuda_available = ort.get_device() == 'GPU' and torch.cuda.is_available() + if is_cuda_available: + providers.insert(0, 'CUDAExecutionProvider') + device_id = device_id or 0 + provider_options.insert(0, {'device_id': device_id}) + sess = ort.InferenceSession( + onnx_file, session_options, providers, provider_options + ) + self.sess = sess + sess_inputs = sess.get_inputs() + assert len(sess_inputs) == 1, "Only onnx with 1 input is supported" + self.input_name = sess_inputs[0].name + sess_outputs = sess.get_outputs() + self.num_classes = sess_outputs[0].shape[1] + assert len(sess_outputs) == 1, "Only onnx with 1 output is supported" + self.output_name_list = [sess_outputs[0].name] + self.cfg = cfg # TODO: necessary? + self.test_cfg = cfg.model.test_cfg + self.test_mode = self.test_cfg.mode # NOTE: should be 'whole' or 'slide' + self.is_cuda_available = is_cuda_available + self.count_mat = None + try: + if 'test' in cfg.data: + dataset_name = cfg.data.test['type'] + else: + dataset_name = cfg.data.train['type'] + dataset_name = dataset_name.lower()[:-7] + self.CLASSES = get_classes(dataset_name) + self.PALETTE = get_palette(dataset_name) + except (AttributeError, KeyError): + warnings.warn( + "Failed to fetch dataset name from config; no CLASSES " + "and PALETTE for this ONNX model" + ) + except ValueError: + warnings.warn( + "Failed to fetch CLASSES and PALETTE from dataset " + f"{dataset_name}; no CLASSES and PALETTE for this " + "ONNX MODEL." + ) + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def encode_decode(self, img, img_metas): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def forward_test(self, imgs, img_metas, **kwargs): + return super().forward_test(imgs, img_metas[0].data, **kwargs) + + def simple_slide_inference( + self, + img: np.ndarray, + img_meta: Union[Iterable, None] = None): + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + _, _, h_img, w_img = img.shape + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = np.zeros((1, num_classes, h_img, w_img), dtype=np.float32) + # NOTE: count_mat should be invariant since + # input shape of kneron's onnx is fixed + if self.count_mat is None: + count_mat = np.zeros((1, 1, h_img, w_img), dtype=np.float32) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.sess.run( + self.output_name_list, + {self.input_name: crop_img} + )[0] + preds += np.pad( + crop_seg_logit, + ([0, 0], + [0, 0], + [int(y1), int(preds.shape[2] - y2)], + [int(x1), int(preds.shape[3] - x2)]), + ) + if self.count_mat is None: + count_mat[:, :, y1:y2, x1:x2] += 1 + if self.count_mat is None: + assert (count_mat == 0).sum() == 0 + self.count_mat = count_mat + preds /= self.count_mat + return preds + + @property + def module(self): + return self + + @torch.no_grad() + def simple_test( + self, + img: torch.Tensor, + img_meta: Union[Iterable, None] = None, + **kwargs) -> list: + img = img.cpu().numpy() + # NOTE: not using run_with_iobinding since some ort versions + # generate wrong results when inferencing with CUDA + if self.test_mode == 'slide': + seg_pred = self.simple_slide_inference(img, img_meta) + else: + seg_pred = self.sess.run( + self.output_name_list, {self.input_name: img} + )[0] + if img_meta is not None: + ori_shape = img_meta[0]['ori_shape'] + if not (ori_shape[0] == seg_pred.shape[-2] + and ori_shape[1] == seg_pred.shape[-1]): + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = resize( + seg_pred, size=tuple(ori_shape[:2]), mode='bilinear') + seg_pred = seg_pred.numpy() + elif img.shape[2:] != seg_pred.shape[2:]: + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = resize( + seg_pred, size=(img.shape[3], img.shape[2]), mode='bilinear') + seg_pred = seg_pred.numpy() + seg_pred = seg_pred.argmax(1) + return list(seg_pred) + + def aug_test(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 72467b4..184ecc5 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -115,7 +115,8 @@ class EncoderDecoder(BaseSegmentor): def forward_dummy(self, img): """Dummy forward function.""" - seg_logit = self.encode_decode(img, None) + seg_logit = self.extract_feat(img) + seg_logit = self._decode_head_forward_test(seg_logit, None) return seg_logit diff --git a/requirements/onnx.txt b/requirements/onnx.txt new file mode 100644 index 0000000..7944fcb --- /dev/null +++ b/requirements/onnx.txt @@ -0,0 +1,3 @@ +onnx>=1.6.0 +onnxruntime +onnxoptimizer diff --git a/requirements_kneron.txt b/requirements_kneron.txt new file mode 100644 index 0000000..d188261 --- /dev/null +++ b/requirements_kneron.txt @@ -0,0 +1,2 @@ +-r requirements.txt +-r requirements/onnx.txt diff --git a/setup.py b/setup.py index 91afefb..1784626 100755 --- a/setup.py +++ b/setup.py @@ -170,13 +170,13 @@ if __name__ == '__main__': setup( name='mmsegmentation', version=get_version(), - description='Open MMLab Semantic Segmentation Toolbox and Benchmark', + description='Open MMLab Semantic Segmentation Toolbox and Benchmark (Kneron Edition)', long_description=readme(), long_description_content_type='text/markdown', - author='MMSegmentation Contributors', - author_email='openmmlab@gmail.com', + author='MMSegmentation Contributors and Kneron', + author_email='', keywords='computer vision, semantic segmentation', - url='http://github.com/open-mmlab/mmsegmentation', + url='http://github.com/kneron/MMSegmentationKN', packages=find_packages(exclude=('configs', 'tools', 'demo')), include_package_data=True, classifiers=[ @@ -191,10 +191,11 @@ if __name__ == '__main__': license='Apache License 2.0', install_requires=parse_requirements('requirements/runtime.txt'), extras_require={ - 'all': parse_requirements('requirements.txt'), + 'all': parse_requirements('requirements_kneron.txt'), 'tests': parse_requirements('requirements/tests.txt'), 'build': parse_requirements('requirements/build.txt'), 'optional': parse_requirements('requirements/optional.txt'), + 'onnx': parse_requirements('requirements/onnx.txt'), }, ext_modules=[], zip_safe=False) diff --git a/tools/deploy_test_kneron.py b/tools/deploy_test_kneron.py new file mode 100644 index 0000000..aa12f66 --- /dev/null +++ b/tools/deploy_test_kneron.py @@ -0,0 +1,214 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import shutil +import warnings + +import mmcv +import torch +from mmcv.runner import get_dist_info +from mmcv.utils import DictAction + +from mmseg.apis import single_gpu_test +from mmseg.datasets import build_dataloader, build_dataset +from mmseg.models.segmentors.base import ONNXRuntimeSegmentorKN + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description='mmseg backend test (and eval)') + parser.add_argument('config', help='test config file path') + parser.add_argument('model', help='Input model file (onnx only)') + parser.add_argument('--out', help='output result file in pickle format') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--eval', + type=str, + nargs='+', + help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' + ' for generic datasets, and "cityscapes" for Cityscapes') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help="--options is deprecated in favor of --cfg_options' and it will " + 'not be supported in version v0.22.0. Override some settings in the ' + 'used config, the key-value pair in xxx=yyy format will be merged ' + 'into config file. If the value to be overwritten is a list, it ' + 'should be like key="[a,b]" or key=a,b It also allows nested ' + 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' + 'marks are necessary and that no white space is allowed.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. In (0, 1] range.') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=None, + help='input image height and width.') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + assert args.out or args.eval or args.format_only or args.show \ + or args.show_dir, \ + ('Please specify at least one operation (save/eval/format/show the ' + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"') + + if args.eval and args.format_only: + raise ValueError('--eval and --format_only cannot be both specified') + + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + cfg.model.pretrained = None + cfg.data.test.test_mode = True + if args.shape is not None: + + if len(args.shape) == 1: + shape = (args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + shape = (args.shape[1], args.shape[0]) + else: + raise ValueError('invalid input shape') + + test_mode = cfg.model.test_cfg.mode + if test_mode == 'slide': + warnings.warn( + "We suggest you NOT assigning shape when exporting " + "slide-mode models. Assigning shape to slide-mode models " + "may result in unexpected results. To see which mode the " + "model is using, check cfg.model.test_cfg.mode, which " + "should be either 'whole' or 'slide'." + ) + cfg.model.test_cfg['crop_size'] = shape + else: + cfg.test_pipeline[1]['img_scale'] = shape + cfg.data.test['pipeline'][1]['img_scale'] = shape + + # init distributed env first, since logger depends on the dist info. + distributed = False + + # build the dataloader + # TODO: support multiple images per gpu (only minor changes are needed) + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + + # load onnx config and meta + cfg.model.train_cfg = None + + model = ONNXRuntimeSegmentorKN(args.model, cfg=cfg, device_id=0) + + model.CLASSES = dataset.CLASSES + model.PALETTE = dataset.PALETTE + + # clean gpu memory when starting a new evaluation. + torch.cuda.empty_cache() + eval_kwargs = {} if args.eval_options is None else args.eval_options + + # Deprecated + efficient_test = eval_kwargs.get('efficient_test', False) + if efficient_test: + warnings.warn( + '``efficient_test=True`` does not have effect in tools/test_kneron.py, ' + 'the evaluation and format results are CPU memory efficient by ' + 'default') + + eval_on_format_results = ( + args.eval is not None and 'cityscapes' in args.eval) + if eval_on_format_results: + assert len(args.eval) == 1, 'eval on format results is not ' \ + 'applicable for metrics other than ' \ + 'cityscapes' + if args.format_only or eval_on_format_results: + if 'imgfile_prefix' in eval_kwargs: + tmpdir = eval_kwargs['imgfile_prefix'] + else: + tmpdir = '.format_cityscapes' + eval_kwargs.setdefault('imgfile_prefix', tmpdir) + mmcv.mkdir_or_exist(tmpdir) + else: + tmpdir = None + + results = single_gpu_test( + model, + data_loader, + args.show, + args.show_dir, + False, + args.opacity, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) + + rank, _ = get_dist_info() + if rank == 0: + if args.out: + warnings.warn( + 'The behavior of ``args.out`` has been changed since MMSeg ' + 'v0.16, the pickled outputs could be seg map as type of ' + 'np.array, pre-eval results or file paths for ' + '``dataset.format_results()``.') + print(f'\nwriting results to {args.out}') + mmcv.dump(results, args.out) + if args.eval: + dataset.evaluate(results, args.eval, **eval_kwargs) + if tmpdir is not None and eval_on_format_results: + # remove tmp dir when cityscapes evaluation + shutil.rmtree(tmpdir) + + +if __name__ == '__main__': + main() diff --git a/tools/optimizer_scripts/.clang-format b/tools/optimizer_scripts/.clang-format new file mode 100644 index 0000000..2593ef5 --- /dev/null +++ b/tools/optimizer_scripts/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: Google \ No newline at end of file diff --git a/tools/optimizer_scripts/.gitignore b/tools/optimizer_scripts/.gitignore new file mode 100644 index 0000000..991fd07 --- /dev/null +++ b/tools/optimizer_scripts/.gitignore @@ -0,0 +1,7 @@ +__pycache__ +.vscode +*.pyc +models.py +temp.py +.ssh/ +docker/test_models/ \ No newline at end of file diff --git a/tools/optimizer_scripts/README.md b/tools/optimizer_scripts/README.md new file mode 100644 index 0000000..cac99c5 --- /dev/null +++ b/tools/optimizer_scripts/README.md @@ -0,0 +1,189 @@ +# Converter Scripts + +[![pipeline status](http://192.168.200.1:8088/jiyuan/converter_scripts/badges/master/pipeline.svg)](http://192.168.200.1:8088/jiyuan/converter_scripts/commits/master) + +This project collects various optimization scripts and converter scritps for +Kneron toolchain. This collection does not include the Keras to ONNX converter +and the Caffe to ONNX converter. They are in seperate projects. + +**The scripts not listed below are used as libraries and cannot be used +directly.** + +## onnx2onnx.py + +### 1.1. Description + +General optimizations on ONNX model for Kneron toolchain. Though Kneron +toolchains are designed to take ONNX models as input, they have some +restrictions on the models (e.g. inferenced shapes for all value_info). Thus, we +have this tool to do some general optimization and conversion on ONNX models. +**Notice that this script should take an valid ONNX model as input.** It cannot +turn an invalid ONNX model into a valid one. + +### 1.2. Basic Usage + +```bash +python onnx2onnx.py input.onnx -o output.onnx +``` + +### 1.3. Optimizations Included + +* Fusing BN into Conv. +* Fusing BN into Gemm. +* Fusing consecutive Gemm. +* Eliminating Identify layers and Dropout layers. +* Eliminating last shape changing nodes. +* Replacing initializers into Constant nodes. +* Replacing global AveragePool with GAP. +* Replacing Squeeze and Unsqueeze with Reshape. +* Replacing 1x1 depthwise with BN. +* Inferencing Upsample shapes. +* Transposing B in Gemm. + +## pytorch2onnx.py + +### 2.1. Description + +Convert Pytorch models or Pytorch generated ONNX models into Kneron toolchain +compatible ONNX files. This script include most of the optimizations in +`onnx2onnx.py`. It also includes some optimizations for Pytorch model only. + +### 2.2. Basic Usage + +```bash +# Take Pytorch model name, input channel number, input height, input width +python pytorch2onnx.py input.pth output.onnx --input-size 3 224 224 +# Or take Pytorch exported ONNX. +python pytorch2onnx.py input.onnx output.onnx +``` + +### 2.3. Optimizations Included + +* Adding name to nodes. +* Unsqueeze nodes constant folding. +* Reshape nodes constant folding. +* Optimizations in `onnx2onnx.py`. + +## editor.py + +### 3.1. Description + +This is an simple ONNX editor which achieves the following functions: + +* Add nop BN or Conv nodes. +* Delete specific nodes or inputs. +* Cut the graph from certain node (Delete all the nodes following the node). +* Reshape inputs and outputs + +### 3.2 Usage + +``` +usage: editor.py [-h] [-c CUT_NODE [CUT_NODE ...]] + [--cut-type CUT_TYPE [CUT_TYPE ...]] + [-d DELETE_NODE [DELETE_NODE ...]] + [--delete-input DELETE_INPUT [DELETE_INPUT ...]] + [-i INPUT_CHANGE [INPUT_CHANGE ...]] + [-o OUTPUT_CHANGE [OUTPUT_CHANGE ...]] + [--add-conv ADD_CONV [ADD_CONV ...]] + [--add-bn ADD_BN [ADD_BN ...]] + in_file out_file + +Edit an ONNX model. The processing sequense is 'delete nodes/values' -> 'add +nodes' -> 'change shapes'. Cutting cannot be done with other operations +together + +positional arguments: + in_file input ONNX FILE + out_file ouput ONNX FILE + +optional arguments: + -h, --help show this help message and exit + -c CUT_NODE [CUT_NODE ...], --cut CUT_NODE [CUT_NODE ...] + remove nodes from the given nodes(inclusive) + --cut-type CUT_TYPE [CUT_TYPE ...] + remove nodes by type from the given nodes(inclusive) + -d DELETE_NODE [DELETE_NODE ...], --delete DELETE_NODE [DELETE_NODE ...] + delete nodes by names and only those nodes + --delete-input DELETE_INPUT [DELETE_INPUT ...] + delete inputs by names + -i INPUT_CHANGE [INPUT_CHANGE ...], --input INPUT_CHANGE [INPUT_CHANGE ...] + change input shape (e.g. -i 'input_0 1 3 224 224') + -o OUTPUT_CHANGE [OUTPUT_CHANGE ...], --output OUTPUT_CHANGE [OUTPUT_CHANGE ...] + change output shape (e.g. -o 'input_0 1 3 224 224') + --add-conv ADD_CONV [ADD_CONV ...] + add nop conv using specific input + --add-bn ADD_BN [ADD_BN ...] + add nop bn using specific input +``` + +### 3.3. Example + +Here is an example of when and how to use the editor.py. + +```bash +# In the `res` folder, there is a vdsr model from tensorflow. +# We need to convert this model firstly. +./tf2onnx.sh res/vdsr_41_20layer_1.pb res/tmp.onnx images:0 output:0 +# This onnx file seems valid. But, it's channel last for the input and output. +# It is using Traspose to convert to channel first, affacting the performance. +# Thus, here we use the editor to delete these Transpose and reset the shapes. +python editor.py debug.onnx new.onnx -d Conv2D__6 Conv2D_19__84 -i 'images:0 1 3 41 41' -o 'output:0 1 3 41 41' +# Now, it has no Transpose and take channel first inputs directly. +``` + +## test_models_opt.py + +### 4.1. Description +Compare all original and optimized onnx models under a specified directory. +Using different endings to locate original and optimized model paths. Apply +onnxruntime inference to the models, and compare the results from original +and optimized models. Calculate basic statistics and store to a csv file. + +### 4.2. Usage + +```bash +python DIR ending1 ending2 csv_out_file -p=Y/N + +# csv_out_file is file path for the stats data. +# -p --plot is the plot option, if Y, stats plots will be generated. +``` + +### 4.3. Statistics +* max_rel_diff +* max_abs_diff +* mean_rel_diff +* mean_abs_diff +* std_rel_diff +* std_abs_diff +* acc_with_diff_precision +* percentile + +### 4.4. Plots +* Max Relative Difference Histogram +* Max Absolute Difference Histogram +* Rel_diff Percentiles of Raw and Optimized Models +* Abs_diff Percentiles of Raw and Optimized Models +* Accuracies with Different Precisions + +## tensorflow2onnx.py + +### 5.1. Description +Convert and optimize tensorflow models. If input file is frozen tensorflow .pb model, +convert to onnx model and do the custmized optimization afterwards. If input model is already +onnx model, apply optimization and save optimized model. + +### 5.2 Dependency + +This scripts depends on the tensorflow-onnx project. Please [check and install it](https://github.com/onnx/tensorflow-onnx/tree/r1.5) before using this script. We currently support up to version 1.5.5. For other versions, you may need to try it our yourself. + +### 5.3. Basic Usage +```bash +python tensorflow2onnx.py in_file out_file -t=True/False + +# -t --test, is the option for test mode, if True, shape change after input will not be eliminated. +``` + +### 5.4. Model Save Paths +`in_file` is the input model path, `out_file` specifies output optimized model path. +If input file is `.pb` model, an unoptimized onnx model will be saved to the output directory as well. + diff --git a/tools/optimizer_scripts/consecutive_conv_opt.py b/tools/optimizer_scripts/consecutive_conv_opt.py new file mode 100644 index 0000000..c7d4068 --- /dev/null +++ b/tools/optimizer_scripts/consecutive_conv_opt.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import sys + +from tools.other import topological_sort +from tools import helper + +def fuse_bias_in_consecutive_1x1_conv(g): + for second in g.node: + # Find two conv + if second.op_type != 'Conv': + continue + first = helper.find_node_by_output_name(g, second.input[0]) + if first is None or first.op_type != 'Conv': + continue + # Check if the first one has only one folloing node + if len(helper.find_following_nodes_by_input_value_name(g, first.output[0])) != 1: + continue + # If first node has no bias, continue + if len(first.input) == 2: + continue + # Check their kernel size + first_kernel_shape = helper.get_list_attribute_by_name(first, 'kernel_shape', 'int') + second_kernel_shape = helper.get_list_attribute_by_name(second, 'kernel_shape', 'int') + prod = first_kernel_shape[0] * first_kernel_shape[1] * second_kernel_shape[0] * second_kernel_shape[1] + if prod != 1: + continue + print('Found: ', first.name, ' ', second.name) + # Get bias of the nodes + first_bias_node = helper.find_node_by_output_name(g, first.input[2]) + second_weight_node = helper.find_node_by_output_name(g, second.input[1]) + second_bias_node = helper.find_node_by_output_name(g, second.input[2]) + first_bias = helper.constant_to_numpy(first_bias_node) + second_weight = helper.constant_to_numpy(second_weight_node) + second_bias = helper.constant_to_numpy(second_bias_node) + # Calculate the weight for second node + first_bias = np.reshape(first_bias, (1, first_bias.size)) + second_weight = np.reshape(second_weight, (second_weight.shape[0], second_weight.shape[1])) + second_weight = np.transpose(second_weight) + new_second_bias = second_bias + np.matmul(first_bias, second_weight) + new_second_bias = np.reshape(new_second_bias, (new_second_bias.size,)) + # Generate new weight + new_first_bias = np.reshape(first_bias, (first_bias.size, )) + for i in range(new_first_bias.shape[0]): + new_first_bias[i] = 0.0 + new_first_bias_node = helper.numpy_to_constant(first_bias_node.output[0], new_first_bias) + new_second_bias_node = helper.numpy_to_constant(second_bias_node.output[0], new_second_bias) + # Delete old weight and add new weights + g.node.remove(first_bias_node) + g.node.remove(second_bias_node) + g.node.extend([new_first_bias_node, new_second_bias_node]) + topological_sort(g) + +if __name__ == "__main__": + if len(sys.argv) != 3: + exit(1) + m = onnx.load(sys.argv[1]) + fuse_bias_in_consecutive_1x1_conv(m.graph) + onnx.save(m, sys.argv[2]) \ No newline at end of file diff --git a/tools/optimizer_scripts/docker/Dockerfile b/tools/optimizer_scripts/docker/Dockerfile new file mode 100644 index 0000000..bb62f7f --- /dev/null +++ b/tools/optimizer_scripts/docker/Dockerfile @@ -0,0 +1,24 @@ +FROM continuumio/miniconda3:latest +LABEL maintainer="jiyuan@kneron.us" + +# Install python packages +RUN conda update -y conda && \ +conda install -y python=3.6 && \ +conda install -y -c intel caffe && \ +conda install -y -c pytorch pytorch=1.3.1 torchvision=0.4.2 cpuonly && \ +conda install -y -c conda-forge tensorflow=1.5.1 keras=2.2.4 && \ +pip install onnx==1.4.1 onnxruntime==1.1.0 tf2onnx==1.5.4 && \ +ln -s /opt/conda/lib/libgflags.so.2.2.2 /opt/conda/lib/libgflags.so.2 + +# Install git lfs packages +RUN apt-get update && apt-get install -y curl apt-utils && \ +curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ +apt-get install -y git-lfs + +RUN conda clean -a -y && rm -rf /var/lib/apt/lists/* + +# copy the test data +COPY ./test_models /test_models + +# Clean the environment and finalize the process +WORKDIR /root \ No newline at end of file diff --git a/tools/optimizer_scripts/editor.py b/tools/optimizer_scripts/editor.py new file mode 100644 index 0000000..8ccc6ca --- /dev/null +++ b/tools/optimizer_scripts/editor.py @@ -0,0 +1,118 @@ +import onnx +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer +import argparse + +import tools.modhelper as helper +import tools.other as other +import tools.replacing as replacing +# Main process +# Argument parser +parser = argparse.ArgumentParser(description="Edit an ONNX model.\nThe processing sequense is 'delete nodes/values' -> 'add nodes' -> 'change shapes'.\nCutting cannot be done with other operations together") +parser.add_argument('in_file', type=str, help='input ONNX FILE') +parser.add_argument('out_file', type=str, help="ouput ONNX FILE") +parser.add_argument('-c', '--cut', dest='cut_node', type=str, nargs='+', help="remove nodes from the given nodes(inclusive)") +parser.add_argument('--cut-type', dest='cut_type', type=str, nargs='+', help="remove nodes by type from the given nodes(inclusive)") +parser.add_argument('-d', '--delete', dest='delete_node', type=str, nargs='+', help="delete nodes by names and only those nodes") +parser.add_argument('--delete-input', dest='delete_input', type=str, nargs='+', help="delete inputs by names") +parser.add_argument('--delete-output', dest='delete_output', type=str, nargs='+', help="delete outputs by names") +parser.add_argument('-i', '--input', dest='input_change', type=str, nargs='+', help="change input shape (e.g. -i 'input_0 1 3 224 224')") +parser.add_argument('-o', '--output', dest='output_change', type=str, nargs='+', help="change output shape (e.g. -o 'input_0 1 3 224 224')") +parser.add_argument('--add-conv', dest='add_conv', type=str, nargs='+', help='add nop conv using specific input') +parser.add_argument('--add-bn', dest='add_bn', type=str, nargs='+', help='add nop bn using specific input') +parser.add_argument('--rename-output', dest='rename_output', type=str, nargs='+', help='Rename the specific output(e.g. --rename-output old_name new_name)') +parser.add_argument('--pixel-bias-value', dest='pixel_bias_value', type=str, nargs='+', help='(per channel) set pixel value bias bn layer at model front for normalization( e.g. --pixel_bias_value "[104.0, 117.0, 123.0]" )') +parser.add_argument('--pixel-scale-value', dest='pixel_scale_value', type=str, nargs='+', help='(per channel) set pixel value scale bn layer at model front for normalization( e.g. --pixel_scale_value "[0.0078125, 0.0078125, 0.0078125]" )') + +args = parser.parse_args() + +# Load model and polish +m = onnx.load(args.in_file) +m = other.polish_model(m) +g = m.graph +replacing.replace_initializer_with_Constant(g) +other.topological_sort(g) + +# Remove nodes according to the given arguments. +if args.delete_node is not None: + helper.delete_nodes(g, args.delete_node) + +if args.delete_input is not None: + helper.delete_input(g, args.delete_input) + +if args.delete_output is not None: + helper.delete_output(g, args.delete_output) + +# Add do-nothing Conv node +if args.add_conv is not None: + other.add_nop_conv_after(g, args.add_conv) + other.topological_sort(g) + +# Add do-nothing BN node +if args.add_bn is not None: + other.add_nop_bn_after(g, args.add_bn) + other.topological_sort(g) + +# Add bias scale BN node +if args.pixel_bias_value is not None or args.pixel_scale_value is not None: + + if len(g.input) > 1: + raise ValueError(" '--pixel-bias-value' and '--pixel-scale-value' only support one input node model currently") + + i_n = g.input[0] + + pixel_bias_value = [0] * i_n.type.tensor_type.shape.dim[1].dim_value + pixel_scale_value = [1] * i_n.type.tensor_type.shape.dim[1].dim_value + + if args.pixel_bias_value is not None and len(args.pixel_bias_value) == 1: + pixel_bias_value = [float(n) for n in args.pixel_bias_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + + if args.pixel_scale_value is not None and len(args.pixel_scale_value) == 1: + pixel_scale_value = [float(n) for n in args.pixel_scale_value[0].replace( '[' , '' ).replace( ']' , '' ).split(',')] + + + if i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_bias_value) or i_n.type.tensor_type.shape.dim[1].dim_value != len(pixel_scale_value): + raise ValueError("--pixel-bias-value (" + str(pixel_bias_value) + ") and --pixel-scale-value (" + str(pixel_scale_value) + ") should be same as input dimension:" + str(i_n.type.tensor_type.shape.dim[1].dim_value) ) + other.add_bias_scale_bn_after(g, i_n.name, pixel_bias_value, pixel_scale_value) + +# Change input and output shapes as requested +if args.input_change is not None: + other.change_input_shape(g, args.input_change) +if args.output_change is not None: + other.change_output_shape(g, args.output_change) + +# Cutting nodes according to the given arguments. +if args.cut_node is not None or args.cut_type is not None: + if args.cut_node is None: + other.remove_nodes(g, cut_types=args.cut_type) + elif args.cut_type is None: + other.remove_nodes(g, cut_nodes=args.cut_node) + else: + other.remove_nodes(g, cut_nodes=args.cut_node, cut_types=args.cut_type) + other.topological_sort(g) + +# Rename nodes +if args.rename_output: + if len(args.rename_output) % 2 != 0: + print("Rename output should be paires of names.") + else: + for i in range(0, len(args.rename_output), 2): + other.rename_output_name(g, args.rename_output[i], args.rename_output[i + 1]) + +# Remove useless nodes +if args.delete_node or args.delete_input or args.input_change or args.output_change: + # If shape changed during the modification, redo shape inference. + while(len(g.value_info) > 0): + g.value_info.pop() +passes = ['extract_constant_to_initializer'] +m = optimizer.optimize(m, passes) +g = m.graph +replacing.replace_initializer_with_Constant(g) +other.topological_sort(g) +# Polish and output +m = other.polish_model(m) +other.add_output_to_value_info(m.graph) +onnx.save(m, args.out_file) \ No newline at end of file diff --git a/tools/optimizer_scripts/norm_on_scaled_onnx.py b/tools/optimizer_scripts/norm_on_scaled_onnx.py new file mode 100644 index 0000000..f99a866 --- /dev/null +++ b/tools/optimizer_scripts/norm_on_scaled_onnx.py @@ -0,0 +1,52 @@ +import onnx +import sys +import json + +from tools import special + +if len(sys.argv) != 3: + print("python norm_on_scaled_onnx.py input.onnx input.json") + exit(1) + +# Modify onnx +m = onnx.load(sys.argv[1]) +special.add_0_5_to_normalized_input(m) +onnx.save(m, sys.argv[1][:-4] + 'norm.onnx') + +# Change input node +origin_file = open(sys.argv[2], 'r') +origin_json = json.load(origin_file) +origin_json["input_node"]["output_datapath_radix"] = [8] +new_json_str = json.dumps(origin_json) + +# Modify json +file = open(sys.argv[1][:-4] + 'norm.onnx' + '.json', 'w') +s = """{{ + \"{0}\" : + {{ + \"bias_bitwidth\" : 16, + \"{0}_bias\" : [15], + \"{0}_weight\" : [3,3,3], + \"conv_coarse_shift\" : [-4,-4,-4], + \"conv_fine_shift\" : [0,0,0], + \"conv_total_shift\" : [-4,-4,-4], + \"cpu_mode\" : false, + \"delta_input_bitwidth\" : [0], + \"delta_output_bitwidth\" : 8, + \"flag_radix_bias_eq_output\" : true, + \"input_scale\" : [[1.0,1.0,1.0]], + \"output_scale\" : [1.0, 1.0, 1.0], + \"psum_bitwidth\" : 16, + \"weight_bitwidth\" : 8, + \"input_datapath_bitwidth\" : [8], + \"input_datapath_radix\" : [8], + \"working_input_bitwidth\" : 8, + \"working_input_radix\" : [8], + \"working_output_bitwidth\" : 16, + \"working_output_radix\" : 15, + \"output_datapath_bitwidth\" : 8, + \"output_datapath_radix\" : 7 + }},\n""".format('input_norm') +file.write(s + new_json_str[1:]) +file.close() +origin_file.close() diff --git a/tools/optimizer_scripts/onnx1_3to1_4.py b/tools/optimizer_scripts/onnx1_3to1_4.py new file mode 100644 index 0000000..64b72b5 --- /dev/null +++ b/tools/optimizer_scripts/onnx1_3to1_4.py @@ -0,0 +1,135 @@ +# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git + +import sys +import onnx +import numpy as np +from onnx import numpy_helper +from tools import other, helper + +""" +Change onnx model from version 1.3 to version 1.4. +Modify the BN node by removing the spatial attribute +Modify the Upsample node by removing the 'scales' attribute, and adding a constant node instead. +Model's ir_version and opset_import are updated. +""" + +def remove_BN_spatial(g): + for node in g.node: + if node.op_type != 'BatchNormalization': + continue + for att in node.attribute: + if att.name == 'spatial': + node.attribute.remove(att) + + +def upsample_attribute_to_const(g): + for node in g.node: + if node.op_type != 'Upsample': + continue + scales_exist = False + for att in node.attribute: + if att.name == 'scales': + scales_exist = True + break + if not scales_exist: + continue + + shape = [len(att.floats)] + node.attribute.remove(att) + new_node = helper.list_to_constant(node.name+'_input', shape, att.floats) + + g.node.extend([new_node]) + value_info = onnx.helper.make_tensor_value_info(node.name+'_input', onnx.TensorProto.FLOAT, shape) + node.input.extend([node.name+'_input']) + g.value_info.extend([value_info]) + +def relu6_to_clip(g): + for node in g.node: + if node.op_type != 'Relu': + continue + max_val = helper.get_var_attribute_by_name(node, 'max', 'float') + if max_val is None: + continue + new_node = onnx.helper.make_node( + "Clip", + node.input, + node.output, + name=node.name, + max=max_val, + min=0.0 + ) + g.node.remove(node) + g.node.extend([new_node]) + +def PRelu_weight_reshape(g): + # For PRelu with single dimension weight. Expand it to 1, x, 1, 1 + for node in g.node: + if node.op_type != "PRelu": + continue + slope = helper.find_node_by_output_name(g, node.input[1]) + if slope is not None: + # Constant node + if len(slope.attribute[0].t.dims) != 1: + continue + slope.attribute[0].t.dims.append(slope.attribute[0].t.dims[0]) + slope.attribute[0].t.dims[0] = 1 + slope.attribute[0].t.dims.append(1) + slope.attribute[0].t.dims.append(1) + else: + # Initializer + for i in g.initializer: + if i.name == node.input[1]: + slope = i + break + if len(slope.dims) != 1: + continue + slope.dims.append(slope.dims[0]) + slope.dims[0] = 1 + slope.dims.append(1) + slope.dims.append(1) + input_value = helper.find_input_by_name(g, node.input[1]) + new_input = onnx.helper.make_tensor_value_info( + node.input[1], + input_value.type.tensor_type.elem_type, + (1, slope.dims[1], 1, 1)) + g.input.remove(input_value) + g.input.append(new_input) + value_info = helper.find_value_by_name(g, node.input[1]) + if value_info is not None: + g.value_info.remove(value_info) + +def do_convert(m): + graph = m.graph + + # Modify the nodes. + remove_BN_spatial(graph) + upsample_attribute_to_const(graph) + relu6_to_clip(graph) + PRelu_weight_reshape(graph) + other.topological_sort(graph) + + # Change model properties. + m.ir_version = 4 + m.opset_import[0].version = 9 + return m + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage:{} file_in file_out".format(sys.argv[0])) + exit(1) + + model = onnx.load(sys.argv[1]) + graph = model.graph + + # Modify the nodes. + remove_BN_spatial(graph) + upsample_attribute_to_const(graph) + relu6_to_clip(graph) + PRelu_weight_reshape(graph) + other.topological_sort(graph) + + # Change model properties. + model.ir_version = 4 + model.opset_import[0].version = 9 + + onnx.save(model, sys.argv[2]) diff --git a/tools/optimizer_scripts/onnx1_4to1_6.py b/tools/optimizer_scripts/onnx1_4to1_6.py new file mode 100644 index 0000000..825b3cd --- /dev/null +++ b/tools/optimizer_scripts/onnx1_4to1_6.py @@ -0,0 +1,184 @@ +# ref http://192.168.200.1:8088/jiyuan/converter_scripts.git + +import sys +import onnx +import onnx.utils +import numpy as np +from onnx import numpy_helper +from tools import other, helper, replacing + +""" +Change onnx model from version 1.4 to version 1.6. +""" + +def replace_all_attribute_to_const_node_in_pad_node(g): + node_to_remove = [] + node_to_extend = [] + for node in g.node: + if node.op_type != 'Pad': + continue + + pad_loc_node = None # must have + pad_mode = 'constant' + pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [0.0]) # need scalar + for att in node.attribute: + if att.name == 'mode': + pad_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') + if att.name == 'pads': + pad_loc_node = helper.list_to_constant(node.name+'_pad_loc', [len(att.ints)], att.ints) + if att.name == 'value': + pad_value_node = helper.list_to_constant(node.name+'_pad_value', [], [att.f]) + + new_node = onnx.helper.make_node( + "Pad", + [node.input[0], pad_loc_node.name, pad_value_node.name], + [node.output[0]], + name=node.output[0], + mode=pad_mode, + ) + node_to_remove.append(node) + node_to_extend.append(new_node) + node_to_extend.append(pad_loc_node) + node_to_extend.append(pad_value_node) + + for node in node_to_remove: + g.node.remove(node) + for node in node_to_extend: + g.node.extend([node]) + + +def upsampling_to_resize(g): + for node in g.node: + if node.op_type != 'Upsample': + continue + upsampling_mode = helper.get_var_attribute_by_name(node, 'mode', 'string') + + scale_value_node = helper.find_node_by_output_name(g, node.input[1]) + if scale_value_node.op_type != "Constant": + raise TypeError('seems there is a dynamic "scales" param in Upsampling node: ' + node.name + ' , you might need to do constant folding first') + + roi_node = helper.list_to_constant(node.name+'_roi_value', [0], []) + + new_node = onnx.helper.make_node( + "Resize", + [node.input[0], roi_node.name, scale_value_node.name], + [node.output[0]], + name=node.output[0], + mode=upsampling_mode, + coordinate_transformation_mode = 'asymmetric' + ) + + g.node.remove(node) + g.node.extend([new_node]) + g.node.extend([roi_node]) + + +def replace_all_attribute_to_const_node_in_slice_node(g): + for node in g.node: + if node.op_type != 'Slice': + continue + + axes_const_node = None + ends_const_node = None + starts_const_node = None + steps_const_node = None + for att in node.attribute: + if att.name == 'axes': + axes_const_node = helper.list_to_constant(node.name+'_axes_value', [len(att.ints)], att.ints) + + if att.name == 'ends': + ends_const_node = helper.list_to_constant(node.name+'_ends_value', [len(att.ints)], att.ints) + + if att.name == 'starts': + starts_const_node = helper.list_to_constant(node.name+'_starts_value', [len(att.ints)], att.ints) + + if att.name == 'steps': + steps_const_node = helper.list_to_constant(node.name+'_steps_value',[ len(att.ints)], att.ints) + + ## pop out from back + attr_len = len(node.attribute) + for i in range(attr_len): + node.attribute.remove(node.attribute[ attr_len -1 - i ]) + + ## according the spec, we need to add node in specific order + if starts_const_node != None: + g.node.extend([starts_const_node]) + node.input.extend([starts_const_node.name]) + if ends_const_node != None: + g.node.extend([ends_const_node]) + node.input.extend([ends_const_node.name]) + if axes_const_node != None: + g.node.extend([axes_const_node]) + node.input.extend([axes_const_node.name]) + if steps_const_node != None: + g.node.extend([steps_const_node]) + node.input.extend([steps_const_node.name]) + + +def replace_min_max_attribute_to_const_node_in_clip_node(g): + for node in g.node: + if node.op_type != 'Clip': + continue + + max_const_node = None + min_const_node = None + for att in node.attribute: + if att.name == 'max': + max_const_node = helper.list_to_constant(node.name+'_max_value', [], [att.f]) + + if att.name == 'min': + min_const_node = helper.list_to_constant(node.name+'_min_value', [], [att.f]) + + ## pop out from back + node.attribute.remove(node.attribute[1]) + node.attribute.remove(node.attribute[0]) + + ## according the spec, we need to add node in specific order + g.node.extend([min_const_node]) + g.node.extend([max_const_node]) + node.input.extend([min_const_node.name]) + node.input.extend([max_const_node.name]) + +def onnx1_4to1_6(model: onnx.ModelProto) -> onnx.ModelProto: + """Update ir_version from 4 to 6 and update opset from 9 to 11. + + Args: + model (onnx.ModelProto): input onnx model. + + Returns: + onnx.ModelProto: updated onnx model. + """ + graph = model.graph + + if model.opset_import[0].version == 11: + print("(Stop) the input model is already opset 11, no need to upgrade") + exit(1) + + # deal with empty node name issue + other.add_name_to_node(graph) + # simplify the node param type from initializer to constant + replacing.replace_initializer_with_Constant(graph) + + # Modify the nodes. + replace_min_max_attribute_to_const_node_in_clip_node(graph) + replace_all_attribute_to_const_node_in_slice_node(graph) + replace_all_attribute_to_const_node_in_pad_node(graph) + upsampling_to_resize(graph) + other.topological_sort(graph) + + # Change model properties. + model.ir_version = 6 + model.opset_import[0].version = 11 + + model = other.polish_model(model) + return model + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage:{} file_in file_out".format(sys.argv[0])) + exit(1) + + model = onnx.load(sys.argv[1]) + model = onnx1_4to1_6(model) + + onnx.save(model, sys.argv[2]) diff --git a/tools/optimizer_scripts/onnx2onnx.py b/tools/optimizer_scripts/onnx2onnx.py new file mode 100644 index 0000000..b820378 --- /dev/null +++ b/tools/optimizer_scripts/onnx2onnx.py @@ -0,0 +1,136 @@ +import onnx +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer +import sys +import argparse +import logging + +from tools import eliminating +from tools import fusing +from tools import replacing +from tools import other +from tools import special +from tools import combo +from tools.helper import logger +# from tools import temp + +def onnx2onnx_flow(m: onnx.ModelProto, + disable_fuse_bn=False, + bn_on_skip=False, + bn_before_add=False, + bgr=False, + norm=False, + rgba2yynn=False, + eliminate_tail=False, + opt_matmul=False, + duplicate_shared_weights=True) -> onnx.ModelProto: + """Optimize the onnx. + + Args: + m (ModelProto): the input onnx ModelProto + disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. + bn_on_skip (bool, optional): add BN operator on skip branches. Defaults to False. + bn_before_add (bool, optional): add BN before Add node on every branches. Defaults to False. + bgr (bool, optional): add an Conv layer to convert rgb input to bgr. Defaults to False. + norm (bool, optional): add an Conv layer to add 0.5 tp the input. Defaults to False. + rgba2yynn (bool, optional): add an Conv layer to convert rgb input to yynn . Defaults to False. + eliminate_tail (bool, optional): remove the trailing NPU unsupported nodes. Defaults to False. + opt_matmul(bool, optional): optimize the MatMul layers according to the NPU limit. Defaults to False. + duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. + + Returns: + ModelProto: the optimized onnx model object. + """ + # temp.weight_broadcast(m.graph) + m = combo.preprocess(m, disable_fuse_bn, duplicate_shared_weights) + # temp.fuse_bias_in_consecutive_1x1_conv(m.graph) + + # Add BN on skip branch + if bn_on_skip: + other.add_bn_on_skip_branch(m.graph) + elif bn_before_add: + other.add_bn_before_add(m.graph) + other.add_bn_before_activation(m.graph) + + # My optimization + m = combo.common_optimization(m) + # Special options + if bgr: + special.change_input_from_bgr_to_rgb(m) + if norm: + special.add_0_5_to_normalized_input(m) + if rgba2yynn: + special.add_rgb2yynn_node(m) + + # Remove useless last node + if eliminate_tail: + eliminating.remove_useless_last_nodes(m.graph) + + # Postprocessing + m = combo.postprocess(m) + + # Put matmul after postprocess to avoid transpose moving downwards + if opt_matmul: + special.special_MatMul_process(m.graph) + m = other.polish_model(m) + + return m + +# Main process +if __name__ == "__main__": + # Argument parser + parser = argparse.ArgumentParser(description="Optimize an ONNX model for Kneron compiler") + parser.add_argument('in_file', help='input ONNX FILE') + parser.add_argument('-o', '--output', dest='out_file', type=str, help="ouput ONNX FILE") + parser.add_argument('--log', default='i', type=str, help="set log level") + parser.add_argument('--bgr', action='store_true', default=False, help="set if the model is trained in BGR mode") + parser.add_argument('--norm', action='store_true', default=False, help="set if you have the input -0.5~0.5") + parser.add_argument('--rgba2yynn', action='store_true', default=False, help="set if the model has yynn input but you want to take rgba images") + parser.add_argument('--add-bn-on-skip', dest='bn_on_skip', action='store_true', default=False, + help="set if you only want to add BN on skip branches") + parser.add_argument('--add-bn', dest='bn_before_add', action='store_true', default=False, + help="set if you want to add BN before Add") + parser.add_argument('-t', '--eliminate-tail-unsupported', dest='eliminate_tail', action='store_true', default=False, + help='whether remove the last unsupported node for hardware') + parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, + help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") + parser.add_argument('--opt-matmul', dest='opt_matmul', action='store_true', default=False, + help="set if you want to optimize the MatMul operations for the kneron hardware.") + parser.add_argument('--no-duplicate-shared-weights', dest='no_duplicate_shared_weights', action='store_true', default=False, + help='do not duplicate shared weights. Defaults to False.') + args = parser.parse_args() + + if args.out_file is None: + outfile = args.in_file[:-5] + "_polished.onnx" + else: + outfile = args.out_file + + if args.log == 'w': + logging.basicConfig(level=logging.WARN) + elif args.log == 'd': + logging.basicConfig(level=logging.DEBUG) + elif args.log == 'e': + logging.basicConfig(level=logging.ERROR) + else: + logging.basicConfig(level=logging.INFO) + + # onnx Polish model includes: + # -- nop + # -- eliminate_identity + # -- eliminate_nop_transpose + # -- eliminate_nop_pad + # -- eliminate_unused_initializer + # -- fuse_consecutive_squeezes + # -- fuse_consecutive_transposes + # -- fuse_add_bias_into_conv + # -- fuse_transpose_into_gemm + + # Basic model organize + m = onnx.load(args.in_file) + + m = onnx2onnx_flow(m, args.disable_fuse_bn, args.bn_on_skip, args.bn_before_add, args.bgr, args.norm, args.rgba2yynn, args.eliminate_tail, args.opt_matmul, not args.no_duplicate_shared_weights) + + onnx.save(m, outfile) diff --git a/tools/optimizer_scripts/onnx_vs_onnx.py b/tools/optimizer_scripts/onnx_vs_onnx.py new file mode 100644 index 0000000..c04c65b --- /dev/null +++ b/tools/optimizer_scripts/onnx_vs_onnx.py @@ -0,0 +1,134 @@ +import onnxruntime +import onnx +import argparse +import numpy as np +from tools import helper + + +onnx2np_dtype = {0: 'float', 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', 6: 'int32', 7: 'int64', 8: 'str', 9: 'bool', 10: 'float16', 11: 'double', 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', 16: 'float'} + + +def onnx_model_results(path_a, path_b, total_times=10): + """ using onnxruntime to inference two onnx models' ouputs + + :onnx model paths: two model paths + :total_times: inference times, default to be 10 + :returns: inference results of two models + """ + # load model a and model b to runtime + session_a = onnxruntime.InferenceSession(path_a, None) + session_b = onnxruntime.InferenceSession(path_b, None) + outputs_a = session_a.get_outputs() + outputs_b = session_b.get_outputs() + + # check outputs + assert len(outputs_a) == len(outputs_b), 'Two models have different output numbers.' + for i in range(len(outputs_a)): + out_shape_a, out_shape_b = outputs_a[i].shape, outputs_b[i].shape + out_shape_a = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_a)) + out_shape_b = list(map(lambda x: x if type(x) == type(1) else 1, out_shape_b)) + assert out_shape_a == out_shape_b, 'Output {} has unmatched shapes'.format(i) + + + # load onnx graph_a and graph_b, to find the initializer and inputs + # then compare to remove the items in the inputs which will be initialized + model_a, model_b = onnx.load(path_a), onnx.load(path_b) + graph_a, graph_b = model_a.graph, model_b.graph + inputs_a, inputs_b = graph_a.input, graph_b.input + init_a, init_b = graph_a.initializer, graph_b.initializer + + # remove initializer from raw inputs + input_names_a, input_names_b = set([ele.name for ele in inputs_a]), set([ele.name for ele in inputs_b]) + init_names_a, init_names_b = set([ele.name for ele in init_a]), set([ele.name for ele in init_b]) + real_inputs_names_a, real_inputs_names_b = input_names_a - init_names_a, input_names_b - init_names_b + + # prepare and figure out matching of real inputs a and real inputs b + # try to keep original orders of each inputs + real_inputs_a, real_inputs_b = [], [] + for item in inputs_a: + if item.name in real_inputs_names_a: + real_inputs_a.append(item) + for item in inputs_b: + if item.name in real_inputs_names_b: + real_inputs_b.append(item) + + # suppose there's only one real single input tensor for each model + # find the real single inputs for model_a and model_b + real_single_input_a = None + real_single_input_b = None + size_a, size_b = 0, 0 + shape_a, shape_b = [], [] + for item_a in real_inputs_a: + size, shape = helper.find_size_shape_from_value(item_a) + if size: + assert real_single_input_a is None, 'Multiple inputs of first model, single input expected.' + real_single_input_a = item_a + size_a, shape_a = size, shape + for item_b in real_inputs_b: + size, shape = helper.find_size_shape_from_value(item_b) + if size: + assert real_single_input_b is None, 'Multiple inputs of second model, single input expected.' + real_single_input_b = item_b + size_b, shape_b = size, shape + assert size_a == size_b, 'Sizes of two models do not match.' + + + # construct inputs tensors + input_data_type_a = real_single_input_a.type.tensor_type.elem_type + input_data_type_b = real_single_input_b.type.tensor_type.elem_type + input_data_type_a = onnx2np_dtype[input_data_type_a] + input_data_type_b = onnx2np_dtype[input_data_type_b] + + # run inference + times = 0 + results_a = [[] for i in range(len(outputs_a))] + results_b = [[] for i in range(len(outputs_b))] + while times < total_times: + # initialize inputs by random data, default to be uniform + data = np.random.random(size_a) + input_a = np.reshape(data, shape_a).astype(input_data_type_a) + input_b = np.reshape(data, shape_b).astype(input_data_type_b) + + input_dict_a = {} + input_dict_b = {} + for item_a in real_inputs_a: + item_type_a = onnx2np_dtype[item_a.type.tensor_type.elem_type] + input_dict_a[item_a.name] = np.array([]).astype(item_type_a) \ + if item_a.name != real_single_input_a.name else input_a + for item_b in real_inputs_b: + item_type_b = onnx2np_dtype[item_b.type.tensor_type.elem_type] + input_dict_b[item_b.name] = np.array([]).astype(item_type_b) \ + if item_b.name != real_single_input_b.name else input_b + + ra = session_a.run([], input_dict_a) + rb = session_b.run([], input_dict_b) + for i in range(len(outputs_a)): + results_a[i].append(ra[i]) + results_b[i].append(rb[i]) + times += 1 + + return results_a, results_b + +if __name__ == '__main__': + # Argument parser. + parser = argparse.ArgumentParser(description="Compare two ONNX models to check if they have the same output.") + parser.add_argument('in_file_a', help='input ONNX file a') + parser.add_argument('in_file_b', help='input ONNX file b') + + args = parser.parse_args() + + results_a, results_b = onnx_model_results(args.in_file_a, args.in_file_b, total_times=10) + ra_flat = helper.flatten_with_depth(results_a, 0) + rb_flat = helper.flatten_with_depth(results_b, 0) + shape_a = [item[1] for item in ra_flat] + shape_b = [item[1] for item in rb_flat] + assert shape_a == shape_b, 'two results data shape doesn\'t match' + ra_raw = [item[0] for item in ra_flat] + rb_raw = [item[0] for item in rb_flat] + + try: + np.testing.assert_almost_equal(ra_raw, rb_raw, 4) + print('Two models have the same behaviour.') + except Exception as mismatch: + print(mismatch) + exit(1) diff --git a/tools/optimizer_scripts/onnx_vs_onnx_opt.py b/tools/optimizer_scripts/onnx_vs_onnx_opt.py new file mode 100644 index 0000000..b660cf4 --- /dev/null +++ b/tools/optimizer_scripts/onnx_vs_onnx_opt.py @@ -0,0 +1,221 @@ +import onnx +import argparse +import glob +import csv +import numpy as np +import matplotlib.pyplot as plt + +from tools import helper +import onnx_vs_onnx as onnx_tester + +def compare_results(results_a, results_b): + """ compare onnx model inference results + calculate basic statistical values + results: results from inference multiple times + returns: list of basic statistical values + """ + # input results data can be of nonuniform shape + # get flatten data to compare + ra_flat = helper.flatten_with_depth(results_a, 0) + rb_flat = helper.flatten_with_depth(results_b, 0) + shape_a = [item[1] for item in ra_flat] + shape_b = [item[1] for item in rb_flat] + assert shape_a == shape_b, 'two results data shape doesn\'t match' + ra_raw = [item[0] for item in ra_flat] + rb_raw = [item[0] for item in rb_flat] + + # the statistical values + max_rel_diff = 0 # defined to be max( { abs(diff)/max(abs(ra), abs(rb) ) } ) + max_abs_diff = 0 # defined to be max( { abs(ra-rb) } ) + mean_rel_diff = 0 + mean_abs_diff = 0 + std_rel_diff = 0 + std_abs_diff = 0 + acc_with_diff_precision = [] + rel_diff = [] + abs_diff_percentiles = [] # rel_diff percentiles + rel_diff_percentiles = [] # abs_diff precentiles + + raw_diff = [ra_raw[i]-rb_raw[i] for i in range(len(ra_raw))] + abs_diff = [abs(num) for num in raw_diff] + for i in range(len(ra_raw)): + divider = max([abs(ra_raw[i]), abs(rb_raw[i])]) + val = abs_diff[i]/divider if divider != 0 else 0 + rel_diff.append(val) + + max_rel_diff = max(rel_diff) + max_abs_diff = max(abs_diff) + mean_rel_diff = np.average(rel_diff) + mean_abs_diff = np.average(abs_diff) + std_rel_diff = np.std(rel_diff) + std_abs_diff = np.std(abs_diff) + + # calculate accuracy with different precison + for digit in range(8): + correct = 0 + for i in range(len(ra_raw)): + if format(ra_raw[i], '.'+str(digit)+'f')\ + == format(rb_raw[i], '.'+str(digit)+'f'): + correct += 1 + acc_with_diff_precision.append([digit, float(format(correct/len(ra_raw), '.3f'))]) + + # analyze rel_diff distribution + rel_diff.sort() + abs_diff.sort() + for i in range(20): + rel_diff_percentiles.append(['{}%'.format(i*5), rel_diff[int((i/20)*len(rel_diff))]]) + abs_diff_percentiles.append(['{}%'.format(i*5), abs_diff[int((i/20)*len(abs_diff))]]) + + results = [ + ['max_rel_diff', max_rel_diff], + ['max_abs_diff', max_abs_diff], + ['mean_rel_diff', mean_rel_diff], + ['mean_abs_diff', mean_abs_diff], + ['std_rel_diff', std_rel_diff], + ['std_abs_diff', std_abs_diff], + ['acc_with_diff_precision', acc_with_diff_precision], + ['rel_diff_percentiles', rel_diff_percentiles], + ['abs_diff_percentiles', abs_diff_percentiles] + ] + + return results + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='test model optimization results') + + parser.add_argument('dir', type=str, help='the directory that stores onnx models') + parser.add_argument('ending1', type=str, help='model file name ending(eg, .onnx)') + parser.add_argument('ending2', type=str, help='opt model file name ending(eg. _opt.onnx)') + parser.add_argument('out_file', type=str, help='output csv file name') + parser.add_argument('-p', '--plot', default='N', help='get plots (Y/N)') + parser.add_argument('-i', '--iter_times', default=10, type=int, help='inference times') + + args = parser.parse_args() + + old_models_paths = glob.glob(args.dir+'*'+args.ending1) + new_models_paths = glob.glob(args.dir+'*'+args.ending2) + + stats_table = [[ + 'Model', + 'max_rel_diff', + 'max_abs_diff', + 'mean_rel_diff', + 'mean_abs_diff', + 'std_rel_diff', + 'std_abs_diff', + 'acc_with_diff_precision', + 'rel_diff_percentiles', + 'abs_diff_percentiles' + ]] + + for new_model_path in new_models_paths: + old_model_path = new_model_path[:-len(args.ending2)] + args.ending1 + if old_model_path not in old_models_paths: + continue + + # run inference + results_a, results_b = onnx_tester.onnx_model_results(old_model_path, new_model_path, total_times=args.iter_times) + + # compare inference results + comparision = compare_results(results_a, results_b) + + new_line = [old_model_path.split('/')[-1]] + for item in comparision: + new_line.append(item[1]) + + stats_table.append(new_line) + + # try to read existing file + old_stats_table = [] + try: + old_file = open(args.out_file, 'r') + reader = csv.reader(old_file) + old_header = reader.__next__() + for row in reader: + old_stats_table.append(row) + old_file.close() + except: + pass + + # compare and merge possible old stat data file with new stat data file + header = stats_table[0] + stats_table = stats_table[1:] + new_model_names = set([item[0] for item in stats_table]) + for row in old_stats_table: + if row[0] not in new_model_names: + stats_table.append(row) + stats_table.insert(0, header) + + # write a new stat data file, overwrite old file + new_file = open(args.out_file, 'w', newline='') + writer = csv.writer(new_file) + for row in stats_table: + writer.writerow(row) + new_file.close() + + # make some plots + if args.plot == 'Y': + if len(stats_table) < 2: + exit(0) + + sample_table = stats_table[1:] if len(stats_table) < 6 else stats_table[1:6] + + max_rel_diffs = [round(float(item[1]), 2) for item in stats_table[1:]] + plt.hist(max_rel_diffs, bins=15) + plt.title('Max Relavtive Difference Histogram') + plt.xlabel('Max Relative Difference') + plt.ylabel('Counts') + plt.savefig('max_rel_diff_hist.png') + plt.close() + + max_abs_diffs = [round(float(item[2]), 2) for item in stats_table[1:]] + plt.hist(max_abs_diffs, bins=15) + plt.title('Max Absolute Difference Histogram') + plt.xlabel('Max Absolute Difference') + plt.ylabel('Counts') + plt.savefig('max_abs_diff_hist.png') + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-2] + x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title('Rel_diff Percentiles of Raw and Optimized Models') + plt.xlabel('percentage') + plt.ylabel('relative difference') + plt.legend() + plt.savefig('rel_diff_percentiles.png') + plt.close() + + for line in sample_table: + model_name = line[0] + percentiles = line[-1] + x = [round(i*(1/len(percentiles)), 2) for i in range(len(percentiles))] + y = [ele[1] for ele in percentiles] + plt.plot(x, y, label=model_name) + plt.title('Abs_diff Percentiles of Raw and Optimized Models') + plt.xlabel('percentage') + plt.ylabel('absolute difference') + plt.legend() + plt.savefig('abs_diff_percentiles.png') + plt.close() + + for line in sample_table: + model_name = line[0] + accuracies = line[-3] + x = [acc[0] for acc in accuracies] + y = [acc[1] for acc in accuracies] + plt.plot(x, y, label=model_name) + plt.title('Accuracies with Different Precisions') + plt.xlabel('Decimals') + plt.ylabel('Precision') + plt.legend() + plt.savefig('precisions.png') + plt.close() + + + + + diff --git a/tools/optimizer_scripts/pytorch2onnx.py b/tools/optimizer_scripts/pytorch2onnx.py new file mode 100644 index 0000000..0f2c559 --- /dev/null +++ b/tools/optimizer_scripts/pytorch2onnx.py @@ -0,0 +1,81 @@ +import onnx +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer +import sys +import numpy as np +import struct +import logging +import argparse + +from tools import eliminating +from tools import fusing +from tools import replacing +from tools import other +from tools import combo +from tools import special +from pytorch_exported_onnx_preprocess import torch_exported_onnx_flow + +# Debug use +# logging.basicConfig(level=logging.DEBUG) + +###################################### +# Generate a prototype onnx # +###################################### + +parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler") +parser.add_argument('in_file', help='input ONNX or PTH FILE') +parser.add_argument('out_file', help="ouput ONNX FILE") +parser.add_argument('--input-size', dest='input_size', nargs=3, + help='if you using pth, please use this argument to set up the input size of the model. It should be in \'CH H W\' format, e.g. \'--input-size 3 256 512\'.') +parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, + help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") + +args = parser.parse_args() + +if len(args.in_file) <= 4: + # When the filename is too short. + logging.error("Invalid input file: {}".format(args.in_file)) + exit(1) +elif args.in_file[-4:] == '.pth': + # Pytorch pth case + logging.warning("Converting from pth to onnx is not recommended.") + onnx_in = args.out_file + # Import pytorch libraries + from torch.autograd import Variable + import torch + import torch.onnx + # import torchvision + # Standard ImageNet input - 3 channels, 224x224. + # Values don't matter as we care about network structure. + # But they can also be real inputs. + if args.input_size is None: + logging.error("\'--input-size\' is required for the pth input file.") + exit(1) + dummy_input = Variable(torch.randn(1, int(args.input_size[0]), int(args.input_size[1]), int(args.input_size[2]))) + # Obtain your model, it can be also constructed in your script explicitly. + model = torch.load(sys.argv[1], map_location='cpu') + # model = torchvision.models.resnet34(pretrained=True) + # Invoke export. + # torch.save(model, "resnet34.pth") + torch.onnx.export(model, dummy_input, args.out_file, opset_version=11) +elif args.in_file[-4:] == 'onnx': + onnx_in = args.in_file +else: + # When the file is neither an onnx or a pytorch pth. + logging.error("Invalid input file: {}".format(args.in_file)) + exit(1) + +onnx_out = args.out_file + +###################################### +# Optimize onnx # +###################################### + +m = onnx.load(onnx_in) + +m = torch_exported_onnx_flow(m, args.disable_fuse_bn) + +onnx.save(m, onnx_out) diff --git a/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py b/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py new file mode 100644 index 0000000..509db82 --- /dev/null +++ b/tools/optimizer_scripts/pytorch_exported_onnx_preprocess.py @@ -0,0 +1,80 @@ +import onnx +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer +import sys +import numpy as np +import struct +import logging +import argparse + +from .tools import eliminating +from .tools import fusing +from .tools import replacing +from .tools import other +from .tools import combo +from .tools import special + +# Define general pytorch exported onnx optimize process +def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto: + """Optimize the Pytorch exported onnx. + + Args: + m (ModelProto): the input onnx model + disable_fuse_bn (bool, optional): do not fuse BN into Conv. Defaults to False. + + Returns: + ModelProto: the optimized onnx model + """ + m = combo.preprocess(m, disable_fuse_bn) + m = combo.pytorch_constant_folding(m) + m = combo.common_optimization(m) + m = combo.postprocess(m) + + return m + + +# Main Process +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize a Pytorch generated model for Kneron compiler") + parser.add_argument('in_file', help='input ONNX') + parser.add_argument('out_file', help="ouput ONNX FILE") + parser.add_argument('--log', default='i', type=str, help="set log level") + parser.add_argument('--no-bn-fusion', dest='disable_fuse_bn', action='store_true', default=False, + help="set if you have met errors which related to inferenced shape mismatch. This option will prevent fusing BatchNormailization into Conv.") + + args = parser.parse_args() + + if args.log == 'w': + logging.basicConfig(level=logging.WARN) + elif args.log == 'd': + logging.basicConfig(level=logging.DEBUG) + elif args.log == 'e': + logging.basicConfig(level=logging.ERROR) + else: + logging.basicConfig(level=logging.INFO) + + if len(args.in_file) <= 4: + # When the filename is too short. + logging.error("Invalid input file: {}".format(args.in_file)) + exit(1) + elif args.in_file[-4:] == 'onnx': + onnx_in = args.in_file + else: + # When the file is not an onnx file. + logging.error("Invalid input file: {}".format(args.in_file)) + exit(1) + + onnx_out = args.out_file + + ###################################### + # Optimize onnx # + ###################################### + + m = onnx.load(onnx_in) + + m = torch_exported_onnx_flow(m, args.disable_fuse_bn) + + onnx.save(m, onnx_out) diff --git a/tools/optimizer_scripts/res/first_insert_layer.json b/tools/optimizer_scripts/res/first_insert_layer.json new file mode 100644 index 0000000..4fe3f59 --- /dev/null +++ b/tools/optimizer_scripts/res/first_insert_layer.json @@ -0,0 +1,27 @@ +{ + "LAYERNAME" : + { + "bias_bitwidth" : 16, + "LAYERNAME_bias" : [15], + "LAYERNAME_weight" : [3,3,3], + "conv_coarse_shift" : [-4,-4,-4], + "conv_fine_shift" : [0,0,0], + "conv_total_shift" : [-4,-4,-4], + "cpu_mode" : false, + "delta_input_bitwidth" : [0], + "delta_output_bitwidth" : 8, + "flag_radix_bias_eq_output" : true, + "input_scale" : [[1.0,1.0,1.0]], + "output_scale" : [1.0, 1.0, 1.0], + "psum_bitwidth" : 16, + "weight_bitwidth" : 8, + "input_datapath_bitwidth" : [8], + "input_datapath_radix" : [7], + "working_input_bitwidth" : 8, + "working_input_radix" : [7], + "working_output_bitwidth" : 16, + "working_output_radix" : 15, + "output_datapath_bitwidth" : 8, + "output_datapath_radix" : 7 + } +} diff --git a/tools/optimizer_scripts/res/test_onnx_tester_on_difference.sh b/tools/optimizer_scripts/res/test_onnx_tester_on_difference.sh new file mode 100644 index 0000000..342b198 --- /dev/null +++ b/tools/optimizer_scripts/res/test_onnx_tester_on_difference.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +python onnx_tester.py /test_models/mobilenet_v2_224.onnx /test_models/mobilenet_v2_224.cut.onnx +if [ $? -eq 0 ]; then + echo "Those two model results should be different!" + exit 1 +fi + +exit 0 diff --git a/tools/optimizer_scripts/res/vdsr_41_20layer_1.pb b/tools/optimizer_scripts/res/vdsr_41_20layer_1.pb new file mode 100644 index 0000000..81096de Binary files /dev/null and b/tools/optimizer_scripts/res/vdsr_41_20layer_1.pb differ diff --git a/tools/optimizer_scripts/tensorflow2onnx.py b/tools/optimizer_scripts/tensorflow2onnx.py new file mode 100644 index 0000000..13c0dab --- /dev/null +++ b/tools/optimizer_scripts/tensorflow2onnx.py @@ -0,0 +1,147 @@ +import tensorflow as tf +import tf2onnx +import argparse +import logging +import sys +import onnx +import onnx.utils +from tensorflow.python.platform import gfile +from tools import combo, eliminating, replacing, other + +def tf2onnx_flow(pb_path: str, test_mode =False) -> onnx.ModelProto: + """Convert frozen graph pb file into onnx + + Args: + pb_path (str): input pb file path + test_mode (bool, optional): test mode. Defaults to False. + + Raises: + Exception: invalid input file + + Returns: + onnx.ModelProto: converted onnx + """ + TF2ONNX_VERSION = int(tf2onnx.version.version.replace('.', '')) + + if 160 <= TF2ONNX_VERSION: + from tf2onnx import tf_loader + else: + from tf2onnx import loader as tf_loader + + if pb_path[-3:] == '.pb': + model_name = pb_path.split('/')[-1][:-3] + + # always reset tensorflow session at begin + tf.reset_default_graph() + + with tf.Session() as sess: + with gfile.FastGFile(pb_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + sess.graph.as_default() + tf.import_graph_def(graph_def, name='') + + if 160 <= int(tf2onnx.version.version.replace('.', '')): + onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions = tf2onnx.tf_utils.tflist_to_onnx( + sess.graph, + {}) + else: + onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tf2onnx.tfonnx.tflist_to_onnx( + sess.graph.get_operations(), + {}) + + for n in onnx_nodes: + if len(n.output) == 0: + onnx_nodes.remove(n) + + # find inputs and outputs of graph + nodes_inputs = set() + nodes_outputs = set() + + for n in onnx_nodes: + if n.op_type == 'Placeholder': + continue + for input in n.input: + nodes_inputs.add(input) + for output in n.output: + nodes_outputs.add(output) + + graph_input_names = set() + for input_name in nodes_inputs: + if input_name not in nodes_outputs: + graph_input_names.add(input_name) + + graph_output_names = set() + for n in onnx_nodes: + if n.input and n.input[0] not in nodes_outputs: + continue + if len(n.output) == 0: + n.output.append(n.name + ':0') + graph_output_names.add(n.output[0]) + else: + output_name = n.output[0] + if (output_name not in nodes_inputs) and (0 < len(n.input)): + graph_output_names.add(output_name) + + logging.info('Model Inputs: %s', str(list(graph_input_names))) + logging.info('Model Outputs: %s', str(list(graph_output_names))) + + graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path, + input_names=list(graph_input_names), + output_names=list(graph_output_names)) + + with tf.Graph().as_default() as tf_graph: + tf.import_graph_def(graph_def, name='') + + if 160 <= TF2ONNX_VERSION: + with tf_loader.tf_session(graph=tf_graph): + onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, + input_names=inputs, + output_names=outputs, + opset=11) + else: + with tf.Session(graph=tf_graph): + onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph, + input_names=inputs, + output_names=outputs, + opset=11) + + # Optimize with tf2onnx.optimizer + onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph) + model_proto = onnx_graph.make_model(model_name) + + # Make tf2onnx output compatible with the spec. of other.polish_model + replacing.replace_initializer_with_Constant(model_proto.graph) + model_proto = other.polish_model(model_proto) + + else: + raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"') + + # rename + m = model_proto + + m = combo.preprocess(m) + m = combo.common_optimization(m) + m = combo.tensorflow_optimization(m) + m = combo.postprocess(m) + + if not test_mode: + g = m.graph + eliminating.eliminate_shape_changing_after_input(g) + + m = other.polish_model(m) + return m + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert tensorflow pb file to onnx file and optimized onnx file. Or just optimize tensorflow onnx file.') + parser.add_argument('in_file', help='input file') + parser.add_argument('out_file', help='output optimized model file') + parser.add_argument('-t', '--test_mode', default=False, help='test mode will not eliminate shape changes after input') + + args = parser.parse_args() + logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.INFO) + m = tf2onnx_flow(args.in_file, args.test_mode) + onnx.save(m, args.out_file) + logging.info('Save Optimized ONNX: %s', args.out_file) diff --git a/tools/optimizer_scripts/tflite_vs_onnx.py b/tools/optimizer_scripts/tflite_vs_onnx.py new file mode 100644 index 0000000..ffeecea --- /dev/null +++ b/tools/optimizer_scripts/tflite_vs_onnx.py @@ -0,0 +1,68 @@ +import argparse +import numpy as np +import tensorflow as tf +import onnx +import onnxruntime + +from tools import helper + +def compare_tflite_and_onnx(tflite_file, onnx_file, total_times=10): + # Setup onnx session and get meta data + onnx_session = onnxruntime.InferenceSession(onnx_file, None) + onnx_outputs = onnx_session.get_outputs() + assert len(onnx_outputs) == 1, "The onnx model has more than one output" + onnx_model = onnx.load(onnx_file) + onnx_graph = onnx_model.graph + onnx_inputs = onnx_graph.input + assert len(onnx_inputs) == 1, "The onnx model has more than one input" + _, onnx_input_shape = helper.find_size_shape_from_value(onnx_inputs[0]) + # Setup TFLite sessio and get meta data + tflite_session = tf.lite.Interpreter(model_path=tflite_file) + tflite_session.allocate_tensors() + tflite_inputs = tflite_session.get_input_details() + tflite_outputs = tflite_session.get_output_details() + tflite_input_shape = tflite_inputs[0]['shape'] + # Compare input shape + assert(len(onnx_input_shape) == len(tflite_input_shape)), "TFLite and ONNX shape unmatch." + assert(onnx_input_shape == [tflite_input_shape[0], tflite_input_shape[3], tflite_input_shape[1], tflite_input_shape[2]]), "TFLite and ONNX shape unmatch." + # Generate random number and run + tflite_results = [] + onnx_results = [] + for _ in range(total_times): + # Generate input + tflite_input_data = np.array(np.random.random_sample(tflite_input_shape), dtype=np.float32) + onnx_input_data = np.transpose(tflite_input_data, [0, 3, 1, 2]) + # Run tflite + tflite_session.set_tensor(tflite_inputs[0]['index'], tflite_input_data) + tflite_session.invoke() + tflite_results.append(tflite_session.get_tensor(tflite_outputs[0]['index'])) + # Run onnx + onnx_input_dict = {onnx_inputs[0].name: onnx_input_data} + onnx_results.append(onnx_session.run([], onnx_input_dict)[0]) + + return tflite_results, onnx_results + + +if __name__ == '__main__': + # Argument parser. + parser = argparse.ArgumentParser(description="Compare a TFLite model and an ONNX model to check if they have the same output.") + parser.add_argument('tflite_file', help='input tflite file') + parser.add_argument('onnx_file', help='input ONNX file') + + args = parser.parse_args() + + results_a, results_b = compare_tflite_and_onnx(args.tflite_file, args.onnx_file, total_times=10) + ra_flat = helper.flatten_with_depth(results_a, 0) + rb_flat = helper.flatten_with_depth(results_b, 0) + shape_a = [item[1] for item in ra_flat] + shape_b = [item[1] for item in rb_flat] + assert shape_a == shape_b, 'two results data shape doesn\'t match' + ra_raw = [item[0] for item in ra_flat] + rb_raw = [item[0] for item in rb_flat] + + try: + np.testing.assert_almost_equal(ra_raw, rb_raw, 8) + print('Two models have the same behaviour.') + except Exception as mismatch: + print(mismatch) + exit(1) \ No newline at end of file diff --git a/tools/optimizer_scripts/tools/__init__.py b/tools/optimizer_scripts/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/optimizer_scripts/tools/combo.py b/tools/optimizer_scripts/tools/combo.py new file mode 100644 index 0000000..adadecb --- /dev/null +++ b/tools/optimizer_scripts/tools/combo.py @@ -0,0 +1,258 @@ +"""Combo functions that are usually called together. +""" + +import logging +import onnx.utils +try: + from onnx import optimizer +except ImportError: + import onnxoptimizer as optimizer + +from . import helper +from . import other +from . import replacing +from . import eliminating +from . import fusing +from . import constant_folding +from . import removing_transpose +from . import modhelper +from .common_pattern import torch_pattern_match, tf_pattern_match +from .helper import logger + +def preprocess(model_proto, disable_fuse_bn=False, duplicate_shared_weights=True): + """The most common used functions before other processing. + + Args: + model_proto: the original model input + duplicate_shared_weights(bool, optional): duplicate shared weights. Defaults to True. + + Return: + the new model after preprocessing + + It includes: + + - inference shapes + - optimize model by ONNX library + - give names to the nodes + - replace initializer with Constant node + - replace -1 batch size with 1 + - eliminate dropout and identity + - eliminate no children inputs + - topological sort + + The optimizations provided by ONNX: + + - eliminate_identity + - eliminate_nop_dropout + - eliminate_nop_transpose + - eliminate_nop_pad + - eliminate_unused_initializer + - eliminate_deadend + - fuse_consecutive_squeezes + - fuse_consecutive_transposes + - fuse_add_bias_into_conv + - fuse_transpose_into_gemm + - fuse_matmul_add_bias_into_gemm + - fuse_bn_into_conv + - fuse_pad_into_conv + + """ + logger.info("Preprocessing the model...") + helper.setup_current_opset_version(model_proto) + eliminating.eliminate_empty_value_infos(model_proto.graph) + other.add_name_to_node(model_proto.graph) + other.rename_all_node_name(model_proto.graph) + replacing.replace_initializer_with_Constant(model_proto.graph) + other.topological_sort(model_proto.graph) + m = other.polish_model(model_proto) + passes = ['extract_constant_to_initializer', + 'eliminate_nop_dropout', + 'eliminate_deadend', + 'fuse_matmul_add_bias_into_gemm', + 'fuse_pad_into_conv'] + if not disable_fuse_bn: + passes.append('fuse_bn_into_conv') + m = optimizer.optimize(m, passes) + g = m.graph + # Add name again since onnx optimizer higher than 1.7 may remove node names. + other.add_name_to_node(g) + if duplicate_shared_weights: + replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=True) + other.duplicate_param_shared_constant(g) + else: + replacing.replace_initializer_with_Constant(g, duplicate_shared_weights=False) + other.topological_sort(g) + m = other.polish_model(m) + g = m.graph + eliminating.eliminate_consecutive_Cast(m.graph) + eliminating.eliminate_Cast_after_input(m.graph) + eliminating.eliminate_nop_pads(g) + eliminating.eliminate_nop_cast(g) + eliminating.eliminate_Identify_and_Dropout(g) + eliminating.eliminate_trivial_maxpool(g) + eliminating.eliminate_no_children_input(g) + other.format_value_info_shape(g) + other.topological_sort(g) + m = other.inference_shapes(m) + g = m.graph + replacing.replace_split_with_slices(g) + other.topological_sort(g) + + return m + + +def common_optimization(m): + """Common optimizations can be used in most cases. + + :param m: the original model input\\ + :return: the new model after preprocessing + + It includes: + + - transpose B in Gemm + - fuse BN into Gemm + - fuse consecutive Gemm + - replace AveragePool with GAP + - replace Squeeze/Unsqueeze with Reshape + - replace Reshape with Flatten + """ + logger.info("Doing nodes fusion and replacement... ") + m = other.polish_model(m) + g = m.graph + other.transpose_B_in_Gemm(g) + fusing.fuse_BN_into_Gemm(g) + fusing.fuse_BN_with_Reshape_into_Gemm(g) + fusing.fuse_Gemm_into_Gemm(g) + fusing.fuse_consecutive_reducemean(g) + fusing.fuse_slice_nodes_into_conv(g) + fusing.fuse_relu_min_into_clip(g) + other.duplicate_shared_Flatten(g) + replacing.replace_average_pool_with_GAP(g) + + m = other.polish_model(m) + g = m.graph + + replacing.replace_Squeeze_with_Reshape(g) + replacing.replace_Unsqueeze_with_Reshape(g) + replacing.replace_Reshape_with_Flatten(g) + replacing.replace_ReduceMean_with_GlobalAveragePool(g) + replacing.replace_Sum_with_Adds(g) + replacing.replace_constant_input_concat_with_pad(g) + other.topological_sort(g) + return m + + +def pytorch_constant_folding(m): + """Constant folding needed by Pytorch exported models. It should be done + before using onnx optimizers since the dynamic shape structure may affect + the optimizations. + + :param m: the original model input\\ + :return: the new model after preprocessing + """ + logger.info("Working on constant folding.") + replacing.replace_shape_with_constant(m.graph) + replacing.replace_ConstantOfShape_with_constant(m.graph) + + # constant_folding + m = other.inference_shapes(m) + while constant_folding.constant_folding(m.graph): + logging.debug("After constant folding jobs.") + other.topological_sort(m.graph) + while len(m.graph.value_info) != 0: + m.graph.value_info.pop() + + m = other.inference_shapes(m) + replacing.replace_shape_with_constant(m.graph) + other.topological_sort(m.graph) + m = torch_pattern_match(m) + m = optimizer.optimize(m, ['eliminate_deadend']) + return m + + +def tensorflow_optimization(m): + """Optimizations for tf models can be used in most cases. + + :param m: the original model input\\ + :return: the new model after preprocessing + + It includes: + + - eliminate shape change after input + - eliminate Reshape cast + - eliminate Squeeze before Reshape + - fuse Transpose into Constant + - replace Shape with Constant + """ + + fusing.fuse_Transpose_into_Constant(m.graph) + fusing.fuse_MatMul_and_Add_into_Gemm(m.graph) + other.topological_sort(m.graph) + + m = other.polish_model(m) + + # constant folding + replacing.replace_shape_with_constant(m.graph) + + # constant_folding + m = other.inference_shapes(m) + while constant_folding.constant_folding(m.graph): + logging.debug("After constant folding jobs.") + other.topological_sort(m.graph) + while len(m.graph.value_info) != 0: + m.graph.value_info.pop() + + m = other.inference_shapes(m) + replacing.replace_shape_with_constant(m.graph) + other.topological_sort(m.graph) + m = tf_pattern_match(m) + m = optimizer.optimize(m, ['eliminate_deadend']) + + eliminating.eliminate_consecutive_reshape(m.graph) + eliminating.eliminate_Squeeze_before_Reshape(m.graph) + other.topological_sort(m.graph) + return m + + +def postprocess(m): + """Inference the shape and prepare for export. + + :param m: the original model input\\ + :return: the new model after preprocessing + """ + logger.info("Postprocessing the model...") + while len(m.graph.value_info) > 0: + m.graph.value_info.pop() + m = other.polish_model(m) + eliminating.eliminate_single_input_Concat(m.graph) + eliminating.eliminate_nop_Maxpool_and_AveragePool(m.graph) + eliminating.eliminate_trivial_elementwise_calculation(m.graph) + m = other.polish_model(m) + + replacing.replace_depthwise_1x1_with_bn(m.graph) + m = other.polish_model(m) + + # removing transpose + m = removing_transpose.eliminate_transposes(m) + m = other.polish_model(m) + removing_transpose.remove_trivial_transpose(m.graph) + removing_transpose.fuse_Transpose_into_Gemm_weight(m.graph) + + # fuse some nodes + fusing.fuse_mul_and_add_into_bn(m.graph) + m = other.polish_model(m) + fusing.fuse_mul_and_add_into_gemm(m.graph) + m = other.polish_model(m) + fusing.fuse_conv_and_add_into_conv(m.graph) + m = other.polish_model(m) + replacing.replace_mul_to_bn(m.graph) + replacing.replace_div_to_bn(m.graph) + replacing.replace_add_to_bn(m.graph) + replacing.replace_sub_to_bn(m.graph) + replacing.replace_sub_with_bn_and_add(m.graph) + m = other.polish_model(m) + + other.add_output_to_value_info(m.graph) + m = optimizer.optimize(m, ['eliminate_deadend']) + m.producer_name = 'kneron_formatter' + return m diff --git a/tools/optimizer_scripts/tools/common_pattern.py b/tools/optimizer_scripts/tools/common_pattern.py new file mode 100644 index 0000000..b65d5bd --- /dev/null +++ b/tools/optimizer_scripts/tools/common_pattern.py @@ -0,0 +1,157 @@ +from collections import defaultdict +import numpy as np +import onnx.helper +import onnx.utils + +from . import modhelper +from . import helper +from . import other + +def torch_pattern_match(m): + # Create a map from optype to the nodes. + optype2node = defaultdict(list) + for node in m.graph.node: + optype2node[node.op_type].append(node) + for matmul_node in optype2node['MatMul']: + pattern_matmul_mul_add(m.graph, matmul_node) + for resize_node in optype2node['Resize']: + # torch nn.UpsamplingBilinear2d will be given us 4 input: "X, roi, scales, sizes" + if len(resize_node.input) != 4: + continue + make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) + m = onnx.shape_inference.infer_shapes(m) + polish_RESIZE_input_param_node(m.graph, resize_node.name) + m = other.polish_model(m) + return m + +def tf_pattern_match(m): + # Create a map from optype to the nodes. + optype2node = defaultdict(list) + for node in m.graph.node: + optype2node[node.op_type].append(node) + for matmul_node in optype2node['MatMul']: + pattern_matmul_mul_add(m.graph, matmul_node) + for resize_node in optype2node['Resize']: + # In tensorflow2onnx, ReizeXXX will be given us 4 input: "X, roi, scales, sizes" + # and node output name will be given the "node name + :0" + if len(resize_node.input) != 4: + continue + make_UpsamplingBilinear2d_value_info(m.graph, resize_node.name) + m = onnx.shape_inference.infer_shapes(m) + polish_RESIZE_input_param_node(m.graph, resize_node.name) + m = other.polish_model(m) + return m + +def pattern_matmul_mul_add(g, matmul_node): + # Check node match - Mul node + next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0]) + if len(next_nodes) != 1: + return + if next_nodes[0].op_type != 'Mul': + return + mul_node = next_nodes[0] + # Check node match - Add node + next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0]) + if len(next_nodes) != 1: + return + if next_nodes[0].op_type != 'Add': + return + add_node = next_nodes[0] + # Check Mul weight + mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1]) + if mul_weight_node.op_type != 'Constant': + return + weight_size, mul_weight = helper.constant_to_list(mul_weight_node) + for i in mul_weight: + if i != 1: + return + channel = weight_size[0] + # Check Add weight + add_weight_node = helper.find_node_by_output_name(g, add_node.input[1]) + if add_weight_node.op_type != 'Constant': + return + # Check MatMul weight to see if it need weight broadcast + matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1]) + matmul_weight = helper.constant_to_numpy(matmul_weight_node) + if matmul_weight.shape[1] == 1: + # Weight broadcast + new_matmul_weight = np.tile(matmul_weight, channel) + new_matmul_weight_node = helper.numpy_to_constant(matmul_weight_node.name, new_matmul_weight) + g.node.remove(matmul_weight_node) + g.node.extend([new_matmul_weight_node]) + value = helper.find_value_by_name(g, matmul_weight_node.output[0]) + if value is not None: + g.value_info.remove(value) + # Remove Mul node + g.node.remove(mul_weight_node) + value = helper.find_value_by_name(g, mul_weight_node.output[0]) + if value is not None: + g.value_info.remove(value) + g.node.remove(mul_node) + value = helper.find_value_by_name(g, mul_node.output[0]) + if value is not None: + g.value_info.remove(value) + # Fuse Matmul and Add + gemm_node = onnx.helper.make_node( + 'Gemm', + [matmul_node.input[0], matmul_node.input[1], add_node.input[1]], + [add_node.output[0]], + name = matmul_node.name, + alpha = 1.0, + beta = 1.0, + transA = 0, + transB = 0 + ) + g.node.extend([gemm_node]) + # Clean up + g.node.remove(matmul_node) + g.node.remove(add_node) + value = helper.find_value_by_name(g, matmul_node.output[0]) + if value is not None: + g.value_info.remove(value) + other.topological_sort(g) + +def make_UpsamplingBilinear2d_value_info(g, resize_node_name): + resize_node = helper.find_node_by_node_name(g, resize_node_name) + + shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) + shape_data = helper.constant_to_numpy(shape_data_node).astype(int) + l_shape_data = list(shape_data) + if l_shape_data[0] == 0: + l_shape_data[0] = 1 + l_shape_data[0] + shape_data = np.array(l_shape_data) + + new_output_value_info = onnx.helper.make_tensor_value_info( + resize_node.output[0], + onnx.helper.TensorProto.FLOAT, + shape_data.tolist() + ) + + g.value_info.extend([new_output_value_info]) + +def polish_RESIZE_input_param_node(g, resize_node_name): + resize_node = helper.find_node_by_node_name(g, resize_node_name) + + shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3]) + shape_data = helper.constant_to_numpy(shape_data_node).astype(int) + + # handle 0 batch size which is invalid + if shape_data[0] == 0: + shape_data[0] = 1 + + pre_node_output_value_info = helper.find_value_by_name(g, resize_node.input[0]) + ori_shape = np.array([pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value, + pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value]) + + resize_node.input.remove(resize_node.input[3]) + + + resize_scales = np.array(shape_data/ori_shape).astype(float) + resize_scale_node = helper.list_to_constant('resize_scales_node_' + resize_node.name, resize_scales.shape, resize_scales, data_type=onnx.helper.TensorProto.FLOAT) + + resize_node.input[2] = resize_scale_node.name + g.node.extend([resize_scale_node]) + + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/constant_folding.py b/tools/optimizer_scripts/tools/constant_folding.py new file mode 100644 index 0000000..8149628 --- /dev/null +++ b/tools/optimizer_scripts/tools/constant_folding.py @@ -0,0 +1,995 @@ +import onnx.utils +import onnx +import numpy as np +import logging +import traceback + +from . import helper +from .general_graph import Graph, Node +from .other import topological_sort +from .replacing import replace_shape_with_constant +from .helper import logger + +def are_all_inputs_Constant_with_one_child(g, node): + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + if input_node is None or input_node.op_type != 'Constant': + return False + relative_outputs = helper.find_nodes_by_input_name(g, input_name) + if len(relative_outputs) > 1: + return False + return True + + +def constant_folding(g): + """ Do constant folding until nothing more can be done. + + :param g: The onnx GraphProto\\ + :return: If any node is folded, return True. Otherwise, return False. + """ + keep_folding = True # Keep the while loop + folded = False # Return value + try: + # Before constant folding, duplicate the constant nodes. + duplicate_constant_node(g) + while keep_folding: + keep_folding = False + for node in g.node: + # Check if the node is foldable + if node.op_type not in constant_folding_nodes.keys(): + continue + # Check if the parents of the node are all single follower constant node. + if not are_all_inputs_Constant_with_one_child(g, node): + continue + # Constant folding for the specific node + if constant_folding_nodes[node.op_type](g, node): + logging.debug("Constant nodes and %s %s are folded.", + node.op_type, node.name) + folded = True + keep_folding = True + else: + logging.debug( + "Constant nodes and %s %s are skipped.", node.op_type, node.name) + except Exception as e: + logger.error("An exception is raised while constant folding.") + logger.error(traceback.format_exc()) + return folded + + + +def duplicate_constant_node(g): + """ Duplicate the constant node if its following nodes contain constant folding + nodes. Create and link the new constant nodes to the constant folding nodes. + """ + for node in g.node: + # Find a valid constant node + if node.op_type != 'Constant': + continue + output_val_info = helper.find_value_by_name(g, node.output[0]) + if output_val_info is None: + print("Cannot inference the shape of Const node output: " + + node.output[0]) + exit(1) + data_shape = helper.get_shape_from_value_info(output_val_info) + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + + # For constant that has only one following node, no need to duplicate + if len(output_nodes) < 2: + continue + + # Check if its following nodes are foldable + foldable_output_nodes = list(filter(lambda n: n.op_type in + constant_folding_nodes.keys(), output_nodes)) + if not foldable_output_nodes: + continue + + # Duplicate the node needed by foldable nodes + for i in range(len(foldable_output_nodes)): + logging.debug("Found constant %s and %s %s are availble for folding. Duplicate constant.", + node.name, foldable_output_nodes[i].op_type, foldable_output_nodes[i].name) + output_name = node.output[0] + '_dup_' + str(i) + new_constant_node = onnx.helper.make_node( + 'Constant', + [], + [output_name], + name=output_name, + value=node.attribute[0].t + ) + new_val_info = onnx.helper.make_tensor_value_info( + output_name, + node.attribute[0].t.data_type, + data_shape + ) + input_ind = list(foldable_output_nodes[i].input).index( + node.output[0]) + foldable_output_nodes[i].input[input_ind] = output_name + + g.node.extend([new_constant_node]) + g.value_info.extend([new_val_info]) + + # If all following nodes are foldable node, delete the original node. + if len(foldable_output_nodes) == len(output_nodes): + g.node.remove(node) + g.value_info.remove(output_val_info) + + topological_sort(g) + + return + +def slice_constant_folding(g, node): + op_version = helper.get_current_opset_version() + # only support opset 9 & 11 + if op_version == 11: + return slice_constant_folding_Opset_11(g, node) + elif op_version == 9: + return slice_constant_folding_Opset_9(g, node) + +def slice_constant_folding_Opset_11(g, node): + """ Fold constant and slice nodes to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape, data_list = helper.constant_to_list(pre_node) + + starts_node = helper.find_node_by_output_name(g, node.input[1]) + _, starts = helper.constant_to_list(starts_node) + + ends_node = helper.find_node_by_output_name(g, node.input[2]) + _, ends = helper.constant_to_list(ends_node) + + + axes_node = None if len(node.input) <= 3 else helper.find_node_by_output_name(g, node.input[3]) + if not axes_node: + axes = list(range(len(helper.get_shape(data_list)))) + else: + _, axes = helper.constant_to_list(axes_node) + + steps_node = None if len(node.input) <= 4 else helper.find_node_by_output_name(g, node.input[4]) + if not steps_node: + steps = [1]*len(helper.get_shape(data_list)) + else: + _, steps = helper.constant_to_list(steps_node) + + + data_list = list(map(int, data_list)) + starts = list(map(int, starts)) + ends = list(map(int, ends)) + axes = list(map(int, axes)) + steps = list(map(int, steps)) + + data_list = np.reshape(data_list, pre_shape) + + new_data = None + for idx, _ in enumerate(axes): + new_data = np.apply_along_axis( lambda x: x[starts[idx] : ends[idx] : steps[idx]], idx, data_list ) + + new_node = helper.list_to_constant(node.output[0], helper.get_shape( + new_data), helper.flatten_to_list(new_data)) + g.node.extend([new_node]) + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(node) + g.node.remove(pre_node) + + return True + +def slice_constant_folding_Opset_9(g, node): + """ Fold constant and slice nodes to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape, data_list = helper.constant_to_list(pre_node) + + data_list = np.reshape(data_list, pre_shape) + axes = helper.get_attribute_by_name(node, 'axes') + ends = list(helper.get_attribute_by_name(node, 'ends').ints) + starts = list(helper.get_attribute_by_name(node, 'starts').ints) + + if not axes: + axes = list(range(len(helper.get_shape(data_list)))) + else: + axes = list(axes.ints) + + new_data = helper.slice_data(data_list, starts, ends, axes) + new_node = helper.list_to_constant(node.output[0], helper.get_shape( + new_data), helper.flatten_to_list(new_data)) + g.node.extend([new_node]) + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(node) + g.node.remove(pre_node) + + return True + +def cast_constant_folding(g, node): + """ Fold constant and cast node to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + data_type = node.attribute[0].i + if data_type in (6, 7): + data = list(map(int, data)) + elif data_type == onnx.helper.TensorProto.FLOAT: + data = list(map(float, data)) + else: + raise RuntimeError('data type not supported') + + if shape == 1: + tensor = onnx.helper.make_tensor( + name=pre_node.attribute[0].name, + data_type=data_type, + dims=[], + vals=data + ) + else: + tensor = onnx.helper.make_tensor( + name=pre_node.attribute[0].name, + data_type=data_type, + dims=shape, + vals=helper.flatten_to_list(data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=tensor + ) + g.node.extend([new_node]) + + value_info = helper.find_value_by_name(g, pre_node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + value_info = helper.find_value_by_name(g, node.output[0]) + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(pre_node) + g.node.remove(node) + + return True + + +def reduceprod_constant_folding(g, node): + """ Fold constant and reduceprod nodes to a single constant node. + """ + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data_set = helper.constant_to_list(pre_node) + tensor = pre_node.attribute[0].t + + data_set = np.reshape(data_set, shape) + for att in node.attribute: + if att.name == 'axes': + axes = list(att.ints) + else: + keepdims = int(att.i) + + new_data = np.prod(data_set, axis=tuple(axes), keepdims=keepdims == 1) + new_shape = helper.get_shape(new_data) + new_flat_data = helper.flatten_to_list(new_data) + new_tensor = onnx.helper.make_tensor( + name=node.output[0], + data_type=tensor.data_type, + dims=new_shape, + vals=new_flat_data + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.node.extend([new_node]) + value_info = None + for item in g.value_info: + if item.name == pre_node.output[0]: + value_info = item + if value_info is not None: + g.value_info.remove(value_info) + g.node.remove(pre_node) + g.node.remove(node) + + return True + + +def reshape_constant_input_folding(g, node): + """ Fold constant and reshape nodes to a single constant node. + """ + pre_data_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) + + data = helper.constant_to_numpy(pre_data_node) + _, shape = helper.constant_to_list(pre_shape_node) + new_data = np.reshape(data, shape) + + new_tensor = onnx.helper.make_tensor( + name=node.output[0], + data_type=pre_data_node.attribute[0].t.data_type, + dims=new_data.shape, + vals=helper.flatten_to_list(new_data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + g.node.extend([new_node]) + + data_val_info = helper.find_value_by_name(g, pre_data_node.output[0]) + shape_val_info = helper.find_value_by_name(g, pre_shape_node.output[0]) + + g.value_info.remove(data_val_info) + g.value_info.remove(shape_val_info) + + g.node.remove(node) + g.node.remove(pre_data_node) + g.node.remove(pre_shape_node) + + return True + + +def concat_constant_folding(g, node): + """ Fold constant and concat nodes to a single constant node. + """ + node_to_del = [] + valid_inputs = True + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + input_node_output = helper.find_nodes_by_input_name(g, input_name) + if len(input_node_output) > 1: + valid_inputs = False + break + if input_node.op_type != 'Constant': + valid_inputs = False + break + + if not valid_inputs: + return False + + input_data = [] + input_shapes = [] + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + s, d = helper.constant_to_list(input_node) + d = np.reshape(d, s) + input_data.append(d) + input_shapes.append(s) + node_to_del.append(input_node) + + concat_data = np.concatenate(input_data, axis=node.attribute[0].i) + node_data_type = input_node.attribute[0].t.data_type + if concat_data.dtype in [np.int32, np.int64]: + node_data_type = onnx.helper.TensorProto.INT64 + elif concat_data.dtype in [np.float32, np.float64]: + node_data_type = onnx.helper.TensorProto.FLOAT + + new_node = helper.list_to_constant( + node.output[0], + helper.get_shape(concat_data), + helper.flatten_to_list(concat_data), + data_type=node_data_type + ) + g.node.extend([new_node]) + node_to_del.append(node) + + for input_name in node.input: + val_info = helper.find_value_by_name(g, input_name) + if val_info: + g.value_info.remove(val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def transpose_constant_folding(g, node): + """Fold constant and transpose nodes to a single constant node. + """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + np_data = np.reshape(data, shape) + permutation = list(node.attribute[0].ints) + + new_data = np.transpose(np_data, permutation) + new_shape = new_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + new_data.flatten().tolist(), + data_type=pre_node.attribute[0].t.data_type + ) + + g.node.extend([new_node]) + node_to_del.extend([node, pre_node]) + + pre_val_info = helper.find_value_by_name(g, node.input[0]) + g.value_info.remove(pre_val_info) + + next_val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(next_val_info) + + new_val_info = onnx.helper.make_tensor_value_info( + node.output[0], + pre_node.attribute[0].t.data_type, + new_shape + ) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + folded = True + + return folded + + +def unsqueeze_constant_folding(g, node): + """Fold constant and unsqueeze nodes to a single constant node. + """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + if type(shape) == int: + np_data = data[0] + else: + np_data = np.reshape(data, shape) + axes = list(node.attribute[0].ints) + axes.sort() + + for dim in axes: + np_data = np.expand_dims(np_data, axis=dim) + new_shape = np_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + np_data.flatten().tolist(), + data_type=pre_node.attribute[0].t.data_type + ) + g.node.extend([new_node]) + node_to_del.extend([node, pre_node]) + + pre_val_info = helper.find_value_by_name(g, node.input[0]) + next_val_info = helper.find_value_by_name(g, node.output[0]) + if pre_val_info is not None: + g.value_info.remove(pre_val_info) + else: + print(node.name) + if next_val_info is not None: + g.value_info.remove(next_val_info) + + new_val_info = onnx.helper.make_tensor_value_info( + node.output[0], + pre_node.attribute[0].t.data_type, + new_shape + ) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def gather_constant_folding(g, node): + """Fold constant and gather nodes to a single constant node. + """ + node_to_del = [] + + pre_data_node = helper.find_node_by_output_name(g, node.input[0]) + pre_indices_node = helper.find_node_by_output_name(g, node.input[1]) + + shape, data = helper.constant_to_list(pre_data_node) + indice_shape, indices = helper.constant_to_list(pre_indices_node) + if type(indice_shape) == int: + indices = indices[0] + + np_data = np.reshape(data, shape) + if len(node.attribute) < 1: + axis = 0 + else: + axis = node.attribute[0].i + + new_data = np.take(np_data, indices, axis=axis) + new_shape = new_data.shape + new_node = helper.list_to_constant( + node.output[0], + new_shape, + new_data.flatten().tolist(), + data_type=pre_data_node.attribute[0].t.data_type + ) + + node_to_del.extend([node, pre_data_node, pre_indices_node]) + g.node.extend([new_node]) + + val_info_1 = helper.find_value_by_name(g, node.input[0]) + val_info_2 = helper.find_value_by_name(g, node.input[1]) + val_info_3 = helper.find_value_by_name(g, node.output[0]) + new_val_info = onnx.helper.make_tensor_value_info( + new_node.output[0], + pre_data_node.attribute[0].t.data_type, + new_shape + ) + + if val_info_1 is not None: + g.value_info.remove(val_info_1) + if val_info_2 is not None: + g.value_info.remove(val_info_2) + if val_info_3 is not None: + g.value_info.remove(val_info_3) + g.value_info.extend([new_val_info]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def add_constant_folding(g, node): + """Fold constant and add nodes to a single constant node. + """ + node_to_del = [] + pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) + pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) + if not pre_node_1 or not pre_node_2: + return False + + shape1, data1 = helper.constant_to_list(pre_node_1) + shape2, data2 = helper.constant_to_list(pre_node_2) + np_data1 = np.reshape(data1, shape1) + np_data2 = np.reshape(data2, shape2) + try: + new_data = np.add(np_data1, np_data2) + except: + raise RuntimeError('can\'t broadcast and add two data sets') + + new_node = helper.list_to_constant( + node.output[0], + new_data.shape, + new_data.flatten().tolist(), + data_type=pre_node_1.attribute[0].t.data_type + ) + + g.node.extend([new_node]) + node_to_del.extend([node, pre_node_1, pre_node_2]) + g.value_info.remove(helper.find_value_by_name(g, pre_node_1.output[0])) + g.value_info.remove(helper.find_value_by_name(g, pre_node_2.output[0])) + folded = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return folded + + +def sqrt_constant_folding(g, node): + """ Fold constant and sqrt nodes to a single node. + """ + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + np_data = np.sqrt(np.reshape(data, shape)) + output_val_info = helper.find_value_by_name(g, node.output[0]) + input_val_info = helper.find_value_by_name(g, node.input[0]) + data_type = output_val_info.type.tensor_type.elem_type + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=data_type, + dims=shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.value_info.remove(input_val_info) + node_to_del.extend([pre_node, node]) + g.node.extend([new_node]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def reciprocal_constant_folding(g, node): + """ Fold constant and reciprocal nodes to a single constant node. + """ + node_to_del = [] + + pre_node = helper.find_node_by_output_name(g, node.input[0]) + shape, data = helper.constant_to_list(pre_node) + data = list(map(lambda x: x if abs(x) > 1.e-8 else 1.e-8, data)) + np_data = np.reshape(data, shape) + np_data = np.reciprocal(np_data) + + input_val_info = helper.find_value_by_name(g, node.input[0]) + output_val_info = helper.find_value_by_name(g, node.output[0]) + data_type = output_val_info.type.tensor_type.elem_type + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=data_type, + dims=shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + node_to_del.extend([node, pre_node]) + g.node.extend([new_node]) + + g.value_info.remove(input_val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def mul_constant_folding(g, node): + """ Fold constant and mul nodes to a single constant node. + """ + node_to_del = [] + pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) + pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) + + pre_value_info1 = helper.find_value_by_name(g, node.input[0]) + pre_value_info2 = helper.find_value_by_name(g, node.input[1]) + if pre_value_info1 is None or pre_value_info2 is None: + return False + + shape1, data1 = helper.constant_to_list(pre_node_1) + shape2, data2 = helper.constant_to_list(pre_node_2) + np_data1 = np.reshape(data1, shape1) + np_data2 = np.reshape(data2, shape2) + + try: + new_data = np.multiply(np_data1, np_data2) + except: + raise RuntimeError('can not broadcast and multiply two data sets') + + # Special shape for single element. + if shape1 == 1 and shape2 == 1: + new_shape = [] + else: + new_shape = new_data.shape + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=pre_node_1.attribute[0].t.data_type, + dims=new_shape, + vals=new_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + node_to_del.extend([node, pre_node_1, pre_node_2]) + g.node.extend([new_node]) + + g.value_info.remove(pre_value_info1) + g.value_info.remove(pre_value_info2) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def div_constant_folding(g, node): + """ Fold constant and mul nodes to a single constant node. + """ + node_to_del = [] + pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) + pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) + + pre_value_info1 = helper.find_value_by_name(g, node.input[0]) + pre_value_info2 = helper.find_value_by_name(g, node.input[1]) + if pre_value_info1 is None or pre_value_info2 is None: + return False + + shape1, data1 = helper.constant_to_list(pre_node_1) + shape2, data2 = helper.constant_to_list(pre_node_2) + np_data1 = np.reshape(data1, shape1) + np_data2 = np.reshape(data2, shape2) + + try: + new_data = np.divide(np_data1, np_data2) + except: + raise RuntimeError('can not broadcast and multiply two data sets') + + # Special shape for single element. + if shape1 == 1 and shape2 == 1: + new_shape = [] + else: + new_shape = new_data.shape + + # Check data type if it is int + if pre_node_1.attribute[0].t.data_type == 7: + new_data = new_data.astype('int64') + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=pre_node_1.attribute[0].t.data_type, + dims=new_shape, + vals=new_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + node_to_del.extend([node, pre_node_1, pre_node_2]) + g.node.extend([new_node]) + + g.value_info.remove(pre_value_info1) + g.value_info.remove(pre_value_info2) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def sub_constant_folding(g, node): + """ Fold constant and sub nodes to a single node. + """ + node_to_del = [] + pre_node_1 = helper.find_node_by_output_name(g, node.input[0]) + pre_node_2 = helper.find_node_by_output_name(g, node.input[1]) + pre_val_info_1 = helper.find_value_by_name(g, node.input[0]) + pre_val_info_2 = helper.find_value_by_name(g, node.input[1]) + + shape1, data1 = helper.constant_to_list(pre_node_1) + shape2, data2 = helper.constant_to_list(pre_node_2) + + new_data = np.subtract(data1, data2) + # Special shape for single element. + if shape1 == 1 and shape2 == 1: + new_shape = [] + else: + new_shape = new_data.shape + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=pre_node_1.attribute[0].t.data_type, + dims=new_shape, + vals=helper.flatten_to_list(new_data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.node.extend([new_node]) + node_to_del.extend([node, pre_node_1, pre_node_2]) + + g.value_info.remove(pre_val_info_1) + g.value_info.remove(pre_val_info_2) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def neg_constant_folding(g, node): + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + + shape, data_list = helper.constant_to_list(pre_node) + new_data_list = [-num for num in data_list] + + new_tensor = onnx.helper.make_tensor( + name=pre_node.name+'_neg_tensor', + data_type=pre_node.attribute[0].t.data_type, + dims=shape, + vals=new_data_list + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.node.extend([new_node]) + node_to_del.extend([pre_node, node]) + g.value_info.remove(helper.find_value_by_name(g, node.input[0])) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + return True + + +def floor_constant_folding(g, node): + node_to_del = [] + pre_node = helper.find_node_by_output_name(g, node.input[0]) + + shape, data = helper.constant_to_list(pre_node) + new_data = np.floor(data).flatten().tolist() + + if shape == 1: + new_shape = [] + else: + new_shape = shape + + new_tensor = onnx.helper.make_tensor( + name=node.output[0]+'_data', + data_type=pre_node.attribute[0].t.data_type, + dims=new_shape, + vals=helper.flatten_to_list(new_data) + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + g.node.extend([new_node]) + node_to_del.extend([pre_node, node]) + old_value = helper.find_value_by_name(g, node.input[0]) + if old_value is not None: + g.value_info.remove(old_value) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + return True + + +def bn_constant_folding(g, node): + """ Fold constant and mul nodes to a single constant node. + """ + # Prepare data + node_to_del = [] + input_node = helper.find_node_by_output_name(g, node.input[0]) + scale_node = helper.find_node_by_output_name(g, node.input[1]) + bias_node = helper.find_node_by_output_name(g, node.input[2]) + mean_node = helper.find_node_by_output_name(g, node.input[3]) + var_node = helper.find_node_by_output_name(g, node.input[4]) + + input_value_info = [] + for i in range(5): + input_value_info.append(helper.find_value_by_name(g, node.input[i])) + + if input_value_info[0] is None: + return False + + input_data = helper.constant_to_numpy(input_node) + scale_data = helper.constant_to_numpy(scale_node) + bias_data = helper.constant_to_numpy(bias_node) + mean_data = helper.constant_to_numpy(mean_node) + var_data = helper.constant_to_numpy(var_node) + + epsilon = helper.get_var_attribute_by_name(node, 'epsilon', 'float') + if epsilon is None: + epsilon = 0.00001 + + # Calculate new node + new_data = scale_data * (input_data - mean_data) / np.sqrt(var_data + epsilon) + bias_data + + new_node = helper.numpy_to_constant(node.output[0], new_data) + + # Reconnect the graph + node_to_del.extend([node, input_node, scale_node, bias_node, mean_node, var_node]) + g.node.extend([new_node]) + + for value in input_value_info: + if value is not None: + g.value_info.remove(value) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +def DequantizeLinear_constant_folding(g, node): + """ Fold constant and mul nodes to a single constant node. + """ + # Prepare data + node_to_del = [] + x_node = helper.find_node_by_output_name(g, node.input[0]) + x_scale_node = helper.find_node_by_output_name(g, node.input[1]) + if len(node.input) > 2: + x_zero_point_node = helper.find_node_by_output_name(g, node.input[2]) + else: + x_zero_point_node = None + + input_value_info = [] + for i in range(len(node.input)): + input_value_info.append(helper.find_value_by_name(g, node.input[i])) + + if input_value_info[0] is None: + return False + + x_data = helper.constant_to_numpy(x_node) + x_scale_data = helper.constant_to_numpy(x_scale_node) + if x_zero_point_node is not None: + x_zero_point_data = helper.constant_to_numpy(x_zero_point_node) + else: + x_zero_point_data = np.array([0.0]) + + # Calculate new node + new_data = (x_data.astype(np.float32) - x_zero_point_data.astype(np.float32)) * x_scale_data + + new_node = helper.numpy_to_constant(node.output[0], new_data) + + # Reconnect the graph + node_to_del.extend([node, x_node, x_scale_node]) + if x_zero_point_node is not None: + node_to_del.append(x_zero_point_node) + g.node.extend([new_node]) + + for value in input_value_info: + if value is not None: + g.value_info.remove(value) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return True + + +# Available constant folding names to function map. +constant_folding_nodes = { + 'Add': add_constant_folding, + 'BatchNormalization': bn_constant_folding, + 'Cast': cast_constant_folding, + 'Concat': concat_constant_folding, + 'DequantizeLinear': DequantizeLinear_constant_folding, + 'Div': div_constant_folding, + 'Floor': floor_constant_folding, + 'Gather': gather_constant_folding, + 'Mul': mul_constant_folding, + 'Reciprocal': reciprocal_constant_folding, + 'ReduceProd': reduceprod_constant_folding, + 'Reshape': reshape_constant_input_folding, + 'Slice': slice_constant_folding, + 'Sqrt': sqrt_constant_folding, + 'Transpose': transpose_constant_folding, + 'Unsqueeze': unsqueeze_constant_folding, + 'Sub': sub_constant_folding, + 'Neg': neg_constant_folding +} diff --git a/tools/optimizer_scripts/tools/eliminating.py b/tools/optimizer_scripts/tools/eliminating.py new file mode 100644 index 0000000..bc22b2e --- /dev/null +++ b/tools/optimizer_scripts/tools/eliminating.py @@ -0,0 +1,669 @@ +import collections +import struct +import onnx +import numpy as np +from . import other +from . import helper +from . import modhelper +from .general_graph import Graph + +def eliminate_Identify_and_Dropout(g): + """ + Eliminate Identify layers + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Identity' and node.op_type != 'Dropout': + continue + # If this node is the last node, leave it to `eliminate_useless_last node` + if helper.find_output_by_name(g, node.output[0]) is not None: + continue + # Replace the parents in all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + # Delete value info + value_between = helper.find_value_by_name(g, node.output[0]) + try: + g.value_info.remove(value_between) + except: + print("No value info to delete while eliminating identity layers.") + # Node is waiting for elimination + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + +# Remove last useless nodes +def remove_useless_last_nodes(g): + """Remove useless nodes from the tail of the graph + """ + USELESS = ["Reshape", "Identity", "Transpose", "Flatten", "Dropout", "Mystery", "Constant", "Squeeze", "Unsqueeze", 'Softmax'] + graph = Graph(g) + todo = collections.deque() + for node in graph.output_nodes: + if len(node.children) == 0: + todo.append(node) + node_to_remove = [] + while todo: + # BFS find nodes to remove + cur_node = todo.popleft() + if cur_node.proto is None: + continue + if cur_node.proto.op_type not in USELESS: + continue + # Find the output + cur_node_output = helper.find_output_by_name(g, cur_node.proto.output[0]) + for cur_input in cur_node.parents: + cur_input.children.remove(cur_node) + if len(cur_input.children) == 0: + todo.append(cur_input) + if cur_node_output is not None: + cur_input_output = helper.find_value_by_name(g, cur_input.proto.output[0]) + cur_input_output_in_output = helper.find_output_by_name(g, cur_input.proto.output[0]) + if cur_input_output is not None and cur_input_output_in_output is None: + g.output.extend([cur_input_output]) + node_to_remove.append(cur_node.proto) + try: + g.value_info.remove(helper.find_value_by_name(g, cur_node.proto.output[0])) + except ValueError: + pass + if cur_node_output is not None: + g.output.remove(cur_node_output) + cur_node.proto = None + cur_node.parents.clear() + for node in node_to_remove: + g.node.remove(node) + +###################################### +# TF only optimization passes # +###################################### + +def eliminate_shape_changing_after_input(g): + """ + Eliminate the Reshape node after input and reshape the input + + :param g: the onnx graph + """ + node_to_remove = [] + REMOVE_LIST = ["Reshape", "Transpose", "Flatten", "Dropout", "Squeeze", "Unsqueeze"] + for node in g.node: + # Find an input and the shape node + if node.op_type not in REMOVE_LIST: + continue + old_input = helper.find_input_by_name(g, node.input[0]) + if old_input is None: + continue + # If the input is used by multiple nodes, skip. + counter = 0 + for tnode in g.node: + if old_input.name in tnode.input: + counter += 1 + if counter > 1: + continue + # Remove Weight if any. + output_val_info = helper.find_value_by_name(g, node.output[0]) + + if node.op_type == 'Reshape': + shape_node = helper.find_node_by_output_name(g, node.input[1]) + if shape_node.op_type != 'Constant': + continue + + # manuelly set the input shape + shape_info = helper.find_value_by_name(g, shape_node.output[0]) + old_size, old_shape = helper.find_size_shape_from_value(shape_info) + + _, new_shape = helper.constant_to_list(shape_node) + for i in range(len(new_shape)): + if new_shape[i] == -1: + dim = int(old_size//np.prod(new_shape)*(-1)) + new_shape[i] = dim + new_input = onnx.helper.make_tensor_value_info( + output_val_info.name, + output_val_info.type.tensor_type.elem_type, + new_shape + ) + + node_to_remove.append(node) + + shape_outputs = helper.find_nodes_by_input_name(g, shape_node.output[0]) + if len(shape_outputs) == 1: + node_to_remove.append(shape_node) + g.value_info.remove(helper.find_value_by_name(g, shape_node.output[0])) + + g.input.remove(old_input) + g.input.extend([new_input]) + g.value_info.remove(output_val_info) + elif node.op_type == 'Transpose': + permutation = list(node.attribute[0].ints) + pre_shape = helper.get_shape_from_value_info(old_input) + new_shape = [pre_shape[i] for i in permutation] + + new_input = onnx.helper.make_tensor_value_info( + output_val_info.name, + output_val_info.type.tensor_type.elem_type, + new_shape + ) + + node_to_remove.append(node) + + g.input.remove(old_input) + g.input.extend([new_input]) + g.value_info.remove(output_val_info) + elif node.op_type == 'Flatten': + axis = node.attribute[0].int + pre_shape = helper.get_shape_from_value_info(old_input) + dim_1, dim_2 = 1, 1 + if axis == 0: + dim_1 = 1 + dim_2 = np.prod(pre_shape) + else: + dim_1 = np.prod(pre_shape[:axis]).astype(int) + dim_2 = np.prod(pre_shape[axis:]).astype(int) + new_shape = [dim_1, dim_2] + + new_input = onnx.helper.make_tensor_value_info( + output_val_info.name, + output_val_info.type.tensor_type.elem_type, + new_shape + ) + + node_to_remove.append(node) + + g.input.remove(old_input) + g.input.extend([new_input]) + g.value_info.remove(output_val_info) + elif node.op_type == 'Dropout': + g.input.remove(old_input) + g.input.extend([output_val_info]) + g.value_info.remove(output_val_info) + + node_to_remove.append(node) + elif node.op_type == 'Squeeze': + axis = list(node.attribute[0].ints) + pre_shape = helper.get_shape_from_value_info(old_input) + for pos in sorted(axis)[::-1]: + if pre_shape[pos] != 1: + raise RuntimeError('invalid axis for squeeze') + else: + pre_shape.pop(pos) + new_shape = pre_shape + + new_input = onnx.helper.make_tensor_value_info( + output_val_info.name, + output_val_info.type.tensor_type.elem_type, + new_shape + ) + + node_to_remove.append(node) + + g.input.remove(old_input) + g.input.extend([new_input]) + g.value_info.remove(output_val_info) + elif node.op_type == 'Unsqueeze': + axis = list(node.attribute[0].ints) + pre_shape = helper.get_shape_from_value_info(old_input) + new_shape = pre_shape + for pos in axis: + new_shape.insert(pos, 1) + new_input = onnx.helper.make_tensor_value_info( + output_val_info.name, + output_val_info.type.tensor_type.elem_type, + new_shape + ) + node_to_remove.append(node) + + g.input.remove(old_input) + g.input.extend([new_input]) + g.value_info.remove(output_val_info) + else: + pass + + for node in node_to_remove: + g.node.remove(node) + + other.topological_sort(g) + + +def eliminate_Reshape_Cast(g): + """Eliminate the cast layer for shape of Reshape layer + + :param g: the onnx graph + """ + #Find all reshape layers + node_to_remove = [] + for node in g.node: + if node.op_type != 'Reshape': + continue + prev_node = helper.find_node_by_output_name(g, node.input[1]) + if prev_node.op_type != 'Cast': + continue + # Now we find the cast weight pattern. Cast the weight, delete the cast. + reshape_node = node + cast_node = prev_node + weight_node = helper.find_node_by_output_name(g, cast_node.input[0]) + if weight_node is None: + raise RuntimeError("Unexpected None before Cast-Reshape.") + weight_node.attribute[0].t.data_type = 7 + if weight_node.attribute[0].t.raw_data: + raw_data = weight_node.attribute[0].t.raw_data + int_data = [i[0] for i in struct.iter_unpack('i', raw_data)] + raw_data = struct.pack('q' * len(int_data), *int_data) + elif len(weight_node.attribute[0].t.int64_data) > 0\ + or len(weight_node.attribute[0].t.int32_data) > 0: + # It's already int. Do nothing + pass + else: + raise NotImplementedError() + # Change Value info + origin_weight_out = helper.find_value_by_name(g, weight_node.output[0]) + weight_node.output.pop() + weight_node.output.extend([reshape_node.input[1]]) + # Delete + g.value_info.remove(origin_weight_out) + g.node.remove(cast_node) + +def eliminate_Cast_after_input(g): + """Eliminate the cast layer right after the input + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Cast': + continue + old_input = helper.find_input_by_name(g, node.input[0]) + if old_input is None: + continue + next_val_info = helper.find_value_by_name(g, node.output[0]) + shape = helper.get_shape_from_value_info(next_val_info) + new_val_info = onnx.helper.make_tensor_value_info( + next_val_info.name, + node.attribute[0].i, + shape + ) + # Delete old value_info + g.input.remove(old_input) + g.value_info.remove(next_val_info) + # Append nodes to node_to_remove + node_to_remove.append(node) + # Add new input + g.input.extend([new_val_info]) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_consecutive_Cast(g): + """If two cast is next to each other, remove the first cast + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Cast': + continue + first_node = helper.find_node_by_output_name(g, node.input[0]) + if first_node is None or first_node.op_type != 'Cast': + continue + # Here we have two consecutive Cast Node + # Reset the input of the later node + node.input[0] = first_node.input[0] + # Remove the first node and its output value info + node_to_remove.append(first_node) + first_output = helper.find_value_by_name(g, first_node.output[0]) + g.value_info.remove(first_output) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_Squeeze_before_Reshape(g): + """If Squeeze and Reshape is next to each other, remove the first node + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Reshape': + continue + first_node = helper.find_node_by_output_name(g, node.input[0]) + if not first_node: + continue + if first_node.op_type != 'Squeeze': + continue + # Here we have two consecutive Cast Node + # Reset the input of the later node + node.input[0] = first_node.input[0] + # Remove the first node and its output value info + node_to_remove.append(first_node) + first_output = helper.find_value_by_name(g, first_node.output[0]) + g.value_info.remove(first_output) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_no_children_input(g): + """Eliminate inputs with no children at all. + """ + # Create a set of input names + input_names = set([i.name for i in g.input]) + # If a name is used in any node, remove this name from the set. + for n in g.node: + for i in n.input: + input_names.discard(i) + # Remove the inputs with the left names. + for i in input_names: + info = helper.find_input_by_name(g, i) + g.input.remove(info) + +def eliminate_consecutive_reshape(g): + """Replace consecutive reshape nodes by a single node. + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Reshape': + continue + pre_data_node = helper.find_node_by_output_name(g, node.input[0]) + pre_shape_node = helper.find_node_by_output_name(g, node.input[1]) + if not pre_data_node or not pre_shape_node: + continue + if pre_shape_node.op_type != 'Constant': + continue + if pre_data_node.op_type != 'Reshape': + continue + + pre_pre_shape_node = helper.find_node_by_output_name(g, pre_data_node.input[1]) + if pre_pre_shape_node.op_type != 'Constant': + continue + + new_reshape_node = onnx.helper.make_node( + 'Reshape', + [pre_data_node.input[0], node.input[1]], + [node.output[0]], + name = node.output[0] + ) + + g.node.extend([new_reshape_node]) + node_to_del.append(node) + node_to_del.append(pre_data_node) + node_to_del.append(pre_pre_shape_node) + + val_info_to_del1 = helper.find_value_by_name(g, node.input[0]) + val_info_to_del2 = helper.find_value_by_name(g, pre_data_node.input[1]) + g.value_info.remove(val_info_to_del1) + g.value_info.remove(val_info_to_del2) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + +def eliminate_single_input_Concat(g): + """ + Eliminate single input Concat layers + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Concat': + continue + # If this node has more than 1 input, continue. + if len(node.input) > 1: + continue + # If this node is the output node, set its previous node as output nodes. + if helper.find_output_by_name(g, node.output[0]) is not None: + todel_output = helper.find_output_by_name(g, node.output[0]) + the_input_value = helper.find_value_by_name(g, node.input[0]) + g.output.remove(todel_output) + g.output.extend([the_input_value]) + node_to_remove.append(node) + continue + # Replace the parents in all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + # Delete value info + value_between = helper.find_value_by_name(g, node.output[0]) + try: + g.value_info.remove(value_between) + except: + print("No value info to delete while eliminating identity layers.") + # Node is waiting for elimination + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_nop_Maxpool_and_AveragePool(g): + """ + Eliminate do nothing MaxPool and AveragePool layers. + Those layers have valid padding, 1x1 kernel and [1,1] strides. + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'MaxPool' and node.op_type != 'AveragePool': + continue + # If this node is actually working, continue. + kernel = helper.get_list_attribute_by_name(node, "kernel_shape", "int") + pads = helper.get_list_attribute_by_name(node, "pads", "int") + strides = helper.get_list_attribute_by_name(node, "strides", "int") + if kernel != [1, 1] or pads != [0, 0, 0, 0] or strides != [1, 1]: + continue + # If this node is the output node, set its previous node as output nodes. + if helper.find_output_by_name(g, node.output[0]) is not None: + todel_output = helper.find_output_by_name(g, node.output[0]) + the_input_value = helper.find_value_by_name(g, node.input[0]) + g.output.remove(todel_output) + g.output.extend([the_input_value]) + node_to_remove.append(node) + continue + # Replace the parents in all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + # Delete value info + value_between = helper.find_value_by_name(g, node.output[0]) + try: + g.value_info.remove(value_between) + except: + print("No value info to delete while eliminating identity layers.") + # Node is waiting for elimination + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + + +def eliminate_trivial_maxpool(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'MaxPool': + continue + pads = None + strides = None + dilation = None + kernel_shape = None + for att in node.attribute: + if att.name == 'pads': + pads = list(att.ints) + elif att.name == 'strides': + strides = list(att.ints) + elif att.name == 'kernel_shape': + kernel_shape = list(att.ints) + elif att.name == 'dilation': + dilation = list(att.ints) + else: + pass + if pads and any([pad != 0 for pad in pads]): + continue + if strides and any([stride != 1 for stride in strides]): + continue + if dilation and any([dila != 1 for dila in dilation]): + continue + if any([dim != 1 for dim in kernel_shape]): + continue + + node_to_del.append(node) + + next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + + if next_nodes[0] == None: + output_value = helper.find_output_by_name(g, node.output[0]) + if not output_value: + continue + else: + pre_val_info = helper.find_value_by_name(g, node.input[0]) + g.output.extend([pre_val_info]) + g.output.remove(output_value) + + for next_node in next_nodes: + modhelper.replace_node_input(next_node, node.output[0], node.input[0]) + + next_val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(next_val_info) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + other.topological_sort(g) + +def eliminate_empty_value_infos(g): + to_remove = [] + for value_info in g.value_info: + if len(value_info.type.tensor_type.shape.dim) == 0: + to_remove.append(value_info) + for value_info in to_remove: + g.value_info.remove(value_info) + +def eliminate_nop_pads(g): + node_to_remove = [] + for node in g.node: + if node.op_type != 'Pad': + continue + # Check if the Pad is empty or not + pads_node = helper.find_node_by_output_name(g, node.input[1]) + pads_np = helper.constant_to_numpy(pads_node) + all_zero = True + for value in pads_np: + if value != 0: + all_zero = False + if not all_zero: + continue + # Check if it has the constant_value_node + constant_value_node = None + if len(node.input) > 2: + constant_value_node = helper.find_node_by_output_name(g, node.input[2]) + # If this node is the output node, set its previous node as output nodes. + if helper.find_output_by_name(g, node.output[0]) is not None: + todel_output = helper.find_output_by_name(g, node.output[0]) + g.output.remove(todel_output) + if helper.find_output_by_name(g, node.input[0]) is None: + the_input_value = helper.find_value_by_name(g, node.input[0]) + if the_input_value is not None: + g.output.extend([the_input_value]) + # Replace the parents in all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + # Delete value info + value_between = helper.find_value_by_name(g, node.output[0]) + try: + g.value_info.remove(value_between) + except: + helper.logger.info("No value info to delete while eliminating identity layers.") + # Node is waiting for elimination + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_trivial_elementwise_calculation(g): + """Eliminate Add, Sub, Mul, Sub nodes which do nothing. + """ + node_to_remove = [] + for node in g.node: + weight_node = None + if node.op_type == 'Add' or node.op_type == 'Sub': + # For add and sub, check if the weights are 0s. + weight_node = helper.find_node_by_output_name(g, node.input[1]) + if weight_node is None or weight_node.op_type != 'Constant': + continue + weight_np = helper.constant_to_numpy(weight_node) + if np.any(weight_np): + continue + elif node.op_type == 'Mul' or node.op_type == 'Div': + # For Mul and Div, check if the weights are 1s. + weight_node = helper.find_node_by_output_name(g, node.input[1]) + if weight_node is None or weight_node.op_type != 'Constant': + continue + weight_np = helper.constant_to_numpy(weight_node) + weight_np = weight_np - 1 + if np.any(weight_np): + continue + else: + # For other nodes, just skip + continue + # Remove the node + node_to_remove.append(node) + output_value_info = helper.find_value_by_name(g, node.output[0]) + if output_value_info is not None: + g.value_info.remove(output_value_info) + # Replace next node input if any. + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + todel_output = helper.find_output_by_name(g, node.output[0]) + if todel_output is not None: + g.output.remove(todel_output) + previous_output = helper.find_output_by_name(g, node.input[0]) + if previous_output is None: + the_input_value = helper.find_value_by_name(g, node.input[0]) + g.output.extend([the_input_value]) + # Delete the constant node if it is not used by other nodes + constant_following_nodes = helper.find_following_nodes_by_input_value_name(g, weight_node.output[0]) + if len(constant_following_nodes) == 1: + node_to_remove.append(weight_node) + output_value_info = helper.find_value_by_name(g, weight_node.output[0]) + if output_value_info is not None: + g.value_info.remove(output_value_info) + for node in node_to_remove: + g.node.remove(node) + +def eliminate_nop_cast(g): + """Eliminate do nothing Cast nodes. + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Cast': + continue + # Get input value_info + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + helper.logger.debug(f"Cannot find the input value_info for Cast node {node.name}. Skip elimination check.") + continue + # Get output value_info + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value is None: + helper.logger.debug(f"Cannot find the output value_info for Cast node {node.name}. Skip elimination check.") + continue + # Compare the type. + if input_value.type.tensor_type.elem_type != output_value.type.tensor_type.elem_type: + continue + # If this node is the output node, set its previous node as output nodes. + if helper.find_output_by_name(g, node.output[0]) is not None: + todel_output = helper.find_output_by_name(g, node.output[0]) + g.output.remove(todel_output) + if helper.find_output_by_name(g, node.input[0]) is None: + the_input_value = helper.find_value_by_name(g, node.input[0]) + if the_input_value is not None: + g.output.extend([the_input_value]) + # Replace the parents in all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + # Delete value info + value_between = helper.find_value_by_name(g, node.output[0]) + if value_between is not None: + g.value_info.remove(value_between) + # Node is waiting for elimination + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) diff --git a/tools/optimizer_scripts/tools/fusing.py b/tools/optimizer_scripts/tools/fusing.py new file mode 100644 index 0000000..202a4c2 --- /dev/null +++ b/tools/optimizer_scripts/tools/fusing.py @@ -0,0 +1,1064 @@ +import onnx.helper +import numpy as np +from . import helper +from .other import topological_sort +from .modhelper import delete_value_with_name_if_exists, replace_node_input + +def fuse_Transpose_into_Constant(g): + """ + Fuse Transpose layers into the Constant layers before + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is None or prev_node.op_type != 'Constant': + continue + + pre_shape, data_list = helper.constant_to_list(prev_node) + w = np.reshape(data_list, pre_shape) + w = w.transpose(node.attribute[0].ints) + new_shape = w.shape + w = w.flatten() + + new_tensor = onnx.helper.make_tensor( + name=prev_node.name+'_data', + data_type=prev_node.attribute[0].t.data_type, + dims=new_shape, + vals=w.tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [node.output[0]], + name=node.output[0], + value=new_tensor + ) + + value_between = helper.find_value_by_name(g, prev_node.output[0]) + value_type = value_between.type.tensor_type.elem_type + g.value_info.remove(value_between) + + g.node.extend([new_node]) + node_to_remove.append(node) + node_to_remove.append(prev_node) + + if new_node.output[0] not in [i.name for i in g.value_info]: + new_value = onnx.helper.make_tensor_value_info( + name=new_node.output[0], + elem_type=value_type, + shape=new_shape + ) + g.value_info.extend([new_value]) + if new_node.output[0]: + val_info_to_del = helper.find_value_by_name(g, new_node.output[0]) + g.value_info.remove(val_info_to_del) + + for node in node_to_remove: + g.node.remove(node) + + topological_sort(g) + +def fuse_Add_into_Conv(g): + """ + Fuse Transpose layers into the Constant layers before + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Add': + continue + conv_node = helper.find_node_by_output_name(g, node.input[0]) + cons_node = helper.find_node_by_output_name(g, node.input[1]) + if conv_node is None or cons_node is None: + continue + if conv_node.op_type != 'Conv' or cons_node.op_type != 'Constant': + continue + if len(conv_node.input) > 2: + continue + # This layer should be fused. Connect constant node into convolution node. + add_node = node + conv_node.input.extend([cons_node.output[0]]) + old_value = helper.find_value_by_name(g, conv_node.output[0]) + conv_node.output[0] = add_node.output[0] + # Remove origin conv_node_output + g.value_info.remove(old_value) + # Remove current node + node_to_remove.append(add_node) + # Apply changes to the model + for node in node_to_remove: + g.node.remove(node) + +def fuse_BN_into_Gemm(g): + """Fuse the following BN into the previous Gemm. + + :param g: the graph + """ + node_to_remove = [] + for node in g.node: + # Check for BN and Gemm + if node.op_type != 'BatchNormalization': + continue + gemm_node = helper.find_node_by_output_name(g, node.input[0]) + if gemm_node is None: + continue + if gemm_node.op_type != 'Gemm': + continue + if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1: + continue + bn_node = node + # Get original weights + gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1]) + gemm_b = helper.constant_to_numpy(gemm_b_node) + gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2]) + gemm_c = helper.constant_to_numpy(gemm_c_node) + bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1]) + bn_scale = helper.constant_to_numpy(bn_scale_node) + bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2]) + bn_bias = helper.constant_to_numpy(bn_bias_node) + bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3]) + bn_mean = helper.constant_to_numpy(bn_mean_node) + bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4]) + bn_var = helper.constant_to_numpy(bn_var_node) + # Apply attributes + # epsilon + epsilon = helper.get_attribute_by_name(bn_node, 'epsilon') + if epsilon is None: + epsilon = 0.00001 + else: + epsilon = epsilon.f + bn_var = bn_var + epsilon + # alpha + alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + if alpha is None: + alpha = 1 + else: + alpha = alpha.f + gemm_b = gemm_b * alpha + # beta + beta = helper.get_attribute_by_name(gemm_node, 'beta') + if beta is None: + beta = 1 + else: + beta = beta.f + gemm_c = gemm_c * beta + # transA + transA = helper.get_attribute_by_name(gemm_node, 'transA') + if transA is not None and transA.i == 1: + raise RuntimeError("Do not support transA") + # transB + transB = helper.get_attribute_by_name(gemm_node, 'transB') + if transB is not None and transB.i == 1: + gemm_b = gemm_b.transpose() + # Calculate new weights + new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var) + new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias + # Replace original weights + new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b) + new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c) + g.node.extend([new_gemm_b_node, new_gemm_c_node]) + node_to_remove.extend([gemm_b_node, + gemm_c_node, + bn_node, + bn_scale_node, + bn_bias_node, + bn_mean_node, + bn_var_node]) + # Modify attributes + # alpha + alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + if alpha is not None: + alpha.f = 1.0 + # beta + beta = helper.get_attribute_by_name(gemm_node, 'beta') + if beta is not None: + beta.f = 1.0 + # transB + transB = helper.get_attribute_by_name(gemm_node, 'transB') + if transB is not None: + transB.i = 0 + # Connect the new graph + gemm_node.input[1] = new_gemm_b_node.output[0] + gemm_node.input[2] = new_gemm_c_node.output[0] + gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0]) + gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0]) + gemm_b_value.name = new_gemm_b_node.output[0] + gemm_c_value.name = new_gemm_c_node.output[0] + gemm_value = helper.find_value_by_name(g, gemm_node.output[0]) + g.value_info.remove(gemm_value) + gemm_node.output[0] = bn_node.output[0] + for i in range(1, 5): + value = helper.find_value_by_name(g, bn_node.input[i]) + g.value_info.remove(value) + # Remove useless nodes + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def fuse_BN_with_Reshape_into_Gemm(g): + """Fuse the following BN into the previous Gemm, even with Reshape or \\ + Squeeze and Unsqueeze surrounding. + + :param g: the graph + """ + node_to_remove = [] + for node in g.node: + # Check for BN and Gemm pattern: Gemm A BN B + # Find BatchNorm Node + if node.op_type != 'BatchNormalization': + continue + bn_node = node + # Find A Node + a_node = helper.find_node_by_output_name(g, node.input[0]) + if a_node is None or len(a_node.input) == 0: + continue + # Find Gemm Node + gemm_node = helper.find_node_by_output_name(g, a_node.input[0]) + if gemm_node is None or gemm_node.op_type != 'Gemm': + continue + # Find B Node + b_node_list = helper.find_following_nodes_by_input_value_name(g, bn_node.output[0]) + if len(b_node_list) == 0: + the_output = helper.find_output_by_name(g, bn_node.output[0]) + if the_output is None: + continue + b_node = None + elif len(b_node_list) > 1: + continue + else: + b_node = b_node_list[0] + # Check for branches + if len(helper.find_following_nodes_by_input_value_name(g, gemm_node.output[0])) > 1: + continue + if len(helper.find_following_nodes_by_input_value_name(g, a_node.output[0])) > 1: + continue + # Check type of A + if a_node.op_type == 'Unsqueeze': + axes = helper.get_attribute_by_name(a_node, 'axes') + if axes.ints != [2]: + continue + elif a_node.op_type == 'Reshape': + a = helper.constant_to_list(helper.find_node_by_output_name(g, a_node.input[1]))[1] + if len(a) != 3 or a[2] != 1: + continue + else: + continue + # Check type of B + if b_node is None: + pass + elif b_node.op_type == 'Flatten': + pass + elif b_node.op_type == 'Squeeze': + axes = helper.get_attribute_by_name(a_node, 'axes') + if axes.ints != [2]: + continue + elif b_node.op_type == 'Reshape': + a = helper.constant_to_list(helper.find_node_by_output_name(g, b_node.input[1]))[1] + if len(a) != 2: + continue + else: + continue + # Construct new Nodes + # Get original weights + gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1]) + gemm_b = helper.constant_to_numpy(gemm_b_node) + gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2]) + gemm_c = helper.constant_to_numpy(gemm_c_node) + bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1]) + bn_scale = helper.constant_to_numpy(bn_scale_node) + bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2]) + bn_bias = helper.constant_to_numpy(bn_bias_node) + bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3]) + bn_mean = helper.constant_to_numpy(bn_mean_node) + bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4]) + bn_var = helper.constant_to_numpy(bn_var_node) + # Apply attributes + # epsilon + epsilon = helper.get_attribute_by_name(bn_node, 'epsilon') + if epsilon is None: + epsilon = 0.00001 + else: + epsilon = epsilon.f + bn_var = bn_var + epsilon + # alpha + alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + if alpha is None: + alpha = 1 + else: + alpha = alpha.f + gemm_b = gemm_b * alpha + # beta + beta = helper.get_attribute_by_name(gemm_node, 'beta') + if beta is None: + beta = 1 + else: + beta = beta.f + gemm_c = gemm_c * beta + # transA + transA = helper.get_attribute_by_name(gemm_node, 'transA') + if transA is not None and transA.i == 1: + raise RuntimeError("Do not support transA") + # transB + transB = helper.get_attribute_by_name(gemm_node, 'transB') + if transB is not None and transB.i == 1: + gemm_b = gemm_b.transpose() + # Calculate new weights + new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var) + new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias + # Replace original weights + new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused', new_gemm_b) + new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused', new_gemm_c) + g.node.extend([new_gemm_b_node, new_gemm_c_node]) + # Modify attributes + # alpha + alpha = helper.get_attribute_by_name(gemm_node, 'alpha') + if alpha is not None: + alpha.f = 1.0 + # beta + beta = helper.get_attribute_by_name(gemm_node, 'beta') + if beta is not None: + beta.f = 1.0 + # transB + transB = helper.get_attribute_by_name(gemm_node, 'transB') + if transB is not None: + transB.i = 0 + # Remove useless nodes + node_to_remove.extend([gemm_b_node, + gemm_c_node, + bn_node, + bn_scale_node, + bn_bias_node, + bn_mean_node, + bn_var_node, + a_node]) + if a_node.op_type == 'Reshape': + node_to_remove.append(helper.find_node_by_output_name(g, a_node.input[1])) + if b_node is not None: + node_to_remove.append(b_node) + if b_node.op_type == 'Reshape': + node_to_remove.append(helper.find_node_by_output_name(g, b_node.input[1])) + # Delete useless value infos + value = helper.find_value_by_name(g, a_node.output[0]) + g.value_info.remove(value) + if a_node.op_type == 'Reshape': + value = helper.find_value_by_name(g, a_node.input[1]) + g.value_info.remove(value) + for i in range(1, 5): + value = helper.find_value_by_name(g, bn_node.input[i]) + g.value_info.remove(value) + value = helper.find_value_by_name(g, bn_node.output[0]) + if value is not None: + g.value_info.remove(value) + if b_node is not None: + value = helper.find_value_by_name(g, gemm_node.output[0]) + g.value_info.remove(value) + if b_node.op_type == 'Reshape': + value = helper.find_value_by_name(g, b_node.input[1]) + g.value_info.remove(value) + # Connect the new graph + # Connect Gemm new weights + gemm_node.input[1] = new_gemm_b_node.output[0] + gemm_node.input[2] = new_gemm_c_node.output[0] + gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0]) + gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0]) + gemm_b_value.name = new_gemm_b_node.output[0] + gemm_b_value.type.tensor_type.shape.dim[0].dim_value = new_gemm_b.shape[0] + gemm_b_value.type.tensor_type.shape.dim[1].dim_value = new_gemm_b.shape[1] + gemm_c_value.name = new_gemm_c_node.output[0] + if b_node is None: + # If b node is None, set the Gemm output as the graph output + output_value = helper.find_output_by_name(g, bn_node.output[0]) + g.output.remove(output_value) + g.output.extend([helper.find_value_by_name(g, gemm_node.output[0])]) + else: + # Else, set node B output as gemm output + gemm_node.output[0] = b_node.output[0] + # Remove useless nodes + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + + +def fuse_Gemm_into_Gemm(g): + """Fuse the previous Gemm into the following Gemm. + + :param g: the graph + """ + node_to_remove = [] + for node in g.node: + # Check for Gemm and Gemm + if node.op_type != 'Gemm': + continue + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is None: + continue + if prev_node.op_type != 'Gemm': + continue + # Get original weights + prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1]) + prev_b = helper.constant_to_numpy(prev_b_node) + prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2]) + prev_c = helper.constant_to_numpy(prev_c_node) + b_node = helper.find_node_by_output_name(g, node.input[1]) + b = helper.constant_to_numpy(b_node) + c_node = helper.find_node_by_output_name(g, node.input[2]) + c = helper.constant_to_numpy(c_node) + # Apply attributes + # alpha + alpha = helper.get_attribute_by_name(node, 'alpha') + if alpha is None: + alpha = 1 + else: + alpha = alpha.f + b = b * alpha + alpha = helper.get_attribute_by_name(prev_node, 'alpha') + if alpha is None: + alpha = 1 + else: + alpha = alpha.f + prev_b = prev_b * alpha + # beta + beta = helper.get_attribute_by_name(node, 'beta') + if beta is None: + beta = 1 + else: + beta = beta.f + c = c * beta + beta = helper.get_attribute_by_name(prev_node, 'beta') + if beta is None: + beta = 1 + else: + beta = beta.f + prev_c = prev_c * beta + # transA + transA = helper.get_attribute_by_name(node, 'transA') + if transA is not None and transA.i == 1: + raise RuntimeError("Do not support transA") + transA = helper.get_attribute_by_name(prev_node, 'transA') + if transA is not None and transA.i == 1: + raise RuntimeError("Do not support transA") + # transB + transB = helper.get_attribute_by_name(node, 'transB') + if transB is not None and transB.i == 1: + b = b.transpose() + transB = helper.get_attribute_by_name(prev_node, 'transB') + if transB is not None and transB.i == 1: + prev_b = prev_b.transpose() + # Calculate new weights + new_b = prev_b.dot(b) + new_c = prev_c.dot(b) + c + # Replace original weights + new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b) + new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c) + g.node.extend([new_b_node, new_c_node]) + node_to_remove.extend([b_node, + c_node, + prev_b_node, + prev_c_node, + prev_node]) + # Modify attributes + # alpha + alpha = helper.get_attribute_by_name(node, 'alpha') + if alpha is not None: + alpha.f = 1.0 + # beta + beta = helper.get_attribute_by_name(node, 'beta') + if beta is not None: + beta.f = 1.0 + # transB + transB = helper.get_attribute_by_name(node, 'transB') + if transB is not None: + transB.i = 0 + # Connect the new graph + node.input[0] = prev_node.input[0] + delete_value_with_name_if_exists(g, prev_node.output[0]) + for i in range(1, 3): + delete_value_with_name_if_exists(g, prev_node.input[i]) + delete_value_with_name_if_exists(g, node.input[i]) + node.input[1] = new_b_node.output[0] + node.input[2] = new_c_node.output[0] + # Remove useless nodes + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def fuse_MatMul_and_Add_into_Gemm(g): + """ + Fuse MatMul and Add layers into a new Gemm layers. + + :param g: the onnx graph + :raises ValueError: MatMul must be followed by an Add node + """ + node_to_remove = [] + node_to_add = [] + for node in g.node: + if node.op_type != 'MatMul': + continue + add_node = None + for i in g.node: + if not i.input: + continue + if i.input[0] == node.output[0]: + add_node = i + break + value_to_remove = helper.find_value_by_name(g, node.output[0]) + if add_node is None or value_to_remove is None or add_node.op_type != 'Add': + continue + input_list = node.input + input_list.append(add_node.input[1]), + new_node = onnx.helper.make_node( + "Gemm", + input_list, + add_node.output, + name=node.name, + alpha=1.0, + beta=1.0, + transA=0, + transB=0 + ) + node_to_add.append(new_node) + node_to_remove.append(node) + node_to_remove.append(add_node) + g.value_info.remove(value_to_remove) + for node in node_to_remove: + g.node.remove(node) + g.node.extend(node_to_add) + +def fuse_consecutive_transposes(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + pre_node = helper.find_node_by_output_name(g, node.input[0]) + if pre_node.op_type != 'Transpose': + continue + + pre_permutation = list(pre_node.attribute[0].ints) + cur_permutation = list(node.attribute[0].ints) + if len(pre_permutation) != len(cur_permutation): + continue + + new_permutation = [] + for ind in cur_permutation: + new_permutation.append(pre_permutation[ind]) + + new_trans_node = onnx.helper.make_node( + 'Transpose', + [pre_node.input[0]], + [node.output[0]], + name=node.name, + perm=new_permutation + ) + + g.node.extend([new_trans_node]) + node_to_del.extend([pre_node, node]) + + mid_val_info = helper.find_value_by_name(g, node.input[0]) + if mid_val_info: + g.value_info.remove(mid_val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + topological_sort(g) + +def fuse_mul_and_add_into_bn(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'Add': + continue + add_node = node + input_nodes_add = [helper.find_node_by_output_name(g, input_name) for input_name in add_node.input] + if any([n == None for n in input_nodes_add]): + continue + mul_node, const_add = None, None + for input_node_add in input_nodes_add: + if input_node_add.op_type == 'Mul': + mul_node = input_node_add + elif input_node_add.op_type == 'Constant': + const_add = input_node_add + else: + pass + if not mul_node or not const_add: + continue + data_input_name, const_mul = None, None + for input_name in mul_node.input: + input_node = helper.find_node_by_output_name(g, input_name) + if not input_node: + data_input_name = input_name + elif input_node.op_type == 'Constant': + if not const_mul: + const_mul = input_node + else: + data_input_name = input_name + else: + data_input_name = input_name + + if not const_mul: + continue + + scale_shape, scale_data = helper.constant_to_list(const_mul) + bias_shape, __ = helper.constant_to_list(const_add) + c_dim = len(scale_data) + if scale_shape != bias_shape: + continue + + data_input_value = helper.find_value_by_name(g, data_input_name) + if data_input_value is None: + data_input_value = helper.find_input_by_name(g, data_input_name) + _ , previous_node_output_shape = helper.find_size_shape_from_value(data_input_value) + # only allow 4 dim data input due to the hardware limitation + if previous_node_output_shape is None or len(previous_node_output_shape) != 4: + continue + + # check if mul's dim and input channel dimension are matched + if previous_node_output_shape[1] != c_dim: + continue + + if scale_shape == [1, c_dim, 1, 1]: + # remove all '1' + for _ in range(3): + const_add.attribute[0].t.dims.remove(1) + const_mul.attribute[0].t.dims.remove(1) + elif scale_shape == [1, c_dim]: + # remove all '1' + const_add.attribute[0].t.dims.remove(1) + const_mul.attribute[0].t.dims.remove(1) + elif scale_shape == 1 and c_dim == 1: + # Single value weight + const_add.attribute[0].t.dims.append(1) + const_mul.attribute[0].t.dims.append(1) + else: + continue + + bn_name = add_node.output[0] + const_mean = helper.list_to_constant(bn_name+'_mean', [c_dim], [0.0 for _ in range(c_dim)]) + const_var = helper.list_to_constant(bn_name+'_var', [c_dim], [1.0 for _ in range(c_dim)]) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [data_input_name, const_mul.output[0], const_add.output[0],\ + const_mean.output[0], const_var.output[0]], + [add_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + mid_val_info = helper.find_value_by_name(g, mul_node.output[0]) + scale_val_info = helper.find_value_by_name(g, const_mul.output[0]) + bais_val_info = helper.find_value_by_name(g, const_add.output[0]) + g.value_info.remove(mid_val_info) + g.value_info.remove(scale_val_info) + g.value_info.remove(bais_val_info) + + new_scale_val_info = onnx.helper.make_tensor_value_info( + const_mul.output[0], + const_mul.attribute[0].t.data_type, + [c_dim] + ) + new_bais_val_info = onnx.helper.make_tensor_value_info( + const_add.output[0], + const_add.attribute[0].t.data_type, + [c_dim] + ) + mean_val_info = onnx.helper.make_tensor_value_info( + const_mean.output[0], + const_mean.attribute[0].t.data_type, + [c_dim] + ) + var_val_info = onnx.helper.make_tensor_value_info( + const_var.output[0], + const_var.attribute[0].t.data_type, + [c_dim] + ) + + g.value_info.extend([new_scale_val_info]) + g.value_info.extend([new_bais_val_info]) + g.value_info.extend([mean_val_info]) + g.value_info.extend([var_val_info]) + g.node.extend([bn_node]) + g.node.extend([const_mean]) + g.node.extend([const_var]) + node_to_del.extend([mul_node, add_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + + +def fuse_mul_and_add_into_gemm(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'Add': + continue + add_node = node + mul_node = helper.find_node_by_output_name(g, add_node.input[0]) + if not mul_node or mul_node.op_type != 'Mul': + continue + mul_const = helper.find_node_by_output_name(g, mul_node.input[1]) + if not mul_const or mul_const.op_type != 'Constant': + continue + add_const = helper.find_node_by_output_name(g, add_node.input[1]) + if not add_const or add_const.op_type != 'Constant': + continue + + input_val = helper.find_value_by_name(g, mul_node.input[0]) + if not input_val: + input_val = helper.find_input_by_name(g, mul_node.input[0]) + if not input_val: + continue + + _, input_shape = helper.find_size_shape_from_value(input_val) + if not input_shape: + continue + + dim = int(np.prod(input_shape)) + if input_shape != [1, dim]: + continue + + mul_const_shape, mul_const_data = helper.constant_to_list(mul_const) + add_const_shape, __ = helper.constant_to_list(add_const) + + if len(mul_const_shape) != 1 or mul_const_shape[0] != dim: + continue + if len(add_const_shape) != 1 or add_const_shape[0] != dim: + continue + + b_data = np.zeros([dim, dim]) + for i in range(dim): + b_data[i][i] = mul_const_data[i] + b_data = b_data.flatten().tolist() + b_tensor = onnx.helper.make_tensor( + name=mul_const.name+'_tensor', + data_type=mul_const.attribute[0].t.data_type, + dims=[dim, dim], + vals=b_data + ) + b_const_node = onnx.helper.make_node( + 'Constant', + [], + [mul_const.output[0]], + value=b_tensor, + name=mul_const.output[0] + ) + + add_const.attribute[0].t.dims.insert(0, 1) + + gemm_node = onnx.helper.make_node( + 'Gemm', + [mul_node.input[0], b_const_node.output[0], add_const.output[0]], + [add_node.output[0]], + name=add_node.output[0] + ) + + g.node.extend([gemm_node, b_const_node]) + node_to_del.extend([mul_const, mul_node, add_node]) + + val_info_mid = helper.find_value_by_name(g, mul_node.output[0]) + val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0]) + val_info_add_const = helper.find_value_by_name(g, add_const.output[0]) + if val_info_mid: + g.value_info.remove(val_info_mid) + if val_info_mul_const: + g.value_info.remove(val_info_mul_const) + if val_info_add_const: + g.value_info.remove(val_info_add_const) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def fuse_conv_and_add_into_conv(g): + node_to_del = [] + for node in g.node: + # Check if two nodes can be fused + if node.op_type != 'Add': + continue + add_node = node + add_const = helper.find_node_by_output_name(g, add_node.input[1]) + if not add_const or add_const.op_type != 'Constant': + continue + + conv_node = helper.find_node_by_output_name(g, add_node.input[0]) + if not conv_node or conv_node.op_type != 'Conv': + continue + weight_node = helper.find_node_by_output_name(g, conv_node.input[1]) + if not weight_node or weight_node.op_type != 'Constant': + continue + + m_dim = weight_node.attribute[0].t.dims[0] + if add_const.attribute[0].t.dims != [1, m_dim, 1, 1]: + continue + for _ in range(3): + add_const.attribute[0].t.dims.remove(1) + + # Link the add weight to constant. + conv_node.input.extend([add_const.output[0]]) + + # Remove the node + node_to_del.append(node) + output_value_info = helper.find_value_by_name(g, add_node.output[0]) + if output_value_info is not None: + g.value_info.remove(output_value_info) + add_weight_value_info = helper.find_value_by_name(g, add_const.output[0]) + if add_weight_value_info is not None: + g.value_info.remove(add_weight_value_info) + # Replace next node input if any. + following_nodes = helper.find_following_nodes_by_input_value_name(g, add_node.output[0]) + for following_node in following_nodes: + replace_node_input(following_node, add_node.output[0], add_node.input[0]) + # Replace output if any + todel_output = helper.find_output_by_name(g, add_node.output[0]) + if todel_output is not None: + g.output.remove(todel_output) + previous_output = helper.find_output_by_name(g, add_node.input[0]) + if previous_output is None: + the_input_value = helper.find_value_by_name(g, add_node.input[0]) + g.output.extend([the_input_value]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + + +def fuse_consecutive_reducemean(g): + node_to_del = [] + for node in g.node: + # Find consecutive ReduceMean + if node.op_type != 'ReduceMean': + continue + pre_node = helper.find_node_by_output_name(g, node.input[0]) + if pre_node is None or pre_node.op_type != 'ReduceMean': + continue + # Check attributes + pre_keepdims = helper.get_var_attribute_by_name(pre_node, 'keepdims', 'int') + pre_axes = helper.get_list_attribute_by_name(pre_node, 'axes', 'int') + cur_keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int') + cur_axes = helper.get_list_attribute_by_name(node, 'axes', 'int') + if pre_keepdims != 0 or cur_keepdims != 0: + continue + axes = sorted(pre_axes + cur_axes) + if axes != [2, 3]: + continue + # Merge two ReduceMean into GlobalAveragePool. + new_gap_node = onnx.helper.make_node( + 'GlobalAveragePool', + [pre_node.input[0]], + [node.output[0] + '_intermedia'], + name = node.name + '_gap' + ) + new_flatten_node = onnx.helper.make_node( + 'Flatten', + [node.output[0] + '_intermedia'], + [node.output[0]], + name = node.name + '_flatten', + axis = 1 + ) + + # Clean up + g.node.extend([new_gap_node, new_flatten_node]) + node_to_del.extend([pre_node, node]) + mid_val_info = helper.find_value_by_name(g, node.input[0]) + if mid_val_info: + g.value_info.remove(mid_val_info) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + topological_sort(g) + +def fuse_slice_nodes_into_conv(g): + # define pattern checker + def check_is_slice(node): + if node.op_type == 'Concat': + return True + if node.op_type != 'Slice': + return False + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + if len(following_nodes) != 1: + return False + # also check attributes + if len(node.input) != 5: + return False + # starts should be 0 or 1 + starts_node = helper.find_node_by_output_name(g, node.input[1]) + if starts_node.op_type != 'Constant': + return False + _, starts_list = helper.constant_to_list(starts_node) + for num in starts_list: + if num != 0 and num != 1: + return False + # ends + ends_node = helper.find_node_by_output_name(g, node.input[2]) + if ends_node.op_type != 'Constant': + return False + # axes should be 2 or 3 + axes_node = helper.find_node_by_output_name(g, node.input[3]) + if axes_node.op_type != 'Constant': + return False + _, axes_list = helper.constant_to_list(axes_node) + for num in axes_list: + if num != 2 and num != 3: + return False + # Steps can only be 2 + steps_node = helper.find_node_by_output_name(g, node.input[4]) + if steps_node.op_type != 'Constant': + return False + _, steps_list = helper.constant_to_list(steps_node) + for num in steps_list: + if num != 2: + return False + # Recursion + return check_is_slice(following_nodes[0]) + # defind concat finder + def find_concat_node(node): + while node.op_type != 'Concat': + node = helper.find_following_nodes_by_input_value_name(g, node.output[0])[0] + return node + # define remove node function. + def remove_nodes(input_name): + following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + # Remove concat directly + if len(following_nodes) == 1 and following_nodes[0].op_type == 'Concat': + g.node.remove(following_nodes[0]) + return + for following_node in following_nodes: + # Recursion first + remove_nodes(following_node.output[0]) + # Remove weights + for i in range(1, len(following_node.input)): + if len(helper.find_following_nodes_by_input_value_name(g, following_node.input[i])) > 1: + # More than one following nodes. Skip. + continue + input_weight = helper.find_node_by_output_name(g, following_node.input[i]) + g.node.remove(input_weight) + # Remove Slice nodes + g.node.remove(following_node) + # define remove value_info function + def remove_value_infos(input_name): + following_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + if following_nodes[0].op_type == 'Concat': + return + for following_node in following_nodes: + output_value = helper.find_value_by_name(g, following_node.output[0]) + # Remove output values + if output_value is not None: + g.value_info.remove(output_value) + # Remove weight values + for i in range(1, len(following_node.input)): + input_value = helper.find_value_by_name(g, following_node.input[i]) + if input_value is not None: + g.value_info.remove(input_value) + # Recursion + remove_value_infos(following_node.output[0]) + # define get slice position + def get_slice_position(final_slice_output): + slice_position = [0, 0] + prev_node = helper.find_node_by_output_name(g, final_slice_output) + while prev_node is not None: + starts_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[1])) + axes_np = helper.constant_to_numpy(helper.find_node_by_output_name(g, prev_node.input[3])) + for i in range(len(axes_np)): + if axes_np[i] == 2: + slice_position[0] = starts_np[i] + elif axes_np[i] == 3: + slice_position[1] = starts_np[i] + prev_node = helper.find_node_by_output_name(g, prev_node.input[0]) + return slice_position + # Check pattern from each input + for input_value in g.input: + nodes_after_input = helper.find_following_nodes_by_input_value_name(g, input_value.name) + pattern_matched = True + for following_node in nodes_after_input: + if following_node.op_type != 'Slice': + pattern_matched = False + break + else: + pattern_matched = check_is_slice(following_node) + if not pattern_matched: + continue + # Pattern found. Check limitation + # Currently only support 2D + if len(nodes_after_input) != 4: + continue + # Get the concat node + concat_node = find_concat_node(nodes_after_input[0]) + # Get basic information + input_shape = helper.get_shape_from_value_info(input_value) + channel_num = input_shape[1] + # Construct weight + weight_np = np.zeros((input_shape[1] * 4, input_shape[1], 3, 3), dtype=np.float32) + for i in range(4): + # Check each branch + slice_position = get_slice_position(concat_node.input[i]) + for j in range(channel_num): + weight_np[i * channel_num + j, j, slice_position[0], slice_position[1]] = 1 + weight_node = helper.numpy_to_constant(concat_node.name + '_weight', weight_np) + # Construct Conv node + new_conv = onnx.helper.make_node( + 'Conv', + [input_value.name, concat_node.name + '_weight'], + [concat_node.output[0]], + name = concat_node.name + '_fused', + dilations = [1, 1], + group = 1, + kernel_shape = [3, 3], + strides = [2, 2], + pads = [0, 0, 2, 2] + ) + # Delete old nodes, weights and value_infos + remove_value_infos(input_value.name) + remove_nodes(input_value.name) + # Replace node + g.node.append(weight_node) + g.node.append(new_conv) + + +def fuse_relu_min_into_clip(g): + node_to_del = [] + for node in g.node: + # Check Min node + if node.op_type != 'Min': + continue + min_node = node + # Check Constant node + min_const = helper.find_node_by_output_name(g, min_node.input[1]) + if not min_const or min_const.op_type != 'Constant': + continue + min_shape, min_value = helper.constant_to_list(min_const) + if min_shape != 1: + continue + # Check Relu node + relu_node = helper.find_node_by_output_name(g, min_node.input[0]) + if not relu_node or relu_node.op_type != 'Relu': + continue + + # Create Clip node + relu_min_const_node = helper.list_to_constant(relu_node.name+'_min_value', [], [0.0]) + clip_node = onnx.helper.make_node( + "Clip", + [relu_node.input[0], relu_min_const_node.output[0], min_const.output[0]], + [min_node.output[0]], + name=min_node.name + ) + + node_to_del.extend([relu_node, min_node]) + + old_relu_const_val_info = helper.find_value_by_name(g, min_node.input[0]) + if old_relu_const_val_info: + g.value_info.remove(old_relu_const_val_info) + g.node.extend([relu_min_const_node, clip_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) \ No newline at end of file diff --git a/tools/optimizer_scripts/tools/general_graph.py b/tools/optimizer_scripts/tools/general_graph.py new file mode 100644 index 0000000..352445b --- /dev/null +++ b/tools/optimizer_scripts/tools/general_graph.py @@ -0,0 +1,83 @@ +from collections import deque + +class Node: + """A Node which maps a node proto. It has pointers to its parents and + children. + """ + def __init__(self, onnx_node): + """Initialize a node. This initialization only set up the mapping to + node proto. The pointers should be set up by outside. + """ + self.name = None + self.parents = [] + self.children = [] + self.proto = None + self.output_value = None + if onnx_node is not None: + self.name = onnx_node.name + self.proto = onnx_node + +class Graph: + """A graph which is constructed from the onnx proto. + """ + def __init__(self, onnx_graph): + """Construct the graph from onnx. + """ + self.input_nodes = [] + self.output_nodes = [] + self.name2node = {} + self.output2node = {} + self.proto = onnx_graph + # Add input nodes + for value in onnx_graph.input: + input_node = Node(None) + input_node.name = "Input_" + value.name + input_node.output_value = value + self.name2node[input_node.name] = input_node + self.output2node[value.name] = input_node + self.input_nodes.append(input_node) + output_value_names = [value.name for value in onnx_graph.output] + # Add regular nodes + for onnx_node in onnx_graph.node: + node = Node(onnx_node) + self.name2node[node.name] = node + self.output2node[onnx_node.output[0]] = node + for value_name in onnx_node.input: + node.parents.append(self.output2node[value_name]) + self.output2node[value_name].children.append(node) + if onnx_node.output[0] in output_value_names: + self.output_nodes.append(node) + # Add value infos + for value in onnx_graph.value_info: + node = self.output2node[value.name] + node.output_value = value + def get_sorted_node_list(self): + """Return a node list in topological order. + """ + visited = set() + todo = deque() + result = [] + for node in self.input_nodes: + todo.append(node) + visited.add(node) + for onnx_node in self.proto.node: + if onnx_node.op_type == "Constant": + node = self.name2node[onnx_node.name] + todo.append(node) + visited.add(node) + while todo: + node = todo.popleft() + result.append(node) + for child in node.children: + if child in visited: + continue + ready = True + for child_parent in child.parents: + if child_parent in visited: + continue + ready = False + break + if ready: + todo.append(child) + visited.add(child) + return result diff --git a/tools/optimizer_scripts/tools/helper.py b/tools/optimizer_scripts/tools/helper.py new file mode 100644 index 0000000..18bc1e3 --- /dev/null +++ b/tools/optimizer_scripts/tools/helper.py @@ -0,0 +1,621 @@ +"""This module contains helper functions that do not modify the graph. +""" +import onnx +import onnx.helper +import struct +import numpy as np +import logging + +__ONNX_VERSION__ = -1 + +logger = logging.getLogger("optimizer_scripts") + +def setup_current_opset_version(m): + global __ONNX_VERSION__ + __ONNX_VERSION__ = m.opset_import[0].version + if __ONNX_VERSION__ not in [11]: + raise RuntimeError('Only support opset 11, but got ' + str(__ONNX_VERSION__)) + +def get_current_opset_version(): + if __ONNX_VERSION__ == -1: + raise RuntimeError('do setup_current_opset_version first please') + return __ONNX_VERSION__ + +def find_nodes_by_input_name(g, name): + nodes = [] + for node in g.node: + if name in node.input: + nodes.append(node) + return nodes + +def find_node_by_output_name(g, name): + """ + Find a node in the graph by its output name + + :param g: the onnx graph\\ + :param name: the target node output name\\ + :returns: the node find by name + """ + for i in g.node: + if name in i.output: + return i + return None + +def find_node_by_node_name(g, name): + """ + Find a node in the graph by its output name + + :param g: the onnx graph\\ + :param name: the target node output name\\ + :returns: the node find by name + """ + for i in g.node: + if i.name == name: + return i + return None + +def find_following_nodes_by_input_value_name(g, name): + """ Find the following nodes of a specific value. + + :param g: the onnx graph. \\ + :param name: the value name. \\ + :return: a list of following nodes. + """ + return find_nodes_by_input_name(g, name) + +def find_value_by_name(g, name): + """ + Find a value_info in the graph by name + + :param g: the onnx graph\\ + :param name: the target value_info name\\ + :returns: the value_info find by name + """ + for i in g.value_info: + if i.name == name: + return i + return None + +def find_output_by_name(g, name): + """ + Find a value_info in the graph by name + + :param g: the onnx graph\\ + :param name: the target value_info name\\ + :returns: the value_info find by name + """ + for i in g.output: + if i.name == name: + return i + return None + +def find_input_by_name(g, name): + """ + Find a input in the graph by name + + :param g: the onnx graph\\ + :param name: the target input name\\ + :returns: the input find by name + """ + for i in g.input: + if i.name == name: + return i + return None + +def list_to_constant(name, shape, data, data_type=None): + """Generate a constant node using the given infomation. + + :name: the node name and the output value name\\ + :shape: the data shape\\ + :data: the data itself\\ + :returns: the generated onnx constant node + """ + if not data_type: + if isinstance(data, int): + data_type = onnx.helper.TensorProto.INT64 + elif isinstance(data, float): + data_type = onnx.helper.TensorProto.FLOAT + elif len(data) > 0 and isinstance(data[0], int): + data_type = onnx.helper.TensorProto.INT64 + else: + data_type = onnx.helper.TensorProto.FLOAT + tensor = onnx.helper.make_tensor( + name, + data_type, + shape, + data + ) + new_w_node = onnx.helper.make_node( + "Constant", + [], + [name], + name = name, + value = tensor + ) + return new_w_node + + +def scaler_to_constant(name, data, data_type=None): + """Generate a constant node using the given infomation. + + :name: the node name and the output value name\\ + :shape: the data shape\\ + :data: the data itself\\ + :returns: the generated onnx constant node + """ + if not data_type: + if isinstance(data, int): + data_type = onnx.helper.TensorProto.INT64 + elif isinstance(data, float): + data_type = onnx.helper.TensorProto.FLOAT + else: + logger.error("Cannot create scaler constant with a list.") + exit(1) + tensor = onnx.helper.make_tensor( + name, + data_type, + None, + [data] + ) + new_w_node = onnx.helper.make_node( + "Constant", + [], + [name], + name = name, + value = tensor + ) + return new_w_node + + +def numpy_to_constant(name, np_array): + return list_to_constant(name, np_array.shape, np_array.flatten().tolist()) + +def constant_to_list(node): + """Generate a list from the constant node + + :node: the Constant node\\ + :returns: the shape of the constant node, the data of the constant node + """ + tensor = node.attribute[0].t + # 1. check data type + # 2. get data from raw or data + # 3. get shape from dim + if tensor.data_type == onnx.helper.TensorProto.INT32: + if len(tensor.int32_data) != 0: + data = list(tensor.int32_data) + else: + data = [i[0] for i in struct.iter_unpack('i', tensor.raw_data)] + elif tensor.data_type == onnx.helper.TensorProto.INT64: + if len(tensor.int64_data) != 0: + data = list(tensor.int64_data) + else: + data = [i[0] for i in struct.iter_unpack('q', tensor.raw_data)] + elif tensor.data_type == onnx.helper.TensorProto.INT8: + if len(tensor.int32_data) != 0: + data = list(tensor.int32_data) + else: + data = [i[0] for i in struct.iter_unpack('b', tensor.raw_data)] + elif tensor.data_type == onnx.helper.TensorProto.FLOAT: + if len(tensor.float_data) != 0: + data = list(tensor.float_data) + else: + data = [i[0] for i in struct.iter_unpack('f', tensor.raw_data)] + elif tensor.data_type == onnx.helper.TensorProto.DOUBLE: + if len(tensor.double_data) != 0: + data = list(tensor.double_data) + else: + data = [i[0] for i in struct.iter_unpack('d', tensor.raw_data)] + else: + print("Not supported data type {}".format(tensor.data_type)) + raise RuntimeError + if len(tensor.dims) == 0: + shape = len(data) + else: + shape = list(tensor.dims) + return shape, data + +def constant_to_numpy(node): + """Generate a numpy array from the constant node + + :node: the Constant node\\ + :returns: the numpy array + """ + shape, data = constant_to_list(node) + return np.array(data).reshape(shape) + +def all_constant_input(node): + """Find the inputs of the given node. If the inputs of this node are all\\ + constant nodes, return True. Otherwise, return False. + + :param node: the input node which has a Node structure\\ + :return: whether the node of this node are all constant + """ + if node.proto is None: + return False + isConstant = True + for parent in node.parents: + if parent.proto is None or parent.proto.op_type != 'Constant': + isConstant = False + break + return isConstant + +def get_padding(size, kernel_size, strides): + """ Calculate the padding array for same padding in the Tensorflow fashion.\\ + See https://www.tensorflow.org/api_guides/python/nn#Convolution for more. + """ + if size[0] % strides[0] == 0: + pad_h = max(kernel_size[0] - strides[0], 0) + else: + pad_h = max(kernel_size[0] - (size[0] % strides[0]), 0) + if size[1] % strides[1] == 0: + pad_w = max(kernel_size[1] - strides[1], 0) + else: + pad_w = max(kernel_size[1] - (size[1] % strides[1]), 0) + return [pad_h//2, pad_w//2, pad_h-pad_h//2, pad_w-pad_w//2] + +def get_shape_from_value_info(value): + """Get shape from a value info. + + :param value: the value_info proto\\ + :return: list of the shape + """ + return [d.dim_value for d in value.type.tensor_type.shape.dim] + +def find_size_shape_from_value(value): + ''' + Find the size of data within the value_info object. + :param value: value_info + :return: int size and list shape of the data in the value_info + ''' + if not value: + return None, None + if not value.type.tensor_type.shape.dim: + return 0, [] + size = 1 + shape = [] + for i in range(len(value.type.tensor_type.shape.dim)): + size *= max(1, value.type.tensor_type.shape.dim[i].dim_value) + shape.append(max(1, value.type.tensor_type.shape.dim[i].dim_value)) + + return size, shape + + +def get_attribute_by_name(node, attr_name): + """Get attribute proto with specific name in the given node proto. + + :param node: the node proto.\\ + :param attr_name: a str for the name of the target.\\ + :return: if found, return the attribute_proto. Else, return None. + """ + for attr in node.attribute: + if attr.name == attr_name: + return attr + return None + +def get_list_attribute_by_name(node, attr_name: str, attr_type: str): + """Get list attribute with specific name in the given node proto. + + :param node: the node proto.\\ + :param attr_name: a str for the name of the target.\\ + :param attr_type: a str which should be "float" or "int".\\ + :return: if found, return the list. Else, return None. + """ + attr_proto = get_attribute_by_name(node, attr_name) + if attr_proto is None: + return None + if attr_type == "int": + if len(attr_proto.ints) == 0: + return None + else: + return list(attr_proto.ints) + elif attr_type == "float": + if len(attr_proto.ints) == 0: + return None + else: + return list(attr_proto.floats) + else: + print("Warning: undefined type for list attribute extraction") + return None + +def get_var_attribute_by_name(node, attr_name: str, attr_type: str): + """Get variable attribute with specific name in the given node proto. + + :param node: the node proto.\\ + :param attr_name: str for the name of the target.\\ + :param attr_type: str which should be "float", "int", "string" or "tensor".\\ + :return: if found, return the variable. Else, return None. + """ + attr_proto = get_attribute_by_name(node, attr_name) + if attr_proto is None: + return None + if attr_type == "int": + return attr_proto.i + elif attr_type == "float": + return attr_proto.f + elif attr_type == "string": + if type(attr_proto.s) == type(b'abc'): + return attr_proto.s.decode("utf-8") + else: + return attr_proto.s + elif attr_type == "tensor": + return attr_proto.t + else: + print("Warning: undefined type for variable attribute extraction") + return None + +def flatten_with_depth(data, depth): + output = [] + if type(data) not in [type(np.array([1])), type([1])]: + return [[data, 0]] + for item in data: + if type(item) not in [type(np.array([1])), type([1])]: + output.append([item, depth+1]) + else: + output += flatten_with_depth(item, depth+1) + return output + +def flatten_to_list(data): + flatten_depth = flatten_with_depth(data, 0) + flat_data = [item[0] for item in flatten_depth] + return flat_data + +def get_shape(data): + shape = [] + if type(data) not in [type(np.array([1])), type([1])]: + return [] + sub_data = data[0] + shape.append(len(data)) + while type(sub_data) in [type(np.array([1])), type([1])]: + shape.append(len(sub_data)) + sub_data = sub_data[0] + return shape + + +def slice_data(data, starts, ends, axes): + flat_data = [item[0] for item in flatten_with_depth(data, 0)] + shape = get_shape(data) + + starts_updated = [] + ends_updated = [] + for i in range(len(starts)): + start_updated = min(starts[i], shape[i]-1) % shape[i] + starts_updated.append(start_updated) + for j in range(len(starts)): + if ends[j] >= shape[j]: + end_updated = shape[j] + else: + end_updated = min(ends[j], shape[j]) % shape[j] + ends_updated.append(end_updated) + + index_slices = [] + for i in range(len(shape)): + if i not in axes: + index_slices.append(list(range(shape[i]))) + else: + axe_ind = axes.index(i) + index_slices.append(list(range(starts_updated[axe_ind], ends_updated[axe_ind]))) + + indices = [1] + for i in range(len(shape)-1, -1, -1): + step = np.prod(shape[i+1:]) + temp_pos = indices + new_indices = [] + for n in index_slices[i]: + for pos in temp_pos: + new_indices.append(int(n*step+pos)) + indices = new_indices + + sliced_data = [flat_data[k-1] for k in indices] + + # reshape to correct shape. + new_shape = [] + for i in range(len(shape)): + if i not in axes: + new_shape.append(shape[i]) + else: + axe_ind = axes.index(i) + new_shape.append(ends_updated[axe_ind]-starts_updated[axe_ind]) + if any([dim < 1 for dim in new_shape]): + raise RuntimeError('Invalid starts ends.') + + sliced_data = np.reshape(sliced_data, new_shape) + + return sliced_data + +def concatenate(data_sets, axis): + # check shapes + shapes = [] + shapes_ = [] + for data_set in data_sets: + shape = get_shape(data_set) + shapes.append(list(shape)) + shape.pop(axis) + shapes_.append(shape) + if not all([s == shapes_[0] for s in shapes_]): + raise RuntimeError('data sets shapes do not match') + + new_dim = sum([s[axis] for s in shapes]) + new_shape = list(shapes[0]) + new_shape[axis] = new_dim + + flat_data_sets = [] + for data_set in data_sets: + flat_data_sets.append(flatten_to_list(data_set)) + + sub_block_size = 1 + for i in range(axis+1, len(shapes[0])): + sub_block_size *= shapes[0][i] + + split_num = 1 + for i in range(axis): + split_num *= shapes[0][i] + + total_flat_data = [] + for i in range(split_num): + for j in range(len(shapes)): + block_size = sub_block_size*shapes[j][axis] + total_flat_data.extend(flat_data_sets[j][i*block_size:(i+1)*block_size]) + + new_data = np.reshape(total_flat_data, new_shape) + + return new_data + + +def broadcast_data_sets(data_set_1, data_set_2): + shape1 = get_shape(data_set_1) + shape2 = get_shape(data_set_2) + + # compare shapes and get broadcasted shape + list_a, list_b = (shape1, shape2) if len(shape1) > len(shape2) else (shape2, shape1) + while len(list_a) > len(list_b): + list_b.insert(0, 0) + broadcasted_shape = [] + for i in range(len(list_a)): + if list_b[i] == 0: + broadcasted_shape.append(list_a[i]) + elif list_b[i] == 1: + broadcasted_shape.append(list_a[i]) + elif list_a[i] == 1: + broadcasted_shape.append(list_b[i]) + elif list_a[i] == list_b[i]: + broadcasted_shape.append(list_a[i]) + else: + raise RuntimeError('Can not broadcast two data sets') + + # prepare data for broadcasting. + shape1 = list(map(lambda x:x if x != 0 else 1, shape1)) + shape2 = list(map(lambda x:x if x != 0 else 1, shape2)) + data_1 = np.reshape(data_set_1, shape1) + data_2 = np.reshape(data_set_2, shape2) + + for i in range(len(shape1)): + if shape1[i] != broadcasted_shape[i]: + new_data_total = [list(data_1) for _ in range(broadcasted_shape[i])] + data_1 = concatenate(new_data_total, axis=i) + for i in range(len(shape2)): + if shape2[i] != broadcasted_shape[i]: + new_data_total = [list(data_2) for _ in range(broadcasted_shape[i])] + data_2 = concatenate(new_data_total, axis=i) + + return data_1, data_2 + + +def add(data_set_1, data_set_2): + broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2) + + flat_data_1 = flatten_to_list(broadcasted_data_1) + flat_data_2 = flatten_to_list(broadcasted_data_2) + shape = get_shape(broadcasted_data_1) + res = [] + for i in range(len(flat_data_1)): + res.append(flat_data_1[i]+flat_data_2[i]) + + res = np.reshape(res, shape) + + return res + + +def reduceprod(data_set, axis, keepdims=1): + flat_data = flatten_to_list(data_set) + old_shape = get_shape(data_set) + + temp_shape = old_shape + temp_flat_data = flat_data + for ax in axis: + split_num = 1 + step = 1 + for i in range(ax): + split_num *= temp_shape[i] + for i in range(ax+1, len(temp_shape)): + step *= temp_shape[i] + + block_size = len(temp_flat_data)//split_num + new_flat_data = [] + for j in range(split_num): + block_data = temp_flat_data[j*block_size:(j+1)*block_size] + reduced_block_data = [] + for k in range(step): + val = block_data[k] + for l in range(1, block_size//step): + val *= block_data[k+l*step] + reduced_block_data.append(val) + new_flat_data.extend(reduced_block_data) + temp_flat_data = new_flat_data + temp_shape[ax] = 1 + + new_flat_data = temp_flat_data + new_shape = temp_shape + if not keepdims: + axis = sorted(list(axis)) + for pos in axis[::-1]: + new_shape.pop(pos) + + return np.reshape(new_flat_data, new_shape) + + +def transpose(data_set, permutation): + # find series of local swaps + data_set = list(data_set) + perm = list(permutation) + shape = get_shape(data_set) + flat_data = flatten_to_list(data_set) + assert set(perm) == set(range(len(shape))), 'invalid permutation' + + new_shape = [shape[i] for i in perm] + swaps = [] + bubbled = True + while bubbled: + bubbled = False + for i in range(len(new_shape)-1): + if perm[i] > perm[i+1]: + swaps.append([i, i+1]) + p_1, p_2 = perm[i], perm[i+1] + perm[i], perm[i+1] = p_2, p_1 + bubbled = True + + # apply local swaps + current_shape = list(shape) + temp_flat_data = flat_data + + for swap in swaps[::-1]: + ind_1, ind_2 = swap[0], swap[1] + dim_1 = current_shape[ind_1] + dim_2 = current_shape[ind_2] + split_num = 1 + block_size = 1 + + for i in range(ind_1): + split_num *= current_shape[i] + for i in range(ind_2+1, len(current_shape)): + block_size *= current_shape[i] + + data_blocks = np.reshape(temp_flat_data, [-1, block_size]) + flat_data_1 = [] + for k in range(split_num): + block = [] + for m in range(dim_2): + for n in range(dim_1): + block_pos = k*dim_1*dim_2 + n*dim_2+m + block.extend(data_blocks[block_pos]) + flat_data_1.extend(block) + + temp_flat_data = flat_data_1 + current_shape[ind_1] = dim_2 + current_shape[ind_2] = dim_1 + + return np.reshape(temp_flat_data, current_shape) + +def subtract(data_set_1, data_set_2): + broadcasted_data_1, broadcasted_data_2 = broadcast_data_sets(data_set_1, data_set_2) + + shape = get_shape(broadcasted_data_1) + flat_data_1 = flatten_to_list(broadcasted_data_1) + flat_data_2 = flatten_to_list(broadcasted_data_2) + + substracted_data = [flat_data_1[i] - flat_data_2[i] for i in range(len(flat_data_1))] + + new_data = np.reshape(substracted_data, shape) + + return new_data + + \ No newline at end of file diff --git a/tools/optimizer_scripts/tools/modhelper.py b/tools/optimizer_scripts/tools/modhelper.py new file mode 100644 index 0000000..5e8302f --- /dev/null +++ b/tools/optimizer_scripts/tools/modhelper.py @@ -0,0 +1,78 @@ +"""This module contains helper functions that do graph modifications. +""" + +import onnx +from . import helper + + +def replace_node_input(node, old_input, new_input): + for i, input_name in enumerate(node.input): + if input_name == old_input: + node.input[i] = new_input + +def delete_nodes(g, node_list): + node_to_delete = [] + #Find target nodes + for node in g.node: + if node.name not in node_list: + continue + else: + node_to_delete.append(node) + if len(node_list) != len(node_to_delete): + print("Some nodes do not exist in the graph. Skipping them.") + for node in node_to_delete: + # Check the node whether if it is valid to delete + if len(node.input) == 0: + print("Deleting an Constant node. Please make sure you also delete all its following nodes") + elif len(node.input) > 1: + print("Warning: Node {} has more than one input. This script cannot delete merge nodes.".format(node.name)) + # Connect the nodes around the target node. + # Set the following node input as the previous node output. + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + if len(node.input) == 0: + for following_node in following_nodes: + following_node.input.remove(node.output[0]) + elif len(following_nodes) > 0 and len(node.input) == 1 and helper.find_input_by_name(g, node.input[0]) is not None: + # The node input is an input + new_input = helper.find_value_by_name(g, node.output[0]) + g.input.append(new_input) + g.input.remove(helper.find_input_by_name(g, node.input[0])) + g.value_info.remove(new_input) + elif len(following_nodes) > 0: + for following_node in following_nodes: + replace_node_input(following_node, node.output[0], node.input[0]) + else: + # If the node is the output, replace the output with the previous input. + value = helper.find_value_by_name(g, node.input[0]) + output_values = [] + while len(g.output): + output_values.append(g.output.pop()) + while output_values: + output_value = output_values.pop() + if output_value.name == node.output[0]: + g.output.extend([value]) + else: + g.output.extend([output_value]) + # Remove the node and value info. + g.node.remove(node) + +def delete_input(g, target_list): + for name in target_list: + input_value = helper.find_input_by_name(g, name) + if input_value is None: + print("Cannot find input {}".format(name)) + continue + g.input.remove(input_value) + +def delete_output(g, target_list): + for name in target_list: + output_value = helper.find_output_by_name(g, name) + if output_value is None: + print("Cannot find output {}".format(name)) + continue + g.output.remove(output_value) + +def delete_value_with_name_if_exists(g, name): + value = helper.find_value_by_name(g, name) + if value is not None: + g.value_info.remove(value) diff --git a/tools/optimizer_scripts/tools/other.py b/tools/optimizer_scripts/tools/other.py new file mode 100644 index 0000000..171179e --- /dev/null +++ b/tools/optimizer_scripts/tools/other.py @@ -0,0 +1,1200 @@ +"""Optimization functions that are not fusing, eliminating or replacing. In most +cases, these are the modifications on the original nodes. +""" +import struct +import collections +import numpy as np +import onnx.helper +import onnxoptimizer as optimizer +import math +import logging +from . import helper +from .modhelper import replace_node_input +import copy +from .helper import logger + + +def polish_model(model): + ''' + This function combines several useful utility functions together. + ''' + onnx.checker.check_model(model) + onnx.helper.strip_doc_string(model) + model = onnx.shape_inference.infer_shapes(model) + model = optimizer.optimize(model) + onnx.checker.check_model(model) + return model + + +def format_value_info_shape(g): + """ + Replace -1 and 0 batch size in value info + + :param g: the onnx graph + """ + for value in g.input: + if len(value.type.tensor_type.shape.dim) > 0 and\ + (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\ + not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + value.type.tensor_type.shape.dim[0].dim_value = 1 + for value in g.output: + if len(value.type.tensor_type.shape.dim) > 0 and\ + (value.type.tensor_type.shape.dim[0].dim_value <= 0 or\ + not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + value.type.tensor_type.shape.dim[0].dim_value = 1 + for value in g.value_info: + if len(value.type.tensor_type.shape.dim) > 0 and\ + (value.type.tensor_type.shape.dim[0].dim_value < 0 or\ + not isinstance(value.type.tensor_type.shape.dim[0].dim_value, int)): + value.type.tensor_type.shape.dim[0].dim_value = 1 + +def add_name_to_node(g): + """ + If no name presents, give a name based on output name. + + :param g: the onnx graph + """ + for node in g.node: + if len(node.name) == 0: + node.name = node.output[0] + +def rename_all_node_name(g): + """ + rename all nodes if the node name is a number: + + new_name = old_name + "_kn" + + :param g: the onnx graph + """ + + for node in g.node: + if not node.name.isdigit(): + # Skip not number names + continue + new_node_name = node.name + "_kn" + new_node_output0_name = node.output[0] + "_kn" + + # in order to keep same output node name, skip if it is output node. + output_value_info = helper.find_output_by_name(g, node.output[0]) + if output_value_info != None: + continue + + # rename the input of all the following nodes + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + replace_node_input(following_node, node.output[0], new_node_output0_name ) + + # rename value info + value_info = helper.find_value_by_name(g, node.output[0]) + if value_info != None: + value_info.name = new_node_output0_name + + # rename node + node.output[0] = new_node_output0_name + node.name = new_node_name + +def add_output_to_value_info(g): + """ + If output does not present in value_info, copy one + + :param g: the onnx graph + """ + for output in g.output: + if helper.find_value_by_name(g, output.name) is None: + g.value_info.extend([output]) + +def find_first_sequential_output(g, node): + for value_name in node.output: + value = helper.find_output_by_name(g, value_name) + if value is not None: + return value + next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if len(next_nodes) == 0: + # No following nodes + return None + return find_first_sequential_output(g, next_nodes[0]) + +def remove_nodes(g, cut_nodes=[], cut_types=[]): + node_to_delete = [] + #Find target nodes + for node in g.node: + if node.name not in cut_nodes and node.op_type not in cut_types: + continue + else: + node_to_delete.append(node) + # Mapping originnal outputs to new outputs. This mapping is to keep the output order. + output_mapping = {} + new_output = set() + for node in node_to_delete: + original_output = find_first_sequential_output(g, node) + if original_output.name not in output_mapping: + output_mapping[original_output.name] = [] + for input_name in node.input: + value = helper.find_value_by_name(g, input_name) + if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output: + output_mapping[original_output.name].append(value) + new_output.add(value.name) + # Remove them + while node_to_delete: + g.node.remove(node_to_delete.pop()) + # Remove unreachable nodes + visited_values = set() + unused_constant_map = {} + for input_value in g.input: + visited_values.add(input_value.name) + for node in g.node: + if node.op_type == 'Constant': + visited_values.add(node.output[0]) + unused_constant_map[node.output[0]] = node + continue + can_reach = True + for input_name in node.input: + if input_name not in visited_values: + can_reach = False + break + if can_reach: + for output_name in node.output: + visited_values.add(output_name) + else: + node_to_delete.append(node) + # Mapping outputs again + for node in node_to_delete: + original_output = find_first_sequential_output(g, node) + if original_output is None: + continue + if original_output.name not in output_mapping: + output_mapping[original_output.name] = [] + for input_name in node.input: + value = helper.find_value_by_name(g, input_name) + if value is not None and helper.find_output_by_name(g, input_name) is None and value.name not in new_output: + output_mapping[original_output.name].append(value) + new_output.add(value.name) + # Remove them + while node_to_delete: + g.node.remove(node_to_delete.pop()) + #Remove unused constants + for node in g.node: + for input_name in node.input: + if input_name in unused_constant_map: + del unused_constant_map[input_name] + for node in unused_constant_map.values(): + g.node.remove(node) + #Remove unreachable value infos + reachable_values = set() + for input_value in g.input: + reachable_values.add(input_value.name) + for node in g.node: + for input_name in node.input: + reachable_values.add(input_name) + for output_name in node.output: + reachable_values.add(output_name) + value_to_remove = [] + for value_info in g.value_info: + if value_info.name not in reachable_values: + value_to_remove.append(value_info) + while value_to_remove: + value_info = value_to_remove.pop() + g.value_info.remove(value_info) + # Reorder output + output_values = [] + while len(g.output): + output_values.append(g.output.pop()) + while output_values: + output_value = output_values.pop() + if output_value.name in reachable_values: + logger.info("Keep output {}".format(output_value.name)) + g.output.extend([output_value]) + elif output_value.name in output_mapping: + real_outputs = [i for i in output_mapping[output_value.name] if i.name in reachable_values] + logger.info("Replace output {} with {}".format(output_value.name, [i.name for i in real_outputs])) + g.output.extend(real_outputs) + else: + logger.info("Abandon output {}".format(output_value.name)) + continue + +def transpose_B_in_Gemm(g): + """ + If transB is set in Gemm, transpose it + + :param g: the onnx graph + """ + for node in g.node: + if node.op_type != 'Gemm': + continue + do_it = False + for attr in node.attribute: + if attr.name == "transB": + if attr.i == 1: + attr.i = 0 + do_it = True + break + if not do_it: + continue + # Transpose the weight and its output value + w_node = helper.find_node_by_output_name(g, node.input[1]) + w_output = helper.find_value_by_name(g, node.input[1]) + dim_0 = w_output.type.tensor_type.shape.dim[0].dim_value + dim_1 = w_output.type.tensor_type.shape.dim[1].dim_value + w_output.type.tensor_type.shape.dim[0].dim_value = dim_1 + w_output.type.tensor_type.shape.dim[1].dim_value = dim_0 + w_node.attribute[0].t.dims[0] = dim_1 + w_node.attribute[0].t.dims[1] = dim_0 + if w_node.attribute[0].t.raw_data: + raw_data = w_node.attribute[0].t.raw_data + fl_data = [i[0] for i in struct.iter_unpack('f', raw_data)] + else: + fl_data = w_node.attribute[0].t.float_data + w = np.reshape(fl_data, (dim_0, dim_1)) + w = w.transpose((1, 0)).flatten() + if w_node.attribute[0].t.raw_data: + buf = struct.pack('%sf' % len(w), *w) + w_node.attribute[0].t.raw_data = buf + else: + for i in range(len(fl_data)): + w_node.attribute[0].t.float_data[i] = w[i] + +def topological_sort(g): + """ + Topological sort all the layers. + Assume a node do not take the same value as more than one inputs. + + :param g: the onnx graph + """ + # TODO: Topological sort on the same branch + # Map from node name to its input degree + in_degree = {} + # Map from value info name to the nodes using it as input + output_nodes = collections.defaultdict(list) + # Map from node name to node object + node_map = {} + to_add = collections.deque() + # init + length = len(g.node) + for _ in range(length): + node = g.node.pop() + node_map[node.name] = node + if len([i for i in node.input if i != '']) == 0: + to_add.append(node.name) + else: + in_degree[node.name] = len([i for i in node.input if i != '']) + for input_name in node.input: + if input_name == '': + continue + output_nodes[input_name].append(node.name) + # sort + # deal with input first + for value_info in g.input: + input_name = value_info.name + for node_name in output_nodes[input_name]: + in_degree[node_name] -= 1 + if in_degree[node_name] == 0: + to_add.append(node_name) + del in_degree[node_name] + # main sort loop + sorted_nodes = [] + while to_add: + node_name = to_add.pop() + node = node_map[node_name] + del node_map[node_name] + sorted_nodes.append(node) + # Expect only one output name for each node + next_node_names = [] + for output_name in node.output: + next_node_names.extend(output_nodes[output_name]) + for next_node_name in next_node_names: + in_degree[next_node_name] -= 1 + if in_degree[next_node_name] == 0: + to_add.append(next_node_name) + del in_degree[next_node_name] + g.node.extend(sorted_nodes) + if in_degree: + raise RuntimeError("Unreachable nodes exist: {}".format(in_degree.keys())) + if node_map: + raise RuntimeError("Unused nodes exist: {}".format(node_map.keys())) + +def remove_zero_value_info(g): + value_info_list = list(g.value_info) + for vi in value_info_list: + if not vi.type.tensor_type.shape.dim: + g.value_info.remove(vi) + + for dim in vi.type.tensor_type.shape.dim: + if dim.dim_value == 0: + g.value_info.remove(vi) + break + +def inference_shapes(m): + while len(m.graph.value_info) > 0: + m.graph.value_info.pop() + g = m.graph + inferencing_shapes = True + while inferencing_shapes: + inferencing_shapes = False + if inference_cov_shape(g): + inferencing_shapes = True + if inference_upsample_shape(g): + inferencing_shapes = True + if inference_resize_shape(g): + inferencing_shapes = True + if inference_split_shape(g): + inferencing_shapes = True + if inferencing_shapes: + topological_sort(g) + m = polish_model(m) + g = m.graph + remove_zero_value_info(g) + m = polish_model(m) + return m + +def inference_resize_shape(g): + for node in g.node: + if node.op_type != 'Resize': + continue + + output_value = helper.find_value_by_name(g, node.output[0]) + output_value = helper.find_output_by_name(g, node.output[0]) if output_value is None else output_value + if output_value is not None: + continue + + if len(node.input) == 4: # input: X, roi, scales, sizes + shape_node = helper.find_node_by_output_name(g, node.input[3]) + if shape_node.op_type != 'Constant': + continue + + _, shape_value = helper.constant_to_list(shape_node) + output_value = onnx.helper.make_tensor_value_info( + node.output[0], + onnx.TensorProto.FLOAT, + [int(v) for v in shape_value]) + g.value_info.extend([output_value]) + return True + else: + # If output shape is not given, inference from scales + # Get the input shape + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + continue + shape_value = helper.get_shape_from_value_info(input_value) + scales_node = helper.find_node_by_output_name(g, node.input[2]) + if scales_node.op_type != 'Constant': + continue + _, scales_value = helper.constant_to_list(scales_node) + for i in range(len(shape_value)): + shape_value[i] *= scales_value[i] + output_value = onnx.helper.make_tensor_value_info( + node.output[0], + onnx.TensorProto.FLOAT, + [int(v) for v in shape_value]) + g.value_info.extend([output_value]) + return True + return False + +def inference_upsample_shape(g): + """For onnx v1.4.1+, onnx cannot inference upsample output shape. Let's\\ + do it ourselves. This function only inference the next upsample without\\ + output shape each time. + + :param g: the graph\\ + :return: True if any Upsample shape is generated. Otherwise, False. + """ + for node in g.node: + if node.op_type != 'Upsample': + continue + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value and helper.get_shape_from_value_info(output_value): + continue + # Get input shape + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + continue + #raise RuntimeError("Shape for {} has not been generated.".format(node.input[0])) + if not helper.get_shape_from_value_info(input_value): + continue + #raise RuntimeError("Shape for {} is empty.".format(node.input[0])) + input_shape = helper.get_shape_from_value_info(input_value) + # Get upsample weight + weight_node = helper.find_node_by_output_name(g, node.input[1]) + weight_shape, weight = helper.constant_to_list(weight_node) + if len(input_shape) != weight_shape[0]: + raise RuntimeError("Unmatch input shape and weight shape: {} vs {}".format(input_shape, weight_shape)) + # Calculate shape + output_shape = list(input_shape) + for i in range(len(output_shape)): + output_shape[i] = int(input_shape[i] * weight[i]) + output_value = onnx.helper.make_tensor_value_info( + node.output[0], + input_value.type.tensor_type.elem_type, + output_shape) + g.value_info.extend([output_value]) + return True + return False + +def inference_cov_shape(g): + processed = False + for node in g.node: + # Check for Conv output shape need to be inferrenced. + if node.op_type != 'Conv': + continue + # Input shape is not ready yet. Skip. + input_value_info = helper.find_value_by_name(g, node.input[0]) + if not input_value_info: + input_value_info = helper.find_input_by_name(g, node.input[0]) + if not input_value_info: + continue + _, input_shape = helper.find_size_shape_from_value(input_value_info) + if not input_shape: + continue + # Output shape is already there. Skip. + output_value_info = helper.find_value_by_name(g, node.output[0]) + if not output_value_info: + output_value_info = helper.find_output_by_name(g, node.output[0]) + if output_value_info and \ + helper.get_shape_from_value_info(output_value_info): + continue + + # Now start the inference. + # Check kernel shape + kernel_value_info = helper.find_value_by_name(g, node.input[1]) + _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info) + if not kernel_shape: + continue + # If auto_pad is set, use the auto_pad. + auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string') + pads = None + if auto_pad is not None and auto_pad != 'NOTSET': + if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER': + new_output_value_info = onnx.helper.make_tensor_value_info( + node.output[0], + input_value_info.type.tensor_type.elem_type, + [input_shape[0], kernel_shape[0], input_shape[2], input_shape[3]] + ) + if output_value_info: + g.value_info.remove(output_value_info) + g.value_info.extend([new_output_value_info]) + processed = True + continue + elif auto_pad == 'VALID': + pads = [0, 0, 0, 0] + else: + logger.error("Unrecognized auto_pad value: " + str(auto_pad)) + exit(1) + + strides = helper.get_attribute_by_name(node, 'strides').ints + if not pads: + pads = helper.get_attribute_by_name(node, 'pads').ints + dilation = helper.get_attribute_by_name(node, 'dilations').ints + + # Pytorch model has the case where strides only have one number + if len(strides) == 1: + strides.append(strides[0]) + if len(dilation) == 1: + dilation.append(dilation[0]) + + H = math.floor((input_shape[2]+pads[0]+pads[2]-\ + dilation[0]*(kernel_shape[2]-1)-1)/strides[0]+1) + W = math.floor((input_shape[3]+pads[1]+pads[3]-\ + dilation[1]*(kernel_shape[3]-1)-1)/strides[1]+1) + output_shape = [input_shape[0], kernel_shape[0], H, W] + + new_output_value_info = onnx.helper.make_tensor_value_info( + node.output[0], + input_value_info.type.tensor_type.elem_type, + output_shape + ) + + processed = True + + if output_value_info: + g.value_info.remove(output_value_info) + g.value_info.extend([new_output_value_info]) + + return processed + + +def inference_split_shape(g): + processed = False + for node in g.node: + if node.op_type != 'Split': + continue + + input_val_info = helper.find_value_by_name(g, node.input[0]) + if not input_val_info: + input_val_info = helper.find_input_by_name(g, node.input[0]) + if not input_val_info: + continue + + _, input_shape = helper.find_size_shape_from_value(input_val_info) + if not input_shape: + continue + + output_val_names = list(node.output) + output_vals = [helper.find_value_by_name(g, val_name) for val_name in output_val_names] + + output_shapes = [helper.find_size_shape_from_value(output_val)[1] for output_val in output_vals] + if not any([len(s) == 0 for s in output_shapes]): + continue + + for att in node.attribute: + if att.name == 'axis': + axis = att.i + else: + split = list(att.ints) + + new_output_vals = [] + for i in range(len(output_val_names)): + new_shape = list(input_shape) + new_shape[axis] = split[i] + new_output_val = onnx.helper.make_tensor_value_info( + output_val_names[i], + input_val_info.type.tensor_type.elem_type, + new_shape + ) + new_output_vals.append(new_output_val) + + for val in output_vals: + if val is not None: + g.value_info.remove(val) + g.value_info.extend(new_output_vals) + + processed = True + + return processed + + +def parse_shape_change_input(s: str): + """The input should be like 'input 1 1 224 224'. + """ + s_list = s.split(' ') + if len(s_list) < 2: + print("Cannot parse the shape change input: {}".format(s)) + return None + shape = [] + for i in range(1, len(s_list)): + shape.append(int(s_list[i])) + return s_list[0], shape + +def change_input_shape(g, target_list): + for target in target_list: + try: + name, shape = parse_shape_change_input(target) + input_value = helper.find_input_by_name(g, name) + if input_value is None: + print("Cannot find input {}".format(name)) + continue + if len(shape) != len(input_value.type.tensor_type.shape.dim): + print("The dimension doesn't match for input {}".format(name)) + continue + for i in range(len(shape)): + input_value.type.tensor_type.shape.dim[i].dim_value = shape[i] + except TypeError: + # This happens when the parser function returns None. + continue + except ValueError: + # This happens when the input cannot be converter into int + print("Cannot parse {} into name and int".format(target)) + continue + +def change_output_shape(g, target_list): + for target in target_list: + try: + name, shape = parse_shape_change_input(target) + output_value = helper.find_output_by_name(g, name) + if output_value is None: + print("Cannot find output {}".format(name)) + continue + if len(shape) != len(output_value.type.tensor_type.shape.dim): + print("The dimension doesn't match for output {}".format(name)) + continue + for i in range(len(shape)): + output_value.type.tensor_type.shape.dim[i].dim_value = shape[i] + except TypeError: + # This happens when the parser function returns None. + continue + except ValueError: + # This happens when the input cannot be converter into int + print("Cannot parse {} into name and int".format(target)) + continue + +def add_nop_conv_after(g, value_names): + """Add do-nothing depthwise Conv nodes after the given value info. It will\\ + take the given names as the inputs of the new node and replace the inputs\\ + of the following nodes. + + :param g: the graph\\ + :param value_names: a list of string which are the names of value_info. + """ + for value_name in value_names: + # Find the value first + value = helper.find_value_by_name(g, value_name) + if value is None: + value = helper.find_input_by_name(g, value_name) + if value is None: + value = helper.find_output_by_name(g, value_name) + if value is None: + print("Cannot find an value_info named {}".format(value_name)) + continue + # Get the channel number from value info + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_nop_conv" + ones = [1.0] * channel + weight_node = helper.list_to_constant(node_name + "_weight", [channel, 1, 1, 1], ones) + # Construct BN node + conv_node = onnx.helper.make_node( + "Conv", + [value_name, + weight_node.output[0]], + [node_name], + name = node_name, + dilations = [1, 1], + group = channel, + kernel_shape = [1, 1], + pads = [0, 0, 0, 0], + strides = [1, 1] + ) + # Reconnect the graph + following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + if len(following_nodes) > 0: + for following_node in following_nodes: + replace_node_input(following_node, value_name, node_name) + else: + # If the node is the output, replace the output with the previous input. + new_value = onnx.helper.make_tensor_value_info( + node_name, + value.type.tensor_type.elem_type, + shape + ) + output_values = [] + while len(g.output): + output_values.append(g.output.pop()) + while output_values: + output_value = output_values.pop() + if output_value.name == value_name: + g.output.extend([new_value]) + else: + g.output.extend([output_value]) + # Add node to the graph + g.node.extend([conv_node, weight_node]) + topological_sort(g) + +def add_nop_bn_after(g, value_names): + """Add do-nothing BatchNormalization nodes after the given value info. It will\\ + take the given names as the inputs of the new node and replace the inputs\\ + of the following nodes. + + :param g: the graph\\ + :param value_names: a list of string which are the names of value_info. + """ + for value_name in value_names: + # Find the value first + value = helper.find_value_by_name(g, value_name) + if value is None: + value = helper.find_input_by_name(g, value_name) + if value is None: + value = helper.find_output_by_name(g, value_name) + if value is None: + print("Cannot find an value_info named {}".format(value_name)) + continue + # Get the channel number from value info + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_nop_bn" + ones = [1.0] * channel + zeros = [0.0] * channel + scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) + bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) + mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + # Construct BN node + bn_node = onnx.helper.make_node( + "BatchNormalization", + [value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0]], + [node_name], + name = node_name + ) + # Reconnect the graph + following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + if len(following_nodes) > 0: + for following_node in following_nodes: + replace_node_input(following_node, value_name, node_name) + else: + # If the node is the output, replace the output with the previous input. + new_value = onnx.helper.make_tensor_value_info( + node_name, + value.type.tensor_type.elem_type, + shape + ) + output_values = [] + while len(g.output): + output_values.append(g.output.pop()) + while output_values: + output_value = output_values.pop() + if output_value.name == value_name: + g.output.extend([new_value]) + else: + g.output.extend([output_value]) + # Add node to the graph + g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + topological_sort(g) + +def add_bias_scale_bn_after(g, value_name, channel_bias, channel_scale): + """Add do-nothing BatchNormalization nodes after the given value info. It will\\ + take the given names as the inputs of the new node and replace the inputs\\ + of the following nodes. + + :param g: the graph\\ + :param value_name: a list of string which are the name of value_info. + """ + # Find the value first + value = helper.find_value_by_name(g, value_name) + if value is None: + value = helper.find_input_by_name(g, value_name) + if value is None: + value = helper.find_output_by_name(g, value_name) + if value is None: + print("Cannot find an value_info named {}".format(value_name)) + return + # Get the channel number from value info + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_scale_shift_bn" + ones = [1.0] * channel + zeros = [0.0] * channel + scale_node = helper.list_to_constant(node_name + "_scale", [len(channel_scale)], channel_scale) + bias_node = helper.list_to_constant(node_name + "_bias", [len(channel_bias)], channel_bias) + mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + # Construct BN node + bn_node = onnx.helper.make_node( + "BatchNormalization", + [value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0]], + [node_name], + name = node_name + ) + # Reconnect the graph + following_nodes = helper.find_following_nodes_by_input_value_name(g, value_name) + if len(following_nodes) > 0: + for following_node in following_nodes: + replace_node_input(following_node, value_name, node_name) + else: + # If the node is the output, replace the output with the previous input. + new_value = onnx.helper.make_tensor_value_info( + node_name, + value.type.tensor_type.elem_type, + shape + ) + output_values = [] + while len(g.output): + output_values.append(g.output.pop()) + while output_values: + output_value = output_values.pop() + if output_value.name == value_name: + g.output.extend([new_value]) + else: + g.output.extend([output_value]) + # Add node to the graph + g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + topological_sort(g) + +def duplicate_shared_Flatten(g): + """To feed our compiler, bind Flatten with Gemm. If the output of one\\ + Flatten goes to two Gemm nodes, duplicate the Flatten. + + :param g: the graph + """ + for node in g.node: + # Find a Flatten node + if node.op_type != 'Flatten': + continue + # Check Flatten outputs. Get following Gemm + output_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + if len(output_nodes) < 2: + continue + gemm_nodes = [] + for output_node in output_nodes: + if output_node.op_type == 'Gemm': + gemm_nodes.append(output_node) + if len(gemm_nodes) < 2: + continue + # Process all the Gemm nodes except for the first one. + for i in range(1, len(gemm_nodes)): + # Duplicate + new_flatten_name = node.name + "_copy" + str(i) + new_flatten_node = onnx.helper.make_node( + "Flatten", + node.input, + [new_flatten_name], + name=new_flatten_name, + axis=1 + ) + # Connect new graph + replace_node_input(gemm_nodes[i], node.output[0], new_flatten_name) + g.node.extend([new_flatten_node]) + topological_sort(g) + +def deconv_to_conv_info_extraction(input_size, node_proto): + """Extract the information needed for deconv split. + + :param input_size: input shape of the deconv node.\\ + :param node_proto: the deconv node proto.\\ + :return: a dictionary of extracted params. + """ + attr = dict() + # Get attributes from Deconv node + attr["auto_pad"] = helper.get_var_attribute_by_name(node_proto, "auto_pad", "string") + attr["dilations"] = helper.get_list_attribute_by_name(node_proto, "dilations", "int") + attr["group"] = helper.get_var_attribute_by_name(node_proto, "group", "int") + attr["kernel_shape"] = helper.get_list_attribute_by_name(node_proto, "kernel_shape", "int") + attr["output_padding"] = helper.get_list_attribute_by_name(node_proto, "output_padding", "int") + attr["pads"] = helper.get_list_attribute_by_name(node_proto, "pads", "int") + attr["strides"] = helper.get_list_attribute_by_name(node_proto, "strides", "int") + # Get output_padding + if attr["output_padding"] is None: + if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER": + attr["output_padding"] = [attr["strides"][0] - 1, attr["strides"][1]] + else: + attr["output_padding"] = [max(attr["strides"][0] - attr["kernel_shape"][0], 0), + max(attr["strides"][1] - attr["kernel_shape"][1], 0)] + # Calculate conv_padding + if attr["auto_pad"] == "SAME_LOWER" or attr["auto_pad"] == "SAME_UPPER": + pad1_h = attr["kernel_shape"][0] - (attr["kernel_shape"][0] - 1) // 2 - 1 + pad1_w = attr["kernel_shape"][1] - (attr["kernel_shape"][1] - 1) // 2 - 1 + head_h = min(attr["kernel_shape"][0] // 2, (attr["output_padding"][0] + 1) // 2) + head_w = min(attr["kernel_shape"][1] // 2, (attr["output_padding"][1] + 1) // 2) + tail_h = attr["output_padding"][0] - head_h + tail_w = attr["output_padding"][1] - head_w + attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + elif attr["pads"] is not None: + sum_of_pads = sum(attr["pads"]) + if sum_of_pads == 0: + # Valid padding + pad1_h = attr["kernel_shape"][0] - 0 - 1 + pad1_w = attr["kernel_shape"][1] - 0 - 1 + head_h = 0 + head_w = 0 + tail_h = attr["output_padding"][0] - head_h + tail_w = attr["output_padding"][1] - head_w + attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + else: + # Calculate output shape + tmp_output_shape = [0, 0] + tmp_output_shape[0] = attr["strides"][0] * (input_size[2] - 1) + attr["output_padding"][0] + attr["kernel_shape"][0] - attr["pads"][0] - attr["pads"][2] + tmp_output_shape[1] = attr["strides"][1] * (input_size[3] - 1) + attr["output_padding"][1] + attr["kernel_shape"][1] - attr["pads"][1] - attr["pads"][3] + # Calculate real conv output shape + tmp_center_shape = [0, 0] + tmp_center_shape[0] = (input_size[2] - 1) * attr["strides"][0] + 1 + tmp_center_shape[1] = (input_size[3] - 1) * attr["strides"][1] + 1 + # Calculate padding + total_padding = [0, 0] + total_padding[0] = tmp_output_shape[0] - tmp_center_shape[0] + attr["kernel_shape"][0] - 1 + total_padding[1] = tmp_output_shape[1] - tmp_center_shape[1] + attr["kernel_shape"][1] - 1 + if total_padding[0] < 0 or total_padding[1] < 0: + raise RuntimeError(node_proto.name + " cannot infer conv padding.") + conv_pads_ = [0] * 4 + conv_pads_[0] = total_padding[0] // 2 + conv_pads_[1] = total_padding[1] // 2 + conv_pads_[2] = total_padding[0] - total_padding[0] // 2 + conv_pads_[3] = total_padding[1] - total_padding[1] // 2 + attr["conv_pads"] = conv_pads_ + else: + pad1_h = attr["kernel_shape"][0] - 0 - 1 + pad1_w = attr["kernel_shape"][1] - 0 - 1 + head_h = 0 + head_w = 0 + tail_h = attr["output_padding"][0] - head_h + tail_w = attr["output_padding"][1] - head_w + attr["conv_pads"] = [pad1_h + head_h, pad1_w + head_w, pad1_h + tail_h, pad1_w + tail_w] + return attr + +def split_ConvTranspose(model): + """To feed our compiler, split ConvTranspose into Upsample and Conv. + + :param model: the model + """ + node_to_delete = [] + # Change model properties for upsample. + if model.ir_version < 3: + print("Warning: Current model IR version is not fully supported.") + model.ir_version = 4 + model.opset_import[0].version = 9 + g = model.graph + # Get a Convtranspose layer + for node in g.node: + # Find a Flatten node + if node.op_type != 'ConvTranspose': + continue + # Check auto_pad + auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad") + if auto_pad_proto is not None: + print("Currently not split auto_pad ConvTranspose") + continue + # Check output_shape + output_shape_proto = helper.get_attribute_by_name(node, "output_shape") + if output_shape_proto is not None: + print("Currently not split output_shape ConvTranspose") + continue + # Get input shape + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + input_value = helper.find_input_by_name(g, node.input[0]) + if input_value is None: + print("Cannot get value info named {}.".format(node.input[0])) + exit(1) + input_shape = helper.get_shape_from_value_info(input_value) + # Get attrbutes + attr = deconv_to_conv_info_extraction(input_shape, node) + # Generate Upsample scales + upsample_output_shape = list(input_shape) + upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1 + upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1 + upsample_node_name = node.name + "_inner_upsample" + upsample_scale_name = upsample_node_name + "_scales" + scales_np = np.ones([4]).astype('float32') + scales_np[2] = float(upsample_output_shape[2]) / input_shape[2] + scales_np[3] = float(upsample_output_shape[3]) / input_shape[3] + scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np) + # Generate a Upsample layer and an internal value info + upsample_node = onnx.helper.make_node( + "Upsample", + [node.input[0], upsample_scale_name], + [upsample_node_name], + name=upsample_node_name, + mode="zeros" + ) + upsample_value_info = onnx.helper.make_tensor_value_info( + upsample_node_name, + input_value.type.tensor_type.elem_type, + upsample_output_shape + ) + # Check the weight layer, it may need a transpose + if attr["group"] != input_shape[1]: + weight_node = helper.find_node_by_output_name(g, node.input[1]) + weight_np = helper.constant_to_numpy(weight_node) + new_weight_np = np.transpose(weight_np, [1, 0, 2, 3]) + new_weight_node = helper.numpy_to_constant(node.input[1], new_weight_np) + node_to_delete.append(weight_node) + g.node.extend([new_weight_node]) + value = helper.find_value_by_name(g, node.input[1]) + g.value_info.remove(value) + # Generate a Conv layer + conv_node_name = node.name + "_inner_conv" + conv_node_input = [upsample_node_name] + conv_node_input.extend(node.input[1:]) + conv_node = onnx.helper.make_node( + "Conv", + conv_node_input, + [node.output[0]], + name=conv_node_name, + pads=[int(i) for i in attr["conv_pads"]], + dilations=[int(i) for i in attr["dilations"]], + group=int(attr["group"]), + kernel_shape=[int(i) for i in attr["kernel_shape"]], + strides=[int(1), int(1)] + ) + # Reconnect the graph + g.node.extend([scales_node, upsample_node, conv_node]) + g.value_info.extend([upsample_value_info]) + node_to_delete.append(node) + # Delete useless nodes + for node in node_to_delete: + g.node.remove(node) + topological_sort(g) + +def add_bn_on_skip_branch(g): + for n in g.node: + # Find merge node (Add) + if n.op_type != 'Add': + continue + if len(n.input) != 2: + continue + # TODO: Still need to consider more cases + # Check if skip branch exist + input_node_a = helper.find_node_by_output_name(g, n.input[0]) + output_of_input_node_a = helper.find_nodes_by_input_name(g, input_node_a.output[0]) + input_node_b = helper.find_node_by_output_name(g, n.input[1]) + output_of_input_node_b = helper.find_nodes_by_input_name(g, input_node_b.output[0]) + if len(output_of_input_node_a) == 1 and len(output_of_input_node_b) == 1: + continue + if len(output_of_input_node_a) == 2: + split_node = input_node_a + elif len(output_of_input_node_b) == 2: + split_node = input_node_b + else: + continue + # Get the channel number from value info + value_name = split_node.output[0] + value = helper.find_value_by_name(g, value_name) + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_nop_bn" + ones = [1.0] * channel + zeros = [0.0] * channel + scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) + bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) + mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + # Construct BN node + bn_node = onnx.helper.make_node( + "BatchNormalization", + [value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0]], + [node_name], + name = node_name + ) + # Reconnect the graph + replace_node_input(n, value_name, node_name) + # Add node to the graph + g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + topological_sort(g) + +def add_bn_before_add(g): + for n in g.node: + # Find merge node (Add) + if n.op_type != 'Add': + continue + if len(n.input) != 2: + continue + # Get two inputs + input_node_a = helper.find_node_by_output_name(g, n.input[0]) + input_node_b = helper.find_node_by_output_name(g, n.input[1]) + # Skip constant input add + if input_node_a is None or input_node_a.op_type == 'Constant': + continue + if input_node_b is None or input_node_b.op_type == 'Constant': + continue + def add_bn_after(prev_node): + # Get the channel number from value info + value_name = prev_node.output[0] + value = helper.find_value_by_name(g, value_name) + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_nop_bn" + ones = [1.0] * channel + zeros = [0.0] * channel + scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) + bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) + mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + # Construct BN node + bn_node = onnx.helper.make_node( + "BatchNormalization", + [value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0]], + [node_name], + name = node_name, + epsilon=0.00000001 + ) + # Reconnect the graph + replace_node_input(n, value_name, node_name) + # Add node to the graph + g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + if not input_node_a.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_a.output[0])) > 1: + add_bn_after(input_node_a) + if not input_node_b.op_type == 'BatchNormalization' or len(helper.find_following_nodes_by_input_value_name(g, input_node_b.output[0])) > 1: + add_bn_after(input_node_b) + topological_sort(g) + +def add_bn_before_activation(g): + activation_nodes = set(['Relu', 'Clip', 'PRelu', 'LeakyRelu']) + previous_nodes = set(['Conv', 'BatchNormalization']) + for n in g.node: + # Find activation node + if n.op_type not in activation_nodes: + continue + # Get input + input_node = helper.find_node_by_output_name(g, n.input[0]) + if input_node is None or input_node.op_type in previous_nodes: + continue + def add_bn_after(prev_node): + # Get the channel number from value info + value_name = prev_node.output[0] + value = helper.find_value_by_name(g, value_name) + shape = helper.get_shape_from_value_info(value) + channel = shape[1] + # Construct 4 weights + node_name = value_name + "_nop_bn" + ones = [1.0] * channel + zeros = [0.0] * channel + scale_node = helper.list_to_constant(node_name + "_scale", [channel], ones) + bias_node = helper.list_to_constant(node_name + "_bias", [channel], zeros) + mean_node = helper.list_to_constant(node_name + "_mean", [channel], zeros) + var_node = helper.list_to_constant(node_name + "_var", [channel], ones) + # Construct BN node + bn_node = onnx.helper.make_node( + "BatchNormalization", + [value_name, + scale_node.output[0], + bias_node.output[0], + mean_node.output[0], + var_node.output[0]], + [node_name], + name = node_name, + epsilon=0.00000001 + ) + # Reconnect the graph + replace_node_input(n, value_name, node_name) + # Add node to the graph + g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node]) + add_bn_after(input_node) + topological_sort(g) + +def rename_output_name(g, original_name, new_name): + # Output + output_value = helper.find_output_by_name(g, original_name) + if output_value is None: + logging.error("Cannot find output value named " + original_name) + return + output_value.name = new_name + # Value Info + value_info = helper.find_value_by_name(g, original_name) + if value_info is not None: + value_info.name = new_name + # Node output + node = helper.find_node_by_output_name(g, original_name) + node.output[0] = new_name + # Node input + nodes = helper.find_nodes_by_input_name(g, original_name) + for node in nodes: + replace_node_input(node, original_name, new_name) + +def duplicate_param_shared_constant(g): + for node in g.node: + input_names = set() + for n, input_node_name in enumerate(node.input): + param_data_node = helper.find_node_by_output_name(g, input_node_name) + if param_data_node is None or param_data_node.op_type != 'Constant': + continue + if param_data_node.name not in input_names: + input_names.add(input_node_name) + continue + + new_node_name = param_data_node.name + '_' + str(n) + helper.logger.debug(f"Duplicating weight: {param_data_node.name} -> {new_node_name}") + duplicated_node = copy.deepcopy(param_data_node) + + duplicated_node.name = new_node_name + duplicated_node.output[0] = new_node_name + + node.input[n] = new_node_name + g.node.extend([duplicated_node]) diff --git a/tools/optimizer_scripts/tools/removing_transpose.py b/tools/optimizer_scripts/tools/removing_transpose.py new file mode 100644 index 0000000..d0b7882 --- /dev/null +++ b/tools/optimizer_scripts/tools/removing_transpose.py @@ -0,0 +1,317 @@ +from . import helper +from . import other +from . import modhelper +from . import fusing +import numpy as np +import onnx +import onnx.utils + +def eliminate_transposes(m): + g = m.graph + keep_eliminating = True + while keep_eliminating: + while swap_transpose_with_single_next_node(g): + pass + splitted = split_transpose_for_multiple_next_nodes(g) + annihilated = annihilate_transposes(g) + multiple_trans_swapped = swap_multiple_transposes_with_node(g) + keep_eliminating = splitted or annihilated or multiple_trans_swapped + + if keep_eliminating: + m = other.polish_model(m) + g = m.graph + + return m + + +def swap_transpose_with_single_next_node(g): + swapped = False + passable_nodes = set(['Relu', 'Neg', 'LeakyRelu', 'Sqrt', 'Reciprocal', 'Add', 'Mul', 'Tanh']) + for node in g.node: + trans_node = node + # Check for transpose node + if trans_node.op_type != 'Transpose': + continue + next_nodes = helper.find_nodes_by_input_name(g, trans_node.output[0]) + if len(next_nodes) != 1: + continue + next_node = next_nodes[0] + # Check if the next node is the type can be swapped + if next_node.op_type not in passable_nodes: + continue + + input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in next_node.input] + + # Check if the node has nonconstant input other than the Transpose node itself + nonconstant_input = False + for input_node in input_nodes: + if input_node == None: + nonconstant_input = True + break + if input_node.name == trans_node.name: + continue + elif input_node.op_type == 'Constant': + continue + else: + nonconstant_input = True + break + if nonconstant_input: + continue + + for input_node in input_nodes: + if input_node.name == trans_node.name: + # if the input is just the transpose node + next_value_info = helper.find_value_by_name(g, next_node.output[0]) + mid_value_info = helper.find_value_by_name(g, trans_node.output[0]) + + output_nodes = helper.find_nodes_by_input_name(g, next_node.output[0]) + for out_node in output_nodes: + modhelper.replace_node_input(out_node, next_node.output[0], trans_node.name) + + next_node.input[0] = trans_node.input[0] + next_node.output[0] = next_node.name + trans_node.input[0] = next_node.name + trans_node.output[0] = trans_node.name + + if next_value_info: + next_value_info.name = trans_node.name + if mid_value_info: + g.value_info.remove(mid_value_info) + else: + # if the input is a constant node + old_tensor = input_node.attribute[0].t + old_shape, data = helper.constant_to_list(input_node) + # If the constant node is a scaler, no action is needed + if type(old_shape) == int: + old_shape = [old_shape] + permutation = list(trans_node.attribute[0].ints) + while len(old_shape) < len(permutation): + old_shape.insert(0, 1) + np_data = np.reshape(data, old_shape) + reverse_perm = [] + for i in range(len(permutation)): + reverse_perm.append(permutation.index(i)) + np_data = np.transpose(np_data, reverse_perm) + new_shape = np_data.shape + new_tensor = onnx.helper.make_tensor( + name=old_tensor.name, + data_type=old_tensor.data_type, + dims=new_shape, + vals=np_data.flatten().tolist() + ) + new_node = onnx.helper.make_node( + 'Constant', + [], + [input_node.output[0]], + name=input_node.name, + value=new_tensor + ) + g.node.extend([new_node]) + + g.value_info.remove(helper.find_value_by_name(g, input_node.output[0])) + g.node.remove(input_node) + + swapped = True + + other.topological_sort(g) + return swapped + + +def swap_multiple_transposes_with_node(g): + # here only consider same input transposes + swapped = False + passable_nodes = set(['Add', 'Mul']) + node_to_del = [] + for node in g.node: + if node.op_type not in passable_nodes: + continue + input_nodes = [helper.find_node_by_output_name(g, input_name) for input_name in node.input] + if any([input_node == None for input_node in input_nodes]): + continue + if any([input_node.op_type != 'Transpose' for input_node in input_nodes]): + continue + + permutation = list(input_nodes[0].attribute[0].ints) + if any([list(input_node.attribute[0].ints) != permutation for input_node in input_nodes]): + continue + + for input_name in node.input: + input_node = helper.find_node_by_output_name(g, input_name) + modhelper.replace_node_input(node, input_name, input_node.input[0]) + + node_to_del.extend(input_nodes) + for input_node in input_nodes: + input_val_info = helper.find_value_by_name(g, input_node.output[0]) + if input_val_info is not None: + g.value_info.remove(input_val_info) + output_val_info = helper.find_value_by_name(g, node.output[0]) + if output_val_info is not None: + g.value_info.remove(output_val_info) + + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for i in range(len(output_nodes)): + new_trans_node_name = node.name+'_trans_'+str(i) + new_trans_node = onnx.helper.make_node( + 'Transpose', + [node.output[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=permutation + ) + modhelper.replace_node_input(output_nodes[i], node.output[0], new_trans_node_name) + + g.node.extend([new_trans_node]) + + swapped = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return swapped + + +def annihilate_transposes(g): + node_to_del = [] + annihilated = False + for node in g.node: + if node.op_type != 'Transpose': + continue + pre_node = helper.find_node_by_output_name(g, node.input[0]) + if not pre_node or pre_node.op_type != 'Transpose': + continue + nodes_from_top_transpose = helper.find_nodes_by_input_name(g, pre_node.output[0]) + if len(nodes_from_top_transpose) > 1: + continue + + perm_1 = list(pre_node.attribute[0].ints) + perm_2 = list(node.attribute[0].ints) + if perm_1 != perm_2: + continue + + out_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + for out_node in out_nodes: + modhelper.replace_node_input(out_node, node.output[0], pre_node.input[0]) + + node_to_del.extend([node, pre_node]) + mid_value_info = helper.find_value_by_name(g, pre_node.output[0]) + out_value_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(mid_value_info) + g.value_info.remove(out_value_info) + + annihilated = True + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + return annihilated + + +def split_transpose_for_multiple_next_nodes(g): + splitted = False + node_to_del = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + output_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if len(output_nodes) < 2: + continue + for i in range(len(output_nodes)): + output_node = output_nodes[i] + new_trans_node_name = node.name + '_' + str(i) + new_trans_node = onnx.helper.make_node( + 'Transpose', + [node.input[0]], + [new_trans_node_name], + name=new_trans_node_name, + perm=list(node.attribute[0].ints) + ) + modhelper.replace_node_input(output_node, node.output[0], new_trans_node.output[0]) + g.node.extend([new_trans_node]) + + node_to_del.append(node) + val_info = helper.find_value_by_name(g, node.output[0]) + g.value_info.remove(val_info) + + splitted = True + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + return splitted + +def remove_trivial_transpose(g): + node_to_del = [] + for node in g.node: + if node.op_type != 'Transpose': + continue + permutation = list(node.attribute[0].ints) + if permutation != list(range(len(permutation))): + continue + + next_nodes = helper.find_nodes_by_input_name(g, node.output[0]) + if not next_nodes: + input_val_info = helper.find_value_by_name(g, node.input[0]) + out_val_info = helper.find_output_by_name(g, node.output[0]) + if not input_val_info: + input_val_info = helper.find_input_by_name(g, node.input[0]) + g.output.remove(out_val_info) + g.output.extend([input_val_info]) + else: + out_val_info = helper.find_value_by_name(g, node.output[0]) + for next_node in next_nodes: + modhelper.replace_node_input(next_node, node.output[0], node.input[0]) + g.value_info.remove(out_val_info) + + node_to_del.append(node) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) + +def fuse_Transpose_into_Gemm_weight(g): + node_to_del = [] + for node in g.node: + # Check pattern + if node.op_type != 'Gemm': + continue + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is None or prev_node.op_type != 'Flatten': + continue + transpose_node = helper.find_node_by_output_name(g, prev_node.input[0]) + if transpose_node.op_type != 'Transpose': + continue + # Check attribute + perm = helper.get_list_attribute_by_name(transpose_node, 'perm', 'int') + if perm != [0, 2, 3, 1]: + continue + transB = helper.get_var_attribute_by_name(node, 'transB', 'int') + if transB is not None and transB == 1: + continue + # Get the original weight + origin_weight = helper.find_node_by_output_name(g, node.input[1]) + origin_np = helper.constant_to_numpy(origin_weight) + # Calculate a new weight + shape = helper.get_shape_from_value_info(helper.find_value_by_name(g, prev_node.input[0])) + shape.append(-1) + new_np = np.reshape(origin_np, shape) + new_np = np.transpose(new_np, [0, 3, 1, 2, 4]) + new_np = np.reshape(new_np, [-1, new_np.shape[-1]]) + new_weight = helper.numpy_to_constant(origin_weight.output[0], new_np) + # Replace and eliminate + prev_node.input[0] = transpose_node.input[0] + node_to_del.append(transpose_node) + node_to_del.append(origin_weight) + g.value_info.remove(helper.find_value_by_name(g, transpose_node.output[0])) + g.node.extend([new_weight]) + + while node_to_del: + node = node_to_del.pop() + g.node.remove(node) + + other.topological_sort(g) diff --git a/tools/optimizer_scripts/tools/replacing.py b/tools/optimizer_scripts/tools/replacing.py new file mode 100644 index 0000000..091e571 --- /dev/null +++ b/tools/optimizer_scripts/tools/replacing.py @@ -0,0 +1,1171 @@ +"""Optimizations that replace one node with another. +""" +from os import dup +import struct +import copy +import logging +import onnx.helper +import numpy as np +from . import helper +from . import modhelper +from .other import topological_sort + +def replace_initializer_with_Constant(g, duplicate_shared_weights=True): + """ + Replace initializers with Constant and a corresponding value_info + If the initializer has related input, remove it. + + :param g: the onnx graph + """ + + input_map = {i.name: i for i in g.input} + for tensor in g.initializer: + # Check for the initializer related input and remove it + if tensor.name in input_map: + value_info = input_map[tensor.name] + g.input.remove(value_info) + following_nodes = helper.find_nodes_by_input_name(g, tensor.name) + if duplicate_shared_weights and len(following_nodes) >= 2: + for i, node in enumerate(following_nodes): + new_name = tensor.name + "_duplicated_No" + str(i) if i > 0 else tensor.name + helper.logger.debug(f"Duplicating weight: {tensor.name} -> {new_name}") + modhelper.replace_node_input(node, tensor.name, new_name) + new_node = onnx.helper.make_node( + "Constant", + [], + [new_name], + name=new_name, + value=tensor + ) + # Add node to lists + g.node.extend([new_node]) + else: + new_name = tensor.name + new_node = onnx.helper.make_node( + "Constant", + [], + [new_name], + name=new_name, + value=tensor + ) + # Add node to lists + g.node.extend([new_node]) + + # if value info already exists, remove it as well. + value_info = helper.find_value_by_name(g, tensor.name) + if value_info is not None: + g.value_info.remove(value_info) + + # Remove original initializer + while len(g.initializer) != 0: + g.initializer.pop() + + topological_sort(g) + +def replace_Reshape_with_Flatten(g): + """ + Replace Reshape node into Flatten node if applicable. + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + if node.op_type != 'Reshape': + continue + found_Gemm = False + # Flatten could be followed by Gemm + for i in g.node: + if len(i.input) == 0 or i.input[0] != node.output[0]: + continue + if i.op_type == 'Gemm': + found = True + break + # Check weight + shape_node = helper.find_node_by_output_name(g, node.input[1]) + if shape_node.op_type != 'Constant': + continue + shape_value = helper.constant_to_numpy(shape_node) + if (shape_value.size != 2 or shape_value[0] != 1) and not found_Gemm: + continue + # Replace it + node.op_type = "Flatten" + for _ in range(len(node.attribute)): + node.attribute.pop() + shape_value = helper.find_value_by_name(g, shape_node.output[0]) + node.input.pop() + node_to_remove.append(shape_node) + # If found shape value_info, remove it + if shape_value != None: + g.value_info.remove(shape_value) + + for node in node_to_remove: + g.node.remove(node) + +def replace_Squeeze_with_Reshape(g): + """ + Replace Squeeze nodes with Reshape node. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find Squeeze node + if node.op_type != 'Squeeze': + continue + # Get the shape and Construct the shape + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value is None: + raise RuntimeError("Cannot get shape for Squeeze") + shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] + const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + # Construct the Reshape layer with same input, output and name. + new_node = onnx.helper.make_node( + "Reshape", + [node.input[0], node.name + "_shape"], + node.output, + name=node.name + ) + # Append constructed nodes and append old node to remove_list + g.node.extend([const_node, new_node]) + node_to_remove.append(node) + # Remove old nodes + for node in node_to_remove: + g.node.remove(node) + # Topological sort + topological_sort(g) + +def replace_Unsqueeze_with_Reshape(g): + """ + Replace Unsqueeze nodes with Reshape node. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find Squeeze node + if node.op_type != 'Unsqueeze': + continue + # Get the shape and Construct the shape + output_value = helper.find_value_by_name(g, node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, node.output[0]) + if output_value is None: + raise RuntimeError("Cannot get shape for Unsqueeze") + shape = [dim.dim_value for dim in output_value.type.tensor_type.shape.dim] + + const_node = helper.list_to_constant(node.name + "_shape", [len(shape)], shape) + # Construct the Reshape layer with same input, output and name. + new_node = onnx.helper.make_node( + "Reshape", + [node.input[0], node.name + "_shape"], + node.output, + name=node.name + ) + # Append constructed nodes and append old node to remove_list + g.node.extend([const_node, new_node]) + node_to_remove.append(node) + # Remove old nodes + for node in node_to_remove: + g.node.remove(node) + # Topological sort + topological_sort(g) + +def replace_average_pool_with_GAP(g): + """ + Replace AveragePool nodes with GlobalAveragePool node when available. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find a average pool layer + if node.op_type != 'AveragePool': + continue + # Check attributes + not_replace = False + for attr in node.attribute: + if attr.name == 'pads': + if list(attr.ints) != [0, 0, 0, 0]: + not_replace = True + break + if attr.name == 'kernel_shape': + kernel_shape = list(attr.ints) + value_info = helper.find_value_by_name(g, node.input[0]) + if value_info is None: + not_replace = True + break + input_shape = [] + for dim in value_info.type.tensor_type.shape.dim: + input_shape.append(dim.dim_value) + if input_shape[-2:] != kernel_shape: + not_replace = True + break + if not_replace: + continue + # Replace it with GlobalAveragePool + new_node = onnx.helper.make_node( + "GlobalAveragePool", + node.input, + node.output, + name=node.name + ) + g.node.extend([new_node]) + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_dilated_conv(g): + """ + If the dilation of a convolution is not (1, 1), replace it with a regular + convolution with an expanded kernel. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Check if this is a conv layer + if node.op_type != 'Conv': + continue + # Check if this has dilation + has_dilations = False + has_strides = False + for attr in node.attribute: + if attr.name == "dilations": + dilations = list(attr.ints) + if dilations != [1, 1]: + has_dilations = True + if attr.name == "strides": + strides = list(attr.ints) + if strides != [1, 1]: + has_strides = True + if has_dilations and has_strides: + print("Warning: Both strides and dilations are set in ", node.name) + continue + if not has_dilations: + continue + # Construct new kernel + w_node = helper.find_node_by_output_name(g, node.input[1]) + w_output = helper.find_value_by_name(g, node.input[1]) + shape = list(w_node.attribute[0].t.dims) + # get original weight from float_data or raw data + weight = list(w_node.attribute[0].t.float_data) + if len(weight) == 0: + # Unpack from raw data + raw_data = w_node.attribute[0].t.raw_data + weight = [i[0] for i in struct.iter_unpack('f', raw_data)] + weight = np.array(weight) + weight = np.reshape(weight ,shape) + new_shape = copy.copy(shape) + new_shape[2] = 1 + (shape[2] - 1) * dilations[0] + new_shape[3] = 1 + (shape[3] - 1) * dilations[1] + new_weight = np.zeros(new_shape) + for batch in range(shape[0]): + for ch in range(shape[1]): + for h in range(shape[2]): + nh = h * dilations[0] + for w in range(shape[3]): + nw = w * dilations[1] + new_weight[batch, ch, nh, nw] = weight[batch, ch, h, w] + tensor = onnx.helper.make_tensor( + w_node.attribute[0].t.name, + w_node.attribute[0].t.data_type, + new_shape, + new_weight.ravel() + ) + new_w_node = onnx.helper.make_node( + "Constant", + [], + list(w_node.output), + name=w_node.name, + value=tensor + ) + g.node.extend([new_w_node]) + node_to_remove.append(w_node) + # Modify attributes and value info shapes + w_output.type.tensor_type.shape.dim[2].dim_value = new_shape[2] + w_output.type.tensor_type.shape.dim[3].dim_value = new_shape[3] + for attr in node.attribute: + if attr.name == "kernel_shape": + attr.ints[0] = new_shape[2] + attr.ints[1] = new_shape[3] + if attr.name == "dilations": + attr.ints[0] = 1 + attr.ints[1] = 1 + # Remove old weight nodes + for node in node_to_remove: + g.node.remove(node) + +def replace_depthwise_1x1_with_bn(g): + """Replace 1x1 DepthwiseConv node into BN node if applicable. + + :param g: the onnx graph + """ + node_to_remove = [] + for node in g.node: + # Check op_type + if node.op_type != 'Conv': + continue + # Check attributes + attr_map = {attr.name: attr for attr in node.attribute} + if "group" not in attr_map or attr_map["group"].i == 1: + continue + if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1: + continue + if "pads" in attr_map and sum(attr_map["pads"].ints) != 0: + continue + # Check scale + scale_node = helper.find_node_by_output_name(g, node.input[1]) + if scale_node is None or scale_node.attribute[0].t.dims[1] != 1: + continue + scale_node.attribute[0].t.dims.pop() + scale_node.attribute[0].t.dims.pop() + scale_node.attribute[0].t.dims.pop() + scale_info = helper.find_value_by_name(g, node.input[1]) + if scale_info is not None: + scale_info.type.tensor_type.shape.dim.pop() + scale_info.type.tensor_type.shape.dim.pop() + scale_info.type.tensor_type.shape.dim.pop() + # Check bias + if len(node.input) == 3: + bias_name = node.input[2] + else: + bias_name = node.name + "_bias" + bias_node = helper.list_to_constant(bias_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + g.node.extend([bias_node]) + # Construct mean and vars + mean_name = node.name + "_mean" + mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i], [0.0] * attr_map["group"].i) + var_name = node.name + "_var" + var_node = helper.list_to_constant(var_name, [attr_map["group"].i], [1.0] * attr_map["group"].i) + g.node.extend([mean_node, var_node]) + # Convert + bn_node = onnx.helper.make_node( + op_type='BatchNormalization', + inputs=[node.input[0], node.input[1], bias_name, mean_name, var_name], + outputs=node.output, + name=node.name, + epsilon=0.00001, + momentum=0.9 + ) + g.node.extend([bn_node]) + node_to_remove.append(node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_shape_with_constant(g): + """Replace Shape with Constant.\\ + This is the first step of reshape constant folding. + + :param g: the input graph\\ + :return: if anything modified, return true. + """ + node_to_remove = [] + for node in g.node: + # Find a Shape + if node.op_type != 'Shape': + continue + # Check its input + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + input_value = helper.find_input_by_name(g, node.input[0]) + if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + continue + # Check for case where dimension could be 0 or -1 + tmp = True + for d in input_value.type.tensor_type.shape.dim: + tmp = tmp and (d.dim_value > 0) + if not tmp: + continue + # Repalce it + input_shape = [ + d.dim_value for d in input_value.type.tensor_type.shape.dim] + node_name = node.output[0] + new_node = helper.list_to_constant( + node_name, [len(input_shape)], input_shape) + g.node.extend([new_node]) + node_to_remove.append(node) + + # if the input value_info is not used by other node + # delete this input value_info + val_info_used = sum([input_value.name in node.input for node in g.node]) + if val_info_used == 1: + g.value_info.remove(input_value) + + replaced = True if len(node_to_remove) > 0 else False + + for node in node_to_remove: + g.node.remove(node) + + topological_sort(g) + + return replaced + +def replace_ConstantOfShape_with_constant(g): + """Replace Shape with Constant.\\ + This is the first step of reshape constant folding. + + :param g: the input graph\\ + :return: if anything modified, return true. + """ + node_to_remove = [] + for node in g.node: + # Find a Shape + if node.op_type != 'ConstantOfShape': + continue + # Check input + input_value = helper.find_value_by_name(g, node.input[0]) + if input_value is None: + input_value = helper.find_input_by_name(g, node.input[0]) + if input_value is None or len(input_value.type.tensor_type.shape.dim) == 0: + continue + + # Replace to constant node + pre_node = helper.find_node_by_output_name(g, node.input[0]) + _, target_shape = helper.constant_to_list(pre_node) + + value = helper.get_attribute_by_name(node, 'value').i + + node_name = node.output[0] + new_node = helper.list_to_constant( + node_name, [target_shape[0]], [value] * target_shape[0]) + + g.node.extend([new_node]) + + # remove old node + node_to_remove.append(node) + + # delete value_info + val_info_used = sum([input_value.name in node.input for node in g.node]) + if val_info_used == 1: + g.value_info.remove(input_value) + + replaced = True if len(node_to_remove) > 0 else False + + for node in node_to_remove: + g.node.remove(node) + + topological_sort(g) + + return replaced + +def replace_split_with_slices(g): + """Replace split node with slice nodes. + :param g: input graph. + :return: + """ + node_to_remove = [] + for node in g.node: + # Find a Split + if node.op_type != 'Split': + continue + + input_value = helper.find_value_by_name(g, node.input[0]) + if not input_value: + input_value = helper.find_input_by_name(g, node.input[0]) + _, shape = helper.find_size_shape_from_value(input_value) + if len(shape) == 0: + continue + + output_val_names = list(node.output) + + axis = 0 + split = [] + for item in node.attribute: + if item.name == 'axis': + axis = item.i + if item.name == 'split': + split = item.ints + + # For opset 11, axis could be negative. + if axis < 0: + axis = len(shape) + axis + + length = input_value.type.tensor_type.shape.dim[axis].dim_value + if len(split) > 0: + n_out = len(split) + pos = 0 + for i in range(n_out): + pos += split[i] + new_node_name = output_val_names[i] + # Construct starts, ends, axes + starts_name = new_node_name + '_starts_' + str(i) + ends_name = new_node_name + '_ends_' + str(i) + axes_name = new_node_name + '_axes_' + str(i) + starts_node = helper.list_to_constant(starts_name, (1, ), [int(pos-split[i])]) + ends_node = helper.list_to_constant(ends_name, (1, ), [int(pos)]) + axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + # Construtc node + new_node = onnx.helper.make_node( + op_type='Slice', + inputs=[node.input[0], starts_name, ends_name, axes_name], + outputs=[node.output[i]], + name=new_node_name + ) + g.node.extend([starts_node, ends_node, axes_node, new_node]) + node_to_remove.append(node) + else: + n_out = len(output_val_names) + width = length//n_out + for i in range(n_out): + new_node_name = output_val_names[i] + # Construct starts, ends, axes + starts_name = new_node_name + '_starts_' + str(i) + ends_name = new_node_name + '_ends_' + str(i) + axes_name = new_node_name + '_axes_' + str(i) + starts_node = helper.list_to_constant(starts_name, (1, ), [int(i*width)]) + ends_node = helper.list_to_constant(ends_name, (1, ), [int((1+i)*width)]) + axes_node = helper.list_to_constant(axes_name, (1, ), [int(axis)]) + # Construtc node + new_node = onnx.helper.make_node( + op_type='Slice', + inputs=[node.input[0], starts_name, ends_name, axes_name], + outputs=[node.output[i]], + name=new_node_name + ) + g.node.extend([starts_node, ends_node, axes_node, new_node]) + node_to_remove.append(node) + + for old_node in node_to_remove: + g.node.remove(old_node) + topological_sort(g) + + +def replace_ReduceMean_with_GlobalAveragePool(g): + """ + Replace ReduceMean with GlobalAveragePool node when available. + + If there is preceeded Transpose, check the Transpose and the ReduceMean + together. If the keep_dims is set to 0, add a Flatten. + + :param g: the input graph + """ + node_to_remove = [] + for node in g.node: + # Find a ReduceMean layer + if node.op_type != 'ReduceMean': + continue + # Find if it have previous Transpose and its attribute meet the need. + prev_node = helper.find_node_by_output_name(g, node.input[0]) + if prev_node is not None and prev_node.op_type != 'Transpose': + prev_node = None + if prev_node is not None: + perm = helper.get_list_attribute_by_name(prev_node, 'perm', 'int') + if perm != [0, 2, 3, 1]: + prev_node = None + # Check attributes + axes = helper.get_list_attribute_by_name(node, 'axes', 'int') + keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int') + if axes is None: + continue + if prev_node is None and axes != [2, 3]: + continue + if prev_node is not None and axes != [1, 2]: + continue + if keepdims is None: + keepdims = 1 + # Replace it with GlobalAveragePool + if prev_node: + input_list = prev_node.input + else: + input_list = node.input + if keepdims == 1: + output_list = node.output + else: + output_list = [node.output[0] + '_before_flatten'] + flatten_node = onnx.helper.make_node( + "Flatten", + output_list, + node.output, + name = node.name + "_flatten", + axis = 1 + ) + g.node.extend([flatten_node]) + new_node = onnx.helper.make_node( + "GlobalAveragePool", + input_list, + output_list, + name=node.name + ) + g.node.extend([new_node]) + node_to_remove.append(node) + if prev_node: + value = helper.find_value_by_name(g, prev_node.output[0]) + if value: + g.value_info.remove(value) + node_to_remove.append(prev_node) + for node in node_to_remove: + g.node.remove(node) + topological_sort(g) + +def replace_mul_to_bn(g): + """Replace single Mul node with Batchnorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Mul': + continue + + mul_op_node = node + + # only support one input node + if len(mul_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = mul_op_node.input[0] + mul_value_node = helper.find_node_by_output_name(g, mul_op_node.input[1]) + if not mul_value_node or mul_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + scale_shape, scale_data = helper.constant_to_list(mul_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise mul or const mul + if scale_shape == [1, c_dim, 1, 1]: + muls = scale_data + elif scale_shape == [c_dim, 1, 1]: + muls = scale_data + elif scale_shape == 1: + muls = scale_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + bn_name = mul_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [mul_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0]) + g.value_info.remove(scale_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([bias_value_node]) + g.node.extend([new_mul_value_node]) + + node_to_del.extend([mul_op_node]) + node_to_del.extend([mul_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def replace_div_to_bn(g): + """Replace single Div node with Batchnorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Div': + continue + + div_op_node = node + + # only support one input node + if len(div_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = div_op_node.input[0] + div_value_node = helper.find_node_by_output_name(g, div_op_node.input[1]) + if not div_value_node or div_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + scale_shape, scale_data = helper.constant_to_list(div_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise div or const div + if scale_shape == [1, c_dim, 1, 1]: + muls = scale_data + elif scale_shape == [c_dim, 1, 1]: + muls = scale_data + elif scale_shape == 1: + muls = scale_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + muls = (1 / np.array(muls)).tolist() + bn_name = div_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + new_mul_value_node = helper.list_to_constant(bn_name+'_mul', np.array(muls).shape, muls) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + new_mul_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [div_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + scale_val_info = helper.find_value_by_name(g, div_value_node.output[0]) + g.value_info.remove(scale_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([bias_value_node]) + g.node.extend([new_mul_value_node]) + + node_to_del.extend([div_op_node]) + node_to_del.extend([div_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + + +def replace_add_to_bn(g): + """Replace single Add node with Batchnorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Add': + continue + + add_op_node = node + + # only support one input node + if len(add_op_node.input) != 2: # OP node and value node + continue + + input_op_node_name = add_op_node.input[0] + add_value_node = helper.find_node_by_output_name(g, add_op_node.input[1]) + if not add_value_node or add_value_node.op_type != 'Constant': + continue + + prev_shape_value_info = helper.find_value_by_name(g, input_op_node_name) + prev_shape_value_info = helper.find_input_by_name(g, input_op_node_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + bias_shape, bias_data = helper.constant_to_list(add_value_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise add or const add + if bias_shape == [1, c_dim, 1, 1]: + bias = bias_data + elif bias_shape == [c_dim, 1, 1]: + bias = bias_data + elif bias_shape == 1: + bias = bias_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + bn_name = add_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(ones).shape, ones) + new_add_value_node = helper.list_to_constant(bn_name+'_add', np.array(bias).shape, bias) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_op_node_name, + scale_value_node.output[0], + new_add_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [add_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + add_val_info = helper.find_value_by_name(g, add_value_node.output[0]) + g.value_info.remove(add_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([scale_value_node]) + g.node.extend([new_add_value_node]) + + node_to_del.extend([add_op_node]) + node_to_del.extend([add_value_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def replace_sub_to_bn(g): + """Replace single Sub node with BatchNorm node. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + if node.op_type != 'Sub': + continue + + sub_op_node = node + + # only support one input node + if len(sub_op_node.input) != 2: # OP node and value node + continue + + # Check the input type + input_1st_name = sub_op_node.input[0] + input_2nd_name = sub_op_node.input[1] + input_1st_node = helper.find_node_by_output_name(g, input_1st_name) + input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name) + if input_1st_node is not None and input_1st_node.op_type == 'Constant': + real_input_name = input_2nd_name + reverse = True + constant_node = input_1st_node + elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + real_input_name = input_1st_name + reverse = False + constant_node = input_2nd_node + else: + continue + + # Get shapes + prev_shape_value_info = helper.find_value_by_name(g, real_input_name) + prev_shape_value_info = helper.find_input_by_name(g, real_input_name) if prev_shape_value_info is None else prev_shape_value_info + if prev_shape_value_info is None: + continue + + _ , previous_node_output_shape = helper.find_size_shape_from_value(prev_shape_value_info) + bias_shape, bias_data = helper.constant_to_list(constant_node) + + # channel dimension + c_dim = previous_node_output_shape[1] if len(previous_node_output_shape) > 1 else 1 + + # only allow channelwise sub or const sub + if bias_shape == [1, c_dim, 1, 1]: + bias = bias_data + elif bias_shape == [c_dim, 1, 1]: + bias = bias_data + elif bias_shape == 1: + bias = bias_data * c_dim + else: + continue + + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + # If reversed provide special scaler + if reverse: + scale = [-1.0] * c_dim + else: + scale = ones + bias *= -1 + bn_name = sub_op_node.output[0] + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(scale).shape, scale) + new_add_value_node = helper.list_to_constant(bn_name+'_add', np.array(bias).shape, bias) + + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [real_input_name, + scale_value_node.output[0], + new_add_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [sub_op_node.output[0]], + name=bn_name, + epsilon=0.00000001 + ) + + add_val_info = helper.find_value_by_name(g, constant_node.output[0]) + g.value_info.remove(add_val_info) + + g.node.extend([bn_node]) + g.node.extend([mean_value_node]) + g.node.extend([variance_value_node]) + g.node.extend([scale_value_node]) + g.node.extend([new_add_value_node]) + + node_to_del.extend([sub_op_node]) + node_to_del.extend([constant_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + +def replace_sub_with_bn_and_add(g): + """Replace two input Sub node with BN and Add: A - B = A + (-1) * B + :param g: input graph. + :return: + """ + for node in g.node: + if node.op_type != 'Sub': + continue + + sub_op_node = node + + # only support one input node + if len(sub_op_node.input) != 2: # OP node and value node + continue + + # Check the input type + input_1st_name = sub_op_node.input[0] + input_2nd_name = sub_op_node.input[1] + input_1st_node = helper.find_node_by_output_name(g, input_1st_name) + input_2nd_node = helper.find_node_by_output_name(g, input_2nd_name) + if input_1st_node is not None and input_1st_node.op_type == 'Constant': + continue + elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + continue + + # Get shapes + input_2nd_value_info = helper.find_value_by_name(g, input_2nd_name) + if input_2nd_value_info is None: + input_2nd_value_info = helper.find_input_by_name(g, input_2nd_name) + if input_2nd_value_info is None: + continue + + # Get channel dimension + _ , input_2nd_shape = helper.find_size_shape_from_value(input_2nd_value_info) + if len(input_2nd_shape) < 2: + helper.logger.debug(f"{sub_op_node.name} cannot be replaced due to the input shape.") + c_dim = input_2nd_shape[1] + + # Create * -1 bn node. + ones = [1.0] * c_dim + zeros = [0.0] * c_dim + scale = [-1.0] * c_dim + bn_name = input_2nd_name + '_neg_for_' + node.name + mean_value_node = helper.list_to_constant(bn_name+'_mean', np.array(zeros).shape, zeros) + variance_value_node = helper.list_to_constant(bn_name+'_var', np.array(ones).shape, ones) + scale_value_node = helper.list_to_constant(bn_name+'_mul', np.array(scale).shape, scale) + bias_value_node = helper.list_to_constant(bn_name+'_add', np.array(zeros).shape, zeros) + bn_node = onnx.helper.make_node( + 'BatchNormalization', + [input_2nd_name, + scale_value_node.output[0], + bias_value_node.output[0], + mean_value_node.output[0], + variance_value_node.output[0]], + [bn_name], + name=bn_name, + epsilon=0.00000001 + ) + + # Change sub to add + sub_op_node.op_type = "Add" + # Replace add input + modhelper.replace_node_input(sub_op_node, input_2nd_name, bn_name) + + g.node.extend([scale_value_node, bias_value_node, mean_value_node, variance_value_node, bn_node]) + + topological_sort(g) + +def replace_Sum_with_Adds(g): + node_to_del = [] + + for node in g.node: + # Check for sum + if node.op_type != 'Sum': + continue + # Check for input number + if len(node.input) == 1: + # If input number is 1, delete the sum node. + following_nodes = helper.find_following_nodes_by_input_value_name(g, node.output[0]) + for following_node in following_nodes: + modhelper.replace_node_input(following_node, node.output[0], node.input[0]) + node_to_del.append(node) + if helper.find_value_by_name(node.output[0]) is not None: + g.value_info.remove(helper.find_value_by_name(node.output[0])) + elif len(node.input) == 2: + # If input number is 2, replace it with add. + node.op_type = 'Add' + continue + elif len(node.input) > 2: + # If input number is larger than 2, replace it with n-1 add. + input_count = len(node.input) + # First node has 2 inputs + first_node = onnx.helper.make_node( + "Add", + [node.input[0], node.input[1]], + [node.output[0] + '_replacement_1'], + name=node.name + '_replacement_1' + ) + # Last node has the same output as the original sum node + last_node = onnx.helper.make_node( + "Add", + [node.output[0] + '_replacement_' + str(input_count - 2), node.input[input_count - 1]], + [node.output[0]], + name=node.name + ) + g.node.extend([first_node, last_node]) + for i in range(2, input_count - 1): + new_node = onnx.helper.make_node( + "Add", + [node.output[0] + '_replacement_' + str(i - 1), node.input[i]], + [node.output[0] + '_replacement_' + str(i)], + name=node.name + '_replacement_' + str(i) + ) + g.node.extend([new_node]) + node_to_del.append(node) + else: + logging.error("Sum node must have at least 1 input.") + quit(1) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + + +def replace_constant_input_concat_with_pad(g): + """If single input is concating with constant node of same number. Replace it with pad. Currently only support 2-3 inputs. + :param g: input graph. + :return: + """ + node_to_del = [] + for node in g.node: + # Check for Concat node + if node.op_type != 'Concat': + continue + + # Check concat node input + mode = None + value = 0 + real_input_name = None + if len(node.input) == 2: + input_1st_node = helper.find_node_by_output_name(g, node.input[0]) + input_2nd_node = helper.find_node_by_output_name(g, node.input[1]) + if input_1st_node is not None and input_1st_node.op_type == 'Constant': + mode = 'left' + constant_value = helper.constant_to_numpy(input_1st_node) + real_input_name = node.input[1] + value = constant_value.flatten()[0] + # Check if the values are all the same. + if np.any(constant_value - value): + continue + elif input_2nd_node is not None and input_2nd_node.op_type == 'Constant': + mode = 'right' + constant_value = helper.constant_to_numpy(input_2nd_node) + real_input_name = node.input[0] + value = constant_value.flatten()[0] + # Check if the values are all the same. + if np.any(constant_value - value): + continue + else: + # No constant input case + continue + elif len(node.input) == 3: + # For 3 inputs concat node, the 1st and the 3rd input should be constant with the same value. + input_1st_node = helper.find_node_by_output_name(g, node.input[0]) + input_2nd_node = helper.find_node_by_output_name(g, node.input[1]) + input_3rd_node = helper.find_node_by_output_name(g, node.input[2]) + if input_1st_node is None or input_1st_node.op_type != 'Constant' or \ + input_3rd_node is None or input_3rd_node.op_type != 'Constant': + continue + mode = 'both' + real_input_name = node.input[1] + input_1st_value = helper.constant_to_numpy(input_1st_node) + input_3rd_value = helper.constant_to_numpy(input_3rd_node) + value = input_1st_value.flatten()[0] + # Check if all the values are all the same + if np.any(input_1st_value - value): + continue + elif np.any(input_3rd_value - value): + continue + else: + # Too many inputs case. + continue + # Make weight nodes + input_value_info = helper.find_value_by_name(g, real_input_name) + input_shape = helper.get_shape_from_value_info(input_value_info) + pads = [0] * (len(input_shape) * 2) + axis = helper.get_var_attribute_by_name(node, 'axis', 'int') + if axis < 0: + axis = len(input_shape) - axis + if mode == 'left': + left_value_info = helper.find_value_by_name(g, node.input[0]) + left_input_shape = helper.get_shape_from_value_info(left_value_info) + pads[axis] = left_input_shape[axis] + elif mode == 'right': + right_value_info = helper.find_value_by_name(g, node.input[1]) + right_input_shape = helper.get_shape_from_value_info(right_value_info) + pads[axis + len(input_shape)] = right_input_shape[axis] + else: + # mode shoule be both + left_value_info = helper.find_value_by_name(g, node.input[0]) + left_input_shape = helper.get_shape_from_value_info(left_value_info) + pads[axis] = left_input_shape[axis] + right_value_info = helper.find_value_by_name(g, node.input[2]) + right_input_shape = helper.get_shape_from_value_info(right_value_info) + pads[axis + len(input_shape)] = right_input_shape[axis] + pads_node = helper.list_to_constant( + node.name + '_pads', + (len(pads), ), + pads + ) + constant_value_node = helper.scaler_to_constant( + node.name + '_constant_value', + value + ) + # Create new Pad node + new_pad_node = onnx.helper.make_node( + "Pad", + [real_input_name, pads_node.name, constant_value_node.name], + [node.output[0]], + name = node.name, + mode = "constant" + ) + # Replace + node_to_del.append(node) + g.node.extend([pads_node, constant_value_node, new_pad_node]) + + while node_to_del: + g.node.remove(node_to_del.pop()) + + topological_sort(g) + diff --git a/tools/optimizer_scripts/tools/special.py b/tools/optimizer_scripts/tools/special.py new file mode 100644 index 0000000..38de4f5 --- /dev/null +++ b/tools/optimizer_scripts/tools/special.py @@ -0,0 +1,423 @@ +"""Special operations on model. +""" +import logging +import onnx.helper +import numpy as np +from . import helper +from . import other +from . import modhelper + +def change_first_conv_from_bgr_to_rgb(m): + """For input channel format BGR model, use this function to change the first + conv weight to adapt the input into RGB. + + :param m: the model proto + """ + # Check for first node. + g = m.graph + input_name = g.input[0].name + first_nodes = helper.find_following_nodes_by_input_value_name(g, input_name) + if len(first_nodes) > 1: + return False + first_node = first_nodes[0] + # Now we have the first node. Check this first node. + if first_node.op_type != 'Conv': + return False + weight_value = helper.find_value_by_name(g, first_node.input[1]) + weight_shape = helper.get_shape_from_value_info(weight_value) + if weight_shape[1] != 3: + return False + # Do weight shuffle + weight_node = helper.find_node_by_output_name(g, weight_value.name) + weight_np = helper.constant_to_numpy(weight_node) + b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1) + g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1) + r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1) + new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1) + new_node = helper.numpy_to_constant(weight_value.name, new_np) + # Replace the weight and topological sort + g.node.remove(weight_node) + g.node.extend([new_node]) + other.topological_sort(g) + return True + +def change_input_from_bgr_to_rgb(m): + """For input channel format BGR model, use this function to modify the model + to accepct RGB image.If the first node is a non-group Conv. Modify weight to + adapt the input into RGB. Otherwise create a new node. + + :param m: the model proto + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. Cannot change to RGB input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4 or input_shape[1] != 3: + print("The input shape is invalid for bgr conversion.") + return + # Try change conv weight first + if change_first_conv_from_bgr_to_rgb(m): + return + # Otherwise, create a special conv node and replace the input + # Construct weight + weight_np = np.zeros((3, 3, 3, 3)).astype('float32') + weight_np[0, 2, 1, 1] = 1.0 + weight_np[1, 1, 1, 1] = 1.0 + weight_np[2, 0, 1, 1] = 1.0 + new_weight = helper.numpy_to_constant("bgr_shuffle_weight", weight_np) + # Construct Conv + new_conv = onnx.helper.make_node( + 'Conv', + ['rgb_input', "bgr_shuffle_weight"], + [g.input[0].name], + name='bgr_shuffle', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Connect the graph + old_input_value = g.input.pop() + new_input_value = onnx.helper.make_tensor_value_info( + 'rgb_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_conv]) + # topological sort + other.topological_sort(g) + +def add_0_5_to_normalized_input(m): + """For normalized input between -0.5 ~ 0.5, add 0.5 to the input to keep it + between 0 ~ 1. + + :param m: the model proto + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. Cannot normalize input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4: + print("The input shape is not BCHW. Cannot normalize input.") + return + # Construct weight + ch = input_shape[1] + weight_np = np.zeros((ch, ch, 3, 3)).astype('float32') + for i in range(ch): + weight_np[i, i, 1, 1] = 1.0 + new_weight = helper.numpy_to_constant("input_norm_weight", weight_np) + # Construct bias + bias_np = np.array([0.5] * ch).astype('float32') + new_bias = helper.numpy_to_constant("input_norm_bias", bias_np) + # Construct Conv + new_conv = onnx.helper.make_node( + 'Conv', + ['origin_input', "input_norm_weight", "input_norm_bias"], + [g.input[0].name], + name='input_norm', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Construct value_infos + old_input_value = g.input.pop() + weight_value = onnx.helper.make_tensor_value_info( + 'input_norm_weight', + old_input_value.type.tensor_type.elem_type, + [3, 3, 3, 3] + ) + bias_value = onnx.helper.make_tensor_value_info( + 'input_norm_bias', + old_input_value.type.tensor_type.elem_type, + [3] + ) + # Connect the graph + new_input_value = onnx.helper.make_tensor_value_info( + 'origin_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_bias, new_conv]) + g.value_info.extend([weight_value, bias_value, old_input_value]) + # topological sort + other.topological_sort(g) + +def add_rgb2yynn_node(m): + """Add a conv layer which can convert rgb to yynn input. + """ + g = m.graph + if len(g.input) > 1: + print("This model has multiple inputs. Cannot change to rgb input.") + return + input_shape = helper.get_shape_from_value_info(g.input[0]) + if len(input_shape) != 4: + print("The input shape is not BCHW. Cannot normalize input.") + return + # Construct weight + ch = input_shape[1] + weight_np = np.zeros((3, 3, 4, 4)).astype('float32') + weight_np[1, 1, :3, :2] = np.array([[[[0.299], + [0.587], + [0.114]]]]) + weight_np[1, 1, 3, 2:] = 1. + weight_np = np.transpose(weight_np, (3, 2, 0, 1)) + new_weight = helper.numpy_to_constant("input_rgb2yynn_weight", weight_np) + # Construct conv node + new_conv = onnx.helper.make_node( + 'Conv', + ['new_input', "input_rgb2yynn_weight"], + [g.input[0].name], + name='input_rgba2yynn', + dilations=[1, 1], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1] + ) + # Construct value_infos + old_input_value = g.input.pop() + weight_value = onnx.helper.make_tensor_value_info( + 'input_rgb2yynn_weight', + old_input_value.type.tensor_type.elem_type, + [4, 4, 3, 3] + ) + # Connect the graph + new_input_value = onnx.helper.make_tensor_value_info( + 'new_input', + old_input_value.type.tensor_type.elem_type, + input_shape + ) + g.input.extend([new_input_value]) + g.node.extend([new_weight, new_conv]) + g.value_info.extend([weight_value, old_input_value]) + # topological sort + other.topological_sort(g) + +def swap_MatMul_inputs(g, original_matmul_node): + # Create Transpose nodes + input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) + input_a_shape = helper.get_shape_from_value_info(input_a_value) + if len(input_a_shape) == 2: + perm = [1, 0] + else: + perm = [0, 2, 1] + new_input_b_node = onnx.helper.make_node( + 'Transpose', + inputs = [input_a_value.name], + outputs = [input_a_value.name + '_transposed'], + name = f"{input_a_value.name}_transposed_for_{original_matmul_node.name}", + perm = perm + ) + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + input_b_shape = helper.get_shape_from_value_info(input_b_value) + if len(input_b_shape) == 3: + perm = [0, 2, 1] + else: + perm = [0, 1, 3, 2] + new_input_a_node = onnx.helper.make_node( + 'Transpose', + inputs = [input_b_value.name], + outputs = [input_b_value.name + '_transposed'], + name = f'{input_b_value.name}_transposed_for_{original_matmul_node.name}', + perm = perm + ) + # Create new MatMul node + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = [new_input_a_node.output[0], new_input_b_node.output[0]], + outputs = [original_matmul_node.output[0] + '_transposed'], + name = original_matmul_node.name + '_transposed' + ) + # Create final Transpose node + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + output_shape = helper.get_shape_from_value_info(output_value) + if len(output_shape) == 3: + perm = [0, 2, 1] + else: + perm = [0, 1, 3, 2] + new_final_transpose_node = onnx.helper.make_node( + 'Transpose', + inputs = [new_matmul_node.output[0]], + outputs = [original_matmul_node.output[0]], + name = original_matmul_node.name + '_final_transpose', + perm = perm + ) + # Add new nodes + g.node.extend([new_input_a_node, new_input_b_node, new_matmul_node, new_final_transpose_node]) + # Delete original nodes + g.node.remove(original_matmul_node) + +def split_MatMul_batch_then_concat(g, original_matmul_node): + new_nodes = [] + final_concat_inputs = [] + # Get the batch count + input_a_value = helper.find_value_by_name(g, original_matmul_node.input[0]) + input_a_shape = helper.get_shape_from_value_info(input_a_value) + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + input_b_shape = helper.get_shape_from_value_info(input_b_value) + if len(input_a_shape) == 3: + batch_count = input_a_shape[0] + else: + batch_count = input_a_shape[1] + for i in range(batch_count): + # Create Split nodes for input A + starts_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_starts", (1, ), [i]) + ends_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_ends", (1, ), [i+1]) + axes_node = helper.list_to_constant(f"{input_a_value.name}_sliced_{i}_axes", (1, ), [len(input_a_shape) - 3]) + new_sliced_a_node = onnx.helper.make_node( + 'Slice', + inputs = [input_a_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], + outputs = [f"{input_a_value.name}_sliced_{i}"], + name = f"{input_a_value.name}_sliced_{i}_for_{original_matmul_node.name}" + ) + new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_a_node]) + # Create Split nodes for input B + starts_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_starts", (1, ), [i]) + ends_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_ends", (1, ), [i+1]) + axes_node = helper.list_to_constant(f"{input_b_value.name}_sliced_{i}_axes", (1, ), [len(input_b_shape) - 3]) + new_sliced_b_node = onnx.helper.make_node( + 'Slice', + inputs = [input_b_value.name, starts_node.output[0], ends_node.output[0], axes_node.output[0]], + outputs = [f"{input_b_value.name}_sliced_{i}"], + name = f"{input_b_value.name}_sliced_{i}_for_{original_matmul_node.name}" + ) + new_nodes.extend([starts_node, ends_node, axes_node, new_sliced_b_node]) + # Create MatMul nodes + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = [new_sliced_a_node.output[0], new_sliced_b_node.output[0]], + outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], + name = f"{original_matmul_node.name}_sliced_{i}" + ) + new_nodes.append(new_matmul_node) + final_concat_inputs.append(new_matmul_node.output[0]) + # Create Concat nodes + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + if output_value is None: + output_value = helper.find_output_by_name(g, original_matmul_node.output[0]) + if output_value is None: + helper.logger.error(f"Cannot find value_info for {original_matmul_node.output[0]}") + output_shape = helper.get_shape_from_value_info(output_value) + new_concat_node = onnx.helper.make_node( + "Concat", + inputs = final_concat_inputs, + outputs = [original_matmul_node.output[0]], + name = f"{original_matmul_node.name}_final_concat", + axis = len(output_shape) - 3 + ) + new_nodes.append(new_concat_node) + # Add new nodes + g.node.extend(new_nodes) + # Delete original nodes + g.node.remove(original_matmul_node) + + +def split_MatMul_Constant_input_then_concat(g, original_matmul_node): + new_nodes = [] + final_concat_inputs = [] + # Get the batch count + input_b_node = helper.find_node_by_output_name(g, original_matmul_node.input[1]) + input_b_np = helper.constant_to_numpy(input_b_node) + if len(input_b_np.shape) == 3: + batch_count = input_b_np.shape[0] + else: + batch_count = input_b_np.shape[1] + for i in range(batch_count): + # Create new constant node + if len(input_b_np.shape) == 3: + new_np = input_b_np[i:i+1, ...] + else: + new_np = input_b_np[:, i:i+1, ...] + new_weight = helper.numpy_to_constant(f"{input_b_node.name}_sliced_{i}", new_np) + new_nodes.append(new_weight) + # Create MatMul nodes + new_matmul_node = onnx.helper.make_node( + 'MatMul', + inputs = [original_matmul_node.input[0], new_weight.output[0]], + outputs = [f"{original_matmul_node.output[0]}_sliced_{i}"], + name = f"{original_matmul_node.name}_sliced_{i}" + ) + new_nodes.append(new_matmul_node) + final_concat_inputs.append(new_matmul_node.output[0]) + # Create Concat nodes + output_value = helper.find_value_by_name(g, original_matmul_node.output[0]) + output_shape = helper.get_shape_from_value_info(output_value) + new_concat_node = onnx.helper.make_node( + "Concat", + inputs = final_concat_inputs, + outputs = [original_matmul_node.output[0]], + name = f"{original_matmul_node.name}_final_concat", + axis = len(output_shape) - 3 + ) + new_nodes.append(new_concat_node) + # Add new nodes + g.node.extend(new_nodes) + # Delete original value info + input_b_value = helper.find_value_by_name(g, original_matmul_node.input[1]) + if input_b_value is not None: + g.value_info.remove(input_b_value) + # Delete original nodes + g.node.remove(original_matmul_node) + g.node.remove(input_b_node) + + +def special_MatMul_process(g): + for node in g.node: + if node.op_type != 'MatMul': + continue + input_a_name = node.input[0] + input_a_value = helper.find_value_by_name(g, input_a_name) + input_b_name = node.input[1] + input_b_value = helper.find_value_by_name(g, input_b_name) + if input_a_value is None or input_b_value is None: + continue + input_a_shape = helper.get_shape_from_value_info(input_a_value) + input_b_shape = helper.get_shape_from_value_info(input_b_value) + # Check shapes and choose the process + # Normal case, Skip + if len(input_b_shape) == 2: + continue + # Too many dimensions or too few dimensions. Not supported. Skip + if len(input_a_shape) > 4 or len(input_b_shape) > 4: + helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have too many dimensions.") + continue + if len(input_a_shape) < 2 or len(input_b_shape) < 2: + helper.logger.warning(f"Cannot optimize MatMul {node.name}: inputs have two few dimensions.") + continue + # For 4 dimension, check the first dimension (should be 1) and treated as 3 dimensions. + extra_dim = None + if len(input_a_shape) == 4: + extra_dim = input_a_shape[0] + input_a_shape = input_a_shape[1:] + if len(input_b_shape) == 4: + if input_b_shape[0] != extra_dim: + helper.logger.warning(f"Cannot optimize MatMul {node.name}: input dimension batch sizes does not match ({extra_dim} vs {input_b_shape[0]}).") + continue + input_b_shape = input_b_shape[1:] + # Check input B dimension + # If B is 1 x W x V, it is the same as normal case. + if input_b_shape[0] == 1: + continue + # If B is B x W x V, but B is a constant. + input_b_node = helper.find_node_by_output_name(g, input_b_name) + if input_b_node is not None and input_b_node.op_type == 'Constant': + # Constant input + helper.logger.debug(f"Optimizing MatMul node {node.name}: split constant input.") + split_MatMul_Constant_input_then_concat(g, node) + # If B is B x W x V and A is 1 x H x W, do the swap. + elif len(input_a_shape) == 2 or (input_a_shape[0] == 1 and (extra_dim is None or extra_dim == 1)): + helper.logger.debug(f"Optimizing MatMul node {node.name}: swap input.") + swap_MatMul_inputs(g, node) + # If B is B x W x V and A is B x H x W, do the split. + elif input_b_shape[0] == input_a_shape[0]: + helper.logger.debug(f"Optimizing MatMul node {node.name}: split input batch.") + split_MatMul_batch_then_concat(g, node) + # Other cases are not supported: If B is B x W x V but A is X x H x W. + else: + helper.logger.warning(f"Cannot optimize MatMul {node.name}: unknown reason. Might be shape mismatch.") + continue + other.topological_sort(g) \ No newline at end of file diff --git a/tools/pytorch2onnx_kneron.py b/tools/pytorch2onnx_kneron.py new file mode 100644 index 0000000..e32b9a5 --- /dev/null +++ b/tools/pytorch2onnx_kneron.py @@ -0,0 +1,352 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Original: tools/pytorch2onnx.py, modified by Kneron +import argparse + +import warnings +import os +import onnx +import mmcv +import numpy as np +import onnxruntime as rt +import torch +import torch._C +import torch.serialization +from mmcv import DictAction +from mmcv.onnx import register_extra_symbolics +from mmcv.runner import load_checkpoint +from torch import nn + +from mmseg.apis import show_result_pyplot +from mmseg.apis.inference import LoadImage +from mmseg.datasets.pipelines import Compose +from mmseg.models import build_segmentor + +from optimizer_scripts.tools import other +from optimizer_scripts.pytorch_exported_onnx_preprocess import ( + torch_exported_onnx_flow, +) + +torch.manual_seed(3) + + +def _parse_normalize_cfg(test_pipeline): + transforms = None + for pipeline in test_pipeline: + if 'transforms' in pipeline: + transforms = pipeline['transforms'] + break + assert transforms is not None, 'Failed to find `transforms`' + norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize'] + assert len(norm_config_li) == 1, '`norm_config` should only have one' + norm_config = norm_config_li[0] + return norm_config + + +def _convert_batchnorm(module): + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, _convert_batchnorm(child)) + del module + return module_output + + +def _demo_mm_inputs(input_shape): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + rng = np.random.RandomState(0) + img = torch.FloatTensor(rng.rand(*input_shape)) + return img + + +def _prepare_input_img(img_path, + test_pipeline, + shape=None): + # build the data pipeline + if shape is not None: + test_pipeline[1]['img_scale'] = (shape[1], shape[0]) + test_pipeline[1]['transforms'][0]['keep_ratio'] = False + test_pipeline = [LoadImage()] + test_pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img_path) + data = test_pipeline(data) + img = torch.FloatTensor(data['img']).unsqueeze_(0) + return img + + +def pytorch2onnx(model, + img, + norm_cfg=None, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False): + """Export Pytorch model to ONNX model and verify the outputs are same + between Pytorch and ONNX. + + Args: + model (nn.Module): Pytorch model we want to export. + img (dict): Input tensor (1xCxHxW) + opset_version (int): The onnx op version. Default: 11. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the output ONNX model. + Default: `tmp.onnx`. + verify (bool): Whether compare the outputs between Pytorch and ONNX. + Default: False. + """ + model.cpu().eval() + + if isinstance(model.decode_head, nn.ModuleList): + num_classes = model.decode_head[-1].num_classes + else: + num_classes = model.decode_head.num_classes + + # replace original forward function + model.forward = model.forward_dummy + origin_forward = model.forward + + register_extra_symbolics(opset_version) + with torch.no_grad(): + torch.onnx.export( + model, img, + output_file, + input_names=['input'], + output_names=['output'], + export_params=True, + keep_initializers_as_inputs=False, + verbose=show, + opset_version=opset_version, + dynamic_axes=None) + print(f'Successfully exported ONNX model: {output_file}') + model.forward = origin_forward + # NOTE: optimizing onnx for kneron inference + m = onnx.load(output_file) + # NOTE: PyTorch 1.10.x exports onnx ir_version == 7 for opset 11, + # but should be ir_version == 6 + if opset_version == 11: + m.ir_version = 6 + m = torch_exported_onnx_flow(m, disable_fuse_bn=False) + onnx.save(m, output_file) + print(f'{output_file} optimized by KNERON successfully.') + + if verify: + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # check the numerical value + # get pytorch output + with torch.no_grad(): + pytorch_result = model(img).numpy() + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + sess = rt.InferenceSession( + output_file, providers=['CPUExecutionProvider'] + ) + onnx_result = sess.run( + None, {net_feed_input[0]: img.detach().numpy()})[0] + # show segmentation results + if show: + import cv2 + img = img[0][:3, ...].permute(1, 2, 0) * 255 + img = img.detach().numpy().astype(np.uint8) + ori_shape = img.shape[:2] + + # resize onnx_result to ori_shape + onnx_result_ = onnx_result[0].argmax(0) + onnx_result_ = cv2.resize(onnx_result_.astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, (onnx_result_, ), + palette=model.PALETTE, + block=False, + title='ONNXRuntime', + opacity=0.5) + + # resize pytorch_result to ori_shape + pytorch_result_ = pytorch_result.squeeze().argmax(0) + pytorch_result_ = cv2.resize(pytorch_result_.astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, (pytorch_result_, ), + title='PyTorch', + palette=model.PALETTE, + opacity=0.5) + # compare results + np.testing.assert_allclose( + pytorch_result.astype(np.float32) / num_classes, + onnx_result.astype(np.float32) / num_classes, + rtol=1e-5, + atol=1e-5, + err_msg='The outputs are different between Pytorch and ONNX') + print('The outputs are same between Pytorch and ONNX') + + if norm_cfg is not None: + print("Prepending BatchNorm layer to ONNX as data normalization...") + mean = norm_cfg['mean'] + std = norm_cfg['std'] + i_n = m.graph.input[0] + if ( + i_n.type.tensor_type.shape.dim[1].dim_value != len(mean) + or i_n.type.tensor_type.shape.dim[1].dim_value != len(std) + ): + raise ValueError( + f"--pixel-bias-value ({mean}) and --pixel-scale-value " + f"({std}) should be same as input dimension: " + f"{i_n.type.tensor_type.shape.dim[1].dim_value}" + ) + norm_bn_bias = [-1 * cm / cs + 128. / cs for cm, cs in zip(mean, std)] + norm_bn_scale = [1 / cs for cs in std] + other.add_bias_scale_bn_after( + m.graph, i_n.name, norm_bn_bias, norm_bn_scale + ) + m = other.polish_model(m) + bn_outf = os.path.splitext(output_file)[0] + "_bn_prepended.onnx" + onnx.save(m, bn_outf) + print(f"BN-Prepended ONNX saved to {bn_outf}") + + return + + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument( + '--input-img', type=str, help='Images for input', default=None) + parser.add_argument( + '--show', + action='store_true', + help='show onnx graph and segmentation results') + parser.add_argument( + '--verify', action='store_true', help='verify the onnx model') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=None, + help='input image height and width.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='Override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--normalization-in-onnx', + action='store_true', + help='Prepend BatchNorm layer to onnx model as a role of data ' + 'normalization according to the mean and std value in the given' + 'cfg file.' + ) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + assert args.opset_version == 11, "kneron_toolchain currently only supports opset 11" + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + cfg.model.pretrained = None + + test_mode = cfg.model.test_cfg.mode + + if args.shape is None: + if test_mode == 'slide': + crop_size = cfg.model.test_cfg['crop_size'] + input_shape = (1, 3, crop_size[1], crop_size[0]) + else: + img_scale = cfg.test_pipeline[1]['img_scale'] + input_shape = (1, 3, img_scale[1], img_scale[0]) + else: + if test_mode == 'slide': + warnings.warn( + "We suggest you NOT assigning shape when exporting " + "slide-mode models. Assigning shape to slide-mode models " + "may result in unexpected results. To see which mode the " + "model is using, check cfg.model.test_cfg.mode, which " + "should be either 'whole' or 'slide'." + ) + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + # build the model and load checkpoint + cfg.model.train_cfg = None + segmentor = build_segmentor( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + # convert SyncBN to BN + segmentor = _convert_batchnorm(segmentor) + + if args.checkpoint: + checkpoint = load_checkpoint( + segmentor, args.checkpoint, map_location='cpu') + segmentor.CLASSES = checkpoint['meta']['CLASSES'] + segmentor.PALETTE = checkpoint['meta']['PALETTE'] + + # read input or create dummpy input + if args.input_img is not None: + preprocess_shape = (input_shape[2], input_shape[3]) + img = _prepare_input_img( + args.input_img, + cfg.data.test.pipeline, + shape=preprocess_shape) + else: + img = _demo_mm_inputs(input_shape) + + if args.normalization_in_onnx: + norm_cfg = _parse_normalize_cfg(cfg.test_pipeline) + else: + norm_cfg = None + # convert model to onnx file + pytorch2onnx( + segmentor, + img, + norm_cfg=norm_cfg, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify, + )