


#include "kml_scaiss.h"


/* C Interface: */
#include <stdio.h>
#include <stdlib.h>
#include <kml_scaiss.h>
#include "mpi.h"
/*
 * User context passed to the solver callbacks (diagonal_preconditioner and
 * user_spmv). It describes this rank's stripe of a CSR matrix:
 *   n    - number of rows owned by this rank
 *   ia   - row pointers for those rows (n + 1 entries); main rebases them so
 *          the stripe's first entry is 0
 *   ja   - column indices of the stripe's nonzeros (not rebased by main, so
 *          they are global column indices)
 *   a    - values of the stripe's nonzeros
 *   comm - communicator used by user_spmv for its collective gather
 */
typedef struct duser_ctx {
    int n;           /* locally owned rows */
    const int *ia;   /* local CSR row pointers */
    const int *ja;   /* CSR column indices */
    const double *a; /* CSR nonzero values */
    MPI_Comm comm;   /* communicator for collectives */
} duser;
int diagonal_preconditioner(void *usr, double *x)
    duser *u = (duser *)usr;
    int i, j;
    /* Apply diagonal preconditioner */
    for (i = 0; i < u->n; i++) {
        for (j = u->ia[i]; j < u->ia[i + 1]; j++) {
            if (u->ja[j] == i) {
                x[i] /= u->a[j];
    return 0;
int user_spmv(void *usr, const double *x, double *y)
    duser *u = (duser *)usr;
    int size;
    int rank;
    MPI_Comm_size(u->comm, &size);
    MPI_Comm_rank(u->comm, &rank);
    int n = 8 / size;
    double fullX[8] = { 0.0 };
    MPI_Allgather(&x[0], n, MPI_DOUBLE, &fullX[0], n, MPI_DOUBLE, u->comm);
    int i, j;
    for (i = 0; i < u->n; i++) {
        double sum = 0.0;
        for (j = u->ia[i]; j < u->ia[i + 1]; j++) {
            int k = u->ja[j];
            double value = u->a[j];
            sum += fullX[k] * value;
        y[i] = sum;
    return 0;
int main(void)
    /* MPI initialization */
    MPI_Init(NULL, NULL);
    int size, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    /* Matrix data (CSR, stored full matrix)
        |  1    *    *    1    2    *    *    *  |
        |  *    9    2    1    *   -3    *    *  |
        |  *    2    3    *    *    *    *    2  |
        |  1    1    *    9    *    *   -5    *  |
        |  2    *    *    *    6    1    *    *  |
        |  *   -3    *    *    1    4    *    1  |
        |  *    *    *   -5    *    *    7    *  |
        |  *    *    2    *    *    1    *    2  |
    /* Initialize separations */
    int mat_size = 8;
    int n = mat_size / size;
    int n_beg = n * rank;
    if (n * size != 8 && rank == (size - 1)) {
        n = mat_size - n * rank;
    int ia[9] = { 0, 3, 7, 10, 14, 17, 21, 23, 26 };
    int a_beg = ia[n_beg];
    for (int i = n_beg; i < (n_beg + n + 1); i++) {
        ia[i] -= a_beg;
    /* clang-format off */
    int ja[26] = { 0,       3, 4,
                      1, 2, 3,    5,
                      1, 2,             7,
                   0, 1,    3,       6,
                   0,          4, 5,
                      1,       4, 5,    7,
                            3,       6,
                        2,        5,    7 };
    double a[26] = { 1.0,           1.0, 2.0,
                          9.0, 2.0, 1.0,     -3.0,
                          2.0, 3.0,                     2.0,
                     1.0, 1.0,      9.0,          -5.0,
                     2.0,                6.0, 1.0,
                          -3.0,          1.0, 4.0,      1.0,
                                    -5.0,          7.0,
                                2.0,          1.0,      2.0 };
    /* clang-format on */
    /* Right-hand side vector */
    double b[8] = { 4.0, 9.0, 7.0, 6.0, 9.0, 3.0, 2.0, 5.0 };
    /* Internal KML_SCAISS structure */
    KmlScasolverTask *handle;
    /* KML_SCAISS control parameters */
    int nrhs = 1;         /* Number of right-hand sides */
    int ldx = n, ldb = n; /*!Leading dimension of B and X */
    int error;            /* Output error handle */
    /* Create data structures */
    const double *a_holder = &a[a_beg];
    const int *ja_holder = &ja[a_beg];
    const int *ia_holder = &ia[n_beg];
    error = KmlScaissGmresInitStripesDI(&handle, mat_size, 1, &n, &n_beg, &a_holder, &ja_holder, &ia_holder,
    if (error != 0) {
        printf("ERROR in KmlScaissGmresInitStripesDI: %d\n", error);
        return 1;
    error = KmlScaissGmresAnalyzeDI(&handle);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresAnalyzeDI: %d\n", error);
        return 1;
    error = KmlScaissGmresFactorizeDI(&handle);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresFactorizeDI: %d\n", error);
        return 1;
    /* solve */
    double x7[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDI(&handle, nrhs, &x7[0], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("solve, x:\n");
        for (int i = 0; i < n; i++) {
            printf("%lf\n", x7[i]);
    /* solveDx */
    double x0[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDxDI(&handle, nrhs, &x0[n_beg], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("solveDx, x:\n");
        for (int i = 0; i < n; i++) {
            printf("%lf\n", x0[i]);
    /* L2 norm */
    int NormType = KMLSS_L2;
    error = KmlScaissGmresSetDII(&handle, KMLSS_VECTOR_NORM_TYPE, &NormType, 1);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSetDII: %d\n", error);
        return 1;
    /* Solve */
    double x2[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDI(&handle, nrhs, &x2[0], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("L2 norm, x:\n");
        for (int i = 0; i < mat_size; i++) {
            printf("%lf\n", x2[i]);
    /* SolveDx */
    double x6[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDxDI(&handle, nrhs, &x6[n_beg], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("L2 norm, x:\n");
        for (int i = 0; i < n; i++) {
            printf("%lf\n", x6[i]);
    /* abs residual */
    double res = 1e-10;
    error = KmlScaissGmresSetDID(&handle, KMLSS_ABS_TOLERANCE, &res, 1);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSetDIA: %d\n", error);
        return 1;
    /* Solve */
    double x3[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDI(&handle, nrhs, &x3[0], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("abs residual, x:\n");
        for (int i = 0; i < mat_size; i++) {
            printf("%lf\n", x3[i]);
    /* relative residual */
    res = 1e-10;
    error = KmlScaissGmresSetDID(&handle, KMLSS_THRESHOLD, &res, 1);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSetDIA: %d\n", error);
        return 1;
    /* Solve */
    double x4[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    error = KmlScaissGmresSolveDI(&handle, nrhs, &x4[0], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("relative residual, x:\n");
        for (int i = 0; i < mat_size; i++) {
            printf("%lf\n", x4[i]);
    /* Second call with preconditioner */
    double x5[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    duser prea;
    prea.n = n;
    prea.ia = &ia[n_beg];
    prea.ja = &ja[a_beg];
    prea.a = &a[a_beg];
    /* Set user preconditioner */
    error = KmlScaissGmresSetUserPreconditionerDI(&handle, &prea, &diagonal_preconditioner);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSetUserPreconditionerDI: %d\n", error);
        return 1;
    /* Solve */
    error = KmlScaissGmresSolveDI(&handle, nrhs, &x5[0], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Finalize and Clean-up */
    error = KmlScaissGmresCleanDI(&handle);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresCleanDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("user preconditioner, x:\n");
        for (int i = 0; i < mat_size; i++) {
            printf("%lf\n", x5[i]);
    /* Second call with user spmv */
    KmlScasolverTask *handle2;
    double x8[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
    duser prea2;
    prea2.n = n;
    prea2.ia = &ia[n_beg];
    prea2.ja = &ja[a_beg];
    prea2.a = &a[a_beg];
    prea2.comm = MPI_COMM_WORLD;
    error = KmlScaissGmresInitWithoutMatDI(&handle2, 1, &n, MPI_COMM_WORLD);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresInitWithouMatDI: %d\n", error);
        return 1;
    /* Set user spmv */
    error = KmlScaissGmresSetUserSpmvDI(&handle2, &prea2, &user_spmv);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSetUserSpmvDI: %d\n", error);
        return 1;
    /* Solve */
    error = KmlScaissGmresSolveDxDI(&handle2, nrhs, &x8[n_beg], ldx, &b[n_beg], ldb);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresSolveDI: %d\n", error);
        return 1;
    /* Finalize and Clean-up */
    error = KmlScaissGmresCleanDI(&handle2);
    if (error != 0) {
        printf("ERROR in KmlScaissGmresCleanDI: %d\n", error);
        return 1;
    /* Print the solution */
    if (rank == 0) {
        printf("user spmv, x:\n");
        for (int i = 0; i < n; i++) {
            printf("%lf\n", x8[i]);
    return 0;