Before we get into what Type Dispatch is, or how Type Dispatch can make you a better Python programmer, let us get some basics out of the way, like: what is a good Python function? To be honest, I would rather let Jeff Knupp explain it.
I truly believe that the only way to really understand how something works is by breaking it into granular pieces and putting it back together or building it from scratch. Let us take a very well-defined use case and write a function for it from scratch. By scratch, I mean starting with the one-to-one translation of the pseudo-code to python code without thinking about any best practices.
Then, by analyzing what works and what doesn't, we can build on top of it. We are going to repeat this process until we run out of ideas to improve the code. So this post is going to contain not one, but many versions of a function for the same use case.
If you are wondering why: because we are going to unlearn everything we know about writing a function and question each line of code as we write it. This way, we can clearly see how each piece fits. Rather than blindly following a checklist of items about writing a function, we will create our own checklist. In case this post misses something of value, please let me know and I will gladly update it.
We will try to write the best, or at least my best, Python function for a particular use case in 4 iterations. Each iteration is going to be an improvement over the last one, with a clear objective declared upfront. Getting better at anything is an iterative process; depending on where you are in the spectrum of Python proficiency, one of the 4 iterations is going to resonate with you more than the others. The next step of this post is understanding the use case. Buckle up, it's going to be a very interesting journey.
Convert the given image input into a PyTorch tensor. Yes, that's it. Quite simple, right?
Folks with a data science background can skip this part and jump to the next section, as I am going to give a brief introduction to the Python libraries needed for this use case.
If you are wondering why, or to whom, this use case is useful: most Machine Learning applications (face detection, OCR) that take image data as input use these libraries. This use case is considered data preprocessing in the ML world.
The steps involved are very straightforward. Open the image using the Pillow library, convert it to a NumPy array, and transpose it before converting it to a Tensor. BTW, gentle reminder: this post is about writing better Python functions and not a data pre-processing tutorial for deep learning, so if it's intimidating, don't worry about it. You will feel more comfortable once we start coding.
A couple of points to take note of about the process steps before we move on:
- PyTorch expects the image data in channels-first order, which is why we transpose the array before converting it to tensor.
- Why not simply use PyTorch's ToTensor or transform methods and skip NumPy? Yep, you are right, we can absolutely do that, but to drive the point of this post home better, I have consciously decided to take the longer route (a sketch of the shorter route follows below).
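For reference, here is a minimal sketch of that shorter route, assuming torchvision is installed and some image file is on disk; note that ToTensor also rescales the pixel values to the [0, 1] range, which is yet another reason I prefer spelling the steps out in this post.
# A minimal sketch of the shorter route (assumes torchvision is installed)
from PIL import Image
from torchvision import transforms
# Any image file on disk will do here
image = Image.open('sample.png').convert('RGB')
# ToTensor converts HWC uint8 data to a CHW float tensor scaled to [0, 1]
tensor = transforms.ToTensor()(image)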
Sorry, the use case is not that simple. There is a small hiccup: the function should support multiple input data types, namely str/Path, Pillow (Image), and NumPy (Ndarray).
The process steps defined in the above image only support str/Path as the input data type. Therefore we need to add support for the Pillow (Image) and NumPy (Ndarray) types. But if we think about it, these data types are intermediate steps in converting the image file to a torch tensor. So there is really no additional step beyond the above image; we just have to duplicate the steps defined for the str/Path data type and alter a few initial steps so that Pillow (Image) and NumPy (Ndarray) can be accepted as input to our function.
The process steps for the 2 additional workflows are as follows.
After comparing the 3 images, it is very clear that the three workflows differ only in their initial steps and share the same remaining steps.
The first 2 iterations of the function are going to be about implementing the behavior, whereas the last two iterations are all about refactoring the code. If you are one of those readers who like to read the ending of a book first (I am not judging), you can jump straight to the section titled Iteration 4.
Objective:
Make it work. As simple as that. Since we are starting from scratch, it is best to just focus on getting the basic functionality to work; we don't need all the bells and whistles. For this iteration, we will focus only on implementing the process steps of the use case and not worry about the hiccup (multiple input types).
Code:
# Import required libraries
import numpy as np
import torch
from PIL import Image as PILImage
# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Function to convert image to Torch tensor
def Tensor(inp):
    # Read image from disk using PIL Image
    img = PILImage.open(inp).convert('RGB')
    # Convert the image to numpy ndarray
    imgArr = np.asarray(img, np.uint8)
    # Rearrange the shape of the array so that it is pytorch compatible
    imgArr = imgArr.transpose(2, 0, 1)
    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imgArr).to(device)
We managed to translate the process steps into code. The number of lines in the code almost matches the number of boxes in the diagram.
We need a sample image to test our function so we will download one.
# Download a sample image to disk
import urllib.request
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c3/Python-logo-notext.svg/200px-Python-logo-notext.svg.png"
filename = "sample.png"
urllib.request.urlretrieve(url, filename)
Now, let us take our function for a test drive.
# Pass the sample image through our function
imgTensor = Tensor(filename)
# Check whether the output of the function is a Tensor
assert isinstance(imgTensor, torch.Tensor), "Not a Tensor"
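A couple of optional sanity checks on the result; the exact height and width depend on the downloaded image, but the channel dimension should come first and the dtype should be uint8.
# Inspect the shape and dtype of the resulting tensor
print(imgTensor.shape)   # channels-first: torch.Size([3, H, W])
print(imgTensor.dtype)   # torch.uint8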
Looks like we managed to nail the basic functionality. Awesome!
What can we improve?
- Add support for the remaining input data types: PIL Image and Ndarray.
When we add support for more than one data type, we need to be very careful, because without proper validation things can go south real quick, and from an end-user's point of view the error messages won't make any sense. So in this iteration, we will add support for the PIL Image and Ndarray input types and validate the input, raising a clear error whenever an unsupported data type is passed.
Code:
# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath
# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Function to convert image to Torch tensor
def from_multiInput_toTensor(inp):
    # Input type - str/ Path, then read from disk & convert to array
    if isinstance(inp, (str, PurePath)):
        try:
            image = PILImage.open(inp).convert('RGB')
            imageArray = np.asarray(image.copy(), np.uint8)
        except Exception as error:
            raise error
    # Input type - PIL Image, then we convert it to array
    elif isinstance(inp, PILImage.Image):
        imageArray = np.asarray(inp, np.uint8)
    # Input type - ndarray, then assign it to imageArray variable
    elif isinstance(inp, np.ndarray):
        imageArray = inp
    # Raise TypeError when input type is not supported
    else:
        raise TypeError("Input must be of Type - String or Path or PIL Image or Numpy array")
    # Rearrange shape of the array so that it is pytorch compatible
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)
    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imageArray).to(device)
We have managed to implement all the functionality required for this use case. Since ndarray is the last data type before the output data type, tensor, we first convert the input to ndarray for every supported data type and transpose it at the end, right before converting it to tensor. By adopting this style of coding, we avoid having to convert the input to tensor and return the value separately for every supported data type. Instead, we do it only once at the end.
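As a quick standalone illustration of what the transpose step does (the dummy array and its size are just for demonstration):
# Demonstrate the (Height, Width, Channels) -> (Channels, Height, Width) rearrangement
demoArray = np.zeros((224, 224, 3), dtype=np.uint8)
print(demoArray.shape)                     # (224, 224, 3)
print(demoArray.transpose(2, 0, 1).shape)  # (3, 224, 224)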
With the help of the isinstance function, we are able to identify the data type of the input and notify the user by raising a TypeError with an appropriate error message if any unsupported data type is passed. With IO operations, something can easily go wrong, so if the data type is str or Path, we read the image file inside a try & except block and let the user know about the error (if any).
Quick Check:
Before we proceed, we will write a helper function to check whether two tensors are the same or not.
# Test if two torch tensors are same or not
def is_same_tensor(tensor1, tensor2):
    assert torch.eq(tensor1, tensor2).all(), "The Tensors are not equal!"
    return True
When writing a test function, it is better to raise a proper error rather than use a simple print statement, which can get buried among other messages. We will also check a couple of things to ensure that the updated version is all good.
1. Is the output of the new version the same as the old one?
# Verify that the outputs of the two versions are the same
is_same_tensor(Tensor(filename), from_multiInput_toTensor(filename))
2. Is the support for multiple data types working?
# Check the support for Path
path = Path(Path.cwd(), filename).resolve()
is_same_tensor(Tensor(filename), from_multiInput_toTensor(path))
# Check the support for PIL Image
image = PILImage.open(filename).convert('RGB')
is_same_tensor(Tensor(filename), from_multiInput_toTensor(image))
# Check the support for Ndarray
imageArray = np.asarray(image, np.uint8)
is_same_tensor(Tensor(filename), from_multiInput_toTensor(imageArray))
3. Are the validations working?
# Validate whether an error is thrown when user passes wrong file
from_multiInput_toTensor('test.png')
# Validate whether an error is thrown when input type is list
from_multiInput_toTensor([filename])
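If the validations are working, both calls should fail loudly, roughly along these lines:
# OUTPUT
FileNotFoundError: [Errno 2] No such file or directory: 'test.png'
TypeError: Input must be of Type - String or Path or PIL Image or Numpy array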
Great! We didn’t break anything. Now let us move on with refactoring, as we are done with the behavior implementation.
What can we improve?
- Conditionally assigning the imageArray variable in several if/elif branches and only using it at the very end is confusing at the least; this type of code design is not very intuitive and can make life hell for the person who is going to maintain this codebase. Of course, for our simple use case this is not the case, but in the spirit of writing a better function, let us avoid this type of coding style.
- Handling every supported data type, plus the TypeError for everything else, inside one big function is definitely not the right way of coding.
Objective:
After taking a closer look at the points from the last section, it is very clear that we need to break the function into 3 smaller ones, i.e. one function for each data type.
Code:
# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath
# Change numpy array to torch tensor
def numpy_ToImageTensor(imageArray):
    # if not type - ndarray then raise error
    if not isinstance(imageArray, np.ndarray):
        raise TypeError("Input must be of Type - Numpy array")
    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Transpose the numpy array before converting it to Tensor
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)
    return torch.from_numpy(imageArray).to(device)

# Change image to torch tensor
def pil_ToImageTensor(image):
    # if not type - PIL Image then raise error
    if not isinstance(image, PILImage.Image):
        raise TypeError("Input must be of Type - PIL image")
    # Convert the image to numpy
    imageArray = np.array(image)
    # Return output of numpy_ToImageTensor function
    return numpy_ToImageTensor(imageArray)

# Change image file to torch tensor
def file_ToImageTensor(file):
    # if not input - string or Path then raise error
    if not isinstance(file, (str, PurePath)):
        raise TypeError("Input must be of Type - String or Path")
    # Read the image from disk and raise error (if any)
    try:
        image = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error
    # Return output of pil_ToImageTensor function
    return pil_ToImageTensor(image)
This is good progress: we have removed the spaghetti code and added modularity. Now we have a linear chain of three functions, where each function hands off to the next, which makes the code easier to understand and maintain.
file_ToImageTensor -> pil_ToImageTensor -> numpy_ToImageTensor
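Each link in the chain can also be exercised on its own. A quick sanity check, assuming the sample.png file from earlier is still on disk:
# Call each function in the chain directly and compare the results
tensor_from_file = file_ToImageTensor(filename)
pil_image = PILImage.open(filename).convert('RGB')
tensor_from_image = pil_ToImageTensor(pil_image)
tensor_from_array = numpy_ToImageTensor(np.asarray(pil_image, np.uint8))
is_same_tensor(tensor_from_file, tensor_from_image)
is_same_tensor(tensor_from_file, tensor_from_array)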
Modularity makes adding new features quite straightforward and easy to test, as we have to alter only one specific function. Some examples of which function to modify for a hypothetical feature request:
- file_ToImageTensor – Ability to read black & white images as well.
- pil_ToImageTensor – Resize the image before converting it to Tensor (a sketch of this one follows below).
Please do note that even though you don't touch the rest of the codebase for the above-stated examples, it is best to test the other functions as well to avoid any surprises.
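Purely as an illustration of the second example above, a hypothetical resize-enabled version of pil_ToImageTensor could look like this; the size parameter and the resize step are illustrative additions, not part of the original use case.
# Hypothetical sketch: resize support added to pil_ToImageTensor only
def pil_ToImageTensor(image, size=None):
    # if not type - PIL Image then raise error
    if not isinstance(image, PILImage.Image):
        raise TypeError("Input must be of Type - PIL image")
    # Optionally resize the image, e.g. size=(224, 224)
    if size is not None:
        image = image.resize(size)
    # Convert the image to numpy and reuse the rest of the chain
    return numpy_ToImageTensor(np.array(image))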
The modular design also paved the way for better function names and parameter names as the input type has become more specific.
from_multiInput_toTensor -> file_ToImageTensor
inp -> file
These subtle changes can go a long way in making the code more user-friendly. Honestly, at this point, I feel like I am preaching, so I will try to wrap this quickly.
This post is getting too long for my own liking, so I am not going to test the negative and alternative data type scenarios again; I will only check the original use case. You really shouldn't skip testing after refactoring, though.
is_same_tensor(from_multiInput_toTensor(filename), file_ToImageTensor(filename))
What can we improve?
- The functions have no docstrings describing what they do.
- We are checking the input type with the isinstance function and raising a TypeError in every function, which is a kind of code repetition as well.
Objective:
Improve the documentation and avoid code repetition with the help of Type Dispatch.
Type Dispatch:
Type dispatch allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia & Swift. All we have to do is add a decorator typedispatch
before our function. Probably, it is easier to demonstrate than to explain.
Example for Type Dispatch:
Function definitions:
from fastcore.all import *
from typing import List
# Function to multiply two ndarrays
@typedispatch
def multiple(x: np.ndarray, y: np.ndarray):
    return x * y

# Function to multiply a List by an integer
@typedispatch
def multiple(lst: List, x: int):
    return [x * val for val in lst]
Calling 1st function:
x = np.arange(1,3)
print(f'x is {x}')
y = np.array(10)
print(f'y is {y}')
print(f'Result of multiplying two numpy arrays: { multiple(x, y)}')
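Running the cell above should print something like:
# OUTPUT
x is [1 2]
y is 10
Result of multiplying two numpy arrays: [10 20]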
Calling 2nd function:
x = [1, 2]
print(f'x is {x}')
y = 10
print(f'y is {y}')
print(f'Result of multiplying a List of integers by an integer: {multiple(x, y)}')
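And the second call should print something like:
# OUTPUT
x is [1, 2]
y is 10
Result of multiplying a List of integers by an integer: [10, 20]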
The life of a programmer can be made so much better if they don’t have to come up with a different function name for every data type when the purpose stays the same. If this doesn’t encourage you to use Type Dispatch whenever possible, then I don’t know what will 🤷
We will be using the fastcore package to implement Type Dispatch for our use case. For more details on fastcore and Type Dispatch, check out this awesome blog by Hamel Husain. Also check out fastai, which inspired me to write this post.
Code:
# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath
from fastcore.all import *
@typedispatch
def to_imageTensor(arr: np.ndarray) -> torch.Tensor:
    """Change ndarray to torch tensor.

    The ndarray is of the shape (Height, Width, # of Channels),
    but pytorch expects the shape (# of Channels, Height, Width),
    so the array is transposed before the Tensor is put
    on the GPU if one is available.

    Args:
        arr[ndarray]: Ndarray which needs to be
                      converted to torch tensor
    Returns:
        Torch tensor on GPU (if it's available)
    """
    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Transpose the array before converting to tensor
    imgArr = arr.transpose(2, 0, 1) if arr.ndim == 3 else arr
    return torch.Tensor(imgArr).to(device)

@typedispatch
def to_imageTensor(image: PILImage.Image) -> torch.Tensor:
    """Change image to torch tensor.

    The PIL image is cast as a numpy array with dtype uint8
    and then passed to the to_imageTensor(arr: np.ndarray) function
    for converting the numpy array to a torch tensor.

    Args:
        image[PILImage.Image]: PIL Image which
                               needs to be converted to torch tensor
    Returns:
        Torch tensor on GPU (if it's available)
    """
    return to_imageTensor(np.asarray(image, np.uint8))

@typedispatch
def to_imageTensor(file: (str, PurePath)) -> torch.Tensor:
    """Change image file to torch tensor.

    Read the image from disk as 3 channels (RGB) using PIL,
    and pass it on to the to_imageTensor(image: PILImage.Image)
    function for converting the Image to a torch tensor.

    Args:
        file[str, PurePath]: Image file name which needs to
                             be converted to torch tensor
    Returns:
        Torch tensor on GPU (if it's available)
    Raises:
        Any error thrown while reading the image file,
        mostly FileNotFoundError.
    """
    try:
        img = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error
    return to_imageTensor(img)

@typedispatch
def to_imageTensor(x: object) -> None:
    """For unsupported data types, raise TypeError."""
    raise TypeError('Input must be of Type - String or Path or PIL Image or Numpy array')
By utilizing the Type Dispatch functionality, we managed to use the same name for all 3 functions, and each one’s behavior is differentiated by its input type. The function name is also shorter now that the input type has been removed from it, which makes it easier to remember.
By calling just the function name, we can see the different input types supported by the function. Fastcore by default expects two input parameters; since we have only one, it assigns the second one as object. The second parameter does not have any impact on the function’s behavior.
to_imageTensor
# OUTPUT
(ndarray,object) -> to_imageTensor
(Image,object) -> to_imageTensor
(str,object) -> to_imageTensor
(PurePath,object) -> to_imageTensor
(object,object) -> to_imageTensor
With the help of the inspect module, we can access the input & output types of a particular function.
import inspect
inspect.signature(to_imageTensor[np.ndarray])
The docstrings implemented in this iteration make the code more readable. The docstring for a particular input type can be accessed via __doc__ on the function registered for that type.
print(to_imageTensor[np.ndarray].__doc__)
As discussed in the last section, we managed to move the TypeError message to a separate function with input type as object.
Still working!
is_same_tensor(file_ToImageTensor(filename), to_imageTensor(filename))
Validating error message when the unsupported data type is passed to the function.
to_imageTensor([filename])
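If the fallback is doing its job, this call should fail with the familiar message, roughly:
# OUTPUT
TypeError: Input must be of Type - String or Path or PIL Image or Numpy array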
This is it. I believe we are done here.
Did we really write the best function for this use case? I think so, but don’t take my word for it. You can find the notebook version of this post here. If you would like to add anything to this post, feel free to reach out to me via Twitter or the comments section. I will gladly update this post based on your comments.
Even though Type Dispatch was the focal point when I started writing this post, I soon realized that the code evolution process is equally important, so I decided to include that as well. I hope you enjoyed this journey of writing a Python function.