  • 【深度学习 TPU、tensorflow】kaggle竞赛 使用TPU对104种花朵进行分类 第一次尝试 40%准确率





    最新的 TensorFlow 版本(TF 2.1)重点加强了对 TPU 的支持:现在既可以通过 Keras 高级 API 使用 TPU,也可以在使用自定义训练循环的模型中通过较低级别的 API 使用 TPU。






    Getting started with 100+ flowers on TPU

    import math, re, os
    import tensorflow as tf
    import numpy as np
    from matplotlib import pyplot as plt
    from kaggle_datasets import KaggleDatasets
    from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
    # Shorthand passed to the tf.data calls below wherever a parallelism level or buffer size is requested.
    AUTO = tf.data.experimental.AUTOTUNE
    print("Tensorflow version " + tf.__version__)

    TPU or GPU detection

    # Detect hardware, return appropriate distribution strategy
    try:
        # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        # No TPU could be resolved: use the default strategy selected below.
        tpu = None

    if tpu:
        # The TPU system must be connected to and initialised before a TPUStrategy is built on it.
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    else:
        strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

    print("REPLICAS: ", strategy.num_replicas_in_sync)

    Competition data access

    TPUs read data directly from Google Cloud Storage (GCS). This Kaggle utility will copy the dataset to a GCS bucket co-located with the TPU. If you have multiple datasets attached to the notebook, you can pass the name of a specific dataset to the get_gcs_path function. The name of the dataset is the name of the directory it is mounted in. Use !ls /kaggle/input/ to list attached datasets.

    GCS_DS_PATH = KaggleDatasets().get_gcs_path() # you can list the bucket with "!gsutil ls $GCS_DS_PATH"


    IMAGE_SIZE = [512, 512] # At this size, a GPU will run out of memory. Use the TPU.
                            # For GPU training, please select 224 x 224 px image size.
    EPOCHS = 12
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # 16 images per replica

    GCS_PATH_SELECT = { # available image sizes
        192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',
        224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
        331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
        512: GCS_DS_PATH + '/tfrecords-jpeg-512x512',
    }
    # Directory holding the TFRecords for the configured image size, keyed by the first dimension of IMAGE_SIZE.
    GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

    TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
    VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
    TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec') # predictions on this dataset should be submitted for the competition
    # Class names indexed by integer label; the trailing comment on each row gives the index range of that row.
    CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
               'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
               'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
               'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
               'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
               'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
               'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
               'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
               'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
               'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
               'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103

    Visualization utilities

    data -> pixels, nothing of much interest for the machine learning practitioner in this section.

    # numpy print defaults: summarise arrays with more than 15 elements and wrap printed output at 80 characters
    np.set_printoptions(threshold=15, linewidth=80)
    def batch_to_numpy_images_and_labels(data):
        """Convert an (images, labels) batch into numpy values.

        `data` is a pair whose two members both expose `.numpy()`. When the
        second member converts to an object-dtype array (binary strings, i.e.
        image IDs rather than class labels, as for test data), a list holding
        one `None` per image is returned in place of the labels.

        Returns:
            (numpy_images, numpy_labels_or_list_of_None)
        """
        image_tensor, label_tensor = data
        image_array = image_tensor.numpy()
        label_array = label_tensor.numpy()
        # Any dtype other than object means real class labels: hand them back untouched.
        if label_array.dtype != object:
            return image_array, label_array
        # Object dtype: these are image ID strings, so there are no labels to report.
        return image_array, [None] * len(image_array)
    def title_from_label_and_target(label, correct_label):
        """Build the title for a predicted image and report whether the prediction is right.

        Args:
            label: predicted class index (used to index CLASSES).
            correct_label: ground-truth class index, or None when it is unknown.

        Returns:
            A `(title, correct)` tuple. With no ground truth the title is just the
            predicted class name and `correct` is True. Otherwise the title is
            "<predicted> [OK]" for a match, or "<predicted> [NO→<expected>]" for a miss.
        """
        if correct_label is None:
            return CLASSES[label], True
        correct = (label == correct_label)
        # "\u2192" is the rightwards arrow; it separates the wrong prediction from the expected class.
        return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                    CLASSES[correct_label] if not correct else ''), correct
    def display_one_flower(image, title, subplot, red=False, titlesize=16):
        """Draw one image into the current figure and return the next subplot position.

        Args:
            image: image data handed to `plt.imshow`.
            title: title text; an empty string draws no title.
            subplot: `(rows, cols, index)` triple passed to `plt.subplot`.
            red: when True the title is drawn in red and slightly smaller.
            titlesize: base font size of the title.

        Returns:
            `(rows, cols, index + 1)`, i.e. the position for the next image.
        """
        # Select the target cell and draw the picture; without this the image is never shown
        # and every title would land on the same axes.
        plt.subplot(*subplot)
        plt.axis('off')
        plt.imshow(image)
        if len(title) > 0:
            plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
        return (subplot[0], subplot[1], subplot[2]+1)
    def display_batch_of_images(databatch, predictions=None):
        """Display a batch of images in a square-ish grid, optionally titled.

        This will work with:
        display_batch_of_images((images, labels))
        display_batch_of_images((images, labels), predictions)

        Args:
            databatch: pair accepted by batch_to_numpy_images_and_labels.
            predictions: optional predicted class indices, indexed like the images.
                When given, each title compares the prediction with the label.
        """
        # data
        images, labels = batch_to_numpy_images_and_labels(databatch)
        if labels is None:
            labels = [None for _ in enumerate(images)]
        if len(images) == 0:
            return  # nothing to draw; also avoids dividing by rows == 0 below

        # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
        rows = int(math.sqrt(len(images)))
        cols = len(images)//rows

        # size and spacing
        FIGSIZE = 13.0
        SPACING = 0.1
        subplot = (rows, cols, 1)  # first cell of the grid; advanced by display_one_flower
        if rows < cols:
            plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
        else:
            plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

        # display
        for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
            title = '' if label is None else CLASSES[label]
            correct = True
            if predictions is not None:
                title, correct = title_from_label_and_target(predictions[i], label)
            dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
            subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

        # layout: `label` is the last label drawn; with no labels and no predictions
        # there are no titles, so the images can be packed with no spacing.
        plt.tight_layout()
        if label is None and predictions is None:
            plt.subplots_adjust(wspace=0, hspace=0)
        else:
            plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
        plt.show()
    def display_confusion_matrix(cmat, score, precision, recall):
        """Plot a confusion matrix with the CLASSES names on both axes.

        Args:
            cmat: matrix handed to `matshow`; assumed to have one row and one column
                per entry of CLASSES, in the same order.
            score: f1 score to print in the figure, or None to omit it.
            precision: precision to print in the figure, or None to omit it.
            recall: recall to print in the figure, or None to omit it.
        """
        plt.figure(figsize=(15,15))
        ax = plt.gca()
        ax.matshow(cmat, cmap='Reds')
        # Fix one tick per class before labelling, so that each label lines up with its row/column.
        ax.set_xticks(range(len(CLASSES)))
        ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
        ax.set_yticks(range(len(CLASSES)))
        ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        # One line of text per metric that was supplied.
        titlestring = ""
        if score is not None:
            titlestring += 'f1 = {:.3f} '.format(score)
        if precision is not None:
            titlestring += '\nprecision = {:.3f} '.format(precision)
        if recall is not None:
            titlestring += '\nrecall = {:.3f} '.format(recall)
        if len(titlestring) > 0:
            ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
        plt.show()
    def display_training_curves(training, validation, title, subplot):
        """Plot a training curve and a validation curve on one subplot.

        Args:
            training: sequence of per-epoch values for the training set.
            validation: sequence of per-epoch values for the validation set.
            title: metric name, used in the subplot title and as the y-axis label.
            subplot: three-digit matplotlib subplot code (e.g. 211); a code ending
                in 1 also creates the figure.
        """
        if subplot%10==1: # set up the subplots on the first call
            plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
            plt.tight_layout()
        ax = plt.subplot(subplot)
        ax.set_facecolor('#F8F8F8')
        # Draw the curves in the same order as the legend entries below.
        ax.plot(training)
        ax.plot(validation)
        ax.set_title('model '+ title)
        ax.set_ylabel(title)
        ax.set_xlabel('epoch')
        ax.legend(['train', 'valid.'])


    def decode_image(image_data):
        """Decode JPEG bytes into a float32 RGB image of shape [*IMAGE_SIZE, 3] with values in [0, 1]."""
        pixels = tf.image.decode_jpeg(image_data, channels=3)
        scaled = tf.cast(pixels, tf.float32) / 255.0  # convert image to floats in [0, 1] range
        target_shape = list(IMAGE_SIZE) + [3]
        return tf.reshape(scaled, target_shape)  # explicit size needed for TPU
    def read_labeled_tfrecord(example):
        """Parse one serialized labeled example into an (image, label) pair.

        Args:
            example: serialized example passed to `tf.io.parse_single_example`.

        Returns:
            `(image, label)`: the image decoded by decode_image and the class as an int32.
        """
        LABELED_TFREC_FORMAT = {
            "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
            "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
        }
        example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
        image = decode_image(example['image'])
        label = tf.cast(example['class'], tf.int32)
        return image, label # mapped over a dataset, this yields (image, label) pairs
    def read_unlabeled_tfrecord(example):
        """Parse one serialized unlabeled example into an (image, id) pair.

        Args:
            example: serialized example passed to `tf.io.parse_single_example`.

        Returns:
            `(image, idnum)`: the image decoded by decode_image and the record's "id" string.
        """
        UNLABELED_TFREC_FORMAT = {
            "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
            "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
            # class is missing, this competition's challenge is to predict flower classes for the test dataset
        }
        example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
        image = decode_image(example['image'])
        idnum = example['id']
        return image, idnum # mapped over a dataset, this yields (image, id) pairs
    def load_dataset(filenames, labeled=True, ordered=False):
        """Build a parsed dataset from TFRecord files.

        Args:
            filenames: TFRecord files to read.
            labeled: parse records with read_labeled_tfrecord when True,
                read_unlabeled_tfrecord otherwise.
            ordered: keep the default ordering options when True; when False,
                deterministic ordering is switched off for speed.

        Returns:
            A dataset of (image, label) pairs if labeled=True, or (image, id) pairs if labeled=False.
        """
        options = tf.data.Options()
        if not ordered:
            # Order does not matter when the data gets shuffled anyway: use records as soon as they stream in.
            options.experimental_deterministic = False
        parse_record = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
        # Reading several files at once interleaves their records automatically.
        records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        return records.with_options(options).map(parse_record, num_parallel_calls=AUTO)
    def data_augment(image, label):
        """Apply a random left/right flip to the image; the label is passed through unchanged."""
        # Author's note: thanks to dataset.prefetch(AUTO) in the training pipeline, this
        # augmentation runs on the "CPU" part of the TPU while the TPU computes gradients.
        # An additional tf.image.random_saturation(image, 0, 2) step is left disabled.
        flipped = tf.image.random_flip_left_right(image)
        return flipped, label
    def get_training_dataset():
        """Return the augmented, repeating, shuffled and batched training dataset."""
        return (
            load_dataset(TRAINING_FILENAMES, labeled=True)
            .map(data_augment, num_parallel_calls=AUTO)
            .repeat()  # the training dataset must repeat for several epochs
            .shuffle(2048)
            .batch(BATCH_SIZE)
            .prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
        )
    def get_validation_dataset(ordered=False):
        """Return the batched, cached validation dataset; `ordered` is forwarded to load_dataset."""
        return (
            load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
            .batch(BATCH_SIZE)
            .cache()
            .prefetch(AUTO)  # prefetch next batch (autotune prefetch buffer size)
        )
    def get_test_dataset(ordered=False):
        """Return the batched, unlabeled test dataset of (image, id) pairs; `ordered` is forwarded to load_dataset."""
        batches = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered).batch(BATCH_SIZE)
        return batches.prefetch(AUTO)  # prefetch next batch (autotune prefetch buffer size)
    def count_data_items(filenames):
        """Return the total number of data items in the given .tfrec files.

        The number of data items is written in the name of the .tfrec files,
        i.e. flowers00-230.tfrec = 230 data items.

        Args:
            filenames: iterable of file names or paths.

        Returns:
            The sum of the per-file counts (0 for no files).

        Raises:
            ValueError: if a file name contains no "-<digits>." item count.
        """
        # The dot must be escaped and at least one digit required: an unescaped "." with
        # "[0-9]*" also matches a plain "-x" earlier in the path (e.g. in a bucket name)
        # and yields an empty count.
        count_pattern = re.compile(r"-([0-9]+)\.")
        counts = []
        for filename in filenames:
            match = count_pattern.search(filename)
            if match is None:
                raise ValueError("no item count found in file name: {!r}".format(filename))
            counts.append(int(match.group(1)))
        # Explicit integer dtype keeps the result an integer even for an empty list.
        return np.sum(counts, dtype=np.int64)
    # Dataset sizes, derived from the counts embedded in the file names.
    NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
    NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
    NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
    # The training dataset repeats indefinitely, so the length of an epoch has to be given explicitly.
    STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
    print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

    Dataset visualizations

    # data dump: print the shapes of the first three batches of each dataset.
    # After each loop, the loop variables still hold the last batch taken and are reused to print example values.
    print("Training data shapes:")
    for image, label in get_training_dataset().take(3):
        print(image.numpy().shape, label.numpy().shape)
    print("Training data label examples:", label.numpy())
    print("Validation data shapes:")
    for image, label in get_validation_dataset().take(3):
        print(image.numpy().shape, label.numpy().shape)
    print("Validation data label examples:", label.numpy())
    print("Test data shapes:")
    # The test dataset yields (image, id) pairs instead of (image, label) pairs.
    for image, idnum in get_test_dataset().take(3):
        print(image.numpy().shape, idnum.numpy().shape)
    print("Test data IDs:", idnum.numpy().astype('U')) # U=unicode string
    # Peek at training data
    training_dataset = get_training_dataset()
    training_dataset = training_dataset.unbatch().batch(20)  # re-batch into groups of 20 images for display
    train_batch = iter(training_dataset)
    # run this cell again for next set of images
    display_batch_of_images(next(train_batch))


    # peek at test data
    test_dataset = get_test_dataset()
    test_dataset = test_dataset.unbatch().batch(20)  # re-batch into groups of 20 images for display
    test_batch = iter(test_dataset)
    # run this cell again for next set of images
    display_batch_of_images(next(test_batch))



    Not the best but it converges …

    # The model has to be created inside the strategy scope.
    with strategy.scope():
        # Alternative backbones that can be swapped in:
        #pretrained_model = tf.keras.applications.DenseNet201(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])
        #pretrained_model = tf.keras.applications.Xception(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])
        pretrained_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False ,input_shape=[*IMAGE_SIZE, 3])
        pretrained_model.trainable = False # False = transfer learning, True = fine-tuning

        model = tf.keras.Sequential([
            pretrained_model,
            # Collapse the backbone's feature maps to one vector per image before the classifier.
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(len(CLASSES), activation='softmax')
        ])

    model.compile(
        optimizer='adam',
        loss = 'sparse_categorical_crossentropy',
        # This metric name is what produces the 'sparse_categorical_accuracy' history keys plotted after training.
        metrics=['sparse_categorical_accuracy']
    )
    model.summary()


    # Train, then plot loss (top subplot) and accuracy (bottom subplot) for training vs. validation.
    history = model.fit(
        get_training_dataset(),
        validation_data=get_validation_dataset(),
        epochs=EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
    )
    for train_key, valid_key, curve_title, position in (
        ('loss', 'val_loss', 'loss', 211),
        ('sparse_categorical_accuracy', 'val_sparse_categorical_accuracy', 'accuracy', 212),
    ):
        display_training_curves(history.history[train_key], history.history[valid_key], curve_title, position)

    Confusion matrix

    # Images and labels are read in two separate passes over the validation set,
    # so the dataset has to be ordered for predictions and labels to line up.
    cmdataset = get_validation_dataset(ordered=True)
    images_ds = cmdataset.map(lambda image, label: image)
    labels_ds = cmdataset.map(lambda image, label: label).unbatch()
    cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() # get everything as one batch
    cm_probabilities = model.predict(images_ds)
    cm_predictions = np.argmax(cm_probabilities, axis=-1)
    print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
    print("Predicted labels: ", cm_predictions.shape, cm_predictions)

    # Every metric is computed over the full set of class indices, macro-averaged.
    class_ids = range(len(CLASSES))
    cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=class_ids)
    score, precision, recall = (
        metric(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
        for metric in (f1_score, precision_score, recall_score)
    )
    cmat = cmat / cmat.sum(axis=1, keepdims=True) # normalized: each row is divided by its sum
    display_confusion_matrix(cmat, score, precision, recall)
    print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))


    # Predict a class for every test image and write the (id, label) pairs to submission.csv.
    test_ds = get_test_dataset(ordered=True) # since we are splitting the dataset and iterating separately on images and ids, order matters.
    print('Computing predictions...')
    # First pass over the test set: images only, fed to the model.
    test_images_ds = test_ds.map(lambda image, idnum: image)
    probabilities = model.predict(test_images_ds)
    predictions = np.argmax(probabilities, axis=-1)  # index of the highest-probability class per image
    print('Generating submission.csv file...')
    # Second pass over the test set: ids only, gathered into a single batch.
    test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
    # comments='' stops numpy from prefixing the header line with '# '.
    np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
    # Notebook shell escape: show the first lines of the generated file.
    !head submission.csv

    Visual validation

    # Re-batch the validation set into groups of 20 images for display.
    dataset = get_validation_dataset().unbatch().batch(20)
    batch = iter(dataset)
    # run this cell again for next set of images
    images, labels = next(batch)
    probabilities = model.predict(images)
    predictions = np.argmax(probabilities, axis=-1)
    display_batch_of_images((images, labels), predictions=predictions)


