The train.py file is a core building block in Slai. It’s used for:

  • Writing training code for your model

  • Loading pre-trained models

The train.py file has two rules:

  • You need a train() function

  • train() must return a trained model, or an empty object

The smallest valid train.py satisfies both rules by returning an empty object:

def train():
  return {}
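The two rules above can be checked mechanically. The sketch below is only an illustration: `check_train_contract` is a hypothetical helper, not part of Slai, and treating `None` as an invalid return value is an assumption about what "an empty object" excludes.

```python
def train():
  # Minimal valid train(): returns an empty object.
  return {}


def check_train_contract(namespace):
  """Hypothetical helper (not a Slai API): check the two train.py rules."""
  fn = namespace.get("train")
  # Rule 1: a callable named train must exist.
  if not callable(fn):
    raise TypeError("train.py must define a train() function")
  result = fn()
  # Rule 2 (assumption): a bare `return` or missing return yields None,
  # which is treated here as neither a model nor an empty object.
  if result is None:
    raise ValueError("train() must return a model or an empty object")
  return result


model = check_train_contract({"train": train})
```

Here `model` is the empty dictionary returned by train(); a namespace without a train() function raises `TypeError` instead.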

Here’s a more realistic example: a basic neural network class, with train() returning an instance of it.

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(2, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

def train():
  return Net()
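The example above returns the network without fitting it. When train() is used for writing training code, the fitting happens inside the function before the return. The sketch below shows that shape with a tiny linear model trained by gradient descent in pure Python rather than PyTorch, so it runs without dependencies. The synthetic data, learning rate, step count, and the `weight`/`bias` keys are all assumptions made for illustration.

```python
def train():
  # Synthetic data following y = 2x + 1 (illustrative assumption).
  xs = [0.0, 1.0, 2.0, 3.0, 4.0]
  ys = [2.0 * x + 1.0 for x in xs]

  # Parameters of the model y = w * x + b, starting from zero.
  w, b = 0.0, 0.0
  learning_rate = 0.05
  n = len(xs)

  for _ in range(2000):
    # Gradients of the mean squared error with respect to w and b.
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

  # Return the trained parameters as the model object.
  return {"weight": w, "bias": b}


model = train()
```

With a PyTorch model like `Net`, the loop body would instead run a forward pass, compute a loss, and step an optimizer, but train() would still end by returning the fitted model.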

You can also return multiple models from train() by putting them in a dictionary, keyed by name:

from transformers import pipeline


def train():
  gpt2_model = pipeline("text-generation", model="gpt2")
  sentiment_analysis_model = pipeline(
    "sentiment-analysis", model="siebert/sentiment-roberta-large-english"
  )

  return {
    "gpt2_model": gpt2_model,
    "sentiment_analysis_model": sentiment_analysis_model,
  }

The object returned from train() is passed to the handler, which wraps the model object.
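To make the hand-off concrete, here is a rough sketch of a wrapper receiving the object that train() returns. `ExampleHandler` and its `predict` method are hypothetical names and do not reflect Slai's actual handler interface. Plain callables stand in for real models so the sketch runs anywhere.

```python
class ExampleHandler:
  """Hypothetical wrapper around the model object; not Slai's handler class."""

  def __init__(self, model):
    # `model` is whatever train() returned: here, a dict of named models.
    self.model = model

  def predict(self, name, text):
    # Look up a model by the key used in train()'s return value and call it.
    return self.model[name](text)


def train():
  # Stand-in callables instead of real pipelines (illustrative assumption).
  return {
    "shout": lambda text: text.upper(),
    "length": lambda text: len(text),
  }


handler = ExampleHandler(train())
result = handler.predict("shout", "hello")
```

When train() returns several models, the dictionary keys chosen there are the names the wrapper uses to select a model.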