鲲鹏社区首页
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助

P?GESV

使用部分选主元的LU分解算法求解线性方程组Ax=B,其中A是N*N的分布式子矩阵,B是具有NRHS个向量的右端项矩阵。

接口定义

C Interface:

void psgesv_(const int *n, const int *nrhs, float *a, const int *ia, const int *ja, const int *desca, int *ipiv, float *b, const int *ib, const int *jb, const int *descb, int *info);

void pdgesv_(const int *n, const int *nrhs, double *a, const int *ia, const int *ja, const int *desca, int *ipiv, double *b, const int *ib, const int *jb, const int *descb, int *info);

void pcgesv_(const int *n, const int *nrhs, float _Complex *a, const int *ia, const int *ja, const int *desca, int *ipiv, float _Complex *b, const int *ib, const int *jb, const int *descb, int *info);

void pzgesv_(const int *n, const int *nrhs, double _Complex *a, const int *ia, const int *ja, const int *desca, int *ipiv, double _Complex *b, const int *ib, const int *jb, const int *descb, int *info);

Fortran Interface:

PSGESV(n, nrhs, a, ia, ja, desca, ipiv, b, ib, jb, descb, info)

PDGESV(n, nrhs, a, ia, ja, desca, ipiv, b, ib, jb, descb, info)

PCGESV(n, nrhs, a, ia, ja, desca, ipiv, b, ib, jb, descb, info)

PZGESV(n, nrhs, a, ia, ja, desca, ipiv, b, ib, jb, descb, info)

参数

参数

类型

范围

说明

输入/输出

n

整型

全局

矩阵的行数和列数。

输入

nrhs

整型

全局

右侧的数量,即分布式子矩阵子(B)和X的列数。

输入

a

  • 在psgesv中为单精度浮点型数组。
  • 在pdgesv中为双精度浮点型数组。
  • 在pcgesv中为单精度复数型数组。
  • 在pzgesv中为双精度复数型数组。

本地

  • 调用前保存分布式矩阵A的本地M*N部分。
  • 调用后保存本地部分存放的分解结果L和U,不保存L的对角线元素(均为1)。

输入,输出

ia

整型

全局

子矩阵A在全局矩阵中的行索引。

输入

ja

整型

全局

子矩阵A在全局矩阵中的列索引。

输入

desca

整型数组

本地,全局

分布式矩阵A的矩阵描述符。

输入

ipiv

整型

本地

包含了主元及交换信息。

输出

b

  • 在psgesv中为单精度浮点型数组。
  • 在pdgesv中为双精度浮点型数组。
  • 在pcgesv中为单精度复数型数组。
  • 在pzgesv中为双精度复数型数组。

本地

  • 调用前保存右端项
  • 调用后保存求解结果

输入,输出

ib

整型

全局

子矩阵B在全局矩阵中的行索引。

输入

jb

整型

全局

子矩阵B在全局矩阵中的列索引。

输入

descb

整型数组

本地,全局

分布式矩阵B的矩阵描述符。

输入

info

整型

全局

  • 等于0:表示成功。
  • 小于0:info=-i,表示第i个参数非法。
  • 大于0:算法出错。

输出

依赖

#include <kscalapack.h>

