195 lines
5.3 KiB
Python

import collections
import queue
class StatTree(object):
    """A tree of ``StatNode`` objects rooted at ``root_node``.

    Assigns a ``granularity`` to every node and collects the nodes that
    should be reported for a requested granularity.
    """

    def __init__(self, root_node):
        assert isinstance(root_node, StatNode)
        self.root_node = root_node

    def get_same_level_max_node_depth(self, query_node):
        """Return the largest ``depth`` among ``query_node`` and its siblings.

        The root (recognised by name) has no siblings and yields 0.
        """
        if query_node.name == self.root_node.name:
            return 0
        return max(sibling.depth for sibling in query_node.parent.children)

    def update_stat_nodes_granularity(self):
        """Set every node's ``granularity`` in breadth-first order."""
        # A deque is sufficient: queue.Queue adds locking meant for
        # producer/consumer threads, which this traversal does not need.
        pending = collections.deque([self.root_node])
        while pending:
            node = pending.popleft()
            node.granularity = self.get_same_level_max_node_depth(node)
            pending.extend(node.children)

    def get_collected_stat_nodes(self, query_granularity):
        """Return the nodes to report at ``query_granularity``, in pre-order.

        Granularities are recomputed first. A node is collected when its
        depth equals ``query_granularity``. It is also collected when it is
        shallower than the query but its granularity (the deepest depth among
        it and its siblings) reaches the query. This way, nodes shallower than
        their siblings are still reported at coarser granularities.

        :param query_granularity: depth level to collect.
        :returns: list of collected nodes in depth-first pre-order.
        """
        self.update_stat_nodes_granularity()
        collected_nodes = []
        stack = [self.root_node]
        while stack:
            node = stack.pop()
            # Push children reversed so they are popped in original order.
            stack.extend(reversed(node.children))
            # ``depth`` is a recursive property; evaluate it only once.
            depth = node.depth
            if (depth == query_granularity
                    or depth < query_granularity <= node.granularity):
                collected_nodes.append(node)
        return collected_nodes
class StatNode(object):
    """A node in a statistics tree.

    Each node stores its own ("local") statistics in underscore-prefixed
    fields. The getters for ``parameter_quantity``, ``inference_memory``,
    ``MAdd``, ``Flops``, ``Memory`` and ``duration`` return the local value
    plus the totals of all descendants. Their setters only set the local
    value. For an inner node, ``input_shape``/``output_shape`` come from its
    first/last child respectively.
    """

    def __init__(self, name=str(), parent=None):
        self._name = name
        self._input_shape = None
        self._output_shape = None
        self._parameter_quantity = 0
        self._inference_memory = 0
        self._MAdd = 0
        # Two-element pair; the getter sums index 0 and index 1 separately.
        # NOTE(review): presumably (read, write) memory — confirm with callers.
        self._Memory = (0, 0)
        self._Flops = 0
        self._duration = 0
        # NOTE(review): not read anywhere in this class; confirm external use.
        self._duration_percent = 0
        self._granularity = 1
        # Local depth contribution; ``depth`` adds the deepest child's depth.
        self._depth = 1
        # Back-reference only: this node is NOT appended to parent.children
        # here; use parent.add_child(node) for that.
        self.parent = parent
        self.children = []

    @property
    def name(self):
        """The node's name; children are looked up by it."""
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def granularity(self):
        """Granularity level, normally assigned by ``StatTree``."""
        return self._granularity

    @granularity.setter
    def granularity(self, g):
        self._granularity = g

    @property
    def depth(self):
        """Return ``_depth`` plus the largest ``depth`` among the children.

        With the default ``_depth`` of 1 this is the subtree height (1 for a
        leaf). It is recomputed recursively on every access.
        """
        if not self.children:
            return self._depth
        return self._depth + max(child.depth for child in self.children)

    @property
    def input_shape(self):
        """Stored shape for a leaf, else the first child's ``input_shape``."""
        if not self.children:  # leaf
            return self._input_shape
        return self.children[0].input_shape

    @input_shape.setter
    def input_shape(self, input_shape):
        assert isinstance(input_shape, (list, tuple))
        self._input_shape = input_shape

    @property
    def output_shape(self):
        """Stored shape for a leaf, else the last child's ``output_shape``."""
        if not self.children:  # leaf
            return self._output_shape
        return self.children[-1].output_shape

    @output_shape.setter
    def output_shape(self, output_shape):
        assert isinstance(output_shape, (list, tuple))
        self._output_shape = output_shape

    # The aggregate getters below use sum(..., start) with the local value as
    # ``start``. This preserves the left-to-right order
    # local + child0 + child1 + ...

    @property
    def parameter_quantity(self):
        """Local parameter count plus that of all descendants."""
        return sum((child.parameter_quantity for child in self.children),
                   self._parameter_quantity)

    @parameter_quantity.setter
    def parameter_quantity(self, parameter_quantity):
        assert parameter_quantity >= 0
        self._parameter_quantity = parameter_quantity

    @property
    def inference_memory(self):
        """Local inference memory plus that of all descendants."""
        return sum((child.inference_memory for child in self.children),
                   self._inference_memory)

    @inference_memory.setter
    def inference_memory(self, inference_memory):
        self._inference_memory = inference_memory

    @property
    def MAdd(self):
        """Local MAdd count plus that of all descendants."""
        return sum((child.MAdd for child in self.children), self._MAdd)

    @MAdd.setter
    def MAdd(self, MAdd):
        self._MAdd = MAdd

    @property
    def Flops(self):
        """Local Flops count plus that of all descendants."""
        return sum((child.Flops for child in self.children), self._Flops)

    @Flops.setter
    def Flops(self, Flops):
        self._Flops = Flops

    @property
    def Memory(self):
        """Local memory pair plus the element-wise sum over all descendants.

        A leaf returns its stored pair unchanged. A node with children builds
        and returns a new ``[first, second]`` list. The stored value is never
        modified, so repeated reads give the same result.
        """
        if not self.children:
            return self._Memory
        first, second = self._Memory[0], self._Memory[1]
        for child in self.children:
            # Evaluate each child's recursive total once, not once per index.
            child_memory = child.Memory
            first += child_memory[0]
            second += child_memory[1]
        return [first, second]

    @Memory.setter
    def Memory(self, Memory):
        assert isinstance(Memory, (list, tuple))
        self._Memory = Memory

    @property
    def duration(self):
        """Local duration plus that of all descendants."""
        return sum((child.duration for child in self.children),
                   self._duration)

    @duration.setter
    def duration(self, duration):
        self._duration = duration

    def find_child_index(self, child_name):
        """Return the index of the child named ``child_name``, or -1.

        If several children share the name, the last matching index is
        returned.
        """
        assert isinstance(child_name, str)
        index = -1
        for i, child in enumerate(self.children):
            if child.name == child_name:
                index = i
        return index

    def add_child(self, node):
        """Append ``node`` unless a child with the same name already exists.

        ``node.parent`` is not modified here.
        """
        assert isinstance(node, StatNode)
        if self.find_child_index(node.name) == -1:  # not exist
            self.children.append(node)