Image Segmentation with Tensorflow using CNNs and Conditional Random Fields
A post showing how to perform image segmentation with the recently released TF-Slim library and pretrained models. It covers training and post-processing using Conditional Random Fields.
Introduction
In the previous post, we implemented the upsampling and verified it by comparing it against the implementation in the scikit-image library. More specifically, we implemented the FCN-32s segmentation network described in the paper Fully convolutional networks for semantic segmentation.
In this post we perform a simple training run: we take a sample image from the PASCAL VOC dataset along with its annotation, train our network on it, and test the network on the same image. It was set up this way so that it can also be run on a CPU: the training takes only 10 iterations to complete. Another point of this post is to show that the segmentation our network (FCN-32s) produces is very coarse, even when we run it on the same image it was trained on. We tackle this problem with a Conditional Random Field (CRF) post-processing stage, which refines the segmentation by taking into account the raw RGB features of the image and the probabilities produced by our network. The set-up of this post is very simple on purpose. A similar approach to segmentation was described in the paper Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs by Chen et al. Please take into account that the setup in this post was made only to show the limitations of the FCN-32s model; to perform training for a real-life scenario, we refer readers to the paper Fully convolutional networks for semantic segmentation.
The blog post was created using a Jupyter notebook. After each chunk of code you can see the result of its evaluation. You can also get the notebook file from here. The content of the blog post is partially borrowed from the slim walkthrough notebook.
Setup
To be able to run the code, you will need to have TensorFlow installed. I have used r0.12. You will also need to use this fork of tensorflow/models.
I am also using the scikit-image library and numpy for this tutorial, plus other dependencies. One way to install them is to download the Anaconda software package for Python.
Follow all the other steps described in the previous posts: they show how to download the VGG-16 model and perform the other steps necessary for this tutorial.
Upsampling helper functions and Image Loading
In this part, we define the helper functions that were used in the previous post. If you recall, we used upsampling to upsample the downsampled predictions that we get from our network. We get downsampled predictions because of the max-pooling layers used in the VGG-16 network.
We also write the code that loads the image and its ground-truth segmentation. The code is well-commented, so don’t be afraid to read it.
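To make the downsampling concrete: VGG-16 contains five 2x2 max-pooling layers with stride 2, so the prediction map is roughly 32 times smaller than the input along each side, which is why an upsampling factor of 32 appears later on. The sketch below only works through that arithmetic. It assumes that each pooling layer halves the size and rounds down, and the 352x500 input size is a made-up example, not the size of the image used in this post.

```python
def downsampled_size(size, num_pool_layers=5):
    """Spatial size after repeated 2x2, stride-2 pooling (rounding down)."""
    for _ in range(num_pool_layers):
        size = size // 2
    return size


# Hypothetical input size, chosen only for illustration.
height, width = 352, 500

small = (downsampled_size(height), downsampled_size(width))
restored = (small[0] * 32, small[1] * 32)

print(small)     # (11, 15)
print(restored)  # (352, 480)
```

Under these assumptions, a side that is not a multiple of 32 comes back slightly smaller than the original after upsampling, so input sizes divisible by 32 are the easiest to work with.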
# True division keeps the kernel correct for odd sizes under Python 2
from __future__ import division
import numpy as np


def get_kernel_size(factor):
    """
    Find the kernel size given the desired factor of upsampling.
    """
    return 2 * factor - factor % 2


def upsample_filt(size):
    """
    Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size.
    """
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    # Outer product of two 1D triangle (bilinear) profiles
    return ((1 - abs(og[0] - center) / factor) *
            (1 - abs(og[1] - center) / factor))


def bilinear_upsample_weights(factor, number_of_classes):
    """
    Create weights matrix for transposed convolution with bilinear filter
    initialization.
    """
    filter_size = get_kernel_size(factor)
    weights = np.zeros((filter_size,
                        filter_size,
                        number_of_classes,
                        number_of_classes), dtype=np.float32)
    upsample_kernel = upsample_filt(filter_size)
    # Each class channel is upsampled independently of the others
    for i in range(number_of_classes):
        weights[:, :, i, i] = upsample_kernel
    return weights
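As a quick sanity check of the helpers above, the following NumPy-only sketch rebuilds the bilinear kernel for an upsampling factor of 2 and applies it with a naive single-channel transposed convolution. The function names `bilinear_kernel` and `upsample_naive` are hypothetical and exist only for this illustration, and the border cropping is an assumption meant to mimic 'SAME' padding; it is not taken from the TensorFlow implementation.

```python
import numpy as np


def bilinear_kernel(factor):
    """Bilinear kernel built the same way as upsample_filt above."""
    size = 2 * factor - factor % 2
    half = (size + 1) // 2
    center = half - 1 if size % 2 == 1 else half - 0.5
    og = np.ogrid[:size, :size]
    return ((1 - abs(og[0] - center) / half) *
            (1 - abs(og[1] - center) / half))


def upsample_naive(image, factor):
    """Naive single-channel transposed convolution with stride `factor`."""
    kernel = bilinear_kernel(factor)
    k = kernel.shape[0]
    h, w = image.shape
    full = np.zeros(((h - 1) * factor + k, (w - 1) * factor + k))
    # Every input pixel stamps a scaled copy of the kernel into the output.
    for i in range(h):
        for j in range(w):
            full[i * factor:i * factor + k,
                 j * factor:j * factor + k] += image[i, j] * kernel
    # Crop the border so the result is exactly `factor` times larger
    # (assumption: this mimics 'SAME' padding).
    start = (k - factor) // 2
    return full[start:start + h * factor, start:start + w * factor]


kernel = bilinear_kernel(2)
print(kernel.shape)  # (4, 4)
print(kernel.sum())  # 4.0

upsampled = upsample_naive(np.ones((3, 3)), 2)
print(upsampled.shape)  # (6, 6)
```

A constant input stays constant away from the border, which is what one expects from bilinear interpolation; only the outermost ring of pixels is attenuated because fewer kernel copies overlap there.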
%matplotlib inline
from __future__ import division
import os
import sys
import tensorflow as tf
import skimage.io as io
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
sys.path.append("/home/dpakhom1/workspace/my_models/slim/")
checkpoints_dir = '/home/dpakhom1/checkpoints'
image_filename = 'cat.jpg'
annotation_filename = 'cat_annotation.png'
image_filename_placeholder = tf.placeholder(tf.string)
annotation_filename_placeholder = tf.placeholder(tf.string)
is_training_placeholder = tf.placeholder(tf.bool)
feed_dict_to_use = {image_filename_placeholder: image_filename,
                    annotation_filename_placeholder: annotation_filename,
                    is_training_placeholder: True}
image_tensor = tf.read_file(image_filename_placeholder)
annotation_tensor = tf.read_file(annotation_filename_placeholder)
image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)
annotation_tensor = tf.image.decode_png(annotation_tensor, channels=1)
# Get ones for each class instead of a number -- we need that
# for cross-entropy loss later on. Sometimes the groundtruth
# masks have values other than 1 and 0.
class_labels_tensor = tf.equal(annotation_tensor, 1)
background_labels_tensor = tf.not_equal(annotation_tensor, 1)
# Convert the boolean values into floats -- so that
# computations in the cross-entropy loss are correct
bit_mask_class = tf.to_float(class_labels_tensor)
bit_mask_background = tf.to_float(background_labels_tensor)
combined_mask = tf.concat(concat_dim=2, values=[bit_mask_class,
                                                bit_mask_background])
# Let's reshape our input so that it becomes suitable for
# tf.nn.softmax_cross_entropy_with_logits with [batch_size, num_classes]
flat_labels = tf.reshape(tensor=combined_mask, shape=(-1, 2))
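The label preparation above can be reproduced with plain NumPy to see what the flattened labels look like. This is only an illustrative sketch: the tiny 2x3 annotation is made up, and the value 255 stands in for any mask value other than 0 and 1.

```python
import numpy as np

# Made-up annotation of shape (height, width, 1); 1 marks the object class.
annotation = np.array([[0, 1, 1],
                       [0, 255, 1]], dtype=np.uint8)[..., np.newaxis]

# One channel per class: object first, background second.
class_mask = (annotation == 1).astype(np.float32)
background_mask = (annotation != 1).astype(np.float32)

combined_mask = np.concatenate([class_mask, background_mask], axis=2)

# One row per pixel, one column per class.
flat_labels = combined_mask.reshape(-1, 2)

print(flat_labels.shape)  # (6, 2)
```

Every row contains exactly one 1, and any value other than 1 (including the made-up 255) ends up in the background column, matching the `tf.equal` / `tf.not_equal` pair used above.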
Loss function definition and training using the Adam optimization algorithm
In this part, we connect everything together: we add the upsampling layer to our network, define a differentiable loss function, and perform the training.
Following the Fully convolutional networks for semantic segmentation paper, we define the loss as a pixel-wise cross-entropy. We can do this because after upsampling the predictions have the same size as the input, so we can compare the obtained segmentation to the respective ground-truth segmentation:
$$E = -\sum_{n=1}^{N} \sum_{k=1}^{K} t_{nk} \ln(y_{nk})$$
where $N$ is the number of pixels, $K$ is the number of classes, $t_{nk}$ is a variable representing the ground truth with a 1-of-$K$ coding scheme, and $y_{nk}$ are our predictions (softmax output).
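The pixel-wise cross-entropy described above can be checked numerically on a toy example. The sketch below is a NumPy illustration under the assumption that the loss is summed (not averaged) over pixels; the helper names `softmax` and `pixelwise_cross_entropy` and the logit values are made up for this example.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def pixelwise_cross_entropy(logits, one_hot_labels):
    """Sum over pixels n and classes k of -t_nk * ln(y_nk)."""
    y = softmax(logits)
    return -np.sum(one_hot_labels * np.log(y))


# Four pixels, two classes, 1-of-K encoded ground truth.
labels = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=np.float64)

# Uninformative logits: every pixel gets probability 0.5 for each class.
uniform_logits = np.zeros((4, 2))
# Logits that strongly favor the correct class of every pixel.
confident_logits = 10.0 * (2.0 * labels - 1.0)

uniform_loss = pixelwise_cross_entropy(uniform_logits, labels)
confident_loss = pixelwise_cross_entropy(confident_logits, labels)

print(uniform_loss)
print(confident_loss)
```

With uniform predictions each pixel contributes ln 2, so the first loss equals 4 ln 2; the confident, correct predictions drive the loss close to zero.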
For this case we use the Adam optimizer because it requires less parameter tuning to get good results.
In this particular case we train and evaluate on a single image, which is a much simpler case than a real-world scenario. We do this to show the drawback of the approach: it has poor localization capabilities. If this holds for this simple case, the approach will show similar or worse results on unseen images.
import numpy as np
import tensorflow as tf
import sys
import os
from matplotlib import pyplot as plt
fig_size = [15, 4]
plt.rcParams["figure.figsize"] = fig_size
import urllib2
slim = tf.contrib.slim
from nets import vgg
from preprocessing import vgg_preprocessing
# Load the mean pixel values and the function
# that performs the subtraction from each pixel
from preprocessing.vgg_preprocessing import (_mean_image_subtraction,
                                             _R_MEAN, _G_MEAN, _B_MEAN)
upsample_factor = 32
number_of_classes = 2
log_folder = '/home/dpakhom1/tf_projects/segmentation/log_folder'
vgg_checkpoint_path = os.path.join(checkpoints_dir, 'vgg_16.ckpt')
# Convert image to float32 before subtracting the
# mean pixel value
image_float = tf.to_float(image_tensor, name='ToFloat')
# Subtract the mean pixel value from each pixel
mean_centered_image = _mean_image_subtraction(image_float,
                                              [_R_MEAN, _G_MEAN, _B_MEAN])
processed_images = tf.expand_dims(mean_centered_image, 0)
upsample_filter_np = bilinear_upsample_weights(upsample_factor,
                                               number_of_classes)
upsample_filter_tensor = tf.constant(upsample_filter_np)
# Define the model that we want to use -- specify to use only two classes at the last layer
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_16(processed_images,
                                    num_classes=2,
                                    is_training=is_training_placeholder,
                                    spatial_squeeze=False,
                                    fc_conv_padding='SAME')
downsampled_logits_shape = tf.shape(logits)
# Calculate the output size of the upsampled tensor
upsampled_logits_shape = tf.pack([
    downsampled_logits_shape[0],
    downsampled_logits_shape[1] * upsample_factor,
    downsampled_logits_shape[2] * upsample_factor,
    downsampled_logits_shape[3]
])
# Perform the upsampling
upsampled_logits = tf.nn.conv2d_transpose(logits, upsample_filter_tensor,
output_shape=upsampled_logits_shape,
strides=[1, upsample_factor, upsample_factor,