// // Created by xiangzhou on 11/9/23. // #pragma once #include "C_Tensor.h" #ifdef __cplusplus extern "C" { #endif //---------------------------------------------------------------------------- // Conv //------------------------------------------------------------------------ struct ConvAttribute { // user const char* name_; const char* pad_type_; C_Shape dilations_; C_Shape kernel_; C_Shape pads_; C_Shape strides_; C_Shape conv_pads_; // for conv transpose, it's different from // user - for conv transpose // library uint32_t group_; uint32_t out_top_pad_; // for conv transpose uint32_t out_bottom_pad_; uint32_t out_left_pad_; uint32_t out_right_pad_; C_Shape input_0_shape_; C_Shape output_0_shape_; C_Shape weight_0_shape_; // for gemm uint32_t M_; uint32_t K_; uint32_t N_; // for batch uint32_t input_batch_size_; uint32_t output_batch_size_; int allocated_size_; }; typedef struct ConvAttribute ConvAttribute; typedef struct ConvAttribute ConvIntegerAttribute; typedef struct ConvAttribute ConvTransposeAttribute; //---------------------------------------------------------------------------- // Gemm //------------------------------------------------------------------------ struct GemmAttribute { // user const char* name_; float alpha_; float beta_; int transA_; int transB_; // library uint32_t M_;// = 0; uint32_t K_;// = 0; uint32_t N_;// = 0; }; typedef struct GemmAttribute GemmAttribute; //---------------------------------------------------------------------------- // MatMul //------------------------------------------------------------------------ struct MatMulAttribute { // user const char* name_; // library C_Shape first_dim_; C_Shape second_dim_; C_Shape broadcast_first_dim_; C_Shape broadcast_second_dim_; }; typedef struct MatMulAttribute MatMulAttribute; typedef struct Attribute MatMulIntegerAttribute; CREATE_DECL_SETUP_FUN(Conv) CREATE_DECL_SETUP_FUN(Gemm) CREATE_DECL_SETUP_FUN(MatMul) CREATE_DECL_RUN_FUN(Conv) CREATE_DECL_RUN_FUN2(ConvInteger) CREATE_DECL_RUN_FUN(ConvTranspose) CREATE_DECL_RUN_FUN(Gemm) CREATE_DECL_RUN_FUN(MatMul) CREATE_DECL_RUN_FUN(MatMulInteger) #ifdef __cplusplus } #endif