/****************************************************************************
*
*    Copyright (c) 2017 - 2018 by Rockchip Corp.  All rights reserved.
*
*    The material in this file is confidential and contains trade secrets
*    of Rockchip Corporation. This is proprietary information owned by
*    Rockchip Corporation. No part of this work may be disclosed,
*    reproduced, copied, transmitted, or used in any way for any purpose,
*    without the express written permission of Rockchip Corporation.
*
*****************************************************************************/

#ifndef _RKNN_MATMUL_API_H
|
||
#define _RKNN_MATMUL_API_H
|
||
|
||
#ifdef __cplusplus
|
||
extern "C" {
|
||
#endif
|
||
|
||
#include "rknn_api.h"
|
||
|
||
typedef rknn_context rknn_matmul_ctx;
|
||
|
||
typedef struct _rknn_matmul_tensor_attr
|
||
{
|
||
char name[RKNN_MAX_NAME_LEN];
|
||
|
||
// indicate A(M, K) or B(K, N) or C(M, N)
|
||
uint32_t n_dims;
|
||
uint32_t dims[RKNN_MAX_DIMS];
|
||
|
||
// matmul tensor size
|
||
uint32_t size;
|
||
|
||
// matmul tensor data type
|
||
// int8 : A, B
|
||
// int32: C
|
||
rknn_tensor_type type;
|
||
} rknn_matmul_tensor_attr;
|
||
|
||
typedef struct _rknn_matmul_io_attr
|
||
{
|
||
// indicate A(M, K) or B(K, N) or C(M, N)
|
||
rknn_matmul_tensor_attr A;
|
||
rknn_matmul_tensor_attr B;
|
||
rknn_matmul_tensor_attr C;
|
||
} rknn_matmul_io_attr;
|
||
|
||
/*
|
||
matmul information struct
|
||
*/
|
||
typedef struct rknn_matmul_info_t
|
||
{
|
||
int32_t M;
|
||
int32_t K; // limit: rk356x: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte;
|
||
// rk3588: int8 type must be aligned with 32byte, float16 type must be aligned with 32byte;
|
||
int32_t N; // limit: rk356x: int8 type must be aligned with 16byte, float16 type must be aligned with 8byte;
|
||
// rk3588: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte;
|
||
|
||
// matmul data type
|
||
// int8: int8(A) x int8(B) -> int32(C)
|
||
// float16: float16(A) x float16(B) -> float32(C)
|
||
rknn_tensor_type type;
|
||
|
||
// matmul native layout for B
|
||
// 0: normal layout
|
||
// 1: native layout
|
||
int32_t native_layout;
|
||
|
||
// matmul perf layout for A and C
|
||
// 0: normal layout
|
||
// 1: perf layout
|
||
int32_t perf_layout;
|
||
} rknn_matmul_info;
|
||
|
||
/* rknn_matmul_create
|
||
|
||
params:
|
||
rknn_matmul_ctx *ctx the handle of context.
|
||
rknn_matmul_info *info the matmal information.
|
||
rknn_matmul_io_attr *io_attr inputs/output attribute
|
||
return:
|
||
int error code
|
||
*/
|
||
int rknn_matmul_create(rknn_matmul_ctx* ctx, rknn_matmul_info* info, rknn_matmul_io_attr* io_attr);
|
||
|
||
/* rknn_matmul_set_io_mem
|
||
|
||
params:
|
||
rknn_matmul_ctx ctx the handle of context.
|
||
rknn_tensor_mem *mem the pointer of tensor memory information.
|
||
rknn_matmul_tensor_attr *attr the attribute of input or output tensor buffer.
|
||
return:
|
||
int error code.
|
||
|
||
formula:
|
||
C = A * B,
|
||
|
||
limit:
|
||
K <= 4096
|
||
K limit: rk356x: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte;
|
||
rk3588: int8 type must be aligned with 32byte, float16 type must be aligned with 32byte;
|
||
N limit: rk356x: int8 type must be aligned with 16byte, float16 type must be aligned with 8byte;
|
||
rk3588: int8 type must be aligned with 32byte, float16 type must be aligned with 16byte;
|
||
|
||
A shape: M x K
|
||
normal layout: (M, K)
|
||
[M1K1, M1K2, ..., M1Kk,
|
||
M2K1, M2K2, ..., M2Kk,
|
||
...
|
||
MmK1, MmK2, ..., MmKk]
|
||
for rk356x:
|
||
int8:
|
||
perf layout: (K / 8, M, 8)
|
||
[K1M1, K2M1, ..., K8M1,
|
||
K9M2, K10M2, ..., K16M2,
|
||
...
|
||
K(k-7)Mm, K(k-6)Mm, ..., KkMm]
|
||
float16:
|
||
perf layout: (K / 4, M, 4)
|
||
[K1M1, K2M1, ..., K4M1,
|
||
K9M2, K10M2, ..., K8M2,
|
||
...
|
||
K(k-3)Mm, K(k-2)Mm, ..., KkMm]
|
||
for rk3588:
|
||
int8:
|
||
perf layout: (K / 16, M, 16)
|
||
[K1M1, K2M1, ..., K16M1,
|
||
K9M2, K10M2, ..., K32M2,
|
||
...
|
||
K(k-15)Mm, K(k-14)Mm, ..., KkMm]
|
||
float16:
|
||
perf layout: (K / 8, M, 8)
|
||
[K1M1, K2M1, ..., K8M1,
|
||
K9M2, K10M2, ..., K16M2,
|
||
...
|
||
K(k-7)Mm, K(k-6)Mm, ..., KkMm]
|
||
B shape: K x N
|
||
normal layout: (K, N)
|
||
[K1N1, K1N2, ..., K1Nn,
|
||
K2N1, K2N2, ..., K2Nn,
|
||
...
|
||
KkN1, KkN2, ..., KkNn]
|
||
for rk356x:
|
||
int8:
|
||
native layout: (N / 16, K / 32, 16, 32)
|
||
[K1N1, K2N1, ..., K32N1,
|
||
K1N2, K2N2, ..., K32N2,
|
||
...
|
||
K1N16, K2N16, ..., K32N16,
|
||
K33N1, K34N1, ..., K64N1,
|
||
K33N2, K34N2, ..., K64N2,
|
||
...
|
||
K(k-31)N16, K(k-30)N16, ..., KkN16,
|
||
K1N17, K2N17, ..., K32N17,
|
||
K1N18, K2N18, ..., K32N18,
|
||
...
|
||
K(k-31)Nn, K(k-30)Nn, ..., KkNn]
|
||
float16:
|
||
native layout: (N / 8, K / 16, 8, 16)
|
||
[K1N1, K2N1, ..., K16N1,
|
||
K1N2, K2N2, ..., K16N2,
|
||
...
|
||
K1N8, K2N8, ..., K16N8,
|
||
K17N1, K18N1, ..., K32N1,
|
||
K17N2, K18N2, ..., K32N2,
|
||
...
|
||
K(k-15)N8, K(k-30)N8, ..., KkN8,
|
||
K1N9, K2N9, ..., K16N9,
|
||
K1N10, K2N10, ..., K16N10,
|
||
...
|
||
K(k-15)Nn, K(k-14)Nn, ..., KkNn]
|
||
for rk3588:
|
||
int8:
|
||
native layout: (N / 32, K / 32, 32, 32)
|
||
[K1N1, K2N1, ..., K32N1,
|
||
K1N2, K2N2, ..., K32N2,
|
||
...
|
||
K1N32, K2N32, ..., K32N32,
|
||
K33N1, K34N1, ..., K64N1,
|
||
K33N2, K34N2, ..., K64N2,
|
||
...
|
||
K(k-31)N32, K(k-30)N32, ..., KkN32,
|
||
K1N33, K2N33, ..., K32N33,
|
||
K1N34, K2N34, ..., K32N34,
|
||
...
|
||
K(k-31)Nn, K(k-30)Nn, ..., KkNn]
|
||
float16:
|
||
native layout: (N / 16, K / 32, 16, 32)
|
||
[K1N1, K2N1, ..., K32N1,
|
||
K1N2, K2N2, ..., K32N2,
|
||
...
|
||
K1N16, K2N16, ..., K32N16,
|
||
K33N1, K34N1, ..., K64N1,
|
||
K33N2, K34N2, ..., K64N2,
|
||
...
|
||
K(k-31)N16, K(k-30)N16, ..., KkN16,
|
||
K1N17, K2N17, ..., K32N17,
|
||
K1N18, K2N18, ..., K32N18,
|
||
...
|
||
K(k-31)Nn, K(k-30)Nn, ..., KkNn]
|
||
C shape: M x N
|
||
normal layout: (M, N)
|
||
[M1N1, M1N2, ..., M1Nn,
|
||
M2N1, M2N2, ..., M2Nn,
|
||
...
|
||
MmN1, MmN2, ..., MmNn]
|
||
perf layout: (N / 4, M, 4)
|
||
[N1M1, N2M1, ..., N4M1,
|
||
N5M2, N6M2, ..., N8M2,
|
||
...
|
||
N(n-3)Mm, N(n-2)Mm, ..., NnMm]
|
||
*/
|
||
int rknn_matmul_set_io_mem(rknn_matmul_ctx ctx, rknn_tensor_mem* mem, rknn_matmul_tensor_attr* attr);
|
||
|
||
/* rknn_matmul_set_core_mask
|
||
|
||
set rknn core mask.(only support rk3588 in current)
|
||
|
||
RKNN_NPU_CORE_AUTO: auto mode, default value
|
||
RKNN_NPU_CORE_0: core 0 mode
|
||
RKNN_NPU_CORE_1: core 1 mode
|
||
RKNN_NPU_CORE_2: core 2 mode
|
||
RKNN_NPU_CORE_0_1: combine core 0/1 mode
|
||
RKNN_NPU_CORE_0_1_2: combine core 0/1/2 mode
|
||
|
||
input:
|
||
rknn_matmul_ctx context the handle of context.
|
||
rknn_core_mask core_mask the core mask.
|
||
return:
|
||
int error code.
|
||
*/
|
||
int rknn_matmul_set_core_mask(rknn_matmul_ctx context, rknn_core_mask core_mask);
|
||
|
||
/* rknn_matmul_run
|
||
|
||
run the matmul in blocking mode
|
||
|
||
params:
|
||
rknn_matmul_ctx ctx the handle of context.
|
||
return:
|
||
int error code.
|
||
*/
|
||
int rknn_matmul_run(rknn_matmul_ctx ctx);
|
||
|
||
/* rknn_matmul_destroy
|
||
|
||
destroy the matmul context
|
||
|
||
params:
|
||
rknn_matmul_ctx ctx the handle of context.
|
||
return:
|
||
int error code.
|
||
*/
|
||
int rknn_matmul_destroy(rknn_matmul_ctx ctx);
|
||
|
||
#ifdef __cplusplus
|
||
} // extern "C"
|
||
#endif
|
||
|
||
#endif // _RKNN_MATMUL_API_H
|