    Python Programming for Finance

     

     

    Intro and Getting Stock Price Data

    We're going to run through the basics of importing financial (stock) data into Python using the Pandas framework. From here, we'll manipulate the data and attempt to come up with some sort of system for investing in companies, apply some machine learning, even some deep learning, and then learn how to back-test a strategy. 

    I am using 64-bit Python 3.5.

    Required Modules to start:

    1. Numpy
    2. Matplotlib
    3. Pandas
    4. Pandas-datareader
    5. BeautifulSoup4
    6. scikit-learn / sklearn

    To begin, we're going to make the following imports:

    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web

    Datetime will let us easily work with dates, matplotlib will handle the graphing, pandas will handle the data manipulation, and pandas_datareader is the newest pandas IO library.

    Now for some starting setup:

    style.use('ggplot')
    
    start = dt.datetime(2000, 1, 1)
    end = dt.datetime(2016, 12, 31)
    
    

    We're setting a style, so our graphs don't look horrendous. In finance, it's of the utmost importance that your graphs are pretty, even if you're losing money. Next, we're setting start and end datetime objects; this will be the range of dates that we're going to grab stock pricing information for.

    Now, we can make a dataframe from this data:

    df = web.DataReader('TSLA', "yahoo", start, end)
     

    The line web.DataReader('TSLA', "yahoo", start, end) uses the pandas_datareader package to look up the stock ticker TSLA (Tesla), pull the information from Yahoo, and cover the date range from our start variable through our end variable.

    So now we've got a Pandas.DataFrame object that contains stock pricing information for Tesla. Let's see what we have here:

    print(df.head())
                     Open   High        Low      Close    Volume  Adj Close
    Date                                                                   
    2010-06-29  19.000000  25.00  17.540001  23.889999  18766300  23.889999
    2010-06-30  25.790001  30.42  23.299999  23.830000  17187100  23.830000
    2010-07-01  25.000000  25.92  20.270000  21.959999   8218800  21.959999
    2010-07-02  23.000000  23.10  18.709999  19.200001   5139800  19.200001
    2010-07-06  20.000000  20.00  15.830000  16.110001   6866900  16.110001

    The .head() is something you can do with Pandas DataFrames, and it will output the first n rows, where n is the optional parameter you pass. If you don't pass a parameter, 5 is the default value. We mostly will use .head() to just get a quick glimpse of our data to make sure we're on the right track.
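
    For example, a quick look at a different number of rows (a small illustration using the same DataFrame):

    print(df.head(3))   # first 3 rows
    print(df.tail())    # last 5 rows, handy for checking the most recent dates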

    In case you do not know:

    • Open - When the stock market opens in the morning for trading, what was the price of one share?
    • High - over the course of the trading day, what was the highest value for that day?
    • Low - over the course of the trading day, what was the lowest value for that day?
    • Close - When the trading day was over, what was the final price?
    • Volume - For that day, how many shares were traded?
    • Adj Close - This one is slightly more complicated, but, over time, companies may decide to do something called a stock split. For example, Apple did one once their stock price exceeded $1000. Since in most cases, people cannot buy fractions of shares, a stock price of $1,000 is fairly limiting to investors. Companies can do a stock split where they say every share is now 2 shares, and the price is half. Anyone who had 1 share of Apple for $1,000, after a split where Apple doubled the shares, would have 2 shares of Apple (AAPL), each worth $500. Adj Close is helpful, since it accounts for future stock splits and gives prices relative to those splits. For this reason, the adjusted prices are the prices you're most likely to be dealing with. A quick sketch of the arithmetic is shown below.
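
    As a rough sketch of that split arithmetic (the numbers here are made up for illustration, not the real AAPL history):

    # Hypothetical 2-for-1 split: prices before the split are halved so the
    # whole price series stays comparable to post-split prices.
    pre_split_close = 1000.00   # closing price the day before the split
    split_ratio = 2             # 2-for-1 split
    adjusted_close = pre_split_close / split_ratio
    print(adjusted_close)       # 500.0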

    Handling Data and Graphing 

    We're going to further break down some basic data manipulation and visualizations with our stock data. The starting code that we're going to be using (which was covered in the previous tutorial) is:

    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web
    
    style.use('ggplot')
    
    start = dt.datetime(2000, 1, 1)
    end = dt.datetime(2016, 12, 31)
    
    df = web.DataReader('TSLA', "yahoo", start, end)

    What are some things we can do with these DataFrames? For one, we can easily save them to a variety of file formats. One option is a CSV:

    df.to_csv('TSLA.csv')

    Rather than reading data from Yahoo's finance API to a DataFrame, we can also read data from a CSV file into a DataFrame:

    df = pd.read_csv('TSLA.csv', parse_dates=True, index_col=0)

    Now, we can graph with:

    df.plot()
    plt.show()

    Full code up to this point:

    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web
    
    style.use('ggplot')
    
    start = dt.datetime(2000, 1, 1)
    end = dt.datetime(2016, 12, 31)
    
    df = web.DataReader('TSLA', "yahoo", start, end)
    df.to_csv('TSLA.csv')
    df = pd.read_csv('TSLA.csv', parse_dates=True, index_col=0)
    df.plot()
    plt.show()

    Except that the only thing we can really see here is the volume, since it's on a scale MUCH larger than stock price. How can we graph just what we're interested in?

    df['Adj Close'].plot()
    plt.show()

    As you can see, you can reference specific columns in the DataFrame like: df['Adj Close'], but you can also reference multiple at a time, like so:

    df[['High','Low']]
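
    For instance, plotting a couple of columns at once works the same way (a quick sketch building on the DataFrame above):

    df[['High', 'Low']].plot()
    plt.show()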

    Basic stock data Manipulation

     The starting code that we're going to be using is:

    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web
    style.use('ggplot')
    
    df = pd.read_csv('tsla.csv', parse_dates=True, index_col=0)

    The Pandas module comes equipped with a bunch of built-in functionality that you can leverage, along with ways to create custom Pandas functions. We'll cover some custom functions later, but, for now, let's do a very common operation to this data: Moving Averages.

    The idea of a simple moving average is to take a window of time, and calculate the average price in that window. Then we shift that window over one period, and do it again. In our case, we'll do a 100 day rolling moving average. So this will take the current price, and the prices from the past 99 days, add them up, divide by 100, and there's your current 100-day moving average. Then we move the window over 1 day, and do the same thing again. Doing this in Pandas is as simple as:

    df['100ma'] = df['Adj Close'].rolling(window=100).mean()

    Doing df['100ma'] allows us to either re-define an existing column, if we had one called '100ma', or create a new one, which is what we're doing here. We're saying that the df['100ma'] column is the df['Adj Close'] column with a rolling method applied to it, using a window of 100, where the operation applied to that window is a mean() (average).

    Now, we could do:

    print(df.head())
                     Open   High        Low      Close    Volume  Adj Close  100ma
    Date                                                                          
    2010-06-29  19.000000  25.00  17.540001  23.889999  18766300  23.889999    NaN
    2010-06-30  25.790001  30.42  23.299999  23.830000  17187100  23.830000    NaN
    2010-07-01  25.000000  25.92  20.270000  21.959999   8218800  21.959999    NaN
    2010-07-02  23.000000  23.10  18.709999  19.200001   5139800  19.200001    NaN
    2010-07-06  20.000000  20.00  15.830000  16.110001   6866900  16.110001    NaN

    What happened? Under the 100ma column we just see NaN. We chose a 100-day moving average, which theoretically requires 100 prior datapoints to compute, so we won't have any data here for the first 100 rows. NaN means "Not a Number." With Pandas, you can decide to do lots of things with missing data, but, for now, let's actually just change the minimum periods parameter:

    df['100ma'] = df['Adj Close'].rolling(window=100,min_periods=0).mean()
    print(df.head())
                     Open   High        Low      Close    Volume  Adj Close  \
    Date                                                                      
    2010-06-29  19.000000  25.00  17.540001  23.889999  18766300  23.889999   
    2010-06-30  25.790001  30.42  23.299999  23.830000  17187100  23.830000   
    2010-07-01  25.000000  25.92  20.270000  21.959999   8218800  21.959999   
    2010-07-02  23.000000  23.10  18.709999  19.200001   5139800  19.200001   
    2010-07-06  20.000000  20.00  15.830000  16.110001   6866900  16.110001   
    
                    100ma  
    Date                   
    2010-06-29  23.889999  
    2010-06-30  23.860000  
    2010-07-01  23.226666  
    2010-07-02  22.220000  
    2010-07-06  20.998000  

    Alright, that worked, now we want to see it! But we've already seen simple graphs, how about something slightly more complex?

    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=5, colspan=1)
    ax2 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1,sharex=ax1)

    If you want to know more about subplot2grid, check out this subplots with Matplotlib tutorial.

    Basically, we're saying we want to create two subplots, and both subplots are going to act like they're on a 6x1 grid, where we have 6 rows and 1 column. The first subplot starts at (0,0) on that grid, spans 5 rows, and spans 1 column. The next axis is also on a 6x1 grid, but it starts at (5,0), spans 1 row, and 1 column. The 2nd axis also has sharex=ax1, which means that ax2 will always align its x axis with ax1's, and vice versa. Now we just make our plots:

    ax1.plot(df.index, df['Adj Close'])
    ax1.plot(df.index, df['100ma'])
    ax2.bar(df.index, df['Volume'])
    
    plt.show()

    Above, we've graphed the adjusted close and the 100ma on the first axis, and the volume on the 2nd axis.

     Full code up to this point:

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    
    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web
    style.use('ggplot')
    
    df = pd.read_csv('tsla.csv', parse_dates=True, index_col=0)
    df['100ma'] = df['Adj Close'].rolling(window=100,min_periods=0).mean()
    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=5, colspan=1)
    ax2 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1,sharex=ax1)
    ax1.plot(df.index, df['Adj Close'])
    ax1.plot(df.index, df['100ma'])
    ax2.bar(df.index, df['Volume'])
    
    plt.show()

    More stock manipulations  

    We're going to create a candlestick / OHLC graph based on the Adj Close column, which will allow me to cover resampling and a few more data visualization concepts.

    An OHLC chart, commonly drawn as a candlestick chart, condenses the open, high, low, and close data into one nice format. Plus it makes pretty colors, and remember what I told you about good looking charts?

    Starting code that's been covered up to this point in previous tutorials:

    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    import pandas as pd
    import pandas_datareader.data as web
    style.use('ggplot')
    
    df = pd.read_csv('tsla.csv', parse_dates=True, index_col=0)
     

    Unfortunately, making candlestick graphs right from Pandas isn't built in, even though creating OHLC data is. One day, I am sure this graph type will be made available, but, today, it isn't. That's alright though, we'll make it happen! First, we need to make two new imports:

    from matplotlib.finance import candlestick_ohlc
    import matplotlib.dates as mdates

    The first import is the OHLC graph type from matplotlib, and the second import is the special mdates type that...is mostly just a pain in the butt, but that's the date type for matplotlib graphs. Pandas automatically handles that for you, but, like I said, we don't have that luxury yet with candlesticks.

    First, we need proper OHLC data. Our current data does have OHLC values, and, unless I am mistaken, Tesla has never had a split, but you won't always be this lucky. Thus, we're going to create our own OHLC data, which will also allow us to show another data transformation that comes from Pandas:

    df_ohlc = df['Adj Close'].resample('10D').ohlc()
     

    What we've done here is created a new dataframe, based on the df['Adj Close'] column, resampled with a 10-day window, where the resampling is an ohlc (open high low close). We could also do things like .mean() or .sum() for 10-day averages or 10-day sums. Keep in mind, this 10-day average would be a plain 10-day average, not a rolling average. Since our data is daily data, resampling it to 10-day data effectively shrinks the size of our data significantly. This is how you can normalize multiple datasets. Sometimes, you might have data that tracks once a month on the 1st of the month, other data that logs at the end of each month, and finally some data that logs weekly. You can resample this dataframe to the end of the month, every month, and effectively normalize it all! That's a more advanced Pandas feature that you can learn more about from the Pandas series if you like.
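
    For reference, those alternative aggregations would look like this (a quick sketch, not used in the rest of this tutorial):

    # 10-day mean and 10-day sum of the adjusted close, instead of .ohlc()
    df_10d_mean = df['Adj Close'].resample('10D').mean()
    df_10d_sum = df['Adj Close'].resample('10D').sum()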

    We'd like to graph both the candlestick data, as well as the volume data. We don't HAVE to resample the volume data, but we should, since it would be too granular compared to our 10D pricing data.

    df_volume = df['Volume'].resample('10D').sum()
     

    We're using sum here, since we really want to know the total volume traded over those 10 days, but you could also use mean instead. Now if we do:

    print(df_ohlc.head())

    We get:

                     open       high        low      close
    Date                                                  
    2010-06-29  23.889999  23.889999  15.800000  17.459999
    2010-07-09  17.400000  20.639999  17.049999  20.639999
    2010-07-19  21.910000  21.910000  20.219999  20.719999
    2010-07-29  20.350000  21.950001  19.590000  19.590000
    2010-08-08  19.600000  19.600000  17.600000  19.150000

    That's expected, but we now want to move this information to matplotlib, as well as convert the dates to the mdates version. Since we're just going to graph the columns in Matplotlib, we actually don't want the date to be an index anymore, so we can do:

    df_ohlc = df_ohlc.reset_index()

    Now Date is just a regular column. Next, we want to convert it:

    df_ohlc['Date'] = df_ohlc['Date'].map(mdates.date2num)

    Now we're going to setup the figure:

    fig = plt.figure()
    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=5, colspan=1)
    ax2 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1,sharex=ax1)
    ax1.xaxis_date()

    Everything here you've already seen, except ax1.xaxis_date(). What this does for us is convert the axis from the raw mdates numbers to dates.

    Now we can graph the candlestick graph:

    candlestick_ohlc(ax1, df_ohlc.values, width=2, colorup='g')

    Then do volume:

    ax2.fill_between(df_volume.index.map(mdates.date2num),df_volume.values,0)

    The fill_between function will graph x, y, then what to fill to/between. In our case, we're choosing 0.

    plt.show()

    Full code for this tutorial:

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    
    import datetime as dt
    import matplotlib.pyplot as plt
    from matplotlib import style
    from matplotlib.finance import candlestick_ohlc
    import matplotlib.dates as mdates
    import pandas as pd
    import pandas_datareader.data as web
    style.use('ggplot')
    
    df = pd.read_csv('tsla.csv', parse_dates=True, index_col=0)
    
    df_ohlc = df['Adj Close'].resample('10D').ohlc()
    df_volume = df['Volume'].resample('10D').sum()
    
    df_ohlc.reset_index(inplace=True)
    df_ohlc['Date'] = df_ohlc['Date'].map(mdates.date2num)
    
    ax1 = plt.subplot2grid((6,1), (0,0), rowspan=5, colspan=1)
    ax2 = plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)
    ax1.xaxis_date()
    
    candlestick_ohlc(ax1, df_ohlc.values, width=5, colorup='g')
    ax2.fill_between(df_volume.index.map(mdates.date2num), df_volume.values, 0)
    plt.show()

    Automating getting the S&P 500 list 

    We're going to be working on how we can go about grabbing pricing information en masse for a larger list of companies, and then how we can work with all of this data at once.

    To begin, we need a list of companies. I could just hand you a list, but actually acquiring a list of stocks can be just one of the many challenges you might encounter. In our case, we want a Python list of the S&P 500 companies.

    Whether you are looking for the Dow Jones companies, the S&P 500, or the Russell 3000, chances are, someone somewhere has posted a list of these companies. You will want to make sure it is up-to-date, but chances are it's not already in the perfect format for you. In our case, we're going to grab the list from Wikipedia: http://en.wikipedia.org/wiki/List_of_S%26P_500_companies.

    The tickers/symbols in Wikipedia are organized in a table. To handle this, we're going to use the HTML parsing library, Beautiful Soup. If you would like to learn more about Beautiful Soup, I have a quick 4-part tutorial on web scraping with Beautiful Soup.

    First, let's begin with some imports:

    import bs4 as bs
    import pickle
    import requests

    bs4 is for Beautiful Soup, pickle is so we can easily just save this list of companies, rather than hitting Wikipedia every time we run (though remember, in time, you will want to update this list!), and we'll be using requests to grab the source code from Wikipedia's page.

    To begin our function:

    def save_sp500_tickers():
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
        soup = bs.BeautifulSoup(resp.text, 'lxml')
        table = soup.find('table', {'class': 'wikitable sortable'})

    First, we visit the Wikipedia page, and are given the response, which contains our source code. To work with that source code, we access the .text attribute of the response, which we turn into a soup object using BeautifulSoup. If you're not familiar with what BeautifulSoup does for you, it basically turns source code into a BeautifulSoup object that suddenly can be treated much more like a typical Python object.

    There was once a time when Wikipedia attempted to decline access to requests coming from Python. Currently, at the time of my writing this, the code works without changing headers. If you're finding that the original source code (resp.text) doesn't seem to be returning the same page as you see on your home computer, add the following and change the resp var code:

        headers = {'User-Agent': 'Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17'}
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies',
                            headers=headers)

    Once we have our soup, we can find the table of stock data by simply searching for the wikitable sortable classes. The only reason I know to specify this table is because I viewed the source code in a browser first. There may come a time when you want to parse a different website's list of stocks; maybe it's in a table, or maybe it's a list, or maybe something with div tags. This is just one very specific solution. From here, we just iterate through the table:

        tickers = []
        for row in table.findAll('tr')[1:]:
            ticker = row.findAll('td')[0].text
            tickers.append(ticker)

    For each row, after the header row (this is why we're going through with [1:]), we're saying the ticker is the first "table data" (td) element; we grab its .text and append this ticker to our list.

    Now, it'd be nice if we could just save this list. We'll use the pickle module for this, which serializes Python objects for us.

        with open("sp500tickers.pickle","wb") as f:
            pickle.dump(tickers,f)
    
        return tickers

    We'd like to go ahead and save this so we don't have to request Wikipedia multiple times a day. At any time, we can update this list, or we could program it to check once a month...etc.
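
    Whenever we need the list again, loading it back from the pickle is just as simple (a quick sketch):

    with open("sp500tickers.pickle", "rb") as f:
        tickers = pickle.load(f)
    print(tickers[:5])  # peek at the first few tickers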

    Full code up to this point:

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    
    import bs4 as bs
    import pickle
    import requests
    
    
    def save_sp500_tickers():
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
        soup = bs.BeautifulSoup(resp.text, 'lxml')
        table = soup.find('table', {'class': 'wikitable sortable'})
        tickers = []
        for row in table.findAll('tr')[1:]:
            ticker = row.findAll('td')[0].text
            tickers.append(ticker)
    
        with open("sp500tickers.pickle", "wb") as f:
            pickle.dump(tickers, f)
    
        return tickers
    
    
    save_sp500_tickers()
     

    Getting all company pricing data in the S&P 500 

    In the previous tutorial, we covered how to acquire the list of companies that we're interested in (S&P 500 in our case), and now we're going to pull stock pricing data on all of them.

    Code up to this point:

    import bs4 as bs
    import pickle
    import requests
    
    
    def save_sp500_tickers():
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
        soup = bs.BeautifulSoup(resp.text, 'lxml')
        table = soup.find('table', {'class': 'wikitable sortable'})
        tickers = []
        for row in table.findAll('tr')[1:]:
            ticker = row.findAll('td')[0].text
            tickers.append(ticker)
    
        with open("sp500tickers.pickle", "wb") as f:
            pickle.dump(tickers, f)
    
        return tickers
     

    We're going to add a few new imports:

    import datetime as dt
    import os
    import pandas as pd
    import pandas_datareader.data as web

    We'll use datetime to specify dates for the Pandas datareader, os is to check for, and create, directories. You already know what pandas is for!

    To start our new function:

    def get_data_from_yahoo(reload_sp500=False):
        
        if reload_sp500:
            tickers = save_sp500_tickers()
        else:
            with open("sp500tickers.pickle","rb") as f:
                tickers = pickle.load(f)

    Here's where I'll just show a quick example of one way you could handle whether or not to reload the S&P 500 list. If we ask it to, the program will re-pull the S&P 500 list, otherwise it will just use our pickle. Now we want to prepare to grab data.

    Now we need to decide what we're going to do with the data. What I tend to do is try to parse websites ONCE, and store the data locally. I don't try to know in advance all of the things I might do with the data, but I know if I am going to pull it more than once, I might as well just save it (unless it's a huge dataset, which this is not). Thus, we're going to pull everything we can from what Yahoo returns to us for every stock and just save it. To do this, we'll create a new directory, and, in there, store stock data per company. To begin, we need that initial directory:

        if not os.path.exists('stock_dfs'):
            os.makedirs('stock_dfs')

    You could just store these datasets in the same directory as your script, but this would get pretty messy in my opinion. Now we're ready to pull the data. You already know how to do this, we did it in the very first tutorial!

        start = dt.datetime(2000, 1, 1)
        end = dt.datetime(2016, 12, 31)
        
        for ticker in tickers:
            if not os.path.exists('stock_dfs/{}.csv'.format(ticker)):
                df = web.DataReader(ticker, "yahoo", start, end)
                df.to_csv('stock_dfs/{}.csv'.format(ticker))
            else:
                print('Already have {}'.format(ticker))

    You will likely in time want to add some sort of force_data_update parameter to this function, since, right now, it will not re-pull data it sees it already has. Since we're pulling daily data, you'd want to have this re-pulling at least the latest data. That said, if that's the case, you might be better off using a database instead, with a table per company, and then just pulling the most recent values from Yahoo. We'll keep things simple for now though!
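
    One way such a parameter might look (a rough sketch only; the force_update name is made up here and isn't part of the tutorial's code):

    def get_data_from_yahoo(reload_sp500=False, force_update=False):
        if reload_sp500:
            tickers = save_sp500_tickers()
        else:
            with open("sp500tickers.pickle", "rb") as f:
                tickers = pickle.load(f)

        if not os.path.exists('stock_dfs'):
            os.makedirs('stock_dfs')

        start = dt.datetime(2000, 1, 1)
        end = dt.datetime(2016, 12, 31)

        for ticker in tickers:
            csv_path = 'stock_dfs/{}.csv'.format(ticker)
            # re-download when explicitly forced, or when we don't have the file yet
            if force_update or not os.path.exists(csv_path):
                df = web.DataReader(ticker, "yahoo", start, end)
                df.to_csv(csv_path)
            else:
                print('Already have {}'.format(ticker))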

    Full code up to this point:

    #!/usr/bin/env python
    # _*_ coding:utf-8 _*_
    
    import bs4 as bs
    import datetime as dt
    import os
    import pandas as pd
    import pandas_datareader.data as web
    import pickle
    import requests
    
    
    def save_sp500_tickers():
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
        soup = bs.BeautifulSoup(resp.text, 'lxml')
        table = soup.find('table', {'class': 'wikitable sortable'})
        tickers = []
        for row in table.findAll('tr')[1:]:
            ticker = row.findAll('td')[0].text
            tickers.append(ticker)
    
        with open("sp500tickers.pickle", "wb") as f:
            pickle.dump(tickers, f)
    
        return tickers
    
    
    # save_sp500_tickers()
    
    
    def get_data_from_yahoo(reload_sp500=False):
        if reload_sp500:
            tickers = save_sp500_tickers()
        else:
            with open("sp500tickers.pickle", "rb") as f:
                tickers = pickle.load(f)
    
        if not os.path.exists('stock_dfs'):
            os.makedirs('stock_dfs')
    
        start = dt.datetime(2000, 1, 1)
        end = dt.datetime(2016, 12, 31)
    
        for ticker in tickers:
            # just in case your connection breaks, we'd like to save our progress!
            if not os.path.exists('stock_dfs/{}.csv'.format(ticker)):
                df = web.DataReader(ticker, "yahoo", start, end)
                df.to_csv('stock_dfs/{}.csv'.format(ticker))
            else:
                print('Already have {}'.format(ticker))
    
    
    get_data_from_yahoo()
     

    Go ahead and run this. You might want to import time and add a time.sleep(0.5) or something if Yahoo throttles you. At the time of my writing this, Yahoo did not throttle me at all and I was able to run this all the way through without any issues. It might still take you a while, however, especially depending on your machine. The good news is, we won't need to do it again! In practice, though, since this is daily data, you might do this once a day.

    Also, if you have a slow internet connection, you don't need to do all of the tickers; even just 10 would be enough, so you can do for ticker in tickers[:10]:, or something like that, to speed things up.
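
    Both of those tweaks might look something like this (a sketch only; the half-second pause is an arbitrary choice, and the import would normally sit at the top of the script):

    import time

    for ticker in tickers[:10]:  # only grab the first 10 tickers
        if not os.path.exists('stock_dfs/{}.csv'.format(ticker)):
            df = web.DataReader(ticker, "yahoo", start, end)
            df.to_csv('stock_dfs/{}.csv'.format(ticker))
            time.sleep(0.5)  # brief pause between requests in case Yahoo throttles
        else:
            print('Already have {}'.format(ticker))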

    Once you have the data downloaded, we're going to compile the data we're interested in into one large Pandas DataFrame.

    Combining all S&P 500 company prices into one DataFrame

    In the previous tutorial, we grabbed the Yahoo Finance data for the entire S&P 500 of companies. In this tutorial, we're going to bring this data together into one DataFrame.

    While we do have all of the data at our disposal, we may want to actually assess the data together. To do this, we're going to join all of the stock datasets together. Each of the stock files at the moment comes with: Open, High, Low, Close, Volume, and Adj Close. To start, we're mostly just interested in the adjusted close.

    def compile_data():
        with open("sp500tickers.pickle","rb") as f:
            tickers = pickle.load(f)
    
        main_df = pd.DataFrame()

    To begin, we pull our previously-made list of tickers, and begin with an empty DataFrame, called main_df. Now, we're ready to read in each stock's dataframe:

        for count,ticker in enumerate(tickers):
            df = pd.read_csv('stock_dfs/{}.csv'.format(ticker))
            df.set_index('Date', inplace=True)

    You do not need to use Python's enumerate here; I am just using it so we know where we are in the process of reading in all of the data. You could just iterate over the tickers. From this point, we *could* generate extra columns with interesting data, like:

            df['{}_HL_pct_diff'.format(ticker)] = (df['High'] - df['Low']) / df['Low']
            df['{}_daily_pct_chng'.format(ticker)] = (df['Close'] - df['Open']) / df['Open']

    For now, however, we're not going to be bothered with this. Just know this could be a path to pursue down the road. Instead, we're really just interested in that Adj Close column:

            df.rename(columns={'Adj Close':ticker}, inplace=True)
            df.drop(['Open','High','Low','Close','Volume'], axis=1, inplace=True)

    Now we've got just that column (or maybe extras, like above...but remember, in this example, we're not doing the HL_pct_diff or daily_pct_chng). Notice that we have renamed the Adj Close column to whatever the ticker name is. Let's begin building the shared dataframe:

            if main_df.empty:
                main_df = df
            else:
                main_df = main_df.join(df, how='outer')

    If there's nothing in the main_df, then we'll start with the current df, otherwise we're going to use Pandas' join.
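
    To see what that outer join does, here's a tiny standalone sketch with made-up prices: dates present in either frame are kept, and anything missing becomes NaN:

    import pandas as pd

    a = pd.DataFrame({'AAPL': [100, 101]}, index=['2016-01-04', '2016-01-05'])
    b = pd.DataFrame({'TSLA': [220]}, index=['2016-01-05'])
    print(a.join(b, how='outer'))
    #             AAPL   TSLA
    # 2016-01-04   100    NaN
    # 2016-01-05   101  220.0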

    Still within this for loop, we'll add two more lines:

            if count % 10 == 0:
                print(count)

    This will just output the count of the current ticker if it's evenly divisible by 10. What count % 10 gives us is the remainder of count divided by 10. So if we ask if count % 10 == 0, the if statement will only be True when the current count is perfectly divisible by 10 (i.e. the remainder is 0).
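
    A quick illustration of that modulo check:

    for count in range(25):
        if count % 10 == 0:
            print(count)  # prints 0, 10, 20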

    When we're all done with the for-loop:

        print(main_df.head())
        main_df.to_csv('sp500_joined_closes.csv')

    Here is this function, plus the call to it, up to this point:

    def compile_data():
        with open("sp500tickers.pickle","rb") as f:
            tickers = pickle.load(f)
    
        main_df = pd.DataFrame()
        
        for count,ticker in enumerate(tickers):
            df = pd.read_csv('stock_dfs/{}.csv'.format(ticker))
            df.set_index('Date', inplace=True)
    
            df.rename(columns={'Adj Close':ticker}, inplace=True)
            df.drop(['Open','High','Low','Close','Volume'], axis=1, inplace=True)
    
            if main_df.empty:
                main_df = df
            else:
                main_df = main_df.join(df, how='outer')
    
            if count % 10 == 0:
                print(count)
        print(main_df.head())
        main_df.to_csv('sp500_joined_closes.csv')
    
    
    compile_data()

    Full code up to this point:

    import bs4 as bs
    import datetime as dt
    import os
    import pandas as pd
    import pandas_datareader.data as web
    import pickle
    import requests
    
    
    def save_sp500_tickers():
        resp = requests.get('http://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
        soup = bs.BeautifulSoup(resp.text, 'lxml')
        table = soup.find('table', {'class': 'wikitable sortable'})
        tickers = []
        for row in table.findAll('tr')[1:]:
            ticker = row.findAll('td')[0].text
            tickers.append(ticker)
            
        with open("sp500tickers.pickle","wb") as f:
            pickle.dump(tickers,f)
            
        return tickers
    
    
    def get_data_from_yahoo(reload_sp500=False):
        
        if reload_sp500:
            tickers = save_sp500_tickers()
        else:
            with open("sp500tickers.pickle","rb") as f:
                tickers = pickle.load(f)
        
        if not os.path.exists('stock_dfs'):
            os.makedirs('stock_dfs')
    
        start = dt.datetime(2000, 1, 1)
        end = dt.datetime(2016, 12, 31)
        
        for ticker in tickers:
            # just in case your connection breaks, we'd like to save our progress!
            if not os.path.exists('stock_dfs/{}.csv'.format(ticker)):
                df = web.DataReader(ticker, "yahoo", start, end)
                df.to_csv('stock_dfs/{}.csv'.format(ticker))
            else:
                print('Already have {}'.format(ticker))
    
    
    def compile_data():
        with open("sp500tickers.pickle","rb") as f:
            tickers = pickle.load(f)
    
        main_df = pd.DataFrame()
        
        for count,ticker in enumerate(tickers):
            df = pd.read_csv('stock_dfs/{}.csv'.format(ticker))
            df.set_index('Date', inplace=True)
    
            df.rename(columns={'Adj Close':ticker}, inplace=True)
            df.drop(['Open','High','Low','Close','Volume'], axis=1, inplace=True)
    
            if main_df.empty:
                main_df = df
            else:
                main_df = main_df.join(df, how='outer')
    
            if count % 10 == 0:
                print(count)
        print(main_df.head())
        main_df.to_csv('sp500_joined_closes.csv')
    
    
    compile_data()

    In the next tutorial, we're going to attempt to see if we can quickly find any relationships in the data.
