import argparse
import logging
import time
from random import random  # NOTE(review): unused in this file

from PIL import Image, ImageFilter
from skimage import io  # NOTE(review): unused in this file
import numpy as np

# Module-level logger so GetGaussianBlurImage2Graph also works when this file
# is imported as a module. Previously `logger` was bound only inside the
# `__main__` guard, so calling the function from another module raised
# NameError.
logger = logging.getLogger(__name__)


def diff(img, x1, y1, x2, y2):
    """Return the edge weight between two pixels of *img*.

    The weight is the Euclidean distance between ``img[x1, y1]`` and
    ``img[x2, y2]``. For a 3-D (rows x cols x channels) array this is the
    distance between the two channel vectors. For a 2-D array it is the
    absolute difference of the two values.

    Note that ``x`` indexes the FIRST axis of ``img`` and ``y`` the second.
    Unsigned arrays (e.g. ``uint8``) would wrap around on subtraction, so
    pass a signed or float array.
    """
    delta = img[x1, y1] - img[x2, y2]
    return np.sqrt(np.sum(delta ** 2))


def create_edge(img, width, x, y, x1, y1, diff):
    """Build one weighted edge between grid points (x, y) and (x1, y1).

    Args:
        img: pixel array, passed through unchanged to *diff*.
        width: extent of the ``x`` coordinate. Vertex ids are
            ``y * width + x``.
        x, y: coordinates of the first endpoint.
        x1, y1: coordinates of the second endpoint.
        diff: callable ``diff(img, x, y, x1, y1)`` returning the edge weight.

    Returns:
        A ``(vertex_id_a, vertex_id_b, weight)`` tuple.
    """
    weight = diff(img, x, y, x1, y1)
    return (y * width + x, y1 * width + x1, weight)


def build_graph(img, width, height, diff, neighborhood_8=False):
    """Build the weighted edge list of a grid graph over *img*.

    Every grid point (x, y) with ``0 <= x < width`` and ``0 <= y < height``
    is a vertex. Vertex (x, y) gets edges to (x-1, y) and (x, y-1). When
    *neighborhood_8* is true it also gets edges to the diagonals
    (x-1, y-1) and (x-1, y+1). Each undirected edge is therefore emitted
    exactly once.

    Args:
        img: pixel array indexed by *diff* as ``img[x, y]``.
        width: number of ``x`` values (the first axis of ``img`` when used
            with :func:`diff`).
        height: number of ``y`` values.
        diff: edge-weight callable, see :func:`create_edge`.
        neighborhood_8: use 8-connectivity instead of 4-connectivity.

    Returns:
        A list of ``(vertex_id_a, vertex_id_b, weight)`` tuples.
    """
    graph_edges = []
    for y in range(height):
        for x in range(width):
            if x > 0:
                graph_edges.append(create_edge(img, width, x, y, x - 1, y, diff))
            if y > 0:
                graph_edges.append(create_edge(img, width, x, y, x, y - 1, diff))
            if neighborhood_8:
                if x > 0 and y > 0:
                    graph_edges.append(create_edge(img, width, x, y, x - 1, y - 1, diff))
                if x > 0 and y < height - 1:
                    graph_edges.append(create_edge(img, width, x, y, x - 1, y + 1, diff))
    return graph_edges


def GetGaussianBlurImage2Graph(sigma, neighbor, input_file):
    """Gaussian-blur an image file and build its pixel-grid graph.

    Args:
        sigma: radius passed to ``ImageFilter.GaussianBlur``.
        neighbor: 4 or 8, the pixel connectivity of the graph. Any other
            value logs a warning and falls back to 4-connectivity.
        input_file: anything accepted by ``PIL.Image.open``.

    Returns:
        The list of ``(vertex_id_a, vertex_id_b, weight)`` edges built by
        :func:`build_graph`. Previously the edges were computed and then
        discarded.
    """
    if neighbor not in (4, 8):
        logger.warning('Invalid neighborhood {} chosen. The acceptable values '
                       'are 4 or 8; falling back to 4.'.format(neighbor))
    start_time = time.time()

    # Use a context manager so the underlying file handle is released.
    with Image.open(input_file) as image_file:
        size = image_file.size  # (width, height) in Pillow/PIL
        logger.info('Image info: {} | {} | {}'.format(image_file.format, size, image_file.mode))

        # Gaussian Filter. filter() returns a new, fully loaded image, so it
        # remains usable after the source file is closed.
        logger.info("GaussianBlur...")
        smooth = image_file.filter(ImageFilter.GaussianBlur(sigma))

    # Convert to a signed int array so pixel subtraction in diff() cannot
    # wrap around as uint8 would.
    smooth = np.array(smooth).astype(int)  # height x width (x channels)

    logger.info("Creating graph...")
    # diff() indexes img[x, y] as [row, col]. build_graph's `width` therefore
    # receives the image HEIGHT (number of rows) and `height` the image WIDTH.
    # This swap is intentional and keeps every index in bounds.
    graph_edges = build_graph(smooth, size[1], size[0], diff, neighbor == 8)
    logger.info("Number of graph edges: {}".format(len(graph_edges)))
    logger.info('Total running time: {:.4f}s'.format(time.time() - start_time))
    return graph_edges


if __name__ == '__main__':
    # argument parser
    parser = argparse.ArgumentParser(description='Img2Graph(Graph-based Segmentation)')
    parser.add_argument('--sigma', type=float, default=0.5,
                        help='a float for the Gaussian Filter')
    parser.add_argument('--neighbor', type=int, default=8, choices=[4, 8],
                        help='choose the neighborhood format, 4 or 8')
    parser.add_argument('--input-file', type=str, default="./datas/BigTree.jpg",
                        help='the file path of the input image')
    args = parser.parse_args()

    # basic logging settings (the module-level `logger` propagates to root)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%m-%d %H:%M')

    print("input:", 'sigma=', args.sigma, 'neighbor=', args.neighbor, 'input-file=', args.input_file)
    GetGaussianBlurImage2Graph(args.sigma, args.neighbor, args.input_file)