zoukankan      html  css  js  c++  java
  • baostock_multiprocessing 多进程取数据

    #!/usr/bin/env python
    import multiprocessing
    import os
    import queue
    import shutil
    import time

    import baostock as bs
    import pandas as pd
    
    def download_factor(start_date, end_date, stock_df):
        """Collect adjust-factor rows for every code in ``stock_df``.

        Calls ``bs.query_adjust_factor`` once per entry of ``stock_df["code"]``
        and gathers every returned row into one DataFrame.

        Args:
            start_date: Start date forwarded to the query.
            end_date: End date forwarded to the query.
            stock_df: DataFrame with a ``code`` column.

        Returns:
            A DataFrame holding all collected rows, with the field names
            reported by the query results as columns. An empty DataFrame
            without columns is returned when no query reported any fields
            (for example when ``stock_df`` has no rows).
        """
        rows = []
        fields = None
        for code in stock_df["code"]:
            rs_factor = bs.query_adjust_factor(code=code, start_date=start_date, end_date=end_date)
            # Keep the column names from a response that actually carries them.
            # NOTE(review): presumably a failed query reports no fields, which
            # previously could leave the collected rows without matching
            # column names when the last query failed - confirm against baostock.
            if rs_factor.fields:
                fields = rs_factor.fields
            # ``and`` short-circuits, so next() is not called on a failed query.
            while rs_factor.error_code == '0' and rs_factor.next():
                rows.append(rs_factor.get_row_data())
        if fields is None:
            return pd.DataFrame()
        # Build the frame once instead of rebuilding it after every code.
        return pd.DataFrame(rows, columns=fields)
    
    def download_data(start_date,end_date,code):
        """Fetch daily k-line data for one index or stock code.

        Args:
            start_date: Start date forwarded to the query.
            end_date: End date forwarded to the query.
            code: Code forwarded to ``bs.query_history_k_data_plus``.

        Returns:
            The DataFrame produced by the query result's ``get_data()``.
        """
        k_rs = bs.query_history_k_data_plus(
            code,
            "date,code,open,high,low,close,volume,amount,turn,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM",
            start_date=start_date, end_date=end_date, adjustflag="2", frequency="d")
        # The result used to be appended to an empty frame, which added nothing
        # and relied on DataFrame.append (removed in pandas 2.0).
        return k_rs.get_data()
    
    def conpare_list():
        """Tag every code returned by ``bs.query_all_stock`` with its stored flag.

        Reads the module-level names ``end_date`` and ``pathsave``; both must be
        assigned before this function is called.

        Returns:
            The DataFrame from ``bs.query_all_stock(end_date).get_data()`` with a
            ``flag`` column added: the first flag recorded for that code in
            ``<pathsave>/all.csv``, or ``"new"`` when the code is not listed there.
        """
        stock_rs = bs.query_all_stock(end_date)
        stock_df = stock_rs.get_data()
        file_name = os.path.join(pathsave, "all.csv")
        print(file_name)
        stock_read = pd.read_csv(file_name)
        print(stock_read)
        # One pass over the stored list instead of a full boolean-mask scan per
        # code; setdefault keeps the first flag seen for a code.
        known_flags = {}
        for code, flag in zip(stock_read["code"], stock_read["flag"]):
            known_flags.setdefault(code, flag)
        stock_df["flag"] = [known_flags.get(code, "new") for code in stock_df["code"]]
        return stock_df
    
    def add_data(end_date,stock_df,pathsave):
        """Download k-line data for each code and merge it into its CSV file.

        For every distinct code in ``stock_df`` the data between that row's
        ``start_date`` and ``end_date`` is downloaded, merged with the rows
        already stored in ``<pathsave>/<flag>/<code2>.csv`` (when that file
        exists), de-duplicated by date keeping the newest row, and written back.

        Args:
            end_date: End date forwarded to ``download_data``.
            stock_df: DataFrame with ``code``, ``flag`` and ``start_date`` columns.
            pathsave: Root directory containing one sub-directory per flag.
        """
        stock_df = stock_df.drop_duplicates(subset=["code"], keep="last")
        # regex=False: "sh." / "sz." are literal prefixes, the dot is not a wildcard.
        stock_df["code2"] = (
            stock_df["code"]
            .str.replace("sh.", "SH", regex=False)
            .str.replace("sz.", "SZ", regex=False)
        )
        stock_df = stock_df.set_index("code")
        for code in stock_df.index:
            file = os.path.join(pathsave, stock_df.loc[code, "flag"], stock_df.loc[code, "code2"] + ".csv")
            df_old = pd.read_csv(file) if os.path.isfile(file) else None
            df_all = download_data(stock_df.loc[code, "start_date"], end_date, code)
            df_all["code"] = (
                df_all["code"]
                .str.replace("sh.", "SH", regex=False)
                .str.replace("sz.", "SZ", regex=False)
            )
            df_all["date"] = df_all["date"].str.replace("-", "", regex=False)
            if df_old is not None:
                # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
                df_all = pd.concat([df_old, df_all])
            # Dates read back from CSV may not be strings; normalise before
            # comparing them with the freshly downloaded string dates.
            df_all["date"] = df_all["date"].astype(str)
            df_all = df_all.drop_duplicates(subset=["date"], keep="last")
            df_all.to_csv(file, sep=",", encoding="gbk", index=False)
    
    def rewrite_new_file(pathsave):
        """Move newly added stock files into ``sz``/``sh`` and write ``all.csv``.

        Reads ``<pathsave>/list.csv``. For each code flagged ``"new"`` whose CSV
        exists under ``<pathsave>/new``:

        * a file without data rows is left where it is;
        * a file whose first row has no ``peTTM`` value is reported and left
          where it is;
        * any other file is moved to ``<pathsave>/sz`` when its file name
          contains ``SZ`` and to ``<pathsave>/sh`` otherwise, and the code's
          flag is updated to match.

        The resulting table is always written to ``<pathsave>/all.csv``.

        Args:
            pathsave: Root directory holding ``list.csv`` and the ``new``,
                ``sz`` and ``sh`` sub-directories.
        """
        file_name_w = os.path.join(pathsave, "all.csv")
        file_name_r = os.path.join(pathsave, "list.csv")
        pathdir = os.path.join(pathsave, "new")
        stock_read = pd.read_csv(file_name_r)
        pd_new = stock_read.loc[stock_read["flag"] == "new"]
        for code in pd_new["code"]:
            base_name = code.replace("sz.", "SZ").replace("sh.", "SH") + ".csv"
            file = os.path.join(pathdir, base_name)
            if not os.path.isfile(file):
                continue
            df_new = pd.read_csv(file)
            if df_new.empty:
                # Header-only file: there is no first row to inspect.
                continue
            if pd.isna(df_new.loc[0, "peTTM"]):
                print(file,"可能是指数文件")
                continue
            # Look at the file name only, so a directory name that happens to
            # contain "SZ" cannot send a Shanghai file to the sz folder.
            market = "sz" if "SZ" in base_name else "sh"
            stock_read.loc[stock_read["code"] == code, "flag"] = market
            shutil.move(file, os.path.join(pathsave, market, base_name))

        stock_read.to_csv(file_name_w, sep=",", encoding="utf-8", index=False)
    
    def sub_process(start_date,end_date,df_only_name1,q):
        """Worker entry point for the factor download.

        Logs in to baostock, downloads the adjust factors for the codes in
        ``df_only_name1`` via ``download_factor``, puts the resulting DataFrame
        on ``q`` without blocking, and then exits with status 0.

        Args:
            start_date: Start date forwarded to ``download_factor``.
            end_date: End date forwarded to ``download_factor``.
            df_only_name1: Slice of the stock table handled by this worker.
            q: Queue that receives this worker's result.
        """
        stamp_format = "%Y-%m-%d %H:%M:%S"
        login_result = bs.login()
        print('login respond error_code:' + login_result.error_code)
        print('login respond  error_msg:' + login_result.error_msg)
        print('-----process begin-----')
        started_at = time.strftime(stamp_format, time.localtime())
        print(started_at, multiprocessing.current_process().name)
        q.put(download_factor(start_date, end_date, df_only_name1), block=False)
        print('-----process done-----')
        finished_at = time.strftime(stamp_format, time.localtime())
        print(finished_at, multiprocessing.current_process().name)
        exit(0)
    
    def sub_process2(end_date,df_only_name1,pathsave,q):
        """Worker entry point for the k-line data download.

        Logs in to baostock, runs ``add_data`` for the codes in
        ``df_only_name1``, puts this process's name on ``q`` without blocking
        to signal completion, and then exits with status 0.

        Args:
            end_date: End date forwarded to ``add_data``.
            df_only_name1: Slice of the stock table handled by this worker.
            pathsave: Root directory forwarded to ``add_data``.
            q: Queue that receives this worker's completion marker.
        """
        stamp_format = "%Y-%m-%d %H:%M:%S"
        login_result = bs.login()
        print('login respond error_code:' + login_result.error_code)
        print('login respond  error_msg:' + login_result.error_msg)
        print('-----process 下载数据 begin-----')
        started_at = time.strftime(stamp_format, time.localtime())
        print(started_at, multiprocessing.current_process().name)
        add_data(end_date, df_only_name1, pathsave)
        worker_name = multiprocessing.current_process().name
        q.put(worker_name, block=False)
        print('-----process 数据下载写入结束 done-----')
        finished_at = time.strftime(stamp_format, time.localtime())
        print(finished_at, multiprocessing.current_process().name)
        exit(0)
    
    def _chunk_bounds(total, parts):
        """Return ``parts`` ``(begin, end)`` pairs that together cover ``range(total)``.

        Every chunk holds ``total // parts`` items; the last chunk additionally
        takes whatever remains.
        """
        step = total // parts
        bounds = [(i * step, (i + 1) * step) for i in range(parts)]
        bounds[-1] = (bounds[-1][0], total)
        return bounds

    def _run_workers(target, stock_df, leading_args, trailing_args=(), epochs=5, poll_seconds=5):
        """Split ``stock_df`` over ``epochs`` processes and gather one result from each.

        Each worker is started as
        ``target(*leading_args, chunk, *trailing_args, q)`` and is expected to
        put exactly one item on ``q``.

        Returns:
            The items taken from the queue, in arrival order.

        Raises:
            RuntimeError: if every worker has exited but fewer than ``epochs``
                results arrived (previously the script waited forever).
        """
        q = multiprocessing.Queue(maxsize=epochs)
        process_list = []
        for i, (begin, end) in enumerate(_chunk_bounds(len(stock_df), epochs)):
            print("no.", i, begin, end)
            worker_args = tuple(leading_args) + (stock_df[begin:end],) + tuple(trailing_args) + (q,)
            process_list.append(multiprocessing.Process(target=target, args=worker_args))
        for process in process_list:
            process.start()

        # Take results with a blocking get() instead of polling qsize(), which
        # is only approximate and not implemented on every platform. Results
        # are taken off the queue before join() so a worker holding a large
        # item is never blocked while flushing it.
        results = []
        while len(results) < epochs:
            try:
                item = q.get(timeout=poll_seconds)
            except queue.Empty:
                if any(process.is_alive() for process in process_list):
                    continue
                # Every worker has exited: pick up anything queued right
                # before the last exit, then stop waiting.
                try:
                    while len(results) < epochs:
                        results.append(q.get(timeout=1))
                except queue.Empty:
                    pass
                break
            else:
                results.append(item)
                print(len(results))
        for process in process_list:
            process.join()
        if len(results) < epochs:
            raise RuntimeError(
                "only %d of %d worker processes reported a result" % (len(results), epochs))
        return results

    if __name__ == '__main__':
        # Fetch the daily k-line data of all stocks for the given dates.
        # NOTE: pathsave and end_date stay module-level names on purpose;
        # conpare_list() reads them as globals.
        print("hello")
        lg = bs.login()
        print('login respond error_code:' + lg.error_code)
        print('login respond  error_msg:' + lg.error_msg)
        # Raw string: in a normal literal "\b" (in "\baostock") is a backspace.
        pathsave = r'G:\datas of status\python codes\baostock\lx'  # where the data files are stored

        ori_date = "2018-01-01"       # earliest date to download from
        start_date = "2020-08-18"     # edit per run: first date of this download
        end_date = "2020-08-20"       # edit per run: last date; must be a trading day or the run fails
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print("开始比较")
        stock_df = conpare_list()     # tell indexes, Shanghai and Shenzhen codes apart
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print("开始下载factor")
        file_w = os.path.join(pathsave, "list.csv")
        stock_df.to_csv(file_w, sep=",", index=False, header=True)

        # ===================== download the adjust factors
        factor_frames = _run_workers(sub_process, stock_df, (start_date, end_date))
        df_factor = pd.concat(factor_frames)
        if "code" in df_factor.columns:
            df_factor = df_factor.drop_duplicates(subset=["code"], keep="last")
            adjusted_codes = df_factor["code"]
        else:
            # No worker returned any factor row, so nothing needs a full reload.
            adjusted_codes = []
        print(df_factor)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

        print("下载factor结束,开始下载数据")
        # Codes that have an adjust factor in the period are downloaded again
        # from the earliest date; all other codes only need the new range.
        stock_df["start_date"] = start_date
        stock_df.loc[stock_df["code"].isin(adjusted_codes), "start_date"] = ori_date
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),"下边开始下载数据")

        # ===================== download the k-line data
        finished_workers = _run_workers(sub_process2, stock_df, (end_date,), (pathsave,))
        for worker_name in finished_workers:
            print(worker_name,"done")

        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print("下载数据结束")
        rewrite_new_file(pathsave)
        bs.logout()
  • 相关阅读:
    linux kernel内存碎片防治技术
    内核线程
    Linux内核高端内存
    Lcd(一)显示原理
    LSB和MSB
    图解slub
    数据库小试题2
    编写函数获取上月的最后一天
    php中的static静态变量
    mysql小试题
  • 原文地址:https://www.cnblogs.com/rongye/p/13549071.html
Copyright © 2011-2022 走看看