?hgemm

一般矩阵乘矩阵。

即:

op(X)可取值:,alpha,beta为乘法系数,op(A)为m*k矩阵,op(B)为k*n矩阵,C为m*n矩阵。

接口定义

C interface

void cblas_hgemm(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT M, const BLASINT N, const BLASINT K, const __fp16 alpha, const __fp16 *A, const BLASINT lda, const __fp16 *B, const BLASINT ldb, const __fp16 beta, __fp16 *C, const BLASINT ldc);

void cblas_shgemm(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT M, const BLASINT N, const BLASINT K, const float alpha, const __fp16 *A, const BLASINT lda, const __fp16 *B, const BLASINT ldb, const float beta, float *C, const BLASINT ldc);

void cblas_chgemm(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT M, const BLASINT N, const BLASINT K, const void *alpha, const void *A, const BLASINT lda, const void *B, const BLASINT ldb, const void *beta, void *C, const BLASINT ldc);

void cblas_cshgemm(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE transA, const enum CBLAS_TRANSPOSE transB, const BLASINT M, const BLASINT N, const BLASINT K, const void *alpha, const void *A, const BLASINT lda, const void *B, const BLASINT ldb, const void *beta, void *C, const BLASINT ldc);

Fortran interface

CALL HGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)

CALL SHGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)

CALL CHGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)

CALL CSHGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)

参数

参数名

类型

描述

输入/输出

order

枚举类型CBLAS_ORDER

表示矩阵是行主序或列主序。

输入

TransA

枚举类型CBLAS_TRANSPOSE

矩阵A为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransA = CblasNoTrans,
  • 如果TransA = CblasTrans,
  • 如果TransA = CblasConjTrans,
  • 如果TransA = CblasConjTrans,

输入

TransB

枚举类型CBLAS_TRANSPOSE

矩阵B为常规矩阵,转置矩阵或共轭矩阵。

  • 如果TransB = CblasNoTrans,
  • 如果TransB = CblasTrans,
  • 如果TransB = CblasConjTrans,
  • 如果TransB = CblasConjTrans,

输入

M

整型数

矩阵op(A)和矩阵C的行。

输入

N

整型数

矩阵op(B)和矩阵C的列。

输入

K

整型数

矩阵op(A)的列和矩阵op(B)的行。

输入

alpha

  • 在hgemm中是半精度浮点类型。
  • 在shgemm中是单精度浮点类型。
  • 在chgemm中是半精度复数类型。
  • 在cshgemm中是单精度复数类型。

乘法系数。

输入

A

  • 在hgemm中是半精度浮点类型。
  • 在shgemm中是半精度浮点类型。
  • 在chgemm中是半精度复数类型。
  • 在cshgemm中是半精度复数类型。

矩阵A。

输入

lda

整型数

  • 矩阵为列存,TransA = CblasNoTrans,lda至少max(1, m),否则max(1, k)。
  • 矩阵为行存,TransA = CblasNoTrans,lda至少max(1, k),否则max(1, m)。

输入

B

  • 在hgemm中是半精度浮点类型。
  • 在shgemm中是半精度浮点类型。
  • 在chgemm中是半精度复数类型。
  • 在cshgemm中是半精度复数类型。

矩阵B。

输入

ldb

整型数

  • 矩阵为列存,TransB = CblasNoTrans,ldb至少max(1, k),否则max(1, n)。
  • 矩阵为行存,TransB = CblasNoTrans,ldb至少max(1, n),否则max(1, k)。

输入

beta

  • 在hgemm中是半精度浮点类型。
  • 在shgemm中是单精度浮点类型。
  • 在chgemm中是半精度复数类型。
  • 在cshgemm中是单精度复数类型。

乘法系数。

输入

C

  • 在hgemm中是半精度浮点类型。
  • 在shgemm中是单精度浮点类型。
  • 在chgemm中是半精度复数类型。
  • 在cshgemm中是单精度复数类型。

矩阵C。

输入/输出

ldc

整型数

矩阵为列存,ldc至少max(1, m),否则max(1, n)。

