Write Better Python Functions Using Type Dispatch

Guest Blog Last Updated : 12 Apr, 2021

Before we get into what Type Dispatch is, or how it can make you a better Python programmer, let us get some basics out of the way, such as: what is a good Python function? To be honest, I would rather let Jeff Knupp explain it.

I truly believe that the only way to really understand how something works is by breaking it into granular pieces and putting it back together or building it from scratch. Let us take a very well-defined use case and write a function for it from scratch. By scratch, I mean starting with the one-to-one translation of the pseudo-code to python code without thinking about any best practices.

Then by analyzing what works & what didn’t, we can build on top of this. We are going to repeat this process until we run out of ideas to improve the code. So this post is going to be not one, but many versions of a function for the same use case.

If you are wondering why: because we are going to unlearn everything we know about writing a function and question each line of code as we write it. This way, we can clearly see how each piece fits. Rather than blindly following a checklist of items about writing a function, we will create our own checklist. If this post has missed something of value, please let me know and I will gladly update it.

We will try to write the best, or at least my best, Python function for a particular use case in 4 iterations. Each iteration is going to be an improvement over the last one, with a clear objective declared upfront. Getting better at anything is an iterative process. Depending on where you are on the spectrum of Python proficiency, one of the 4 iterations is going to resonate with you more than the others. The next step of this post is understanding the use case. Buckle up, it’s going to be a very interesting journey.

Use Case Briefing

Convert the given image input into a PyTorch tensor. Yes, that’s it. Quite simple, right?

Folks with a data science background can skip this part and jump to the next section, as I am going to give a brief introduction about python libraries needed for this use case.

  1. Pillow is an imaging library (a fork of PIL, the Python Imaging Library) that adds support for opening, manipulating, and saving many different image file formats.
  2. NumPy stands for Numerical Python. NumPy is the core library for scientific computing in Python. It provides a high-performance multidimensional array object and tools for working with these arrays.
  3. PyTorch is an open-source library that provides two major features: tensor computing (like NumPy) with strong GPU acceleration, and deep neural networks built on a tape-based automatic differentiation system. PyTorch can run code on the CPU if no CUDA-supported GPU is available, but training an ML model this way will be much slower.

If you are wondering why, or to whom, this use case is useful: most Machine Learning applications with image data as input (face detection, OCR) use these libraries. This use case is considered data preprocessing in the ML world.

How do we go about this?

The steps involved are very straightforward. Open the image using the Pillow library, convert it to a NumPy array, and transpose it before converting it to a tensor. BTW, a gentle reminder: this post is about writing better Python functions and is not a data pre-processing tutorial for deep learning, so if it’s intimidating, don’t worry about it. You will feel more comfortable once we start coding.

[Image: overall process steps for the use case, with str/Path as input]

A couple of points to take note about the process steps before we move on.

  1. When a PIL image is converted to a NumPy array, it takes the shape (Height, Width, # of Channels). In PyTorch, the image tensor must be of the shape (# of Channels, Height, Width), so we are going to transpose the array before converting it to a tensor.
  2. Experienced PyTorch users may be wondering why the hell we are not using ready-made transforms like torchvision’s ToTensor and skipping NumPy. Yep, you are right, we can absolutely do that, but to drive the point of this post home, I have consciously decided to take the longer route.
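
To make the axis reordering from point 1 concrete, here is a minimal plain-Python sketch (no NumPy needed) of the index mapping that transpose(2, 0, 1) performs. The helper name hwc_to_chw and the tiny 2 x 3 "image" are made up for illustration only.

```python
def hwc_to_chw(image):
    """Reorder a nested list from (height, width, channels) to (channels, height, width)."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    # The value at out[c][h][w] is taken from image[h][w][c]
    return [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A made-up image: 2 rows, 3 columns, 3 channels; each pixel is [R, G, B]
image = [
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
]

chw = hwc_to_chw(image)
shape = (len(chw), len(chw[0]), len(chw[0][0]))
print(shape)   # (3, 2, 3): channels, height, width
print(chw[0])  # the first (red) channel: [[1, 4, 7], [10, 13, 16]]
```

NumPy does the same reordering without copying the data, which is why the real code simply calls transpose on the array.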

 

Hiccup

Sorry, the use case is not that simple. There is a small hiccup: the function should support multiple input data types, namely str/Path, Pillow (Image) and NumPy (ndarray).

The process step defined in the above image only supports the data types str/Path as input. Therefore we need to add support for Pillow(Image) and Numpy(Ndarray) types, but if we think about it, these data types are intermediate steps in converting the image file to torch tensor. So there is really no additional step from the above image, we just have to duplicate the steps defined for str/Path data types and alter few initial steps to support Pillow(Image) and Numpy(Ndarray) data types as input for our function.

The process steps for the 2 additional workflows are as follows.

[Image: process steps for Pillow (Image) input]

[Image: process steps for NumPy (ndarray) input]

After comparing the 3 images it is very clear that:

  • Image data type’s process steps are a subset of the original process steps.
  • Array data type’s process steps are a subset of the Image data type process steps.

 

Solution

The first two iterations of the function are going to be about the behavior implementation, whereas the last two iterations are all about code refactoring. If you are one of those readers who like to read the ending of a book first (I am not judging), you can jump straight to the section titled Iteration 4.

Iteration 1

Objective:

Make it work. As simple as that. As we are starting from scratch, it is best to focus just on getting the basic functionality to work. We don’t need all the bells and whistles. For this iteration, we will focus only on implementing the process steps from the Use Case Briefing and not worry about the Hiccup.

Code:

# Import required libraries
import numpy as np
import torch
from PIL import Image as PILImage

# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Function to convert image to Torch tensor
def Tensor(inp):
    
    # Read image from disk using PIL Image
    img = PILImage.open(inp).convert('RGB')
    
    # Convert the image to numpy ndarray
    imgArr = np.asarray(img, np.uint8)
    
    # Rearrange the shape of the array so that it is pytorch compatible
    imgArr = imgArr.transpose(2, 0, 1)
    
    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imgArr).to(device)

What have we done?

We managed to translate the process steps into code. The number of lines in the code almost matches the number of boxes in the diagram.

Quick Check:

We need a sample image to test our function so we will download one.

# Download a sample image to disk
import urllib.request 

url = "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c3/Python-logo-notext.svg/200px-Python-logo-notext.svg.png"
filename = "sample.png"
urllib.request.urlretrieve(url, filename)

Now, let us take our function for a test drive.

# Pass the sample image through our function
imgTensor = Tensor(filename)

# Check whether the output of the function is a Tensor
assert isinstance(imgTensor, torch.Tensor), "Not a Tensor"

Looks like we managed to nail the basic functionality. Awesome!

What can we improve?

  • We are yet to implement support for multiple input data types, i.e. PIL Image and ndarray.
  • Also, we should definitely improve the name of the function, as it doesn’t really say anything about what the function does.

 

Iteration 2

Objective:

When we add support for more than one data type, we need to be very careful because without proper validation, things can go south real quick and from an end-user point of view the error messages won’t make any sense. So in this iteration, we will:

  1. Implement multiple data type support and validations for these types.
  2. Make the function name more useful.

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath

# Set Torch device to GPU if CUDA supported GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Function to convert image to Torch tensor
def from_multiInput_toTensor(inp):
    
    # Input type - str/ Path, then read from disk & convert to array
    if isinstance(inp, (str, PurePath)):
        try: 
            image = PILImage.open(inp).convert('RGB')
            imageArray = np.asarray(image.copy(), np.uint8)
        except Exception as error:
            raise error
            
    # Input type - PIL Image, then we convert it to array      
    elif isinstance(inp, PILImage.Image):
        imageArray = np.asarray(inp, np.uint8)
        
    # Input type - ndarray, then assign it to imageArray variable  
    elif isinstance(inp, np.ndarray):
        imageArray = inp
        
    # Raise TypeError with input type is not supported
    else: 
        raise TypeError("Input must be of Type - String or Path or PIL Image or Numpy array")
        
    # Rearrange shape of the array so that it is pytorch compatible
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)
        
    # Convert Numpy array to Torch Tensor and move it to device
    return torch.from_numpy(imageArray).to(device)

What have we done?

We have managed to implement all the functionalities required for this use case. Since ndarray is the last data type before the output data type – tensor, we convert the input first to ndarray for every supported data type and transpose it at the end, right before converting it to tensor. By adopting this style of coding, we are able to avoid having to convert the input to tensor & return the value for every supported data type. Instead, we do it only once at the end.

With the help of the isinstance function, we are able to identify the data type of the input and notify the users by raising TypeError with an appropriate error message if any unsupported data type is passed. With I/O operations, something can easily go wrong, so if the data type is str or Path, we read the image file inside a try & except block and let the user know about the error (if any).
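
Note that the except block in the code above re-raises the same exception unchanged. As one possible variation, the sketch below uses only the standard library to show the same two ideas: an isinstance check that rejects unsupported types, and a try & except block that adds the offending path to the error while chaining the original cause. The function read_image_bytes is hypothetical and only stands in for the image-loading step.

```python
from pathlib import Path, PurePath


def read_image_bytes(inp):
    """Hypothetical loader: accept a str or Path and return the raw file bytes."""
    # Path is a subclass of PurePath, so this check covers str, Path and PurePath
    if not isinstance(inp, (str, PurePath)):
        raise TypeError(f"Input must be str or Path, got {type(inp).__name__}")
    try:
        return Path(inp).read_bytes()
    except FileNotFoundError as error:
        # Add context to the message and keep the original error as the cause
        raise FileNotFoundError(f"Image file not found: {inp}") from error


# Both failure modes produce a message the caller can act on
for bad_input in (["sample.png"], "no_such_file_12345.png"):
    try:
        read_image_bytes(bad_input)
    except (TypeError, FileNotFoundError) as error:
        print(f"{type(error).__name__}: {error}")
```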

Quick Check:

Before we proceed, we will write a helper function to check whether two tensors are the same or not.

# Test if two torch tensors are same or not
def is_same_tensor(tensor1, tensor2):
    assert torch.eq(tensor1, tensor2).all(), "The Tensors are not equal!"
    return True

When writing a test helper, it is better to raise a proper error than to rely on a print call, whose output can get buried among other messages. We will also check a couple of things and ensure that the updated version is all good.

  1. Is the output of the functions from the last two iterations the same?

# Verify whether the outputs of the two versions are the same
is_same_tensor(Tensor(filename), from_multiInput_toTensor(filename))

  2. Is the support for multiple data types working?

# Check the support for Path
path = Path(Path.cwd(), filename).resolve()
is_same_tensor(Tensor(filename), from_multiInput_toTensor(path))

# Check the support for PIL Image
image = PILImage.open(filename).convert('RGB')
is_same_tensor(Tensor(filename), from_multiInput_toTensor(image))

# Check the support for Ndarray
imageArray = np.asarray(image, np.uint8)
is_same_tensor(Tensor(filename), from_multiInput_toTensor(imageArray))

  3. Are the validations working?

# Validate whether an error is thrown when user passes wrong file
from_multiInput_toTensor('test.png')

# Validate whether an error is thrown when input type is list
from_multiInput_toTensor([filename])

Great! We didn’t break anything. Now let us move on to refactoring, as we are done with the behavior implementation.

What can we improve?

  • With the goal of adding support for multiple data types, we sort of wrote spaghetti code. The life cycle of the variable imageArray is confusing, to say the least. This type of code design is not very intuitive and can make life hell for the person who is going to maintain this codebase. Of course, for our simple use case this is not a real problem, but in the spirit of writing better functions, let us avoid this coding style.
  • The function name is a lot better but still not ideal. Why? Because it still does not clearly state what the input type for the function is. Letting the user play trial & error with TypeError is definitely not the right way of coding.
  • As the name of the function suggests, it takes many input types, which means we have a lot of moving parts in one function. In the long run, this may lead to many unintended side effects.

Iteration 3

Objective:

After taking a closer look at the points from the last section, it is very clear that we need to break the function into 3 smaller ones, i.e. one function for each data type.

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath

# Change numpy array to torch tensor
def numpy_ToImageTensor(imageArray):
    
    # if not type - ndarray then raise error
    if not isinstance(imageArray, np.ndarray):
        raise TypeError("Input must be of Type - Numpy array")
    
    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Transpose the numpy array before converting it to Tensor
    if imageArray.ndim == 3: imageArray = imageArray.transpose(2, 0, 1)
    return torch.from_numpy(imageArray).to(device)

# Change image to torch tensor
def pil_ToImageTensor(image):
    
    # if not type - PIL Image then raise error
    if not isinstance(image, PILImage.Image):
        raise TypeError("Input must be of Type - PIL image")
        
    # Convert the image to numpy 
    imageArray = np.array(image)
    
    # Return output of numpy_ToImageTensor function
    return numpy_ToImageTensor(imageArray)

# Change image file to torch tensor
def file_ToImageTensor(file):
    
    # if not input - string or Path then raise error
    if not isinstance(file, (str, PurePath)):
        raise TypeError("Input must be of Type - String or Path")
        
    # Read the image from disk and raise error (if any)
    try: 
        image = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error
    
    # Return output of pil_ToImageTensor function
    return pil_ToImageTensor(image)

What have we done?

This is good progress: we have removed the spaghetti code and added modularity. Now we have a linear chain of three functions in which each of the first two calls the next one, which makes the code easier to understand and maintain.

file_ToImageTensor -> pil_ToImageTensor -> numpy_ToImageTensor

Modularity makes adding new features quite straightforward and easy to test, as we have to alter only one specific function. Some examples of which function to modify for a hypothetical feature request:

  1. file_ToImageTensor – Ability to read black & white images as well.
  2. pil_ToImageTensor – Resize the image before converting it to Tensor.

Please do note that even though you didn’t touch the rest of the code base for the above-stated examples, it is best to test other functions as well to avoid any surprise.

The modular design also paved the way for better function names and parameter names as the input type has become more specific.

from_multiInput_toTensor -> file_ToImageTensor
inp -> file

These subtle changes can go a long way in making the code more user-friendly. Honestly, at this point, I feel like I am preaching, so I will try to wrap this quickly.

Quick Check:

This post is getting too long for my own liking, so I am not going to test negative or different data type scenarios again and will only check the original use case. You really shouldn’t skip testing after refactoring.

is_same_tensor(from_multiInput_toTensor(filename), file_ToImageTensor(filename))

What can we improve?

  • Looks like we are following the same pattern of checking data type with isinstance function and raising TypeError for every function, which is a kind of code repetition as well.
  • Okay, we have to do something about the function names. I am pretty sure no one will remember these names, definitely not my future self.
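
To make the repetition from the first bullet visible, the check-and-raise pattern can be factored into a small decorator. This is only a plain-Python sketch: the expects decorator and the double function are hypothetical, and this is not the approach taken in Iteration 4.

```python
from functools import wraps


def expects(*types):
    """Hypothetical decorator: raise TypeError unless the first argument matches `types`."""
    def decorator(func):
        @wraps(func)
        def wrapper(first, *args, **kwargs):
            if not isinstance(first, types):
                names = " or ".join(t.__name__ for t in types)
                raise TypeError(f"Input must be of Type - {names}")
            return func(first, *args, **kwargs)
        return wrapper
    return decorator


@expects(int, float)
def double(value):
    """Toy function standing in for one of the converters."""
    return value * 2


print(double(21))  # 42
try:
    double("21")
except TypeError as error:
    print(error)  # Input must be of Type - int or float
```

A decorator like this removes the duplicated checks but still leaves three differently named functions, which is the second problem in the list.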

Iteration 4

Objective:

Improve the documentation and avoid code repetition with the help of Type Dispatch.

Type Dispatch:

Type dispatch allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia & Swift. All we have to do is add a decorator typedispatch before our function. Probably, it is easier to demonstrate than to explain.

Example for Type Dispatch:

Function definitions:

from fastcore.all import *
from typing import List

# Function to multiply two ndarrays
@typedispatch
def multiple(x:np.ndarray, y:np.ndarray ): 
    return x * y

# Function to multiply a List by an integer
@typedispatch
def multiple(lst:List, x:int): 
    return [ x*val for val in lst]

Calling 1st function:

x = np.arange(1,3)
print(f'x is {x}')

y = np.array(10)
print(f'y is {y}')

print(f'Result of multiplying two numpy arrays: { multiple(x, y)}')

Calling 2nd function:

x = [1, 2]
print(f'x is {x}')

y = 10
print(f'y is {y}')

print(f'Result of multiplying a List of integers by an integer: {multiple(x, y)}')

The life of a programmer can be made so much better if they don’t have to come up with different function names with no change in purpose (for various data types). If this doesn’t encourage you to use Type Dispatch whenever possible then I don’t know what will 🤷

We will be using fastcore package for implementing Type Dispatch to our use case. For more details on fastcore and Type Dispatch, check this awesome blog by Hamel Husain. Also check out fastai, which inspired me to write this post.
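
If adding a dependency such as fastcore is not an option, the Python standard library offers a related tool, functools.singledispatch. The sketch below is not fastcore's typedispatch: it dispatches on the type of the first argument only, and the describe function is a made-up example rather than part of this use case.

```python
from functools import singledispatch
from pathlib import PurePath


@singledispatch
def describe(inp):
    """Fallback for unsupported types."""
    raise TypeError(f"Unsupported input type: {type(inp).__name__}")


@describe.register
def _(inp: str):
    """Handle a file name given as a string."""
    return f"file name {inp!r}"


@describe.register
def _(inp: PurePath):
    """Handle a pathlib path."""
    return f"path {inp.as_posix()!r}"


@describe.register
def _(inp: list):
    """Handle a list by describing each item."""
    return [describe(item) for item in inp]


print(describe("sample.png"))                   # file name 'sample.png'
print(describe(PurePath("images/sample.png")))  # path 'images/sample.png'
```

As with the typedispatch example above, the caller only needs to remember one function name, and the undecorated base function acts as the catch-all for unsupported types.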

Code:

# Import Libraries
import numpy as np
import torch
from PIL import Image as PILImage
from pathlib import Path, PurePath
from fastcore.all import *


@typedispatch
def to_imageTensor(arr: np.ndarray) -> torch.Tensor:
    """Change ndarray to torch tensor.
    
    The ndarray would be of the shape (Height, Width, # of Channels)
    but pytorch tensor expects the shape as 
    (# of Channels, Height, Width) before putting 
    the Tensor on GPU if it's available.
    
    Args:
        arr[ndarray]: Ndarray which needs to be 
        converted to torch tensor
    
    Returns:
        Torch tensor on GPU (if it's available)   
    """
    
    # Set Torch device to GPU if CUDA supported GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Transpose the array before converting to tensor
    imgArr = arr.transpose(2, 0, 1) if arr.ndim == 3 else arr
    # from_numpy keeps the uint8 dtype, matching the earlier iterations
    return torch.from_numpy(imgArr).to(device)


@typedispatch
def to_imageTensor(image: PILImage.Image) -> torch.Tensor:
    """Change image to torch tensor.
    
    The PIL image cast as numpy array with dtype as uint8,
    and then passed to to_imageTensor(arr: np.ndarray) function
    for converting numpy array to torch tensor.
    
    Args:
        image[PILImage.Image]: PIL Image which 
        needs to be converted to torch tensor
    
    Returns:
        Torch tensor on GPU (if it's available)   
    
    """
    return to_imageTensor(np.asarray(image, np.uint8))


@typedispatch
def to_imageTensor(file: (str, PurePath)) -> torch.Tensor:
    """Change image file to torch tensor.
    
    Read the image from disk as 3 channels (RGB) using PIL, 
    and passed on to to_imageTensor(image: PILImage.Image)
    function for converting Image to torch tensor.
    
    Args:
        file[str, PurePath]: Image file name which needs to
        be converted to torch tensor
    
    Returns:
        Torch tensor on GPU (if it's available)
    
    Raises:
        Any error thrown while reading the image file,
        Mostly FileNotFoundError will be raised.
    
    """
    try: 
        img = PILImage.open(file).convert('RGB')
    except Exception as error:
        raise error
    return to_imageTensor(img)


@typedispatch
def to_imageTensor(x:object) -> None:
    """For unsupported data types, raise TypeError. """
    raise TypeError('Input must be of Type - String or Path or PIL Image or Numpy array')

What have we done?

By utilizing the Type dispatch functionality, we managed to use the same name for all 3 functions, and each one’s behavior is differentiated by their input type. The function name is also shortened with the removal of input type. This makes the function name easier to remember.

By calling the function name, we can see the different input types supported by the function. Fastcore by default expects two input parameters; since we have only one, it assigns object as the type of the second one. The second parameter will not have any impact on the function behavior.

to_imageTensor

# OUTPUT
(ndarray,object) -> to_imageTensor
(Image,object) -> to_imageTensor
(str,object) -> to_imageTensor
(PurePath,object) -> to_imageTensor
(object,object) -> to_imageTensor

With the help of the inspect module, we can access the input & output types of a particular function.

import inspect
inspect.signature(to_imageTensor[np.ndarray])
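
For readers who have not used inspect.signature before, here is a small standalone sketch of what it returns for an ordinary annotated function. The scale function is hypothetical and has nothing to do with fastcore; it only shows where the input and output types come from.

```python
import inspect


def scale(values: list, factor: int = 2) -> list:
    """Hypothetical function: multiply every item in `values` by `factor`."""
    return [factor * value for value in values]


signature = inspect.signature(scale)
print(signature)  # (values: list, factor: int = 2) -> list

# Annotations are available per parameter and for the return value
for name, parameter in signature.parameters.items():
    print(name, parameter.annotation, parameter.default)
print(signature.return_annotation)  # <class 'list'>
```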

The docstrings implemented in this iteration make the code more readable. The docstring for a particular input type can be accessed through the __doc__ attribute of the function registered for that input type.

print(to_imageTensor[np.ndarray].__doc__)
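
The standard library's functools.singledispatch offers comparable introspection, sketched below under the assumption that a stdlib-only analogue is useful for comparison. The shout function is a made-up example; registry and dispatch are attributes that singledispatch adds to the decorated function.

```python
from functools import singledispatch


@singledispatch
def shout(inp):
    """Fallback: reject unsupported types."""
    raise TypeError(f"Unsupported input type: {type(inp).__name__}")


@shout.register
def _(inp: str):
    """Upper-case a string."""
    return inp.upper()


@shout.register
def _(inp: bytes):
    """Decode bytes as UTF-8, then upper-case."""
    return shout(inp.decode("utf-8"))


# registry maps each registered type to its implementation
for typ, impl in shout.registry.items():
    print(f"{typ.__name__}: {impl.__doc__}")

# dispatch() returns the implementation that would run for a given type
print(shout.dispatch(bytes).__doc__)
```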

As discussed in the last section, we managed to move the TypeError message to a separate function with input type as object.

Quick Check:

Still working!

is_same_tensor(file_ToImageTensor(filename), to_imageTensor(filename))

Validating error message when the unsupported data type is passed to the function.

to_imageTensor([filename])

What can we improve?

This is it. I believe we are done here.

Closing Thoughts

Did we really write the best function for this use case? I think so, but don’t take my word for it. You can find the notebook version of this post here. If you would like to add anything to this post, feel free to reach out to me via Twitter or the comments section. I will gladly update this post based on your comments.

Even though Type Dispatch was the focal point when I started writing this post, soon I realized that the code evolution process is equally important too. So I have decided to include that as well. I hope you enjoyed this journey of writing a python function.

Credits

  1. Video Card icon by IconBros
  2. zwicon

Author

Aravind Ramalingam 
Deep learning practitioner & Tech enthusiast, Sharing my journey with Python, Computer Vision, Deep Learning, and everything related to it.
