Image segmentation using K-means Clustering in Python

Image segmentation is the process of dividing an image into groups of pixels, or segments, so that those pixels can be identified and used in a decision-making application. It splits a picture into distinct regions such that pixels within a region are highly similar to one another and pixels in different regions contrast strongly.

There are a variety of image segmentation methods; clustering is one of the simplest and most widely used, and it is a standard technique in machine learning.

In this tutorial, we will learn how to perform image segmentation with K-means clustering in Python.

K-means clustering is an unsupervised learning technique: it partitions unlabeled data into K clusters, each represented by the mean (centroid) of the points assigned to it.
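For reference, the quantity K-means tries to make small is the within-cluster sum of squared distances, where $C_k$ is the set of points in cluster $k$ and $\boldsymbol{\mu}_k$ is their mean:

$$
J = \sum_{k=1}^{K} \sum_{\mathbf{x} \in C_k} \lVert \mathbf{x} - \boldsymbol{\mu}_k \rVert^2,
\qquad
\boldsymbol{\mu}_k = \frac{1}{|C_k|} \sum_{\mathbf{x} \in C_k} \mathbf{x}
$$

In this tutorial each point $\mathbf{x}$ is one pixel's (R, G, B) value. With exact (unrounded) means, neither reassigning points to their nearest mean nor recomputing the means can increase $J$.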

So, let's get started.

First, make sure the required libraries (NumPy, and Pillow, which provides the PIL module) are installed on your system, and then import them as shown below.

# numpy and PIL are third-party packages (pip install numpy pillow)
import numpy
import math
import random
import operator
from PIL import Image

After importing the libraries, read the image in as a list of pixel values and get its size.

The code below does this:

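image_path = 'C:\\Users\\OS\\imageseg_folder\\inputimg.jpg'  # change to your own image
image_name = "inputimg"
# convert("RGB") makes every pixel an (R, G, B) tuple, even for grayscale or RGBA files
image = Image.open(image_path, 'r').convert("RGB")
width, height = image.size
imagePixels = list(image.getdata())  # row-major: pixel (x, y) is at index y * width + x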
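To see what getdata() returns, and why the index arithmetic used later in the tutorial works, here is a small self-contained check on images built in memory rather than read from disk (the sizes and colors are arbitrary choices for illustration):

```python
from PIL import Image

# A tiny 3x2 RGB image built in memory; every pixel starts black (0, 0, 0)
width, height = 3, 2
image = Image.new("RGB", (width, height))
image.putpixel((2, 1), (255, 0, 0))  # column x=2, row y=1 becomes red

imagePixels = list(image.getdata())
print(len(imagePixels))              # 6, i.e. width * height
# getdata() is row-major, so pixel (x, y) sits at index y * width + x
print(imagePixels[1 * width + 2])    # (255, 0, 0)

# A grayscale ("L") image yields plain ints; convert("RGB") turns them into (R, G, B) tuples
gray = Image.new("L", (width, height), 128)
print(list(gray.getdata())[0])                 # 128
print(list(gray.convert("RGB").getdata())[0])  # (128, 128, 128)
```

This is why the image is converted to RGB before clustering: the distance and mean computations assume three values per pixel.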

Now select K pixels at random from the image to initialize the centers. The outer loop runs the whole segmentation for K = 2, 5 and 10, and all of the code that follows belongs inside it:

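for K in (2, 5, 10):
    print("Segmenting image by K = {}".format(K))

    initial_centers = set()
    # Keep drawing until there are K distinct centers: two random pixels can share a color,
    # and the set would silently keep only one (assumes the image has at least K colors)
    while len(initial_centers) < K:
        initial_centers.add(imagePixels[random.randint(0, len(imagePixels) - 1)])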
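A variation on this initialization, sketched below, picks K distinct colors directly and accepts a seed so that runs are reproducible. The helper name pick_initial_centers is hypothetical, and note one behavioral difference: it draws uniformly among distinct colors, whereas the tutorial's loop favors colors that cover many pixels.

```python
import random

def pick_initial_centers(pixels, K, seed=None):
    """Hypothetical helper: return a set of K distinct pixel colors chosen at random."""
    distinct_colors = sorted(set(pixels))  # sorted so that a fixed seed gives a fixed result
    if len(distinct_colors) < K:
        raise ValueError("the image has fewer than K distinct colors")
    rng = random.Random(seed)
    return set(rng.sample(distinct_colors, K))

# 90 black pixels, 9 white pixels and 1 red pixel
pixels = [(0, 0, 0)] * 90 + [(255, 255, 255)] * 9 + [(255, 0, 0)]
centers = pick_initial_centers(pixels, 3, seed=0)
print(sorted(centers))  # [(0, 0, 0), (255, 0, 0), (255, 255, 255)]
```

Raising an error when there are too few colors avoids the endless loop that any "draw until K distinct" approach would hit on such an image.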

Now for the k-means algorithm itself. It repeats two steps until convergence, that is, until the centers stop changing. First, compute the distance between each pixel and every center, and assign the pixel to the center closest to it; this forms the clusters.
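As a point of comparison, the assignment step can also be written with NumPy broadcasting, which computes every pixel-to-center distance at once instead of looping in Python. This is a sketch: the helper name assign_to_nearest is hypothetical, and it returns center indices rather than the center tuples the tutorial uses as dictionary keys.

```python
import numpy as np

def assign_to_nearest(pixels, centers):
    """Return, for each pixel, the index of the closest center (Euclidean distance)."""
    pixels = np.asarray(pixels, dtype=np.float64)    # shape (N, 3)
    centers = np.asarray(centers, dtype=np.float64)  # shape (K, 3)
    # (N, 1, 3) - (1, K, 3) broadcasts to every pixel-center difference, shape (N, K, 3)
    diffs = pixels[:, None, :] - centers[None, :, :]
    distances = np.sqrt((diffs ** 2).sum(axis=2))    # shape (N, K)
    return distances.argmin(axis=1)                  # nearest center per pixel

pixels = [(10, 10, 10), (250, 240, 245), (0, 5, 0), (200, 0, 0)]
centers = [(0, 0, 0), (255, 255, 255), (255, 0, 0)]
print(assign_to_nearest(pixels, centers))  # [0 1 0 2]
```

The intermediate (N, K, 3) array grows with image size, so for large images it may be necessary to process the pixels in chunks.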

Then compute the new centers as the mean R, G and B values of the pixels in each cluster.
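The update step can be sketched the same way. The helper name update_centers is hypothetical, and so is its policy for empty clusters: a cluster with no pixels has no mean, so this sketch keeps its previous center, which is one common choice.

```python
import numpy as np

def update_centers(pixels, labels, old_centers):
    """Recompute each center as the mean R, G, B value of the pixels labeled with it."""
    pixels = np.asarray(pixels, dtype=np.float64)
    labels = np.asarray(labels)
    new_centers = np.asarray(old_centers, dtype=np.float64).copy()
    for k in range(len(new_centers)):
        members = pixels[labels == k]
        if len(members) > 0:           # empty cluster: keep the previous center
            new_centers[k] = members.mean(axis=0)
    return new_centers

pixels = [(0, 0, 0), (10, 20, 30), (200, 0, 0)]
labels = [0, 0, 1]
# Rows become (5, 10, 15), (200, 0, 0) and (9, 9, 9); cluster 2 is empty, so it is unchanged
print(update_centers(pixels, labels, [(1, 1, 1), (255, 0, 0), (9, 9, 9)]))
```

Unlike the tutorial, which truncates each mean with int(), this sketch keeps floating-point means.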

    old_centers = set()
    new_centers = initial_centers

    # Repeat until the centers stop changing (convergence)
    while old_centers != new_centers:
        old_centers = new_centers
        clusterDict = dict([(key, []) for key in new_centers])

        # Assignment step: put every pixel in the cluster of its nearest center
        for eachPixelTupleIndex in range(len(imagePixels)):
            distanceDict = {}
            pixelValues = imagePixels[eachPixelTupleIndex]
            for eachCenter in new_centers:
                # Euclidean distance between the pixel and the center in RGB space
                distanceList = numpy.subtract(pixelValues, eachCenter)
                distance = 0
                for eachNumber in distanceList:
                    distance += eachNumber**2
                distance = math.sqrt(distance)
                distanceDict[eachCenter] = distance

            bestCenter = min(distanceDict.items(), key=operator.itemgetter(1))[0]
            clusterDict[bestCenter].append(eachPixelTupleIndex)

        # Update step: each new center is the mean R, G, B value of its cluster
        new_centers = set()
        for center in clusterDict:
            if not clusterDict[center]:
                # An empty cluster has no mean; keep its old center instead of dividing by zero
                new_centers.add(center)
                continue
            new_center_temp = (0, 0, 0)
            for pixelIndex in clusterDict[center]:
                new_center_temp = tuple(map(operator.add, new_center_temp, imagePixels[pixelIndex]))
            new_center = tuple(map(lambda x: int(x / len(clusterDict[center])), new_center_temp))
            new_centers.add(new_center)
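The pure-Python loop above is easy to follow but slow on real photographs. Below is a vectorized sketch of the same k-means loop in NumPy, under a few assumptions of its own: the function name kmeans_pixels is hypothetical, means are kept as floats rather than truncated, and max_iter is an added safety cap that the tutorial's loop does not have.

```python
import numpy as np

def kmeans_pixels(pixels, K, max_iter=100, seed=0):
    """Plain k-means on an (N, 3) array of RGB values; returns (centers, labels)."""
    pixels = np.asarray(pixels, dtype=np.float64)
    rng = np.random.default_rng(seed)

    def nearest(centers):
        # Squared Euclidean distance is enough to find the closest center
        d2 = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return d2.argmin(axis=1)

    # Start from K distinct colors that occur in the image (requires K <= number of colors)
    distinct = np.unique(pixels, axis=0)
    centers = distinct[rng.choice(len(distinct), size=K, replace=False)]

    for _ in range(max_iter):
        labels = nearest(centers)                 # assignment step
        new_centers = centers.copy()
        for k in range(K):                        # update step
            members = pixels[labels == k]
            if len(members) > 0:                  # an empty cluster keeps its old center
                new_centers[k] = members.mean(axis=0)
        if np.array_equal(new_centers, centers):  # centers stopped moving: converged
            break
        centers = new_centers

    # Recompute labels so they always match the returned centers
    return centers, nearest(centers)

pixels = [(0, 0, 0), (5, 5, 5), (250, 250, 250), (255, 255, 255)]
centers, labels = kmeans_pixels(pixels, K=2)
print(sorted(centers[:, 0].tolist()))  # [2.5, 252.5]
```

For a real image, np.asarray(image).reshape(-1, 3) gives the pixels of an RGB Pillow image in the same row-major order as getdata().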

All that remains is to rebuild each segmented image, painting every pixel with the color of its cluster's center, and save it.

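    # Map every pixel index to its cluster's center. The original per-pixel test
    # "index in clusterDict[center]" scanned a whole list for each pixel, which is very slow.
    segmentedPixels = [None] * len(imagePixels)
    for center in clusterDict:
        for pixelIndex in clusterDict[center]:
            segmentedPixels[pixelIndex] = center

    newIm = Image.new("RGB", (width, height))
    newIm.putdata(segmentedPixels)  # same row-major order as getdata()
    newIm.save("segmented_{}_K{}.png".format(image_name, K), "PNG")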
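When the clustering produces one integer label per pixel plus an array of centers, as NumPy-based implementations typically do, the reconstruction is a single indexing operation. The helper name labels_to_image is hypothetical, and it assumes the labels are in the same row-major order as getdata().

```python
import numpy as np
from PIL import Image

def labels_to_image(labels, centers, width, height):
    """Paint each pixel with its cluster center's color and return a PIL RGB image."""
    # Round float centers to valid 8-bit color values
    palette = np.clip(np.rint(np.asarray(centers, dtype=np.float64)), 0, 255).astype(np.uint8)
    # palette[labels] looks up one color per pixel; reshape restores rows of width pixels
    rgb = palette[np.asarray(labels)].reshape(height, width, 3)
    return Image.fromarray(rgb)  # a (height, width, 3) uint8 array becomes an RGB image

labels = [0, 1, 1,
          0, 2, 2]
centers = [(0, 0, 0), (255, 255, 255), (200.4, 10.6, 0)]
img = labels_to_image(labels, centers, width=3, height=2)
print(img.size)              # (3, 2)
print(img.getpixel((1, 0)))  # (255, 255, 255)
print(img.getpixel((2, 1)))  # (200, 11, 0)
```

The result can be saved exactly like newIm above, for example with img.save("segmented.png").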

print("Hurray! It's done!")

Hurray, it’s all finished! Let’s have a look at the results.

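The original image can now be compared with its segmented versions: the script saves one PNG per value of K (segmented_inputimg_K2.png, segmented_inputimg_K5.png and segmented_inputimg_K10.png), each containing at most K distinct colors. We hope you found this tutorial useful.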