输入

线程自定义说明

使用chgemm和cshgemm接口时,用户可以通过设置环境变量“BLAS_MNK_RANGE”和“BLAS_MNK_THREADS”获得线程定制化的需求。

在不设置“BLAS_MNK_RANGE”或该值为零时,系统会根据当前环境自动分配线程数。

依赖

#include "kblas.h"

示例

C interface

    int m = 4, k = 3, n = 4, lda = 4, ldb = 3, ldc = 4;  
    __fp16 alpha = 1.0, beta = 2.0;  
     /*  
     * A:  
     *     0.340188,       0.411647,       -0.222225,  
     *     -0.105617,      -0.302449,      0.053970,  
     *     0.283099,       -0.164777,      -0.022603,  
     *     0.298440,       0.268230,       0.128871,  
     * B:  
     *     -0.135216,      0.416195,       -0.358397,      -0.257113,  
     *     0.013401,       0.135712,       0.106969,       -0.362768,  
     *     0.452230,       0.217297,       -0.483699,      0.304177,  
     * C:  
     *     -0.343321,      0.498924,       0.112640,       -0.006417,  
     *     -0.099056,      -0.281743,      -0.203968,      0.472775,  
     *     -0.370210,      0.012932,       0.137552,       -0.207483,  
     *     -0.391191,      0.339112,       0.024287,       0.271358,  
     */  
    __fp16 a[12] = {0.340188, -0.105617, 0.283099,  
                    0.298440, 0.411647, -0.302449,  
                    -0.164777, 0.268230, -0.222225,  
                    0.053970, -0.022603, 0.128871};  
    __fp16 b[12] = {-0.135216, 0.013401, 0.452230, 0.416195,  
                    0.135712, 0.217297, -0.358397, 0.106969,  
                    -0.483699, -0.257113, -0.362768, 0.304177};  
    __fp16 c[16] = {-0.343321, -0.099056, -0.370210, -0.391191,  
                    0.498924, -0.281743, 0.012932, 0.339112,  
                    0.112640, -0.203968, 0.137552, 0.024287,  
                    -0.006417, 0.472775, -0.207483, 0.271358};  
  
    cblas_hgemm(CblasColMajor,CblasNoTrans,CblasNoTrans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);  
    /*  
     * Output C:  
     *     -0.827621       1.147010        0.254881        -0.317229  
     *     -0.163476       -0.636762       -0.428542       1.098841  
     *     -0.791128       0.116416        0.166949        -0.434854  
     *     -0.760862       0.866839        -0.092028       0.407877  
     *  
     */

Fortran interface

      INTEGER :: M=4, K=3, N=4  
      INTEGER :: LDA=4, LDB=3, LDC=4  
      REAL(4) :: ALPHA=1.0, BETA=2.0  
      REAL(4) :: A(12), B(12), C(16)  
      DATA A/0.340188, -0.105617, 0.283099,  
     $       0.298440, 0.411647, -0.302449,  
     $       -0.164777, 0.268230, -0.222225,  
     $       0.053970, -0.022603, 0.128871/  
      DATA B/-0.135216, 0.013401, 0.452230, 0.416195,  
     $       0.135712, 0.217297, -0.358397, 0.106969,  
     $       -0.483699, -0.257113, -0.362768, 0.304177/  
      DATA C/-0.343321, -0.099056, -0.370210, -0.391191,  
     $       0.498924, -0.281743, 0.012932, 0.339112,  
     $       0.112640, -0.203968, 0.137552, 0.024287,  
     $       -0.006417, 0.472775, -0.207483, 0.271358/  
      EXTERNAL HGEMM  
      CALL HGEMM('N', 'N', M, N, K, ALPHA, A, LDA, B, LDB, BETA, C,  
     $          LDC)  
*     Output C:  
*         -0.827621       1.147010        0.254881        -0.317229  
*         -0.163476       -0.636762       -0.428542       1.098841  
*         -0.791128       0.116416        0.166949        -0.434854  
*         -0.760862       0.866839        -0.092028       0.407877