Now that we’ve trained, evaluated, and fine-tuned the Transformer model, the next step is to deploy it so it can be used in real-time applications. In this post, we’ll cover:

  • Saving and loading the trained model.
  • Deploying the model as an API using FastAPI.
  • Handling real-time user input.
  • Explaining the model’s predictions and outputs.

By the end of this post, you’ll be able to deploy a Transformer model and interact with it in real time through a simple API.


Saving and Loading the Trained Model

Before deploying the model, we need to save it after training. This ensures that you can reload it later for inference (making predictions) without retraining it from scratch.

Saving the model:

import torch

def save_model(model, path="transformer_model.pth"):
    # Save only the learned parameters (the state dict), not the whole Python object
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Example usage after training or fine-tuning:
save_model(model)

Loading the model:

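def load_model(model, path="transformer_model.pth"):
    # map_location="cpu" lets a model trained on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()  # Set the model to evaluation mode (disables dropout)
    print(f"Model loaded from {path}")
    return model

# Example usage (Transformer is the class built earlier in this series;
# the hyperparameters must match the ones used during training):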
loaded_model = Transformer(embed_size=256, heads=8, depth=4, forward_expansion=4, max_len=50, dropout=0.1, vocab_size=30522)
load_model(loaded_model)

What’s happening here:

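  • We use torch.save() to store the model’s state dict (its learned parameters) in a file.
  • Later, torch.load() reads that file back, and load_state_dict() copies the parameters into a freshly constructed model with the same architecture, ready for inference.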
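A quick way to confirm that saving and loading worked is to compare a model's outputs before and after a round trip through the file. The sketch below is an illustration under stated assumptions: it uses a hypothetical `TinyLM` stand-in instead of the series' `Transformer` class so it runs on its own, writes to a temporary directory, and loads onto the CPU via `map_location`.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for the series' Transformer class, kept tiny so the
# example runs on its own. Only the save/load pattern matters here.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, embed_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.1)
        self.out = nn.Linear(embed_size, vocab_size)

    def forward(self, input_ids):
        # (batch, seq_len) -> (batch, seq_len, vocab_size)
        return self.out(self.dropout(self.embed(input_ids)))


torch.manual_seed(0)
original = TinyLM()
path = os.path.join(tempfile.mkdtemp(), "tiny_lm.pth")
torch.save(original.state_dict(), path)  # parameters only, not the class itself

# Rebuild the same architecture, then copy the saved parameters into it.
restored = TinyLM()
restored.load_state_dict(torch.load(path, map_location="cpu"))

# Disable dropout in both models so their outputs are deterministic.
original.eval()
restored.eval()
input_ids = torch.tensor([[1, 2, 3, 4]])
with torch.no_grad():
    same = torch.equal(original(input_ids), restored(input_ids))
print(same)  # True
```

Because the state dict holds parameters but not the architecture, the model must be rebuilt with the same hyperparameters before loading, which is why the loading example above constructs a `Transformer` with the training-time settings first.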

Deploying the Model as an API with FastAPI

FastAPI allows us to create an API endpoint where users can send text input to the Transformer model, and the model will generate a response. We’ll set up a simple API to handle text input and return predictions.

FastAPI Setup:

  1. Install FastAPI, Uvicorn (the server that runs the app), and Hugging Face Transformers (for the tokenizer):
    pip install fastapi uvicorn transformers
  2. Create an API with FastAPI:
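from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer
import torch

# Transformer and load_model come from earlier in this series;
# import them from wherever you defined them.

app = FastAPI()

# Load the model and tokenizer once, when the server starts
model = Transformer(embed_size=256, heads=8, depth=4, forward_expansion=4, max_len=50, dropout=0.1, vocab_size=30522)
load_model(model)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Expected JSON request body: {"input_text": "..."}
class PredictRequest(BaseModel):
    input_text: str

# A plain def (not async def) lets FastAPI run the blocking model call in a worker thread
@app.post("/predict/")
def predict_text(request: PredictRequest):
    # Tokenize the input text; max_length matches the model's max_len
    inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=50)
    input_ids = inputs['input_ids']

    # Make a prediction
    with torch.no_grad():
        outputs = model(input_ids)

    # If the model returns one row of logits per position (batch, seq_len, vocab_size),
    # keep only the last position, which scores the next token
    next_token_logits = outputs[0, -1] if outputs.dim() == 3 else outputs[0]
    predicted_token_id = torch.argmax(next_token_logits).item()

    # Convert the predicted token ID back to a word
    predicted_word = tokenizer.decode([predicted_token_id])

    return {"input": request.input_text, "prediction": predicted_word}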

Running the FastAPI server:

To run this API, use the following command in your terminal:

uvicorn filename:app --reload

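Replace filename with the name of your Python file (without the .py extension). This launches the FastAPI server at http://127.0.0.1:8000; the --reload flag restarts it automatically whenever you edit the code, which is convenient during development.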

What’s happening here:

  • We define a FastAPI app with a /predict/ route that accepts a POST request whose JSON body contains the text, e.g. {"input_text": "The quick brown fox"}.
  • The input text is tokenized using BertTokenizer and passed through the Transformer model, and the highest-scoring token at the last position is taken as the predicted next word.
  • The predicted token ID is decoded back into text and returned as part of the response.
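Why take the last position? Assuming the model returns one row of vocabulary scores per input position, i.e. logits of shape (batch, seq_len, vocab_size), an argmax over the last dimension yields one ID per position rather than a single ID. The sketch below uses random logits as a stand-in for real model output; the sizes are arbitrary assumptions.

```python
import torch

# Assumption: logits shaped (batch, seq_len, vocab_size), one score per
# vocabulary entry at every input position. Random values stand in for a model.
batch, seq_len, vocab_size = 1, 5, 10
torch.manual_seed(0)
logits = torch.randn(batch, seq_len, vocab_size)

# argmax over the vocabulary gives one ID per position, shape (1, 5),
# so calling .item() on it directly would fail.
per_position = torch.argmax(logits, dim=-1)
print(per_position.shape)  # torch.Size([1, 5])

# For next-word prediction only the final position matters:
# its scores describe the token that should follow the whole input.
next_token_id = torch.argmax(logits[0, -1, :]).item()
print(next_token_id)
```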

Handling Real-Time User Input

Once the API is up and running, users can send POST requests with text input, and the model will respond with a prediction. Here’s how to send a request to the API from Python using the requests library:

import requests

response = requests.post("http://127.0.0.1:8000/predict/", json={"input_text": "The quick brown fox"})
print(response.json())

Example output:

{
  "input": "The quick brown fox",
  "prediction": "something"
}

What’s happening here:

  • The model receives the text input "The quick brown fox".
  • It predicts the next word. "something" is a placeholder here; the actual word depends on how your model was trained.
  • The API returns both the input and the predicted next word.
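The request only works if the JSON key (`input_text`) matches what the server expects. To exercise that request/response contract without starting FastAPI or loading the model, here is a stdlib-only sketch. It is hypothetical throughout: `StubPredictHandler` is a stand-in server that echoes the input with a fixed placeholder prediction, and the `predict()` helper swaps `requests` for the standard library's `urllib.request` and adds a timeout.

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubPredictHandler(BaseHTTPRequestHandler):
    """Hypothetical stand-in for the /predict/ route: no model, fixed prediction."""

    def do_POST(self):
        if self.path != "/predict/":
            self.send_error(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length) or b"{}")
        if "input_text" not in payload:
            # FastAPI also answers a body that fails validation with 422
            self.send_error(422, "input_text is required")
            return
        body = json.dumps({"input": payload["input_text"], "prediction": "something"}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # silence per-request logging for the demo


def predict(base_url, text, timeout=5.0):
    """POST {"input_text": text} to /predict/ and return the decoded JSON response."""
    request = urllib.request.Request(
        f"{base_url}/predict/",
        data=json.dumps({"input_text": text}).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Ignore proxy environment variables: the stub server runs on localhost
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    with opener.open(request, timeout=timeout) as response:
        return json.loads(response.read())


server = HTTPServer(("127.0.0.1", 0), StubPredictHandler)  # port 0: pick any free port
threading.Thread(target=server.serve_forever, daemon=True).start()
result = predict(f"http://127.0.0.1:{server.server_address[1]}", "The quick brown fox")
server.shutdown()
server.server_close()
print(result)  # {'input': 'The quick brown fox', 'prediction': 'something'}
```

Setting a timeout on the client is a small but useful habit: if the model server hangs, the caller gets an error instead of waiting indefinitely.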

Explanation of the Output

Input:

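  • Text: The input is a string provided by the user, which is tokenized into token IDs before being passed to the Transformer model. BERT’s WordPiece tokenizer may split an unfamiliar word into several subword tokens, so IDs do not always map one-to-one to words.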

Output:

  • Prediction: The output is the word that the model predicts as the next word in the sequence.

Example:

  • Input: "The quick brown fox"
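  • Output: "something" (a placeholder; your model’s actual prediction will vary)

The model uses its learned weights to score every token in its vocabulary and picks the one most likely to follow the input sequence, based on its training data. For this input, a well-trained model would ideally predict "jumps", because it completes a sentence that appears frequently in English text.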


Understanding the API’s Behavior

As you test the API with more input texts, the Transformer model will make predictions based on patterns it learned during training. Here are some things to note:

  • Contextual Predictions: The model tries to predict the next word based on the context of the input text. For example, in "The quick brown fox", the model might predict "jumps" because it has seen similar sentence structures during training.
  • Unknown Input: If the input text is very different from what the model was trained on, the predictions might not make much sense. Fine-tuning the model on more relevant datasets can help improve these predictions.

Output Explanation (Real-Time Predictions)

When you query the API with input, here’s what happens under the hood:

  1. Tokenization: The input text is converted into token IDs using a tokenizer (e.g., BertTokenizer).
    • Example: "The quick brown fox" → [101, 1996, 4248, 2829, 4419, 102], where 101 and 102 are the special [CLS] and [SEP] tokens that BertTokenizer adds by default.
  2. Model Prediction: The token IDs are passed through the Transformer model, which outputs a score (logit) for every token in the vocabulary. The highest-scoring token ID is selected as the predicted next word.
    • Example logits: [0.1, 0.3, 0.6, ...] (higher values indicate higher confidence; applying softmax turns them into probabilities).
  3. Decoding: The predicted token ID is converted back into readable text using the tokenizer’s decode() method.
    • Example: ID 4419 → "fox".
  4. API Response: The API returns the input text and the predicted next word.
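The four steps can be traced end to end with a toy, stdlib-only sketch. Everything in it is hypothetical: a seven-entry vocabulary with made-up IDs (not BERT’s), whitespace splitting instead of WordPiece, and hand-written logits in place of a real model’s output.

```python
import math

# Hypothetical toy vocabulary. A real BertTokenizer uses a WordPiece vocabulary
# of 30,522 entries; these IDs are made up for illustration.
VOCAB = {"[CLS]": 0, "[SEP]": 1, "the": 2, "quick": 3, "brown": 4, "fox": 5, "jumps": 6}
ID_TO_TOKEN = {i: t for t, i in VOCAB.items()}


def encode(text):
    # Step 1: lowercase, split on whitespace, and wrap in the special tokens.
    return [VOCAB["[CLS]"]] + [VOCAB[w] for w in text.lower().split()] + [VOCAB["[SEP]"]]


def softmax(logits):
    # Turn raw scores into probabilities (subtracting the max avoids overflow).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(token_id):
    # Step 3: map the chosen ID back to text.
    return ID_TO_TOKEN[token_id]


input_ids = encode("The quick brown fox")
print(input_ids)  # [0, 2, 3, 4, 5, 1]

# Step 2: pretend these are the model's next-token logits, one per vocabulary
# entry; a real model would compute them from input_ids.
fake_logits = [0.1, 0.0, 0.2, 0.1, 0.3, 0.5, 2.0]
probs = softmax(fake_logits)
predicted_id = max(range(len(probs)), key=probs.__getitem__)

# Step 4: the API would return this pair as JSON.
response = {"input": "The quick brown fox", "prediction": decode(predicted_id)}
print(response)  # {'input': 'The quick brown fox', 'prediction': 'jumps'}
```

Because softmax preserves the ordering of the scores, picking the highest probability selects the same ID as picking the highest logit; softmax is only needed if you also want to report a confidence value.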

Conclusion

In this post, we deployed a trained Transformer model to handle real-time text predictions via a FastAPI interface. Here’s what we covered:

  1. We saved and loaded the trained model for reuse.
  2. We deployed the model using FastAPI to serve predictions via an API.
  3. We handled real-time user input and explained how the model generates predictions.


With the Transformer model now deployed, it can serve predictions for a wide variety of applications, such as chatbots and text autocompletion. Its predictions, however, are only as good as its training, so further training or fine-tuning on relevant data is still worthwhile.

