
data_process.py File

The content of the DLRM data preprocessing file data_process.py is shown below. The script splits the raw Criteo Kaggle training file into seven day-sized chunks using a process pool over shared memory (Step 1), then integer-encodes the 26 categorical feature columns with one process per column (Step 2), and saves the result as kaggle_processed.npz.

from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import os
import time
import warnings
from io import StringIO
from multiprocessing import Pool, shared_memory
from os import path

import numpy as np

warnings.filterwarnings('ignore')

# Error callback for worker processes
def showerror(err):
    print("show error:", err)

# Using shared memory, multiple processes count the categories of each column
# in parallel, mapping each column's strings to unique integer IDs
def speedup_col(out_name, in_name, shape, col, d_path, split):
    pid = os.getpid()
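    # Pin this worker to a dedicated core via Linux's sched_setaffinity (one core per column)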
    os.sched_setaffinity(pid, [col])
    ssm = shared_memory.SharedMemory(name=out_name)
    out = np.ndarray(shape=shape, dtype="i4", buffer=ssm.buf)

    ssm_in = shared_memory.SharedMemory(name=in_name)
    mat = np.ndarray(shape=shape, dtype="U8", buffer=ssm_in.buf)
    convertDict = {}
    count = 0
    for i in range(mat.shape[0]):
        # assign the next integer ID to each category string on first sight
        if mat[i, col] not in convertDict:
            convertDict[mat[i, col]] = count
            count += 1
        out[i, col] = convertDict[mat[i, col]]
    # detach this worker's view of the shared memory (the parent unlinks it)
    ssm.close()
    ssm_in.close()
    return (count, col)

def processKaggleCriteoAdData(split, d_path, o_filename, datas):
    # Concatenate the per-day chunks produced in Step 1
    for i, data in enumerate(datas):
        if i == 0:
            X_cat = data["X_cat"]
            X_int = data["X_int"]
            y = data["y"]
        else:
            X_cat = np.concatenate((X_cat, data["X_cat"]))
            X_int = np.concatenate((X_int, data["X_int"]))
            y = np.concatenate((y, data["y"]))
    shape = X_cat.shape
    counts = [0 for _ in range(shape[1])]
    pool = Pool(shape[1])
    results = []
    out_X_cat = np.zeros(shape=shape, dtype='i4')
    out_name = "out_X_cat"
    in_name = "in_X_cat"
    # Create shared memory: one block for the integer-encoded output and one
    # holding the input strings, shared by all column workers
    ssm = shared_memory.SharedMemory(name=out_name, create=True, size=out_X_cat.nbytes)
    ssm_in = shared_memory.SharedMemory(name=in_name, create=True, size=X_cat.nbytes)
    del out_X_cat
    new_X_cat = np.ndarray(shape=shape, dtype="U8", buffer=ssm_in.buf)
    new_X_cat[:] = X_cat
    del X_cat
    X_cat = new_X_cat
    # Start the pool: each process integer-encodes one column
    for j in range(shape[1]):
        res = pool.apply_async(speedup_col, args=(out_name, in_name, shape, j, d_path, split,), error_callback=showerror)
        results.append(res)
    pool.close()
    pool.join()
    for res in results:
        count, col = res.get()
        counts[col] = count
    # Clamp negative dense features to zero, then view the integer-encoded result
    X_int[X_int < 0] = 0
    X_cat = np.ndarray(shape=shape, dtype="i4", buffer=ssm.buf)
    np.savez(
        str(d_path) + str(o_filename) + ".npz",
        X_cat=X_cat,
        X_int=X_int,
        y=y,
        counts=counts,
    )
    
    print("counts: ",counts)
    print("X_cat shape: ",X_cat.shape)
    print("X_int shape: ",X_int.shape)
    print("Step 2 done !!!")
    # Return the shared-memory handles so the caller can close and unlink them
    return (ssm, ssm_in)

# Each worker parses a batch of raw lines, splitting the label, the 13 dense
# features, and the 26 sparse (categorical) features into the shared numpy arrays
def speedup(i, info, num_data_in_split):
    row_type = np.dtype(
        [("label", "i4"), ("int_feature", ("i4", 13)), ("cat_feature", ("U8", 26))]
    )
    pid = os.getpid()
    os.sched_setaffinity(pid, [i])
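    # Attach to the shared-memory blocks created by the parent process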
    one = shared_memory.SharedMemory(name="one")
    two = shared_memory.SharedMemory(name="two")
    three = shared_memory.SharedMemory(name="three")
    y = np.ndarray(shape=(num_data_in_split),dtype="i4",buffer=one.buf)
    X_int = np.ndarray(shape=(num_data_in_split, 13),dtype="i4",buffer=two.buf)
    X_cat = np.ndarray(shape=(num_data_in_split, 26),dtype="U8",buffer=three.buf)
    for cut in info:
        line, index = cut
        # parse one tab-separated line into the structured dtype
        data = np.genfromtxt(StringIO(line), dtype=row_type, delimiter="\t")
        y[index] = data["label"]
        X_int[index] = data["int_feature"]
        X_cat[index] = data["cat_feature"]

def getKaggleCriteoAdData(datafile="", o_filename=""):
    d_path = "./kaggle_data/"
    # determine if intermediate data path exists
    if path.isdir(str(d_path)):
        print("Saving intermediate data files at %s" % (d_path))
    else:
        os.mkdir(str(d_path))
        print("Created %s for storing intermediate data files" % (d_path))
    # determine if data file exists (train.txt)
    if path.exists(str(datafile)):
        print("Reading data from path=%s" % (str(datafile)))
    else:
        print(
            "Path of Kaggle Display Ad Challenge Dataset is invalid; please download from https://labs.criteo.com/2014/09/kaggle-contest-dataset-now-available-academic-use/"
        )
        exit(0)
    # count number of datapoints in training set
    total_count = 0
    with open(str(datafile)) as f:
        for _ in f:
            total_count += 1
    print("Total number of datapoints:", total_count)
    # determine length of split over 7 days
    split = 1
    num_data_per_split, extras = divmod(total_count, 7)
    # initialize data to store; the first `extras` splits hold one extra row
    if extras > 0:
        num_data_in_split = num_data_per_split + 1
        extras -= 1
    else:
        num_data_in_split = num_data_per_split
    y = np.zeros(num_data_in_split, dtype="i4")
    X_int = np.zeros((num_data_in_split, 13), dtype="i4")
    X_cat = np.zeros((num_data_in_split, 26), dtype="U8")
    
    cpus = os.cpu_count()
    y_size = y.nbytes
    X_int_size = X_int.nbytes
    X_cat_size = X_cat.nbytes
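    # Create shared-memory blocks sized for the largest split; line-parsing
    # workers write their rows into these buffers in place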
    one = shared_memory.SharedMemory(name="one", create=True, size=y_size)
    two = shared_memory.SharedMemory(name="two", create=True, size=X_int_size)
    three = shared_memory.SharedMemory(name="three", create=True, size=X_cat_size)
    pool = Pool(cpus)
    count = 0
    process_id = 0
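    # each worker task parses a batch of up to batch_size raw lines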
    batch_size = 10000
    info = []
    datas = []
    if split == 1:
        # load training data
        start = time.time()
        with open(str(datafile)) as f:
            for i, line in enumerate(f):
                # store a day's worth of data, then reinitialize for the next day
                if i == (count + num_data_in_split):
                    pool.close()
                    pool.join()
                    # all workers have finished: snapshot the day out of shared memory
                    y = np.ndarray(shape=(num_data_in_split,), dtype="i4", buffer=one.buf)
                    X_int = np.ndarray(shape=(num_data_in_split, 13), dtype="i4", buffer=two.buf)
                    X_cat = np.ndarray(shape=(num_data_in_split, 26), dtype="U8", buffer=three.buf)
                    datas.append({
                        'X_int': X_int.copy(),
                        'X_cat': X_cat.copy(),
                        'y': y.copy()
                    })
                    print("\nSaved day_{0} in datas!".format(split))
                    split += 1
                    count += num_data_in_split
                    if extras > 0:
                        num_data_in_split = num_data_per_split + 1
                        extras -= 1
                    else:
                        num_data_in_split = num_data_per_split
                    pool = Pool(cpus)
                index = i - count
                # buffer the line, then dispatch the batch when it is full or when
                # the day's last line arrives, so every row is written before join
                info.append((line, index))
                if ((index + 1) == num_data_in_split) or (((index + 1) % batch_size) == 0):
                    pool.apply_async(speedup, args=(process_id % cpus, info, num_data_in_split), error_callback=showerror)
                    info = []
                    process_id += 1
        # flush any lines still buffered for the final day
        if info:
            pool.apply_async(speedup, args=(process_id % cpus, info, num_data_in_split), error_callback=showerror)
        pool.close()
        pool.join()
        y = np.ndarray(shape=(num_data_in_split,), dtype="i4", buffer=one.buf)
        X_int = np.ndarray(shape=(num_data_in_split, 13), dtype="i4", buffer=two.buf)
        X_cat = np.ndarray(shape=(num_data_in_split, 26), dtype="U8", buffer=three.buf)
        datas.append({
            'X_int': X_int.copy(),
            'X_cat': X_cat.copy(),
            'y': y.copy()
        })
        print("\nSaved day_{0} in datas!".format(split))
        # Step 1 no longer needs the raw-row buffers; release them
        ssms = [one, two, three]
        for ssm in ssms:
            ssm.close()
            ssm.unlink()
        print(f"Step 1 took {int(time.time() - start)}s!")
        
    else:
        print("Using existing %skaggle_day_*.npz files" % str(d_path))
    print("Step 1 done !!!")
    return datas

parser = argparse.ArgumentParser()
parser.add_argument('--file', default='train.txt', choices=['split.txt', 'train.txt'])
args = parser.parse_args()
split = 7
d_path = "./kaggle_data/"
o_filename = "kaggle_processed"
start = time.time()
try:
    datas = getKaggleCriteoAdData(args.file, o_filename)
    ssms = processKaggleCriteoAdData(split, d_path, o_filename, datas)
    # Unlink the Step 2 shared-memory blocks only after the .npz has been saved
    for ssm in ssms:
        ssm.close()
        ssm.unlink()
except Exception as e:
    print("something went wrong:", e)

print(f"All done, total time: {int(time.time() - start)}s")