2026-01-28 06:16:04 +00:00

101 lines
2.3 KiB
C

//
// Created by xiangzhou on 11/9/23.
//
#pragma once
#include "C_Tensor.h"
#ifdef __cplusplus
extern "C" {
#endif
//----------------------------------------------------------------------------
// Conv
//------------------------------------------------------------------------
/*
 * Attributes shared by the Conv, ConvInteger and ConvTranspose operators
 * (all three names below alias this one struct).
 *
 * Member order is significant: it fixes the struct layout and the meaning of
 * positional initializers, so it must not be rearranged.
 */
typedef struct ConvAttribute {
    /* ---- provided by the user ---- */
    const char* name_;
    const char* pad_type_;
    C_Shape dilations_;
    C_Shape kernel_;
    C_Shape pads_;
    C_Shape strides_;
    /* Padding used for conv transpose; per the original author's note it
     * differs from the user-supplied padding in that case. */
    C_Shape conv_pads_;

    /* ---- computed by the library ---- */
    uint32_t group_;
    /* Output-side padding on each edge (originally marked "for conv
     * transpose"). */
    uint32_t out_top_pad_, out_bottom_pad_, out_left_pad_, out_right_pad_;
    C_Shape input_0_shape_, output_0_shape_, weight_0_shape_;
    /* Matrix dimensions for the GEMM-based implementation. */
    uint32_t M_, K_, N_;
    /* Batch sizes of the input and output. */
    uint32_t input_batch_size_, output_batch_size_;
    /* NOTE(review): meaning not visible in this header - presumably the size
     * of a library-managed buffer; confirm in the implementation. */
    int allocated_size_;
} ConvAttribute;

typedef struct ConvAttribute ConvIntegerAttribute;
typedef struct ConvAttribute ConvTransposeAttribute;
//----------------------------------------------------------------------------
// Gemm
//------------------------------------------------------------------------
/*
 * Attributes for the Gemm operator.
 *
 * C has no default member initializers, so the library-computed fields start
 * with whatever value the object is initialized with; zero-initialize the
 * struct (e.g. `GemmAttribute a = {0};`) if a known starting value is needed.
 */
typedef struct GemmAttribute {
    /* ---- provided by the user ---- */
    const char* name_;
    /* Presumably the ONNX Gemm scaling factors (Y = alpha*A'*B' + beta*C);
     * confirm in the implementation. */
    float alpha_, beta_;
    /* Presumably nonzero means "transpose the corresponding operand". */
    int transA_, transB_;

    /* ---- computed by the library ---- */
    uint32_t M_, K_, N_;
} GemmAttribute;
//----------------------------------------------------------------------------
// MatMul
//------------------------------------------------------------------------
/*
 * Attributes for the MatMul operator. Member order is significant (layout and
 * positional initializers) and is kept as originally declared.
 */
typedef struct MatMulAttribute {
    /* ---- provided by the user ---- */
    const char* name_;

    /* ---- computed by the library ---- */
    C_Shape first_dim_, second_dim_;
    C_Shape broadcast_first_dim_, broadcast_second_dim_;
} MatMulAttribute;

/*
 * NOTE(review): unlike ConvIntegerAttribute (an alias of ConvAttribute),
 * MatMulIntegerAttribute aliases `struct Attribute`, which is not defined in
 * this section. If it is not defined elsewhere (e.g. in C_Tensor.h), this is
 * an incomplete type usable only through pointers. Confirm whether
 * `struct MatMulAttribute` was intended.
 */
typedef struct Attribute MatMulIntegerAttribute;
// Function declarations generated by project macros (defined outside this
// header, presumably in C_Tensor.h - not visible here).
//
// In this header CREATE_DECL_SETUP_FUN is invoked only for Conv, Gemm and
// MatMul; ConvInteger, ConvTranspose and MatMulInteger get only the run
// declarations below.
CREATE_DECL_SETUP_FUN(Conv)
CREATE_DECL_SETUP_FUN(Gemm)
CREATE_DECL_SETUP_FUN(MatMul)
// Run declarations. ConvInteger uses the CREATE_DECL_RUN_FUN2 variant.
// NOTE(review): presumably that variant yields a different signature
// (e.g. extra inputs such as zero points) - confirm against the macro.
CREATE_DECL_RUN_FUN(Conv)
CREATE_DECL_RUN_FUN2(ConvInteger)
CREATE_DECL_RUN_FUN(ConvTranspose)
CREATE_DECL_RUN_FUN(Gemm)
CREATE_DECL_RUN_FUN(MatMul)
CREATE_DECL_RUN_FUN(MatMulInteger)
#ifdef __cplusplus
}
#endif