简易回测框架开发:
框架内容:
上下文信息保存,context
获取数据
下单函数
用户接口
........
import pandas as pd import matplotlib.pyplot as plt import tushare import datetime import dateutil ''' 获取所有的股票交易日,交易日信息保存在csv文件 ''' try: trade_cal = pd.read_csv("trade_cal.csv") except: trade_cal = tushare.trade_cal() trade_cal.to_csv("trade_cal.csv") class Context: def __init__(self, cash, start_date, end_date): ''' 保存股票信息 :param cash: 现金量 :param start_date: 量化策略开始时间 :param end_date: 量化策略结束时间 :param positions: 持仓股票和对应的数量 :param benchmark: 参考股票 :param date_range: 开始-结束之间的所有交易日 :param dt: 当前日期 (循环时当前日期会发生变化) ''' self.cash = cash self.start_date = start_date self.end_date = end_date self.positions = {} # 持仓信息 self.benchmark = None self.date_range = trade_cal[(trade_cal['isOpen']==1)& (trade_cal['calendarDate']>=start_date)& (trade_cal['calendarDate']<=end_date)]['calendarDate'].values self.dt = None class G: ''' 保存用户的全局参数 ''' pass ''' 默认的初始化信息 ''' g = G() CASH = 100000 START_DATE = '2016-01-07' END_DATE = '2017-01-31' context = Context(CASH,START_DATE,END_DATE) def attribute_history(security, count, field=('open','close','high','low','volume')): ''' 获取某股票count天的历史行情,每运行一次该函数,日期范围后移 :param security: 股票代码 :param count: 天数 :param field: 字段 :return: ''' end_date = (context.dt - datetime.timedelta(days=1)).strftime('%Y-%m-%d') start_date = trade_cal[(trade_cal['isOpen']==1)& (trade_cal['calendarDate']<=end_date)][-count:]['calendarDate'].iloc[0] return attribute_daterange_history(security,start_date,end_date,field) def attribute_daterange_history(security, start_date,end_date, field=('open','close','high','low','volume')): ''' 底层,获取某股票某一段时间的历史行情 :param security: :param start_date: :param end_date: :param field: :return: ''' df = tushare.get_k_data(security,start_date,end_date) df.index = df['date'] return df[list(field)] def get_today_data(security): ''' 获取context的"当天"的股票信息,停牌返回Null :param security: :return: ''' try: today = context.dt.strftime('%Y-%m-%d') df = tushare.get_k_data(security,today,today) df.index = df['date'] data = df.loc[today] except KeyError: # 股票停牌 data = pd.Series() return data def _order(today_data, security, amount): ''' 底层买股票的函数 :param today_data: "当天"的股票价格OCHL :param security: 股票代码 :param amount: 交易股数,正数为买入,负数为卖出 :return: ''' p = today_data['open'] # 找不到该股票默认为0股 old_amount = context.positions.get(security, 0) if len(today_data) == 0: print("今日停牌") return if context.cash - amount * p < 0: amount = context.cash // p print('%s:现金不足,已调整为%d' %(today_data['date'],amount)) if amount % 100 != 0: # 买或卖不是100的倍数就调整为100的倍数,卖光则不调整 if amount != -old_amount: # 2345 => 2300 amount = int(amount / 100) * 100 print('%s:不是100的倍数,已调整为%d' %(today_data['date'],amount)) if old_amount < -amount: amount = -old_amount print('%s:卖出股票不能超过持仓数,已调整为%d'%(today_data['date'],amount)) # 更新持仓信息 context.positions[security] = old_amount + amount # 更新钱 context.cash -= amount*p # 持仓为0就删掉 if context.positions[security] == 0: del context.positions[security] def order(security, amount): # 买入股票。amount为正表示买入,负表示卖出 today_data = get_today_data(security) _order(today_data, security, amount) def order_target(security, amount): # 把股票交易到多少股,不能为负数,比原来小是卖出,比原来大是买入 if amount < 0: print("数量不能为负,已调整为0") amount = 0 today_data = get_today_data(security) hold_amount = context.positions.get(security, 0) # TODO: T + 1 closeable total delta_amount = amount - hold_amount _order(today_data,security,delta_amount) def order_value(security, value): # 买多少钱的股票或者卖多少钱的股票 today_data = get_today_data(security) amount = value / today_data['open'] _order(today_data,security,amount) def order_target_value(security, value): # 买到或者卖到多少钱 if value < 0: print("价值不能为负,已调整为0") value = 0 today_data = get_today_data(security) hold_value = context.positions.get(security,0) * today_data['open'] dalta_value = value - hold_value order_value(security,dalta_value) def run(): plt_df = pd.DataFrame(index=pd.to_datetime(context.date_range), columns=['value']) # 最初的钱,算收益率用 init_value = context.cash # 保存停牌前一天的股票价格 last_price = {} # 用户接口1 initialize(context) for dt in context.date_range: context.dt = dateutil.parser.parse(dt) # 用户接口2 handle_data(context) # 股票和现金的总价值 value = context.cash for stock in context.positions: # 考虑停牌的情况 today_data = get_today_data(stock) if len(today_data) == 0: p = last_price[stock] else: p = today_data['open'] last_price[stock] = p value += p * context.positions[stock] plt_df.loc[dt, 'value'] = value plt_df['ratio'] = (plt_df['value']-init_value) / init_value bm_df = attribute_daterange_history(context.benchmark, context.start_date, context.end_date) bm_init = bm_df['open'][0] plt_df['benchmark_raito'] = (bm_df['open']-bm_init) / bm_init print(plt_df) plt_df[['ratio','benchmark_raito']].plot() plt.show() ''' initialize和handle_data是用户的操作 ''' def initialize(context): context.benchmark = '601318' g.p1 = 5 g.p2 = 60 g.security = '601318' def handle_data(context): hist = attribute_history(g.security, g.p2) ma5 = hist['close'][-g.p1:].mean() ma60 = hist['close'].mean() if ma5 > ma60 and g.security not in context.positions: order_value(g.security, context.cash) elif ma5 < ma60 and g.security in context.positions: order_target(g.security,0) if __name__ == '__main__': run()