
Is It Possible To Use Vector Methods To Shift Images Stored In A Numpy Ndarray For Data Augmentation?

Background: This is one of the exercise problems in the textbook Hands-On Machine Learning by Aurélien Géron. The question is: Write a function that can shift an MNIST image in a

Solution 1:

Three nested for loops with an if condition, combined with reshaping and appending, is clearly not a good idea; numpy.roll does the job in a vectorized way:

import numpy as np
import matplotlib.pyplot as plt 
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train.shape
# (60000, 28, 28)

# plot an original image
plt.gray() 
plt.matshow(x_train[0]) 
plt.show() 


Let's first demonstrate the operations:

# one pixel down:
x_down = np.roll(x_train[0], 1, axis=0)
plt.gray() 
plt.matshow(x_down) 
plt.show() 


# one pixel up:
x_up = np.roll(x_train[0], -1, axis=0)
plt.gray() 
plt.matshow(x_up) 
plt.show() 


# one pixel left:
x_left = np.roll(x_train[0], -1, axis=1)
plt.gray() 
plt.matshow(x_left) 
plt.show() 


# one pixel right:
x_right = np.roll(x_train[0], 1, axis=1)
plt.gray() 
plt.matshow(x_right) 
plt.show() 
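
One caveat about the four shifts above: np.roll wraps around, so pixels pushed off one edge reappear on the opposite edge. MNIST digits usually have blank borders, so a one-pixel roll rarely moves any ink, but for larger shifts you may prefer to blank the wrapped pixels. A minimal sketch of that idea follows; the helper name shift_zero_fill and the 3x3 toy array are illustrative assumptions, not part of the original answer:

```python
import numpy as np

def shift_zero_fill(image, shift, axis):
    """Shift like np.roll, but set the pixels that wrapped around to 0.

    Hypothetical helper for illustration only.
    """
    shifted = np.roll(image, shift, axis=axis)  # np.roll returns a new array
    if shift == 0:
        return shifted
    # Select the rows/columns that np.roll wrapped to the opposite edge.
    index = [slice(None)] * image.ndim
    index[axis] = slice(0, shift) if shift > 0 else slice(shift, None)
    shifted[tuple(index)] = 0
    return shifted

toy = np.arange(1, 10).reshape(3, 3)  # toy stand-in for an image
print(np.roll(toy, 1, axis=1))        # last column wraps to the front
print(shift_zero_fill(toy, 1, 1))     # first column is zero instead
```

Because np.roll returns a copy, zeroing the wrapped slice does not modify the original image.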


Having established that, we can generate, say, the "right" versions of all the training images with a simple list comprehension:

x_all_right = [np.roll(x, 1, axis=1) for x in x_train]

and similarly for the other 3 directions.
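
As a sketch of what "similarly" could look like, the four (shift, axis) pairs from the demonstrations can be kept in a dictionary and applied in one pass. The names DIRECTIONS, images, and shifted are assumptions made for this example, and a small random array stands in for x_train so the snippet runs without downloading MNIST:

```python
import numpy as np

# Stand-in for x_train: 5 random 28x28 "images" (assumed MNIST-like layout).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# (shift, axis) pairs for a single 2-D image, matching the demos above.
DIRECTIONS = {
    "down": (1, 0),
    "up": (-1, 0),
    "left": (-1, 1),
    "right": (1, 1),
}

# One list of shifted images per direction.
shifted = {
    name: [np.roll(img, shift, axis=axis) for img in images]
    for name, (shift, axis) in DIRECTIONS.items()
}
```

With real data, replacing images by x_train should give the same four lists as writing out each comprehension by hand.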

Let's confirm that the first image in x_all_right is indeed what we want:

plt.gray() 
plt.matshow(x_all_right[0]) 
plt.show()


You can even avoid the list comprehension altogether in favor of pure NumPy code:

x_all_right = np.roll(x_train, 1, axis=2)

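This is more efficient, although slightly less intuitive: x_train has shape (60000, 28, 28), so axis 0 indexes the images, and each single-image command carries over to the whole array by increasing its axis by 1. It also returns a single ndarray rather than a list of arrays.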
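
To turn this into an augmented training set, one possible approach (a sketch, not necessarily what the book's solution does) is to concatenate the original array with its four shifted copies and repeat the labels in the same order. The arrays x and y below are small random stand-ins for x_train and y_train, and the names x_aug and y_aug are assumptions for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8)  # stand-in for x_train
y = rng.integers(0, 10, size=6)                              # stand-in for y_train

# Batched shifts: axis 0 is the sample axis, so rows are axis 1, columns axis 2.
shifts = [(1, 1), (-1, 1), (-1, 2), (1, 2)]  # down, up, left, right
x_aug = np.concatenate([x] + [np.roll(x, s, axis=a) for s, a in shifts])
y_aug = np.tile(y, len(shifts) + 1)  # labels repeat once per block of images

# The batched roll matches the per-image list comprehension.
per_image = np.stack([np.roll(img, 1, axis=1) for img in x])
assert np.array_equal(np.roll(x, 1, axis=2), per_image)

print(x_aug.shape, y_aug.shape)  # (30, 28, 28) (30,)
```

The augmented arrays are ordered block by block (originals first, then each direction), so shuffling them together before training is probably a good idea.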

