Deploying fast.ai models with Baseten

Baseten makes it easy to build complex user-facing applications with machine learning models, and step one of building these applications is often deploying a model. Baseten supports deploying most types of machine learning models. In this post, we’ll walk through deploying a simple fast.ai model and end up with a REST API for calling it.

Training and serializing models

To get started, let’s create a fast.ai model to deploy (this step is agnostic to Baseten; we’re just using a simple training script so we have a model to deploy). Our model is a simple CNN that predicts whether the animal in a photo is a cat. We’ll train it on pet images in a create_model function, then serialize the trained learner with joblib. That leaves us with a serialized model binary to upload to Baseten.


import joblib
from fastai.data.external import URLs, untar_data
from fastai.data.transforms import get_image_files
from fastai.metrics import error_rate
from fastai.vision.augment import Resize
from fastai.vision.data import ImageDataLoaders
from fastai.vision.learner import cnn_learner
from torchvision.models import resnet34

from fai_model import label_func


def create_model():
    # Download the Oxford-IIIT Pets dataset and build dataloaders labeled by label_func
    path = untar_data(URLs.PETS)
    files = get_image_files(f"{path}/images")
    dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
    # ResNet-34 backbone with a classification head, reporting error rate as the metric
    learn = cnn_learner(dls, resnet34, metrics=error_rate)
    learn.fine_tune(1)  # one epoch is plenty for a demo model
    return learn


pet_model = create_model()

# Serialize the trained learner so it can be uploaded to Baseten
with open("model.joblib", "wb") as f:
    joblib.dump(pet_model, f)

Creating a requirements.txt

Next, we’ll define the requirements needed to load and run our model. In this case we only explicitly depend on fast.ai, joblib, and Pillow, so we add those to a file called requirements.txt.


fastai==2.5.2
joblib==1.0.1
Pillow==8.3.2

Writing a Python class for inference

Fast.ai models are deployed to Baseten using our custom model interface: a Python class whose most important methods are load and predict. The load method is called when the model’s container comes up, and the predict method is what the Baseten inference API serves.


import joblib
import requests
from numpy import asarray
from PIL import Image


def label_func(f):
    # Labeling rule from the Pets dataset: cat breeds have capitalized filenames
    return f[0].isupper()


class FastAiModel:
    def __init__(self):
        self._model = None

    def load(self):
        # Called once when the model's container comes up
        self._model = joblib.load("model/model.joblib")

    def predict(self, inputs):
        # Each input is a dict with a "url" pointing at an image to classify
        image_urls = [inp["url"] for inp in inputs]
        images = [self._fetch_image_url(img) for img in image_urls]
        predictions = [self._model.predict(image) for image in images]
        clean_predictions = [self._clean_prediction(pred) for pred in predictions]
        return clean_predictions

    def _fetch_image_url(self, url):
        # Download the image and convert it to a numpy array for the learner
        img = Image.open(requests.get(url, stream=True).raw)
        return asarray(img)

    def _clean_prediction(self, prediction):
        # Learner.predict returns (decoded label, label index, probabilities);
        # convert the tensors to lists so the response is JSON-serializable
        return [prediction[0], prediction[1].tolist(), prediction[2].tolist()]
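
Before deploying, it can help to sanity-check the class locally. Below is a minimal sketch: it assumes model.joblib from the training step is copied into a local model/ directory (so the "model/model.joblib" path used in load resolves), and the image URL is just a placeholder.


import os
import shutil

from fai_model import FastAiModel

# Assumption: model.joblib from the training step sits next to this script;
# copy it into model/ so that load() finds "model/model.joblib"
os.makedirs("model", exist_ok=True)
shutil.copy("model.joblib", "model/model.joblib")

model = FastAiModel()
model.load()

# Placeholder URL -- point this at any publicly reachable pet photo
print(model.predict([{"url": "https://example.com/pet.jpg"}]))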

Putting it all together

With these pieces in place, we can deploy the model to Baseten's infrastructure with the Baseten client. The client library is pip-installable (pip install baseten) wherever you do your work, whether that’s a local Jupyter notebook, Google Colab, Databricks, or SageMaker Studio.

Once the client is installed, we simply import it and deploy the model with its files (including the serialized model) as well as the requirements.txt.


import baseten

baseten.deploy_custom(
    model_name="FastAI demo",
    model_class="FastAiModel",
    model_files=["fai_model.py", "model.joblib"],
    requirements_file="requirements.txt",
)

The deploy_custom function uploads the model to Baseten where a container is built and deployed within our infrastructure. After a few moments, our model will be ready to use behind a REST API or within a Baseten application.
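
As a rough illustration, calling the model over HTTP might look something like the sketch below. The endpoint URL, Authorization header format, and request body shape are assumptions here rather than Baseten's documented API; check your deployed model’s page in Baseten for the exact invocation details.


import requests

# Hypothetical values -- substitute the endpoint and API key shown for your
# deployed model in Baseten
ENDPOINT = "https://app.baseten.co/models/YOUR_MODEL_ID/predict"
API_KEY = "YOUR_API_KEY"

# The payload mirrors what FastAiModel.predict expects: a list of {"url": ...}
# inputs. How Baseten wraps this payload is an assumption in this sketch.
response = requests.post(
    ENDPOINT,
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={"inputs": [{"url": "https://example.com/pet.jpg"}]},
)
print(response.json())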

We’d love to hear what your experience has been deploying models with Baseten — drop us a line!

Want to deploy your own models? Sign up for Baseten for free!