Sandbox
Train
Training a model in Slai.
The train.py file is a core building block in Slai. It’s used for:
- Writing training code for your model
- Loading pre-trained models
The train.py file has two rules:
- You need a train() function
- train() must return a trained model, or an empty object
def train():
    return {}
Here’s a more realistic example, with a basic neural network class that’s returned in the train() function.
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Three fully connected layers: 2 inputs -> 120 -> 84 -> 10 outputs
        self.fc1 = nn.Linear(2, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train():
    # Returns a newly constructed Net instance
    return Net()
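The example above constructs the network and returns it without fitting it to any data. As a rough sketch of what a train() that actually fits parameters before returning might look like, here is a dependency-free version. The LinearModel class, the toy dataset, and the learning rate are all illustrative assumptions, not part of Slai; a real train() would typically run the equivalent loop with your framework's optimizer.

```python
from dataclasses import dataclass


@dataclass
class LinearModel:
    """Hypothetical stand-in for a real model: y = weight * x + bias."""

    weight: float = 0.0
    bias: float = 0.0

    def predict(self, x: float) -> float:
        return self.weight * x + self.bias


def train():
    # Toy dataset generated from y = 2x + 1 (illustrative only).
    xs = [0.0, 1.0, 2.0, 3.0, 4.0]
    ys = [2.0 * x + 1.0 for x in xs]

    model = LinearModel()
    learning_rate = 0.05

    # Plain batch gradient descent on mean squared error.
    for _ in range(2000):
        grad_w = 0.0
        grad_b = 0.0
        for x, y in zip(xs, ys):
            error = model.predict(x) - y
            grad_w += 2 * error * x / len(xs)
            grad_b += 2 * error / len(xs)
        model.weight -= learning_rate * grad_w
        model.bias -= learning_rate * grad_b

    # The fitted model object is what train() hands back.
    return model
```

The shape is the same as in the examples on this page: all fitting happens inside train(), and the object it returns is the trained model.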
You can also return multiple models from train(), like so.
from transformers import pipeline


def train():
    # Load two pre-trained pipelines
    gpt2_model = pipeline("text-generation", model="gpt2")
    sentiment_analysis_model = pipeline(
        "sentiment-analysis", model="siebert/sentiment-roberta-large-english"
    )
    # Return both models in a single dict, keyed by name
    return {
        "gpt2_model": gpt2_model,
        "sentiment_analysis_model": sentiment_analysis_model,
    }
The object returned from train() in train.py will be passed to the handler, which is a wrapper around the model object.
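To make that hand-off concrete, here is a minimal sketch of how a wrapper could consume a dict returned by train(). This page does not show the handler's real interface, so the handle function, its payload argument, and the model keys below are hypothetical, and the "models" are plain Python callables so the sketch runs without any dependencies.

```python
def train():
    # Stand-ins for real models so the sketch runs without dependencies.
    return {
        "uppercase_model": str.upper,
        "length_model": len,
    }


def handle(model, payload):
    # Hypothetical handler, NOT Slai's actual handler API: it receives
    # whatever train() returned and looks models up by key.
    text = payload["text"]
    return {
        "uppercase": model["uppercase_model"](text),
        "length": model["length_model"](text),
    }


# Simulate the platform passing train()'s return value to the handler.
result = handle(train(), {"text": "hello"})
```

The point of the sketch is only that the keys chosen in train() are the keys the wrapper would use to reach each model.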