/* C Interface: */
#include <stdio.h>
#include <stdlib.h>

#include <kml_scaiss.h>
#include "mpi.h"
/*
 * User data handed to user_spmv through the solver's opaque `usr` pointer:
 * the CSR stripe of the matrix owned by the calling rank.
 */
typedef struct duser {
    int n;           /* number of locally owned rows */
    const int *ia;   /* local row pointers; user_spmv reads n + 1 entries */
    const int *ja;   /* column index (into the full vector) of each local nonzero */
    const double *a; /* value of each local nonzero */
    MPI_Comm comm;   /* communicator the rows are distributed over */
} duser;
int user_spmv(void *usr, const double *x, double *y)
{
duser *u = (duser *)usr;
int size;
int rank;
MPI_Comm_size(u->comm, &size);
MPI_Comm_rank(u->comm, &rank);
int n = 8 / size;
double fullX[8] = { 0.0 };
MPI_Allgather(&x[0], n, MPI_DOUBLE, &fullX[0], n, MPI_DOUBLE, u->comm);
int i, j;
for (i = 0; i < u->n; i++) {
double sum = 0.0;
for (j = u->ia[i]; j < u->ia[i + 1]; j++) {
int k = u->ja[j];
double value = u->a[j];
sum += fullX[k] * value;
}
y[i] = sum;
}
return 0;
}
int main(void)
{
/* MPI initialization */
MPI_Init(NULL, NULL);
int size, rank;
MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
/* Matrix data (CSR, stored full matrix)
| 1 * * 1 2 * * * |
| * 9 2 1 * -3 * * |
| * 2 3 * * * * 2 |
| 1 1 * 9 * * -5 * |
| 2 * * * 6 1 * * |
| * -3 * * 1 4 * 1 |
| * * * -5 * * 7 * |
| * * 2 * * 1 * 2 |
*/
/* Initialize separations */
int mat_size = 8;
int n = mat_size / size;
int n_beg = n * rank;
if (n * size != 8 && rank == (size - 1)) {
n = mat_size - n * rank;
}
int ia[9] = { 0, 3, 7, 10, 14, 17, 21, 23, 26 };
int a_beg = ia[n_beg];
for (int i = n_beg; i < (n_beg + n + 1); i++) {
ia[i] -= a_beg;
}
/* clang-format off */
int ja[26] = { 0, 3, 4,
1, 2, 3, 5,
1, 2, 7,
0, 1, 3, 6,
0, 4, 5,
1, 4, 5, 7,
3, 6,
2, 5, 7 };
double a[26] = { 1.0, 1.0, 2.0,
9.0, 2.0, 1.0, -3.0,
2.0, 3.0, 2.0,
1.0, 1.0, 9.0, -5.0,
2.0, 6.0, 1.0,
-3.0, 1.0, 4.0, 1.0,
-5.0, 7.0,
2.0, 1.0, 2.0 };
/* clang-format on */
/* Right-hand side vector */
double b[8] = { 4.0, 9.0, 7.0, 6.0, 9.0, 3.0, 2.0, 5.0 };
/* Solution vector */
double x[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
/* Internal KML_SCAISS structure */
KmlScasolverTask *handle;
/* KML_SCAISS control parameters */
int nrhs = 1; /* Number of right-hand sides */
int ldx = n, ldb = n; /*!Leading dimension of B and X */
int error; /* Output error handle */
/* Create data structures */
const double *a_holder = &a[a_beg];
const int *ja_holder = &ja[a_beg];
const int *ia_holder = &ia[n_beg];
error = KmlScaissCsiInitStripesDI(&handle, mat_size, 1, &n, &n_beg, &a_holder, &ja_holder, &ia_holder,
MPI_COMM_WORLD);
if (error != 0) {
printf("ERROR in KmlScaissCsiInitStripesDI: %d\n", error);
return 1;
}
double eigs[2] = { 0.061645, 13.350966 };
error = KmlScaissCsiSetDID(&handle, KMLSS_SPECTRUM_BOUNDS, &eigs[0], 2);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDID: %d\n", error);
return 1;
}
error = KmlScaissCsiAnalyzeDI(&handle);
if (error != 0) {
printf("ERROR in KmlScaissCsiAnalyzeDI: %d\n", error);
return 1;
}
error = KmlScaissCsiFactorizeDI(&handle);
if (error != 0) {
printf("ERROR in KmlScaissCsiFactorizeDI: %d\n", error);
return 1;
}
/* solve */
double x7[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDI(&handle, nrhs, &x7[0], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("solve, x:\n");
for (int i = 0; i < n; i++) {
printf("%lf\n", x7[i]);
}
}
/* solveDx */
double x0[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDxDI(&handle, nrhs, &x0[n_beg], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("solveDx, x:\n");
for (int i = 0; i < n; i++) {
printf("%lf\n", x0[i]);
}
}
/* L1 norm */
int NormType = KMLSS_L1;
error = KmlScaissCsiSetDII(&handle, KMLSS_VECTOR_NORM_TYPE, &NormType, 1);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDII: %d\n", error);
return 1;
}
/* Solve */
error = KmlScaissCsiSolveDI(&handle, nrhs, &x[0], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("L1 norm, x:\n");
for (int i = 0; i < mat_size; i++) {
printf("%lf\n", x[i]);
}
}
/* SolveDx */
double x6[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDxDI(&handle, nrhs, &x6[n_beg], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("L2 norm, x:\n");
for (int i = 0; i < n; i++) {
printf("%lf\n", x6[i]);
}
}
/* L2 norm */
NormType = KMLSS_L2;
error = KmlScaissCsiSetDII(&handle, KMLSS_VECTOR_NORM_TYPE, &NormType, 1);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDII: %d\n", error);
return 1;
}
/* Solve */
double x2[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDI(&handle, nrhs, &x2[0], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("L2 norm, x:\n");
for (int i = 0; i < mat_size; i++) {
printf("%lf\n", x2[i]);
}
}
/* abs residual */
double res = 1e-10;
error = KmlScaissCsiSetDID(&handle, KMLSS_ABS_TOLERANCE, &res, 1);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDIA: %d\n", error);
return 1;
}
/* Solve */
double x3[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDI(&handle, nrhs, &x3[0], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("abs residual, x:\n");
for (int i = 0; i < mat_size; i++) {
printf("%lf\n", x3[i]);
}
}
/* relative residual */
res = 1e-10;
error = KmlScaissCsiSetDID(&handle, KMLSS_THRESHOLD, &res, 1);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDIA: %d\n", error);
return 1;
}
/* Solve */
double x4[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
error = KmlScaissCsiSolveDI(&handle, nrhs, &x4[0], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("relative residual, x:\n");
for (int i = 0; i < mat_size; i++) {
printf("%lf\n", x4[i]);
}
}
/* Second call with user spmv */
double x8[8] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
duser prea;
prea.n = n;
prea.ia = &ia[n_beg];
prea.ja = &ja[a_beg];
prea.a = &a[a_beg];
prea.comm = MPI_COMM_WORLD;
error = KmlScaissCsiInitWithoutMatDI(&handle, 1, &n, MPI_COMM_WORLD);
if (error != 0) {
printf("ERROR in KmlScaissCsiInitWithouMatDI: %d\n", error);
return 1;
}
error = KmlScaissCsiSetDID(&handle, KMLSS_SPECTRUM_BOUNDS, &eigs[0], 2);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetDID: %d\n", error);
return 1;
}
/* Set user spmv */
error = KmlScaissCsiSetUserSpmvDI(&handle, &prea, &user_spmv);
if (error != 0) {
printf("ERROR in KmlScaissCsiSetUserSpmvDI: %d\n", error);
return 1;
}
/* Solve */
error = KmlScaissCsiSolveDxDI(&handle, nrhs, &x8[n_beg], ldx, &b[n_beg], ldb);
if (error != 0) {
printf("ERROR in KmlScaissCsiSolveDxDI: %d\n", error);
return 1;
}
/* Finalize and Clean-up */
error = KmlScaissCsiCleanDI(&handle);
if (error != 0) {
printf("ERROR in KmlScaissCsiCleanDI: %d\n", error);
return 1;
}
/* Print the solution */
if (rank == 0) {
printf("user spmv, x:\n");
for (int i = 0; i < n; i++) {
printf("%lf\n", x8[i]);
}
}
MPI_Finalize();
return 0;
}