import sys

import numpy as np
import onnx

from tools import helper
from tools.other import topological_sort


def _find_attribute(node, name):
    """Return the AttributeProto called ``name`` on ``node``, or None if absent."""
    for attr in node.attribute:
        if attr.name == name:
            return attr
    return None


def _is_pointwise_kernel(node):
    """Return True if ``node`` declares a non-empty kernel_shape of all 1s."""
    kernel_shape = helper.get_list_attribute_by_name(node, "kernel_shape", "int")
    return bool(kernel_shape) and all(k == 1 for k in kernel_shape)


def _has_single_consumer(g, value_name):
    """Return True if exactly one node in ``g`` reads the value ``value_name``."""
    return (
        len(helper.find_following_nodes_by_input_value_name(g, value_name))
        == 1
    )


def _constant_producer(g, value_name):
    """Return the Constant node producing ``value_name``, or None otherwise.

    Values that come from an initializer or a non-Constant node yield None.
    """
    node = helper.find_node_by_output_name(g, value_name)
    if node is None or node.op_type != "Constant":
        return None
    return node


def fuse_bias_in_consecutive_1x1_conv(g):
    """Move the bias of a Conv into the following 1x1 Conv, in place.

    For ``first -> second`` where both are 1x1 convolutions:

        second(first(x)) = W2 (W1 x + b1) + b2 = W2 W1 x + (W2 b1 + b2)

    so ``first``'s bias is replaced by zeros and ``second``'s bias becomes
    ``b2 + W2 @ b1``. The rewrite is skipped whenever it would not be exact
    or would affect other parts of the graph:

    * ``first``'s output has other consumers or is a graph output;
    * either conv lacks a bias input;
    * a bias/weight is not produced by a Constant node, or a bias constant
      is shared with other consumers;
    * ``second`` has non-zero ``pads`` (padded positions would see 0
      instead of b1) or ``group != 1``;
    * the weight and bias sizes do not line up.

    Args:
        g: the onnx GraphProto to modify. Bias Constant nodes are replaced
            (same output names), and the graph is topologically sorted
            afterwards if anything changed.
    """
    graph_outputs = {value.name for value in g.output}
    modified = False

    # Iterate over a snapshot: the body removes and appends nodes in g.node,
    # and mutating a sequence while iterating it can skip elements.
    for second in list(g.node):
        # Find two consecutive Conv nodes.
        if second.op_type != "Conv":
            continue
        first = helper.find_node_by_output_name(g, second.input[0])
        if first is None or first.op_type != "Conv":
            continue

        # Zeroing first's bias is only safe if `second` is its sole observer.
        if first.output[0] in graph_outputs or not _has_single_consumer(
            g, first.output[0]
        ):
            continue

        # Both convs need an explicit bias (the optional input may also be "").
        if len(first.input) < 3 or not first.input[2]:
            continue
        if len(second.input) < 3 or not second.input[2]:
            continue

        # Check their kernel size.
        if not (_is_pointwise_kernel(first) and _is_pointwise_kernel(second)):
            continue

        # The algebra above needs `second` to be a plain channel mixer.
        group_attr = _find_attribute(second, "group")
        if group_attr is not None and group_attr.i != 1:
            continue
        pads_attr = _find_attribute(second, "pads")
        if pads_attr is not None and any(pads_attr.ints):
            continue

        # Get bias/weight constants of the nodes.
        first_bias_node = _constant_producer(g, first.input[2])
        second_weight_node = _constant_producer(g, second.input[1])
        second_bias_node = _constant_producer(g, second.input[2])
        if (
            first_bias_node is None
            or second_weight_node is None
            or second_bias_node is None
        ):
            continue
        # Both bias constants are rewritten, so they must not be shared.
        if not (
            _has_single_consumer(g, first.input[2])
            and _has_single_consumer(g, second.input[2])
        ):
            continue

        first_bias = helper.constant_to_numpy(first_bias_node)
        second_weight = helper.constant_to_numpy(second_weight_node)
        second_bias = helper.constant_to_numpy(second_bias_node)

        # With all kernel dims equal to 1, the weight is an
        # (out_channels, in_channels) matrix.
        weight_matrix = np.reshape(second_weight, (second_weight.shape[0], -1))
        first_bias_vec = np.reshape(first_bias, (-1,))
        if weight_matrix.shape[1] != first_bias_vec.size:
            continue

        print("Found: ", first.name, " ", second.name)

        new_second_bias = (
            np.reshape(second_bias, (-1,)) + weight_matrix.dot(first_bias_vec)
        ).astype(second_bias.dtype)
        new_first_bias = np.zeros_like(first_bias_vec)

        new_first_bias_node = helper.numpy_to_constant(
            first_bias_node.output[0], new_first_bias
        )
        new_second_bias_node = helper.numpy_to_constant(
            second_bias_node.output[0], new_second_bias
        )

        # Delete old bias constants and add the new ones under the same names.
        g.node.remove(first_bias_node)
        g.node.remove(second_bias_node)
        g.node.extend([new_first_bias_node, new_second_bias_node])
        modified = True

    # New constants were appended after their consumers; restore ordering once.
    if modified:
        topological_sort(g)


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print(f"Usage: {sys.argv[0]} <input.onnx> <output.onnx>", file=sys.stderr)
        sys.exit(1)
    m = onnx.load(sys.argv[1])
    fuse_bias_in_consecutive_1x1_conv(m.graph)
    onnx.save(m, sys.argv[2])