示例

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
    int izero=0;
    int ione=1;
    int myrank_mpi, nprocs_mpi;
    MPI_Init( &argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &myrank_mpi);
    MPI_Comm_size(MPI_COMM_WORLD, &nprocs_mpi);
 
    int n = 8;       // (Global) Matrix size
    int nprow = 2;   // Number of row procs
    int npcol = 2;   // Number of column procs
    int nb = 4;      // (Global) Block size
    char uplo='L';   // Matrix is lower triangular
    char layout='R'; // Block cyclic, Row major processor mapping
    int nrhs = 1;
 
    printf("Usage: ./test matrix_size block_size nprocs_row nprocs_col\n");
 
    if(argc > 1) {
        n = atoi(argv[1]);
    }
    if(argc > 2) {
        nb = atoi(argv[2]);
    }
    if(argc > 3) {
        nprow = atoi(argv[3]);
    }
    if(argc > 4) {
        npcol = atoi(argv[4]);
    }
    assert(nprow * npcol == nprocs_mpi);
    // Initialize BLACS
    int iam, nprocs;
    int zero = 0;
    int ictxt, myrow, mycol;
    blacs_pinfo_(&iam, &nprocs) ; // BLACS rank and world size
    blacs_get_(&zero, &zero, &ictxt ); // -> Create context
    blacs_gridinit_(&ictxt, &layout, &nprow, &npcol ); // Context -> Initialize the grid
    blacs_gridinfo_(&ictxt, &nprow, &npcol, &myrow, &mycol ); // Context -> Context grid info (# procs row/col, current procs row/col)
 
    // Compute the size of the local matrices
    int mpA    = numroc_( &n, &nb, &myrow, &izero, &nprow ); // My proc -> row of local A
    int nqA    = numroc_( &n, &nb, &mycol, &izero, &npcol ); // My proc -> col of local A
    int mpB    = numroc_( &n, &nb, &myrow, &izero, &nprow );
    ofstream f1;
    string filename = to_string(myrank_mpi)+"Abegin.dat";
    f1.open(filename);
    double *A;
    A = (double *)calloc(mpA*nqA,sizeof(double)) ;
    if (A==NULL){ printf("Error of memory allocation A on proc %dx%d\n",myrow,mycol); exit(0); }
    int k = 0;
    for (int j = 0; j < nqA; j++) { // local col
        int l_j = j / nb; // which block
        int x_j = j % nb; // where within that block
        int J   = (l_j * npcol + mycol) * nb + x_j; // global col
        for (int i = 0; i < mpA; i++) { // local row
            int l_i = i / nb; // which block
            int x_i = i % nb; // where within that block
            int I   = (l_i * nprow + myrow) * nb + x_i; // global row
            assert(I < n);
            assert(J < n);
            if(I == J) {
                A[k] = 2*n + 1.5  +  (rand())%10;
            } else {
                A[k] = i + j + rand()% 10;
            }
            //printf("%d %d -> %d %d -> %f\n", i, j, I, J, A[k]);
            f1 <<I << " "<<J << " " << A[k]<<endl;
            k++;
        }
    }
    f1.close();
    
    //creat descriptor
    int descA[9];
    int info=0;
    int ipiv[10] = {0};
    int lddB = mpB > 1 ? mpB : 1;
    descinit_( descA,  &n, &n, &nb, &nb, &izero, &izero, &ictxt, &lddB, &info);
    if(info != 0) {
        printf("Error in descinit, info = %d\n", info);
    }
    
    filename = to_string(myrank_mpi)+"Bbegin.dat";
    f1.open(filename);
    double *B;
    B = (double *)calloc(mpA,sizeof(double)) ;
    if (A==NULL){ printf("Error of memory allocation A on proc %dx%d\n",myrow,mycol); exit(0); }
    k = 0;
    for (int j = 0; j < mpB; j++) { // local col
        int l_i = j / nb; // which block
        int x_i = j % nb; // where within that block
        int I   = (l_i * nprow + myrow) * nb + x_i; // global row
       B[j] = j + 1.5  +  (rand())%10;
        f1 <<I << " " << B[j]<<endl;
    }  
    f1.close();
    int descB[9];
    int nbrhs=1;    
    descinit_( descB,  &n, &nrhs, &nb, &nbrhs, &izero, &izero, &ictxt, &lddB, &info); // nbrhs need to be revised when nrhs!=1
        
 
 
    //run pdpotrf_ and time
    double MPIt1 = MPI_Wtime();
    printf("[%dx%d] Starting \n", myrow, mycol);
    pdgesv_(&n, &nrhs, A, &ione, &ione, descA, ipiv, B, &ione, &ione, descB, &info);
    if (info != 0) {
        printf("Error in calculate, info = %d\n", info);
    }
    filename = to_string(myrank_mpi)+"Bend.dat"; 
    f1.open(filename);
    for (int j = 0; j < mpB; j++) {
        int l_i = j / nb; // which block
        int x_i = j % nb; // where within that block
        int I   = (l_i * nprow + myrow) * nb + x_i; // global row
       f1 <<I<< " " << B[j]<<endl;
    }
    f1.close();
 
    double MPIt2 = MPI_Wtime();
    printf("[%dx%d] Done, time %e s.\n", myrow, mycol, MPIt2 - MPIt1);
    filename = to_string(myrank_mpi)+"end.dat";
    f1.open(filename);
    k = 0;
    for (int j = 0; j < nqA; j++) { // local col
        int l_j = j / nb; // which block
        int x_j = j % nb; // where within that block
        int J   = (l_j * npcol + mycol) * nb + x_j; // global col
        for (int i = 0; i < mpA; i++) { // local row
            int l_i = i / nb; // which block
            int x_i = i % nb; // where within that block
            int I   = (l_i * nprow + myrow) * nb + x_i; // global row
            assert(I < n);
            assert(J < n);
            f1 <<I << " "<<J << " " << A[k]<<endl;
            k++;
        }
    }
    f1.close();
    free(A);
    /*
    origin A:
[[20.500000  4.000000 11.000000  3.000000  3.000000  4.000000 11.000000 3.000000]
 [ 7.000000 22.500000  4.000000 13.000000  7.000000  7.000000  4.000000 13.000000]
 [ 9.000000  9.000000 19.500000  8.000000  9.000000  9.000000  6.000000 8.000000]
 [ 8.000000  6.000000 12.000000 23.500000  8.000000  6.000000 12.000000 12.000000]
 [ 3.000000  4.000000 11.000000  3.000000 20.500000  4.000000 11.000000 3.000000]
 [ 7.000000  7.000000  4.000000 13.000000  7.000000 22.500000  4.000000 13.000000]
 [ 9.000000  9.000000  6.000000  8.000000  9.000000  9.000000 19.500000 8.000000]
 [ 8.000000  6.000000 12.000000 12.000000  8.000000  6.000000 12.000000 23.500000]]
origin B:
[[ 1.500000]
 [ 8.500000]
 [ 5.500000]
 [10.500000]
 [ 1.500000]
 [ 8.500000]
 [ 5.500000]
 [10.500000]]
X:
[[-0.073846]
 [ 0.069101]
 [ 0.047280]
 [ 0.273735]
 [-0.073846]
 [ 0.069101]
 [ 0.047280]
 [ 0.273735]]
    */