This is a great demo of the Discrete Cosine Transform (DCT) and is the essence of JPEG encoding. The demo was originally created by Frank Ong.
# Import functions and libraries
import numpy as np
import matplotlib.pyplot as plt
import scipy
from numpy import pi
from numpy import sin
from numpy import zeros
from numpy import r_
from scipy import signal
from scipy import misc # pip install Pillow
import matplotlib.pylab as pylab
%matplotlib inline
pylab.rcParams['figure.figsize'] = (20.0, 7.0)
# im = misc.imread("einstein.tif").astype(float)
# im = misc.imread("house.tif").astype(float)
im = misc.imread("zelda.tif").astype(float)
# im = misc.imread("barbara.png").astype(float)
f = plt.figure()
plt.imshow(im,cmap='gray')
def dct2(a):
return scipy.fftpack.dct( scipy.fftpack.dct( a, axis=0, norm='ortho' ), axis=1, norm='ortho' )
def idct2(a):
return scipy.fftpack.idct( scipy.fftpack.idct( a, axis=0 , norm='ortho'), axis=1 , norm='ortho')
imsize = im.shape
dct = np.zeros(imsize)
# Do 8x8 DCT on image (in-place)
for i in r_[:imsize[0]:8]:
for j in r_[:imsize[1]:8]:
dct[i:(i+8),j:(j+8)] = dct2( im[i:(i+8),j:(j+8)] )
pos = 128
# Extract a block from image
plt.figure()
plt.imshow(im[pos:pos+8,pos:pos+8],cmap='gray')
plt.title( "An 8x8 Image block")
# Display the dct of that block
plt.figure()
plt.imshow(dct[pos:pos+8,pos:pos+8],cmap='gray',vmax= np.max(dct)*0.01,vmin = 0, extent=[0,pi,pi,0])
plt.title( "An 8x8 DCT block")
# Display entire DCT
plt.figure()
plt.imshow(dct,cmap='gray',vmax = np.max(dct)*0.01,vmin = 0)
plt.title( "8x8 DCTs of the image")
# Threshold
thresh = 0.012
dct_thresh = dct * (abs(dct) > (thresh*np.max(dct)))
plt.figure()
plt.imshow(dct_thresh,cmap='gray',vmax = np.max(dct)*0.01,vmin = 0)
plt.title( "Thresholded 8x8 DCTs of the image")
percent_nonzeros = np.sum( dct_thresh != 0.0 ) / (imsize[0]*imsize[1]*1.0)
print "Keeping only %f%% of the DCT coefficients" % (percent_nonzeros*100.0)
im_dct = np.zeros(imsize)
for i in r_[:imsize[0]:8]:
for j in r_[:imsize[1]:8]:
im_dct[i:(i+8),j:(j+8)] = idct2( dct_thresh[i:(i+8),j:(j+8)] )
plt.figure()
plt.imshow( np.hstack( (im, im_dct) ) ,cmap='gray')
plt.title("Comparison between original and DCT compressed images" )
dft = zeros(imsize,dtype='complex');
im_dft = zeros(imsize,dtype='complex');
# 8x8 DFT
for i in r_[:imsize[0]:8]:
for j in r_[:imsize[1]:8]:
dft[i:(i+8),j:(j+8)] = np.fft.fft2( im[i:(i+8),j:(j+8)] )
# Thresh
thresh = 0.013
dft_thresh = dft * (abs(dft) > (thresh*np.max(abs(dft))))
percent_nonzeros_dft = np.sum( dft_thresh != 0.0 ) / (imsize[0]*imsize[1]*1.0)
print "Keeping only %f%% of the DCT coefficients" % (percent_nonzeros*100.0)
print "Keeping only %f%% of the DFT coefficients" % (percent_nonzeros_dft*100.0)
# 8x8 iDFT
for i in r_[:imsize[0]:8]:
for j in r_[:imsize[1]:8]:
im_dft[i:(i+8),j:(j+8)] = np.fft.ifft2( dft_thresh[i:(i+8),j:(j+8)] )
plt.figure()
plt.imshow( np.hstack( (im, im_dct, abs(im_dft)) ) ,cmap='gray')
plt.title("Comparison between original, DCT compressed and DFT compressed images" )
pos = 128
# Extract a block from image
plt.figure()
plt.imshow(im[pos:pos+8,pos:pos+8],cmap='gray')
plt.title( "An 8x8 Image block")
# Display the dct of that block
plt.figure()
plt.imshow(abs(dft[pos:pos+4,pos:pos+4]),cmap='gray',vmax= np.max(abs(dft))*0.01,vmin = 0, extent=[0,2*pi,2*pi,0])
plt.title( "A 4x4 DFT block (because of conjugate symmetry)")