Lower Star Image Filtrations

In this notebook, we explore the lower star filtration (also called the sublevelset filtration) on image data. In the resulting 0-dimensional persistence diagram, local minima of the image (or local maxima, if the image is negated) appear as birth times and saddle points as death times. This is a useful and provably stable way of summarizing the critical points of an image.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import PIL.Image  # import the submodule explicitly so PIL.Image is always available

from persim import plot_diagrams
from ripser import ripser, lower_star_img

The lower_star_img function imported from ripser above constructs a 0-dimensional lower star filtration on an image. It builds a sparse distance matrix in which every pixel in the image is a vertex, and every vertex is connected to its 8 spatial neighbors (pixels on the boundary have fewer). Each edge weight is the maximum of the two pixel values the edge connects (hence the “lower star”).

Despite the large number of pixels, this code is fast (nearly linear in the number of pixels), because the number of edges is linear in the number of pixels and only 0-dimensional homology is computed.
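
To make the edge construction concrete, here is a minimal pure-Python sketch of the 8-neighbor edge list described above. The helper name lower_star_edges is hypothetical and this is not ripser's implementation (which assembles a sparse matrix); it only illustrates that each edge carries the larger of its two pixel values and that the edge count grows linearly with the number of pixels.

```python
def lower_star_edges(img):
    """Return (i, j, weight) edges of the 8-neighbor pixel graph.

    `img` is a 2D list of numbers with rows of equal length. Pixels are
    numbered row by row, and each edge weight is the larger of the two
    pixel values it joins.
    """
    m, n = len(img), len(img[0])
    # Visiting only these four "forward" offsets lists every undirected edge once.
    offsets = [(0, 1), (1, -1), (1, 0), (1, 1)]
    edges = []
    for r in range(m):
        for c in range(n):
            for dr, dc in offsets:
                rr, cc = r + dr, c + dc
                if 0 <= rr < m and 0 <= cc < n:  # boundary pixels have fewer neighbors
                    weight = max(img[r][c], img[rr][cc])
                    edges.append((r * n + c, rr * n + cc, weight))
    return edges


img = [[0, 5, 1],
       [4, 9, 6],
       [2, 7, 3]]
edges = lower_star_edges(img)
print(len(edges))  # prints 20: never more than 4 edges per pixel
```

Because every pixel contributes at most four forward edges, an image with N pixels yields fewer than 4N edges.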

Gaussian Blob Example

Now, let’s test out this code on an image containing some Gaussian blobs. We will put three negative Gaussians in this image: one which reaches its local min at -3, one at -2, and one at -1:

[2]:
ts = np.linspace(-1, 1, 100)
x1 = np.exp(-ts**2/(0.1**2))
ts -= 0.4
x2 = np.exp(-ts**2/(0.1**2))
img = -x1[None, :]*x1[:, None] - 2*x1[None, :]*x2[:, None] - 3*x2[None, :]*x2[:, None]
plt.imshow(img)
plt.show()
[3]:
dgm = lower_star_img(img)

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(img)
plt.colorbar()
plt.title("Test Image")
plt.subplot(122)
plot_diagrams(dgm)
plt.title("0-D Persistence Diagram")
plt.tight_layout()
plt.show()

As this example shows, the diagram contains three dots, one for each of the three Gaussians, and each is born at the value of its local min. Two of them die at around 0, where they merge with other classes at saddle points. By the “elder rule,” the class born at -3 absorbs the others and lives forever.
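
The elder rule can be illustrated with a small union-find sweep over pixels in increasing order of value. This is a sketch of a different, standard technique for 0-dimensional sublevelset persistence, not the routine ripser runs internally, and the function name lower_star_h0 is hypothetical. The toy image is a single row with minima at -1, -3, and -2 separated by saddles at 0, mirroring the three blobs above.

```python
def lower_star_h0(img):
    """0-dimensional sublevelset persistence of a 2D list via union-find."""
    m, n = len(img), len(img[0])
    val = [v for row in img for v in row]
    order = sorted(range(m * n), key=lambda p: val[p])  # sweep pixels from low to high
    parent = {}  # holds only the pixels that have been swept so far

    def find(p):
        while parent[p] != p:
            parent[p] = parent[parent[p]]  # path halving
            p = parent[p]
        return p

    pairs = []
    for p in order:
        parent[p] = p
        r, c = divmod(p, n)
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                rr, cc = r + dr, c + dc
                q = rr * n + cc
                if (dr or dc) and 0 <= rr < m and 0 <= cc < n and q in parent:
                    a, b = find(p), find(q)
                    if a == b:
                        continue
                    # Elder rule: the root with the lower birth value survives.
                    young, old = (a, b) if val[a] > val[b] else (b, a)
                    if val[young] < val[p]:  # skip zero-persistence pairs
                        pairs.append((val[young], val[p]))
                    parent[young] = old
    pairs.append((min(val), float("inf")))  # the oldest class never dies
    return sorted(pairs)


print(lower_star_h0([[-1, 0, -3, 0, -2]]))
# prints [(-3, inf), (-2, 0), (-1, 0)]
```

Each root is always the lowest pixel of its component, so a root's value is the birth time of the class it represents.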

Cell Biology Image Example

Let’s now look at a slightly more exciting example. We’ll analyze the following Creative Commons image of plant cells, taken from this link.


When we convert this image to grayscale in the range \([0, 255]\), where \(0\) is dark and \(255\) is bright, the interiors of the cells take high values, while neighboring cells meet at saddle points on their shared boundaries, where values are closer to \(0\).

In this case, we posit that there is a local max of large persistence within each cell; that is, we want to perform a superlevelset / upper star filtration. To make the code work for local maxes instead of local mins, we can simply feed the lower star filtration function the negative of the image. Let’s do this now and plot the lifetimes:

[4]:
cells_original = plt.imread("Cells.jpg")
cells_grey = np.asarray(PIL.Image.fromarray(cells_original).convert('L'))

plt.subplot(121)
plt.title(cells_original.shape)
plt.imshow(cells_original)
plt.axis('off')
plt.subplot(122)
plt.title(cells_grey.shape)
plt.imshow(cells_grey, cmap='gray')
plt.axis('off')
plt.show()
[5]:
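# Cast before negating: cells_grey is uint8, and negating unsigned integers wraps around
dgm = lower_star_img(-cells_grey.astype(np.float64))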

plt.figure(figsize=(6, 6))
plot_diagrams(dgm, lifetime=True)
plt.show()
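
Two details of the negation trick are worth a small sketch. First, an 8-bit array, such as a grayscale image produced by PIL's 'L' mode, wraps around when negated, so casting to float first is the safer choice. Second, a diagram computed on the negated image can be flipped back to the original intensity scale by negating it. The helper to_superlevel and the diagram values below are hypothetical and purely illustrative.

```python
import numpy as np

pixels = np.array([0, 1, 255], dtype=np.uint8)
wrapped = -pixels                  # unsigned negation wraps modulo 256
safe = -pixels.astype(np.float64)  # cast first to get true negatives
print(wrapped.tolist())            # prints [0, 255, 1]


def to_superlevel(dgm):
    """Flip a diagram computed on -img back to the intensity scale of img."""
    # Births become local-max heights; the essential class now "dies" at -inf.
    return -np.asarray(dgm, dtype=float)


toy_dgm = np.array([[-200.0, -50.0], [-250.0, np.inf]])  # made-up values
flipped = to_superlevel(toy_dgm)
```

After flipping, a point (200, 50) reads as a local max of height 200 that merges into a brighter neighbor at a saddle of height 50.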

Let’s now pick a persistence threshold, above which we consider a dot to be associated to a cell.

Ripser.py does not currently return representatives for 0-dimensional homology classes, but there is a simple workaround: add a small amount of Gaussian noise to each pixel. This makes every pixel value unique (almost surely), so we can locate a class by finding the pixel whose value matches its birth time (up to sign, since we filter the negated image). Before adding the noise, we also apply a local average to encourage the representatives of the maxima to lie closer to the centers of the cells.
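
As a minimal sketch of this workaround on a toy array, the snippet below breaks ties with small Gaussian noise and then recovers the location of a class from its birth value. The helper locate_birth and the 3x3 array are hypothetical, and the random generator is seeded only so the sketch is repeatable.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make this sketch repeatable
flat = np.array([[10.0, 10.0, 10.0],
                 [10.0, 12.0, 10.0],
                 [10.0, 10.0, 10.0]])
noisy = flat + 0.01 * rng.standard_normal(flat.shape)  # ties are now broken


def locate_birth(img, birth):
    """Return (row, col) of the pixel of `img` whose value is closest to -birth."""
    flat_idx = np.argmin(np.abs(img + birth))
    return np.unravel_index(flat_idx, img.shape)


# In a filtration on -noisy, the brightest pixel is born first, at -noisy.max().
birth = -noisy.max()
row, col = locate_birth(noisy, birth)
print(int(row), int(col))  # prints 1 1
```

Without the noise, the eight background pixels would all share the value 10, and a lookup by value could not tell them apart.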

[6]:
smoothed = ndimage.uniform_filter(cells_grey.astype(np.float64), size=10)
smoothed += 0.01 * np.random.randn(*smoothed.shape)

plt.figure(figsize=(10, 5))
plt.subplot(121)
im = plt.imshow(cells_grey, cmap='gray')
plt.colorbar(im, fraction=0.03)

plt.subplot(122)
im = plt.imshow(smoothed, cmap='gray')
plt.colorbar(im, fraction=0.03)

plt.tight_layout()
plt.show()
[7]:
dgm = lower_star_img(-smoothed)
plot_diagrams(dgm, lifetime=True)
plt.show()

We’ll eyeball a cutoff and look at all points with lifetime greater than 70. Below, we show each of the 0-dimensional pixel representatives highlighted in the original image.

[8]:
thresh = 70
idxs = np.arange(dgm.shape[0])
idxs = idxs[np.abs(dgm[:, 1] - dgm[:, 0]) > thresh]

plt.figure(figsize=(8, 5))
plt.imshow(cells_original)

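# Pixel coordinates, flattened so that a flat argmin index maps to an (x, y) position
X, Y = np.meshgrid(np.arange(smoothed.shape[1]), np.arange(smoothed.shape[0]))
X = X.flatten()
Y = Y.flatten()
for idx in idxs:
    # The filtration ran on -smoothed, so the birth pixel satisfies smoothed + birth == 0
    bidx = np.argmin(np.abs(smoothed + dgm[idx, 0]))
    plt.scatter(X[bidx], Y[bidx], 20, 'k')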
plt.axis('off')

plt.show()
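
The lifetime cutoff used above can be explored in isolation. The sketch below counts how many diagram points survive several thresholds; the helper persistent_points is hypothetical and the toy diagram is made up for illustration, not taken from the cell image.

```python
import numpy as np


def persistent_points(dgm, thresh):
    """Keep the rows of a (birth, death) diagram whose lifetime exceeds `thresh`."""
    dgm = np.asarray(dgm, dtype=float)
    lifetimes = dgm[:, 1] - dgm[:, 0]  # inf for the class that never dies
    return dgm[lifetimes > thresh]


# Made-up diagram: one essential class plus lifetimes of 120, 5, and 80.
toy = np.array([[-250.0, np.inf],
                [-240.0, -120.0],
                [-200.0, -195.0],
                [-180.0, -100.0]])
for t in (50, 70, 100):
    print(t, len(persistent_points(toy, t)))
# prints:
# 50 3
# 70 3
# 100 2
```

A count that stays constant over a range of thresholds suggests the cutoff sits in a gap between noise and genuine features.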

The threshold can certainly be tuned, but we see reasonably good results even with this eyeballed choice. There is roughly one max per cell, although some cells have duplicates and others are missed. Overall, this naive approach does a good job of identifying the cells even though they have very different shapes.