extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(extra_update_ops): train_step = optimizer.minimize(mean_loss)