鲲鹏社区首页
中文
注册
开发者
我要评分
获取效率
正确性
完整性
易理解
在线提单
论坛求助

示例

#include <omp.h>
#include <mpi.h>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include "kuqcd.h"

using namespace kuqcd;

// Forward declaration; the definition follows main().
//
// Runs one multi-shift CG solve (kuqcd_solve_multishift) on a small
// constant-valued test problem and prints the elapsed wall time.
//   T            host floating-point type; sizeof(T) == sizeof(double) selects
//                KUQCD_DOUBLE, anything else KUQCD_SINGLE
//   Nbatch       batch factor: divides the per-rank volume and multiplies the
//                buffer lengths
//   ReconS       reconstruct type used for the fat (KUQCD_FAT_LINK) gauge field
//   ReconN       reconstruct type used for the long (KUQCD_LONG_LINK) gauge field
//   gauge_order  host gauge-field layout handed to kuqcd_load_gaugefield
template<typename T, int Nbatch, KuQCDReconstructType ReconS, KuQCDReconstructType ReconN>
void TestCGMPC(const KuQCDGaugeOrder gauge_order);


int main (int argc, char** argv)
{
    int provided = 0;
    int required = MPI_THREAD_FUNNELED;
    int flag = MPI_Init_thread(&argc, &argv, required, &provided);

    TestCGMPC<float, 1, KUQCD_GAUGE_RECON_NO, KUQCD_GAUGE_RECON_NO>(KUQCD_NORMAL_ORDER);
    MPI_Finalize();
    return 0;
}


template<typename T, int Nbatch, KuQCDReconstructType ReconS, KuQCDReconstructType ReconN>
void TestCGMPC(const KuQCDGaugeOrder gauge_order)
{
    int repos = 1;
    int nx = 4, ny = 4, nz = 4, nt = 1, seed = 1;
    const int num_shift = 1;
    std::array<unsigned short, 4> gridDim{1, 1, 1, 1};
    KuQCDRankOrder od = KUQCD_TZYX_ORDER;
    KuQCDSiteSubset site_subset = KUQCD_ODD;
    int haloDepthSpin = 0;
    int haloDepthGauge = 3;

    Lattice latSup(nx, ny, nz, nt);

    KuQCDPrecision prec = (sizeof(T)==sizeof(double)) ? KUQCD_DOUBLE : KUQCD_SINGLE;

    rootLogger.Info("initializing spinor field lattice base");
    LattBase *latBaseOdd = nullptr;
    KuQCDGaugeParam Gagugeoptions;
    Gagugeoptions.lat[0] = nx;
    Gagugeoptions.lat[1] = ny;
    Gagugeoptions.lat[2] = nz;
    Gagugeoptions.lat[3] = nt;
    Gagugeoptions.gauge_order = gauge_order;
    Gagugeoptions.od = od;
    Gagugeoptions.prec = prec;
    site_subset = KUQCD_ALL;
    Gagugeoptions.site_subset = site_subset;
    Gagugeoptions.halo_depth = haloDepthGauge;


    double shifts[num_shift];
    shifts[0] = 3.67167631338816e-05;
    FILE* fp;
    float* origin;
    int np = GetNp();
    int rk = GetRank();

    size_t vol = (size_t)nx *ny * nt *nz / Nbatch / (np);
    size_t volh = vol / 2;
    size_t spinorLen = volh* Nbatch * 6;
    size_t gaugeLen = vol* Nbatch * 4 * ReconS;
    size_t gaugeLenNaik = vol* Nbatch * 4 * ReconN;

    printf("vol is %zu, spinorLen %zu, gaugeLen %zu\n",vol,spinorLen,gaugeLen);
    origin = (float*)malloc(((num_shift + 1)*spinorLen+gaugeLen+gaugeLenNaik)*sizeof(float));
    char name[100];

    for(int i = 0; i<(num_shift + 1)*spinorLen+gaugeLen+gaugeLenNaik;i++) {
        origin[i] = 1.0;
    }
    kuqcd_init();
    float *originSrc = origin;
    float *origingaugeSmeared = &(origin[(num_shift+1)*spinorLen]);
    float *origingaugeNaik = &(origin[(num_shift+1)*spinorLen+gaugeLen]);

    rootLogger.Info("SpinorFieldSetValues1 done");
    Gagugeoptions.link_type = KUQCD_FAT_LINK;
    Gagugeoptions.recon = ReconS;
    kuqcd_load_gaugefield((void*)origingaugeSmeared, &Gagugeoptions);
    Gagugeoptions.link_type = KUQCD_LONG_LINK;
    Gagugeoptions.recon = ReconN;
    kuqcd_load_gaugefield((void*)origingaugeNaik, &Gagugeoptions);

    rootLogger.Info("initializing SetValues done");

    KuQCDInvertParam invertOpts;
    invertOpts.dslash_type = KUQCD_HISQPC_DIRAC;
    invertOpts.inverter_type = KUQCD_CGM;
    invertOpts.matpc_type = KUQCD_ODD_ODD;
    int iter = 0;
    invertOpts.num_shift = num_shift;
    invertOpts.max_iter = 6500;
    invertOpts.tol = 1e-12;
    invertOpts.shifts = shifts;
    invertOpts.prec =prec;
    void *dst[num_shift] = {nullptr};
    float *dst_value[num_shift];
    for(int i = 0;i<num_shift;i++) {
        dst_value[i] = (float*)malloc(spinorLen * sizeof(float));
        dst[i] = (void *)(dst_value[i]);
    }

    double tStart = MPI_Wtime();

    kuqcd_solve_multishift(dst, (void *)originSrc, &invertOpts);
    double tEnd = MPI_Wtime();
    Rank0Printf("done\n");
    fflush(0);
    Rank0Printf("total time =  %e [s]\n", tEnd - tStart);
    Rank0Printf("time/iter  =  %e [s]\n", (tEnd - tStart) / invertOpts.iter);
    fflush(stdout);



}