101 lines
2.3 KiB
C
101 lines
2.3 KiB
C
//
// Created by xiangzhou on 11/9/23.
//
// Attribute structs and setup/run declarations for the Conv, Gemm and
// MatMul operator families.
//
|
|
|
|
#pragma once
|
|
#include "C_Tensor.h"
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif
|
|
|
|
|
|
//----------------------------------------------------------------------------
|
|
// Conv
|
|
//------------------------------------------------------------------------
|
|
struct ConvAttribute {
|
|
// user
|
|
const char* name_;
|
|
const char* pad_type_;
|
|
C_Shape dilations_;
|
|
C_Shape kernel_;
|
|
C_Shape pads_;
|
|
C_Shape strides_;
|
|
C_Shape conv_pads_; // for conv transpose, it's different from
|
|
|
|
// user - for conv transpose
|
|
|
|
// library
|
|
uint32_t group_;
|
|
uint32_t out_top_pad_; // for conv transpose
|
|
uint32_t out_bottom_pad_;
|
|
uint32_t out_left_pad_;
|
|
uint32_t out_right_pad_;
|
|
C_Shape input_0_shape_;
|
|
C_Shape output_0_shape_;
|
|
C_Shape weight_0_shape_;
|
|
|
|
// for gemm
|
|
uint32_t M_;
|
|
uint32_t K_;
|
|
uint32_t N_;
|
|
|
|
// for batch
|
|
uint32_t input_batch_size_;
|
|
uint32_t output_batch_size_;
|
|
|
|
int allocated_size_;
|
|
};
|
|
|
|
typedef struct ConvAttribute ConvAttribute;
|
|
typedef struct ConvAttribute ConvIntegerAttribute;
|
|
typedef struct ConvAttribute ConvTransposeAttribute;
|
|
|
|
//----------------------------------------------------------------------------
// Gemm
//----------------------------------------------------------------------------

/* Attributes for the Gemm operator. */
struct GemmAttribute {
    /* --- user --- */
    const char *name_;
    float alpha_;
    float beta_;
    int transA_;  /* integer flags; presumably non-zero selects transpose */
    int transB_;

    /* --- library --- (earlier revisions hinted these start at 0) */
    uint32_t M_, K_, N_;
};
typedef struct GemmAttribute GemmAttribute;
|
|
//----------------------------------------------------------------------------
|
|
// MatMul
|
|
//------------------------------------------------------------------------
|
|
struct MatMulAttribute {
|
|
// user
|
|
const char* name_;
|
|
// library
|
|
|
|
C_Shape first_dim_;
|
|
C_Shape second_dim_;
|
|
C_Shape broadcast_first_dim_;
|
|
C_Shape broadcast_second_dim_;
|
|
};
|
|
|
|
typedef struct MatMulAttribute MatMulAttribute;
|
|
typedef struct Attribute MatMulIntegerAttribute;
|
|
|
|
CREATE_DECL_SETUP_FUN(Conv)
|
|
CREATE_DECL_SETUP_FUN(Gemm)
|
|
CREATE_DECL_SETUP_FUN(MatMul)
|
|
|
|
CREATE_DECL_RUN_FUN(Conv)
|
|
CREATE_DECL_RUN_FUN2(ConvInteger)
|
|
CREATE_DECL_RUN_FUN(ConvTranspose)
|
|
CREATE_DECL_RUN_FUN(Gemm)
|
|
CREATE_DECL_RUN_FUN(MatMul)
|
|
CREATE_DECL_RUN_FUN(MatMulInteger)
|
|
|
|
#ifdef __cplusplus
|
|
}
|
|
#endif
|