示例
C Interface:
#include <stdio.h>
#include <kml_scadss.h>
int Run(MPI_Comm comm)
{
int ierr;
int rank;
MPI_Comm_rank(comm, &rank);
int n = 8;
int nrhs = 1;
// Create matrix A
int ia[9] = {0, 2, 4, 6, 7, 8, 10, 12, 14};
int ja[14] = {0, 7, 1, 6, 2, 5, 3, 4, 2, 5, 1, 6, 0, 7};
double a[14] = {1.0, 2.0, -2.0, 3.0, 3.0, 4.0, -4.0, 5.0, 4.0, -6.0, 3.0, 7.0, 2.0, 8.0};
KmlSolverMatrixStore storeA;
storeA.indexType = KMLSS_INDEX_INT32;
storeA.valueType = KMLSS_VALUE_FP64;
storeA.format = KMLSS_MATRIX_STORE_CSR;
if (rank == 0) {
storeA.nRow = n;
storeA.nCol = n;
storeA.csr.rowOffset = ia;
storeA.csr.colIndex = ja;
storeA.csr.value = a;
} else {
storeA.nRow = 0;
storeA.nCol = 0;
storeA.csr.rowOffset = nullptr;
storeA.csr.colIndex = nullptr;
storeA.csr.value = nullptr;
}
KmlSolverMatrixOption optA;
optA.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
optA.type = KMLSS_MATRIX_GEN;
KmlScasolverMatrixOption scaOptA;
if (rank == 0) {
scaOptA.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
KMLSS_MATRIX_OPTIONS_PARTITION;
scaOptA.partition.type = KMLSS_MATRIX_PARTITION_ROW;
scaOptA.globalNumRows = n;
scaOptA.globalNumCols = n;
scaOptA.partition.localBegin = 0;
} else {
scaOptA.fieldMask = 0;
}
KmlScasolverMatrix *A;
ierr = KmlScasolverMatrixCreate(&A, &storeA, &optA, &scaOptA);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR when create A: %d\n", ierr);
return 1;
}
// Create vector b
double b[8] = {3.0, 1.0, 7.0, -4.0, 5.0, -2.0, 10.0, 10.0};
KmlSolverMatrixStore storeB;
storeB.indexType = KMLSS_INDEX_INT32;
storeB.valueType = KMLSS_VALUE_FP64;
storeB.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
if (rank == 0) {
storeB.nRow = n;
storeB.nCol = nrhs;
storeB.dense.value = b;
storeB.dense.ld = n;
} else {
storeB.nRow = 0;
storeB.nCol = 0;
storeB.dense.value = nullptr;
storeB.dense.ld = 0;
}
KmlSolverMatrixOption optB;
optB.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
optB.type = KMLSS_MATRIX_GEN;
KmlScasolverMatrixOption scaOptB;
if (rank == 0) {
scaOptB.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
KMLSS_MATRIX_OPTIONS_PARTITION;
scaOptB.partition.type = KMLSS_MATRIX_PARTITION_ROW;
scaOptB.partition.localBegin = 0;
scaOptB.globalNumRows = n;
scaOptB.globalNumCols = nrhs;
} else {
scaOptB.fieldMask = 0;
}
KmlScasolverMatrix *B;
ierr = KmlScasolverMatrixCreate(&B, &storeB, &optB, &scaOptB);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR when create b: %d\n", ierr);
return 1;
}
// Create vector x
double x[8] = {0};
KmlSolverMatrixStore storeX;
storeX.indexType = KMLSS_INDEX_INT32;
storeX.valueType = KMLSS_VALUE_FP64;
storeX.format = KMLSS_MATRIX_STORE_DENSE_COL_MAJOR;
if (rank == 0) {
storeX.nRow = n;
storeX.nCol = nrhs;
storeX.dense.value = x;
storeX.dense.ld = n;
} else {
storeX.nRow = 0;
storeX.nCol = 0;
storeX.dense.value = nullptr;
storeX.dense.ld = 0;
}
KmlSolverMatrixOption optX;
optX.fieldMask = KMLSS_MATRIX_OPTION_TYPE;
optX.type = KMLSS_MATRIX_GEN;
KmlScasolverMatrixOption scaOptX;
if (rank == 0) {
scaOptX.fieldMask = KMLSS_MATRIX_OPTIONS_GLOBAL_NROWS |
KMLSS_MATRIX_OPTIONS_GLOBAL_NCOLS |
KMLSS_MATRIX_OPTIONS_PARTITION;
scaOptX.partition.type = KMLSS_MATRIX_PARTITION_ROW;
scaOptX.partition.localBegin = 0;
scaOptX.globalNumRows = n;
scaOptX.globalNumCols = nrhs;
} else {
scaOptX.fieldMask = 0;
}
KmlScasolverMatrix *X;
ierr = KmlScasolverMatrixCreate(&X, &storeX, &optX, &scaOptX);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR when create x: %d\n", ierr);
return 1;
}
// Init solver
KmlDssInitOption opt;
opt.fieldMask = KMLDSS_INIT_OPTION_BWR_MODE | KMLDSS_INIT_OPTION_NTHREADS;
opt.bwrMode = KMLDSS_BWR_OFF;
opt.nThreads = 32;
KmlScadssInitOption scaOpt;
scaOpt.fieldMask = KMLSCADSS_OPTIONS_COMM;
scaOpt.comm = comm;
KmlScadssSolver *solver;
ierr = KmlScadssInit(&solver, &opt, &scaOpt);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssInit: %d\n", ierr);
return ierr;
}
// Analyze
KmlDssAnalyzeOption optAnalyze;
optAnalyze.fieldMask = KMLDSS_ANALYZE_OPTION_MATCHING_TYPE | KMLDSS_ANALYZE_OPTION_RDR_TYPE |
KMLDSS_ANALYZE_OPTION_NTHREADS_RDR;
optAnalyze.matchingType = KMLDSS_MATCHING_OFF;
optAnalyze.rdrType = KMLDSS_RDR_KRDR;
optAnalyze.nThreadsRdr = 1;
KmlScadssAnalyzeOption scaOptAnalyze;
scaOptAnalyze.fieldMask = 0;
ierr = KmlScadssAnalyze(solver, A, &optAnalyze, &scaOptAnalyze);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssAnalyze: %d\n", ierr);
return ierr;
}
// Factorize
KmlDssFactorizeOption optFact;
optFact.fieldMask = KMLDSS_FACTORIZE_OPTION_PERTURBATION_THRESHOLD;
optFact.perturbationThreshold = 1e-8;
KmlScadssFactorizeOption scaOptFact;
scaOptFact.fieldMask = 0;
ierr = KmlScadssFactorize(solver, A, &optFact, &scaOptFact);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssFactorize: %d\n", ierr);
return ierr;
}
// Solve
KmlDssSolveOption optSolve;
optSolve.fieldMask = KMLDSS_SOLVE_OPTION_SOLVE_STAGE | KMLDSS_SOLVE_OPTION_REFINE_METHOD;
optSolve.stage = KMLDSS_SOLVE_ALL;
optSolve.refineMethod = KMLDSS_REFINE_OFF;
KmlScadssSolveOption scaOptSolve;
scaOptSolve.fieldMask = 0;
ierr = KmlScadssSolve(solver, B, X, &optSolve, &scaOptSolve);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssSolve: %d\n", ierr);
return ierr;
}
// Output result x
if (rank == 0) {
printf("Result of first factorize and solve:\n");
for (int i = 0; i < n; i++) {
printf("%lf ", x[i]);
}
printf("\n");
}
// Set new values of A
double a1[14] = {2.0, 3.0, -3.0, 4.0, 4.0, 5.0, -5.0, 6.0, 5.0, -7.0, 4.0, 8.0, 3.0, 9.0};
KmlScasolverMatrixSetValue(A, a1);
// Factorize with new values
ierr = KmlScadssFactorize(solver, A, &optFact, &scaOptFact);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssFactorize: %d\n", ierr);
return ierr;
}
// Set new values of B
double b1[8] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
KmlScasolverMatrixSetValue(B, b1);
// Solve with new values
ierr = KmlScadssSolve(solver, B, X, &optSolve, &scaOptSolve);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssSolve: %d\n", ierr);
return ierr;
}
// Output new result x
if (rank == 0) {
printf("Result of second factorize and solve:\n");
for (int i = 0; i < n; i++) {
printf("%lf ", x[i]);
}
printf("\n");
}
// Query
KmlDssInfo info;
info.fieldMask = KMLDSS_INFO_PEAK_MEM;
KmlScadssInfo scaInfo;
scaInfo.fieldMask = 0;
ierr = KmlScadssQuery(solver, &info, &scaInfo);
if (ierr != KMLSS_NO_ERROR) {
printf("ERROR in KmlDssQuery: %d\n", ierr);
return ierr;
}
if (rank == 0) {
printf("Peak memory is %ld Byte\n", info.peakMem);
}
// Destroy
KmlScadssClean(&solver);
KmlScasolverMatrixDestroy(&A);
KmlScasolverMatrixDestroy(&B);
KmlScasolverMatrixDestroy(&X);
return 0;
}
int main(int argc, char **argv)
{
MPI_Init(&argc, &argv);
Run(MPI_COMM_WORLD);
MPI_Finalize();
return 0;
}
运行结果:
Result of first factorize and solve:
1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000
Result of second factorize and solve:
0.666667 -0.100000 0.226415 -0.200000 0.166667 0.018868 0.175000 -0.111111
Peak memory is 102376 Byte
父主题: 求解器函数