Integrating ML Models#
This document explains how to integrate a machine learning model into a FastAPI application, using a simple worked example. The goal is to create an API that receives input data, passes it through a pre-trained machine learning model, and returns predictions.
Scenario: Let's consider a simple scenario where we want to predict whether a person has diabetes based on some medical features. We'll use a pre-trained machine learning model (for simplicity, we will use scikit-learn).
Steps:
- Train a simple machine learning model.
- Save the trained model to a file.
- Create a FastAPI app to serve predictions from the model.
Prerequisites#
Python 3.7+
FastAPI
Uvicorn (for running FastAPI)
Scikit-learn
Joblib (for saving and loading the model)
Install dependencies:#
pip install fastapi[all] scikit-learn joblib uvicorn matplotlib
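With the dependencies installed, it can help to look at the dataset before training. The sketch below only inspects `load_diabetes` and makes no changes to the tutorial's flow. scikit-learn documents the features of this dataset as mean-centered and scaled, and the target as a quantitative measure of disease progression one year after baseline, so the printed ranges will look quite different from raw clinical readings. scikit-learn names the columns `age`, `sex`, `bmi`, `bp`, and `s1` to `s6`.

```python
from sklearn.datasets import load_diabetes

data = load_diabetes()

# Column names as shipped by scikit-learn
print(data.feature_names)

# Number of samples and features
print(data.data.shape)

# Smallest and largest value found anywhere in the (scaled) feature matrix
print(float(data.data.min()), float(data.data.max()))

# Range of the disease-progression score used as the training target
print(float(data.target.min()), float(data.target.max()))
```

Any input sent to a model trained on this data should be on the same scale as these features.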
Step 1: Train a Simple Model#
We will start by training a simple machine learning model to predict whether a person has diabetes based on their medical features.
# train_model.py
import joblib
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Load diabetes dataset
diabetes_data = load_diabetes()
# Visualize the distribution of the target variable
plt.hist(diabetes_data.target, bins=50)
plt.xlabel('Target Value')
plt.ylabel('Frequency')
plt.title('Distribution of Target Variable')
plt.show()
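# Convert the target into a binary label (True = "diabetic", False = "not diabetic").
# Note: the load_diabetes target measures disease progression one year after baseline,
# so the cutoff of 100 is an illustrative proxy, not a clinical diagnosis.
y = diabetes_data.target > 100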
# Split into training and testing sets
X = diabetes_data.data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train a RandomForestClassifier
model = RandomForestClassifier()
model.fit(X_train, y_train)
# Save the trained model to a file
joblib.dump(model, 'diabetes_model.joblib')
print("Model trained and saved.")
Run this script to train and save the model:
python train_model.py
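The training script holds out 20% of the data but does not use it. A minimal sketch of how that held-out split could be used to check the model is shown below. The function name `evaluate_holdout`, the fixed `random_state` for the classifier, and the choice of plain accuracy as the metric are assumptions made for illustration, not part of the original script.

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def evaluate_holdout(cutoff: float = 100, seed: int = 42) -> float:
    """Train on 80% of the data and return accuracy on the remaining 20%."""
    data = load_diabetes()
    X = data.data
    y = data.target > cutoff  # same binary proxy label as train_model.py

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )

    # random_state is fixed here (an assumption) so repeated runs agree
    model = RandomForestClassifier(random_state=seed)
    model.fit(X_train, y_train)

    return accuracy_score(y_test, model.predict(X_test))


print(f"Hold-out accuracy: {evaluate_holdout():.3f}")
```

Because the label is a thresholded progression score, this number says how well the model reproduces that cutoff, not how well it diagnoses diabetes.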
Step 2: Create FastAPI Application#
Now, we will create a FastAPI application that loads the saved model and exposes an endpoint to make predictions.
# main.py
import joblib
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
# Load the saved model
model = joblib.load('diabetes_model.joblib')
# Create FastAPI app
app = FastAPI()
# Define the input data structure
class DiabetesFeatures(BaseModel):
    age: float
    sex: float
    bmi: float
    map: float
    tc: float
    ldl: float
    hdl: float
    tch: float
    ltg: float
    glu: float
# Define a prediction endpoint
@app.post("/predict/")
def predict(features: DiabetesFeatures):
    # Convert the input features to a 2D numpy array (one row, ten columns),
    # in the same column order used for training
    input_data = np.array([[
        features.age,
        features.sex,
        features.bmi,
        features.map,
        features.tc,
        features.ldl,
        features.hdl,
        features.tch,
        features.ltg,
        features.glu
    ]])
    # Get the class probabilities for the input row
    prediction_prob = model.predict_proba(input_data)
    # Set a threshold for predicting 'Diabetic' (e.g., 0.6 probability for being diabetic)
    threshold = 0.6
    # Use the second column for the probability of class '1'
    is_diabetic = prediction_prob[0][1] > threshold
    # Return the result based on the probability threshold
    return {"prediction": "Diabetic" if is_diabetic else "Not Diabetic"}
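The thresholding step in the endpoint can be looked at in isolation. The sketch below uses a hypothetical helper, `label_from_proba`, which is not part of `main.py`. It assumes each row of `predict_proba` output has the negative class in column 0 and the positive class in column 1, which is how scikit-learn orders `[False, True]` or `[0, 1]` labels.

```python
from typing import Sequence


def label_from_proba(proba_row: Sequence[float], threshold: float = 0.6) -> str:
    """Map one row of predict_proba output to a label string.

    proba_row[1] is taken to be the probability of the positive class.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    # Strictly greater than, matching the comparison used in the endpoint
    return "Diabetic" if proba_row[1] > threshold else "Not Diabetic"


print(label_from_proba([0.35, 0.65]))  # Diabetic
print(label_from_proba([0.45, 0.55]))  # Not Diabetic
```

A threshold above 0.5 makes the service more conservative about returning "Diabetic"; a probability exactly equal to the threshold is reported as "Not Diabetic".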
Step 3: Run the FastAPI Server#
To start the FastAPI server, use uvicorn:
uvicorn main:app --reload
Step 4: Test the API#
Once the server is running, you can test the API using curl or through an interactive UI at http://127.0.0.1:8000/docs.
Example curl request:
curl -X 'POST' \
'http://127.0.0.1:8000/predict/' \
-H 'Content-Type: application/json' \
-d '{
"age": 50.0,
"sex": 1.0,
"bmi": 25.0,
"map": 92.0,
"tc": 220.0,
"ldl": 120.0,
"hdl": 50.0,
"tch": 70.0,
"ltg": 3.5,
"glu": 90.0
}'
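Example response (the exact label depends on the trained model). Note that the features in load_diabetes are mean-centered and scaled, so raw clinical values like those in the request above fall outside the range the model was trained on: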
{
"prediction": "Not Diabetic"
}
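The same request can also be sent from Python. The following is a minimal sketch using only the standard library; the helper names `build_predict_request` and `predict` and the 10-second timeout are assumptions for illustration, and the base URL assumes the server was started locally with the uvicorn command shown earlier.

```python
import json
import urllib.request

# Assumed to match a server started locally on uvicorn's default port
BASE_URL = "http://127.0.0.1:8000"


def build_predict_request(features: dict, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the /predict/ endpoint."""
    body = json.dumps(features).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict/",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(features: dict, base_url: str = BASE_URL) -> dict:
    """Send the request and return the parsed JSON response."""
    request = build_predict_request(features, base_url)
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))


# Same values as the curl example
payload = {
    "age": 50.0, "sex": 1.0, "bmi": 25.0, "map": 92.0, "tc": 220.0,
    "ldl": 120.0, "hdl": 50.0, "tch": 70.0, "ltg": 3.5, "glu": 90.0,
}

request = build_predict_request(payload)
print(request.get_method(), request.full_url)
```

With the server running, `predict(payload)` returns the response as a dictionary with a single `prediction` key.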
Conclusion#
This guide demonstrated how to integrate a machine learning model with FastAPI. The key steps were:
Train and save a machine learning model.
Load the model in a FastAPI application.
Expose an endpoint to make predictions based on input data.