示例
#include <omp.h>
#include <mpi.h>
#include <cstdio>
#include <vector>
#include "kuqcd.h"
using namespace kuqcd;
template<typename T, int Nbatch, KuQCDReconstructType ReconS, KuQCDReconstructType ReconN>
void TestCGMPC(const KuQCDGaugeOrder gauge_order);
int main (int argc, char** argv)
{
int provided = 0;
int required = MPI_THREAD_FUNNELED;
int flag = MPI_Init_thread(&argc, &argv, required, &provided);
TestCGMPC<float, 1, KUQCD_GAUGE_RECON_NO, KUQCD_GAUGE_RECON_NO>(KUQCD_NORMAL_ORDER);
MPI_Finalize();
return 0;
}
template<typename T, int Nbatch, KuQCDReconstructType ReconS, KuQCDReconstructType ReconN>
void TestCGMPC(const KuQCDGaugeOrder gauge_order)
{
int repos = 1;
int nx = 4, ny = 4, nz = 4, nt = 1, seed = 1;
const int num_shift = 1;
std::array<unsigned short, 4> gridDim{1, 1, 1, 1};
KuQCDRankOrder od = KUQCD_TZYX_ORDER;
KuQCDSiteSubset site_subset = KUQCD_ODD;
int haloDepthSpin = 0;
int haloDepthGauge = 3;
Lattice latSup(nx, ny, nz, nt);
KuQCDPrecision prec = (sizeof(T)==sizeof(double)) ? KUQCD_DOUBLE : KUQCD_SINGLE;
rootLogger.Info("initializing spinor field lattice base");
LattBase *latBaseOdd = nullptr;
KuQCDGaugeParam Gagugeoptions;
Gagugeoptions.lat[0] = nx;
Gagugeoptions.lat[1] = ny;
Gagugeoptions.lat[2] = nz;
Gagugeoptions.lat[3] = nt;
Gagugeoptions.gauge_order = gauge_order;
Gagugeoptions.od = od;
Gagugeoptions.prec = prec;
site_subset = KUQCD_ALL;
Gagugeoptions.site_subset = site_subset;
Gagugeoptions.halo_depth = haloDepthGauge;
double shifts[num_shift];
shifts[0] = 3.67167631338816e-05;
FILE* fp;
float* origin;
int np = GetNp();
int rk = GetRank();
size_t vol = (size_t)nx *ny * nt *nz / Nbatch / (np);
size_t volh = vol / 2;
size_t spinorLen = volh* Nbatch * 6;
size_t gaugeLen = vol* Nbatch * 4 * ReconS;
size_t gaugeLenNaik = vol* Nbatch * 4 * ReconN;
printf("vol is %zu, spinorLen %zu, gaugeLen %zu\n",vol,spinorLen,gaugeLen);
origin = (float*)malloc(((num_shift + 1)*spinorLen+gaugeLen+gaugeLenNaik)*sizeof(float));
char name[100];
for(int i = 0; i<(num_shift + 1)*spinorLen+gaugeLen+gaugeLenNaik;i++) {
origin[i] = 1.0;
}
kuqcd_init();
float *originSrc = origin;
float *origingaugeSmeared = &(origin[(num_shift+1)*spinorLen]);
float *origingaugeNaik = &(origin[(num_shift+1)*spinorLen+gaugeLen]);
rootLogger.Info("SpinorFieldSetValues1 done");
Gagugeoptions.link_type = KUQCD_FAT_LINK;
Gagugeoptions.recon = ReconS;
kuqcd_load_gaugefield((void*)origingaugeSmeared, &Gagugeoptions);
Gagugeoptions.link_type = KUQCD_LONG_LINK;
Gagugeoptions.recon = ReconN;
kuqcd_load_gaugefield((void*)origingaugeNaik, &Gagugeoptions);
rootLogger.Info("initializing SetValues done");
KuQCDInvertParam invertOpts;
invertOpts.dslash_type = KUQCD_HISQPC_DIRAC;
invertOpts.inverter_type = KUQCD_CGM;
invertOpts.matpc_type = KUQCD_ODD_ODD;
int iter = 0;
invertOpts.num_shift = num_shift;
invertOpts.max_iter = 6500;
invertOpts.tol = 1e-12;
invertOpts.shifts = shifts;
invertOpts.prec =prec;
void *dst[num_shift] = {nullptr};
float *dst_value[num_shift];
for(int i = 0;i<num_shift;i++) {
dst_value[i] = (float*)malloc(spinorLen * sizeof(float));
dst[i] = (void *)(dst_value[i]);
}
double tStart = MPI_Wtime();
kuqcd_solve_multishift(dst, (void *)originSrc, &invertOpts);
double tEnd = MPI_Wtime();
Rank0Printf("done\n");
fflush(0);
Rank0Printf("total time = %e [s]\n", tEnd - tStart);
Rank0Printf("time/iter = %e [s]\n", (tEnd - tStart) / invertOpts.iter);
fflush(stdout);
}
父主题: 函数定义