开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

示例2

使用Schur补(Schur complement)求解稀疏线性方程组:

#include <stdio.h>
#include <stdlib.h>
#include "kml_solver.h"
/*
 * Prototype for the LAPACK routine DGESV, declared by hand and called through
 * its Fortran symbol name (trailing underscore, every argument passed by
 * pointer). main() uses it to solve the dense Schur-complement system.
 *
 * Per the LAPACK documentation, DGESV solves A * X = B for a general n-by-n
 * matrix A using LU factorization with partial pivoting:
 *   n, nrhs  - order of A and number of right-hand-side columns
 *   a, lda   - column-major matrix A (overwritten by its LU factors) and its
 *              leading dimension
 *   ipiv     - output pivot indices, at least n entries
 *   b, ldb   - right-hand sides on entry, solution on exit, and the leading
 *              dimension of that array
 *   info     - 0 on success, < 0 for an illegal argument, > 0 if A is singular
 */
void dgesv_(const int *n, const int *nrhs, double *a, const int *lda, int *ipiv, double *b, const int *ldb,
    int *info);
int main()
{
    int ierr;
    int n = 8;
    int nrhs = 1;
    // Create matrix A
    int ia[9] = {0, 2, 4, 6, 7, 8, 10, 12, 14};
    int ja[14] = {0, 7, 1, 6, 2, 5, 3, 4, 2, 5, 1, 6, 0, 7};
    double a[14] = {1.0, 2.0, -2.0, 3.0, 3.0, 4.0, -4.0, 5.0, 4.0, -6.0, 3.0, 7.0, 2.0, 8.0};
    KmlSolverMatrixStore storeA;
    storeA.indexType = KMLSS_INDEX_INT32;
    storeA.valueType = KMLSS_VALUE_FP64;
    storeA.nRow = n;
    storeA.nCol = n;
    storeA.format = KMLSS_MATRIX_STORE_CSR;
    storeA.csr.rowOffset = ia;
    storeA.csr.colIndex = ja;
    storeA.csr.value = a;
    KmlSolverMatrixOption optA;
    optA.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optA.type = KMLSS_MATRIX_GEN;
    KmlSolverMatrix *A;
    ierr = KmlSolverMatrixCreate(&A, &storeA, &optA);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create A: %d\n", ierr);
        return 1;
    }
    // Create vector b
    double b[8] = {3.0, 1.0, 7.0, -4.0, 5.0, -2.0, 10.0, 10.0};
    KmlSolverMatrixStore storeB;
    storeB.indexType = KMLSS_INDEX_INT32;
    storeB.valueType = KMLSS_VALUE_FP64;
    storeB.nRow = n;
    storeB.nCol = nrhs;
    storeB.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
    storeB.dense.value = b;
    storeB.dense.ld = n;
    KmlSolverMatrixOption optB;
    optB.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optB.type = KMLSS_MATRIX_GEN;
    KmlSolverMatrix *B;
    ierr = KmlSolverMatrixCreate(&B, &storeB, &optB);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create b: %d\n", ierr);
        return 1;
    }
    // Create vector x
    double x[8] = {0};
    KmlSolverMatrixStore storeX;
    storeX.indexType = KMLSS_INDEX_INT32;
    storeX.valueType = KMLSS_VALUE_FP64;
    storeX.nRow = n;
    storeX.nCol = nrhs;
    storeX.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
    storeX.dense.value = x;
    storeX.dense.ld = n;
    KmlSolverMatrixOption optX;
    optX.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
    optX.type = KMLSS_MATRIX_GEN;
    KmlSolverMatrix *X;
    ierr = KmlSolverMatrixCreate(&X, &storeX, &optX);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR when create x: %d\n", ierr);
        return 1;
    }
    // Init solver
    KmlDssInitOption opt;
    opt.fieldMask = KMLDSS_INIT_OPTION_BWR_MODE | KMLDSS_INIT_OPTION_NTHREADS;
    opt.bwrMode = KMLDSS_BWR_OFF;
    opt.nThreads = 32;
    KmlDssSolver *solver;
    ierr = KmlDssInit(&solver, &opt);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssInit: %d\n", ierr);
        return ierr;
    }
    // Analyze
    int schurSize = 2;
    KmlDssAnalyzeOption optAnalyze;
    optAnalyze.fieldMask = KMLDSS_ANALYZE_OPTION_MATCHING_TYPE |
                        KMLDSS_ANALYZE_OPTION_RDR_TYPE |
                        KMLDSS_ANALYZE_OPTION_NTHREADS_RDR |
                        KMLDSS_ANALYZE_OPTION_SCHUR_SIZE |
                        KMLDSS_ANALYZE_OPTION_SCHUR_FORMAT;
    optAnalyze.matchingType = KMLDSS_MATCHING_OFF;
    optAnalyze.rdrType = KMLDSS_RDR_KRDR;
    optAnalyze.nThreadsRdr = 8;
    optAnalyze.schurSize = schurSize;
    optAnalyze.schurFormat = KMLSS_MATRIX_STORE_CSR;
    ierr = KmlDssAnalyze(solver, A, &optAnalyze);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssAnalyze: %d\n", ierr);
        return ierr;
    }
    // Factorize
    KmlDssFactorizeOption optFact;
    optFact.fieldMask = KMLDSS_FACTORIZE_OPTION_PERTURBATION_THRESHOLD;
    optFact.perturbationThreshold = 1e-8;
    ierr = KmlDssFactorize(solver, A, &optFact);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssFactorize: %d\n", ierr);
        return ierr;
    }
    // Solve forward
    KmlDssSolveOption optSolve;
    optSolve.fieldMask = KMLDSS_SOLVE_OPTION_SOLVE_STAGE;
    optSolve.stage = KMLDSS_SOLVE_FORWARD;
    ierr = KmlDssSolve(solver, B, X, &optSolve);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssSolve: %d\n", ierr);
        return ierr;
    }
    // Solve schur
    double xTemp[8] = {0};
    KmlDssInfo info;
    info.fieldMask = KMLDSS_INFO_SCHUR_NNZ;
    ierr = KmlDssQuery(solver, &info);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssQuery: %d\n", ierr);
        return ierr;
    }
    int schurNnz = info.schurNnz;
    KmlSolverMatrixStore schurMat;
    schurMat.csr.value = malloc(sizeof(double) * schurNnz);
    schurMat.csr.rowOffset = malloc(sizeof(int) * (schurSize + 1));
    schurMat.csr.colIndex = malloc(sizeof(int) * schurNnz);
    info.fieldMask = KMLDSS_INFO_SCHUR_MAT;
    info.schurMat = &schurMat;
    ierr = KmlDssQuery(solver, &info);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssQuery: %d\n", ierr);
        return ierr;
    }
    double *denseA = (double *)malloc(sizeof(double) * schurSize * schurSize);
    int *ja_schur = (int *)(schurMat.csr.colIndex);
    int *ia_schur = (int *)(schurMat.csr.rowOffset);
    double *value_schur = (double *)(schurMat.csr.value);
    for (int i = 0; i < schurSize; i++) {
        for (int id = ia_schur[i]; id < ia_schur[i + 1]; id++) {
            int j = ja_schur[id];
            denseA[i + j * schurSize] = value_schur[id];
        }
    }
    // use lapack to solve scuhrmat
    int *ipiv = (int *)malloc(sizeof(int) * schurSize);
    int iinfo = 0;
    dgesv_(&schurSize, &nrhs, denseA, &schurSize, ipiv, &xTemp[n - schurSize], &n, &iinfo);
    free(denseA);
    free(ipiv);
    KmlSolverMatrixSetValue(B, xTemp);
    // Solve backward
    optSolve.stage = KMLDSS_SOLVE_BACKWARD;
    ierr = KmlDssSolve(solver, B, X, &optSolve);
    if (ierr != KMLSS_NO_ERROR) {
        printf("ERROR in KmlDssSolve: %d\n", ierr);
        return ierr;
    }
    // Output result x
    printf("Result of first factorize and solve:\n");
    for (int i = 0; i < n; i++) {
        printf("%lf ", x[i]);
    }
    printf("\n");
    // Destroy
    KmlDssClean(&solver);
    KmlSolverMatrixDestroy(&A);
    KmlSolverMatrixDestroy(&B);
    KmlSolverMatrixDestroy(&X);
    return 0;
}