示例 (Example)

C Interface:

#include <stdio.h>
#include <kml_scadss.h>

/*
 * Print a caption line, then the n entries of x on one line separated by
 * spaces. In this example it is only called on rank 0, whose x buffer is the
 * one handed to the solver for the solution.
 */
static void PrintSolution(const char *caption, const double *x, int n)
{
    printf("%s\n", caption);
    for (int i = 0; i < n; i++) {
        printf("%lf ", x[i]);
    }
    printf("\n");
}

/*
 * End-to-end example of the KML distributed direct sparse solver (scadss).
 *
 * Builds an 8x8 CSR matrix A and dense right-hand side b on rank 0 (other
 * ranks pass empty local parts), then runs analyze / factorize / solve.
 * It then replaces the values of A and b, re-factorizes and re-solves, and
 * finally queries peak memory.
 *
 * comm: MPI communicator the solver runs on.
 * Returns 0 on success, 1 if creating a matrix handle fails, otherwise the
 * error code of the first failing KML call. Every handle created before a
 * failure is released before returning.
 */
int Run(MPI_Comm comm)
{
    int ierr;
    int ret = 0;
    int rank;
    /* Handles start NULL so the shared cleanup path knows what to release. */
    KmlScasolverMatrix *A = NULL;
    KmlScasolverMatrix *B = NULL;
    KmlScasolverMatrix *X = NULL;
    KmlScadssSolver *solver = NULL;

    MPI_Comm_rank(comm, &rank);

    int n = 8;
    int nrhs = 1;

    // Create matrix A (CSR, all rows owned by rank 0)
    int ia[9] = {0, 2, 4, 6, 7, 8, 10, 12, 14};
    int ja[14] = {0, 7, 1, 6, 2, 5, 3, 4, 2, 5, 1, 6, 0, 7};
    double a[14] = {1.0, 2.0, -2.0, 3.0, 3.0, 4.0, -4.0, 5.0, 4.0, -6.0, 3.0, 7.0, 2.0, 8.0};
    KmlSolverMatrixStore storeA;
    storeA.indexType = KMLSS_INDEX_INT32;
    storeA.valueType = KMLSS_VALUE_FP64;
    storeA.format = KMLSS_MATRIX_STORE_CSR;
    if (rank == 0) {
        storeA.nRow = n;
        storeA.nCol = n;
        storeA.csr.rowOffset = ia;
        storeA.csr.colIndex = ja;
        storeA.csr.value = a;
    } else {
        /* NULL, not nullptr: nullptr is not a C keyword before C23. */
        storeA.nRow = 0;
        storeA.nCol = 0;
        storeA.csr.rowOffset = NULL;
        storeA.csr.colIndex = NULL;
        storeA.csr.value = NULL;
    }

    KmlSolverMatrixOption optA;
    optA.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optA.type = KMLSS_MATRIX_GEN;

    KmlScasolverMatrixOption scaOptA;
    if (rank == 0) {
        scaOptA.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
                            KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
                            KMLSS_MATRIX_OPTIONS_PARTITION;
        scaOptA.partition.type = KMLSS_MATRIX_PARTITION_ROW;
        scaOptA.globalNumRows = n;
        scaOptA.globalNumCols = n;
        scaOptA.partition.localBegin = 0;
    } else {
        scaOptA.fieldMask = 0;
    }

    ierr = KmlScasolverMatrixCreate(&A, &storeA, &optA, &scaOptA);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create A: %d\n", ierr);
        /* Do not trust the handle after a failed create; never destroy it. */
        A = NULL;
        ret = 1;
        goto cleanup;
    }

    // Create vector b (dense column-major, all rows owned by rank 0)
    double b[8] = {3.0, 1.0, 7.0, -4.0, 5.0, -2.0, 10.0, 10.0};
    KmlSolverMatrixStore storeB;
    storeB.indexType = KMLSS_INDEX_INT32;
    storeB.valueType = KMLSS_VALUE_FP64;
    storeB.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
    if (rank == 0) {
        storeB.nRow = n;
        storeB.nCol = nrhs;
        storeB.dense.value = b;
        storeB.dense.ld = n;
    } else {
        storeB.nRow = 0;
        storeB.nCol = 0;
        storeB.dense.value = NULL;
        storeB.dense.ld = 0;
    }

    KmlSolverMatrixOption optB;
    optB.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optB.type = KMLSS_MATRIX_GEN;

    KmlScasolverMatrixOption scaOptB;
    if (rank == 0) {
        scaOptB.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
                            KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
                            KMLSS_MATRIX_OPTIONS_PARTITION;
        scaOptB.partition.type = KMLSS_MATRIX_PARTITION_ROW;
        scaOptB.partition.localBegin = 0;
        scaOptB.globalNumRows = n;
        scaOptB.globalNumCols = nrhs;
    } else {
        scaOptB.fieldMask = 0;
    }

    ierr = KmlScasolverMatrixCreate(&B, &storeB, &optB, &scaOptB);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create b: %d\n", ierr);
        B = NULL;
        ret = 1;
        goto cleanup;
    }

    // Create vector x; on rank 0 the solution is read back from this buffer
    double x[8] = {0};
    KmlSolverMatrixStore storeX;
    storeX.indexType = KMLSS_INDEX_INT32;
    storeX.valueType = KMLSS_VALUE_FP64;
    storeX.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
    if (rank == 0) {
        storeX.nRow = n;
        storeX.nCol = nrhs;
        storeX.dense.value = x;
        storeX.dense.ld = n;
    } else {
        storeX.nRow = 0;
        storeX.nCol = 0;
        storeX.dense.value = NULL;
        storeX.dense.ld = 0;
    }

    KmlSolverMatrixOption optX;
    optX.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optX.type = KMLSS_MATRIX_GEN;

    KmlScasolverMatrixOption scaOptX;
    if (rank == 0) {
        scaOptX.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
                            KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
                            KMLSS_MATRIX_OPTIONS_PARTITION;
        scaOptX.partition.type = KMLSS_MATRIX_PARTITION_ROW;
        scaOptX.partition.localBegin = 0;
        scaOptX.globalNumRows = n;
        scaOptX.globalNumCols = nrhs;
    } else {
        scaOptX.fieldMask = 0;
    }

    ierr = KmlScasolverMatrixCreate(&X, &storeX, &optX, &scaOptX);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create x: %d\n", ierr);
        X = NULL;
        ret = 1;
        goto cleanup;
    }

    // Init solver
    KmlDssInitOption opt;
    opt.fieldMask = KMLDSS_INIT_OPTION_BWR_MODE | KMLDSS_INIT_OPTION_NTHREADS;
    opt.bwrMode = KMLDSS_BWR_OFF;
    opt.nThreads = 32;

    KmlScadssInitOption scaOpt;
    scaOpt.fieldMask = KMLSCADSS_OPTIONS_COMM;
    scaOpt.comm = comm;

    ierr = KmlScadssInit(&solver, &opt, &scaOpt);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssInit: %d\n", ierr);
        solver = NULL;
        ret = ierr;
        goto cleanup;
    }

    // Analyze
    KmlDssAnalyzeOption optAnalyze;
    optAnalyze.fieldMask = KMLDSS_ANALYZE_OPTION_MATCHING_TYPE | KMLDSS_ANALYZE_OPTION_RDR_TYPE |
                           KMLDSS_ANALYZE_OPTION_NTHREADS_RDR;
    optAnalyze.matchingType = KMLDSS_MATCHING_OFF;
    optAnalyze.rdrType = KMLDSS_RDR_KRDR;
    optAnalyze.nThreadsRdr = 1;

    KmlScadssAnalyzeOption scaOptAnalyze;
    scaOptAnalyze.fieldMask = 0;

    ierr = KmlScadssAnalyze(solver, A, &optAnalyze, &scaOptAnalyze);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssAnalyze: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    // Factorize
    KmlDssFactorizeOption optFact;
    optFact.fieldMask = KMLDSS_FACTORIZE_OPTION_PERTURBATION_THRESHOLD;
    optFact.perturbationThreshold = 1e-8;

    KmlScadssFactorizeOption scaOptFact;
    scaOptFact.fieldMask = 0;

    ierr = KmlScadssFactorize(solver, A, &optFact, &scaOptFact);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssFactorize: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    // Solve
    KmlDssSolveOption optSolve;
    optSolve.fieldMask = KMLDSS_SOLVE_OPTION_SOLVE_STAGE | KMLDSS_SOLVE_OPTION_REFINE_METHOD;
    optSolve.stage = KMLDSS_SOLVE_ALL;
    optSolve.refineMethod = KMLDSS_REFINE_OFF;

    KmlScadssSolveOption scaOptSolve;
    scaOptSolve.fieldMask = 0;

    ierr = KmlScadssSolve(solver, B, X, &optSolve, &scaOptSolve);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssSolve: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    if (rank == 0) {
        PrintSolution("Result of first factorize and solve:", x, n);
    }

    // Set new values of A (same sparsity pattern, new numeric values)
    double a1[14] = {2.0, 3.0, -3.0, 4.0, 4.0, 5.0, -5.0, 6.0, 5.0, -7.0, 4.0, 8.0, 3.0, 9.0};
    ierr = KmlScasolverMatrixSetValue(A, a1);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScasolverMatrixSetValue(A): %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    // Factorize with new values
    ierr = KmlScadssFactorize(solver, A, &optFact, &scaOptFact);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssFactorize: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    // Set new values of B
    double b1[8] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
    ierr = KmlScasolverMatrixSetValue(B, b1);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScasolverMatrixSetValue(B): %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    // Solve with new values
    ierr = KmlScadssSolve(solver, B, X, &optSolve, &scaOptSolve);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssSolve: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }

    if (rank == 0) {
        PrintSolution("Result of second factorize and solve:", x, n);
    }

    // Query peak memory
    KmlDssInfo info;
    info.fieldMask = KMLDSS_INFO_PEAK_MEM;
    KmlScadssInfo scaInfo;
    scaInfo.fieldMask = 0;
    ierr = KmlScadssQuery(solver, &info, &scaInfo);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlScadssQuery: %d\n", ierr);
        ret = ierr;
        goto cleanup;
    }
    if (rank == 0) {
        /* The exact integer type of peakMem is not visible here; widening to
         * long long keeps the format specifier correct either way. */
        printf("Peak memory is %lld Byte\n", (long long)info.peakMem);
    }

cleanup:
    /* Release in the original order: solver first, then the matrices. */
    if (solver != NULL) {
        KmlScadssClean(&solver);
    }
    if (A != NULL) {
        KmlScasolverMatrixDestroy(&A);
    }
    if (B != NULL) {
        KmlScasolverMatrixDestroy(&B);
    }
    if (X != NULL) {
        KmlScasolverMatrixDestroy(&X);
    }

    return ret;
}

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);
    Run(MPI_COMM_WORLD);
    MPI_Finalize();
    return 0;
}

运行结果 (Output):

Result of first factorize and solve:
1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 
Result of second factorize and solve:
0.666667 -0.100000 0.226415 -0.200000 0.166667 0.018868 0.175000 -0.111111 
Peak memory is 102376 Byte