LLM Training and Generation: Bringing Mini-GPT to Life
Welcome to the final part of our Mini-GPT journey! In Part 3, we built all the core components of our model. Now it's time for the most exciting part: training the model and watching it generate text.
- LLM Training Simplified: Building Your First Language Model – 1
- LLM Training Simplified: Building Your First Language Model – 2
- LLM Training Simplified: Building Your First Language Model – 3
The LLM Training Process: Teaching Our Model to Predict
Training a language model involves showing it millions of examples and having it learn to predict the next word. Let’s implement this process step by step:
1. Setting Up the Training Loop
First, let’s define our training function:
import math
import os
import torch
def train_model(
    model,
    train_loader,
    val_loader=None,
    epochs=3,
    learning_rate=3e-4,
    max_grad_norm=1.0,
    warmup_steps=1000,
    checkpoint_dir="./checkpoints",
    log_interval=100
):
    """Train the Mini-GPT model."""
    # Create directory for checkpoints
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Learning rate scheduler with warmup
    def lr_lambda(step):
        # Linear warmup followed by cosine decay
        if step < warmup_steps:
            return step / warmup_steps
        else:
            # Cosine decay to 10% of max learning rate
            progress = (step - warmup_steps) / max(1, epochs * len(train_loader) - warmup_steps)
            return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # ... the training loop itself (forward pass, loss, gradient clipping,
    # optimizer/scheduler steps, logging, and checkpointing) follows here.
This training function is the heart of our learning process. Think of it as a teacher guiding students through increasingly complex language exercises. The function orchestrates several critical tasks:
- Optimization: The AdamW optimizer adjusts model weights based on prediction errors
- Learning Rate Scheduling: The learning rate ramps up during a short "warmup" period, then decays gradually for the rest of training
- Gradient Clipping: We prevent extreme weight updates that might destabilize training
- Checkpointing: We save our progress regularly in case training is interrupted
The learning rate scheduler deserves special attention – it’s like adjusting a student’s challenge level over time. We start easy (warmup), then gradually increase difficulty, and finally ease off to help the model converge to its best performance.
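To see the shape of this schedule concretely, here is a standalone copy of the multiplier for illustration. The numbers assume warmup_steps=1000 and 3,750 total steps, which matches three epochs of the 1,250-batch loader shown in the logs below; lr_multiplier is just an illustrative name, not part of the training code.
import math
warmup_steps, total_steps = 1000, 3750  # 3 epochs x 1,250 batches
def lr_multiplier(step):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))
for step in [0, 500, 1000, 2375, 3750]:
    print(step, round(lr_multiplier(step), 3))
# 0 -> 0.0, 500 -> 0.5, 1000 -> 1.0 (peak), 2375 -> ~0.55, 3750 -> 0.1
The multiplier climbs linearly to its peak during warmup, then follows a cosine curve down to 10% of the maximum learning rate by the final step.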
When we run this training function with our dataset, we'll see output like this:
Epoch 1/3: 100%|██████████| 1250/1250 [05:48<00:00, 3.59it/s, loss=3.542, lr=0.000300]
Epoch 1 average loss: 3.9821
Validation loss: 3.7624
Checkpoint saved to ./checkpoints/epoch_1.pt
Epoch 2/3: 100%|██████████| 1250/1250 [05:47<00:00, 3.60it/s, loss=3.128, lr=0.000150]
Epoch 2 average loss: 3.2145
Validation loss: 3.1872
Checkpoint saved to ./checkpoints/epoch_2.pt
2. Preparing Our Dataset
Let’s prepare our WikiText-2 dataset for training:
# Initialize tokenizer
tokenizer = initialize_tokenizer()
# Load dataset
train_data = get_training_data(split="train")
val_data = get_training_data(split="validation")
# Create datasets
train_dataset = prepare_dataset_from_huggingface(
    train_data, tokenizer, max_length=128)
val_dataset = prepare_dataset_from_huggingface(
    val_data, tokenizer, max_length=128)
# Create dataloaders
train_loader = create_dataloader(train_dataset, batch_size=16)
val_loader = create_dataloader(val_dataset, batch_size=16, shuffle=False)
This data preparation creates the curriculum for our model’s education. The training dataset provides examples to learn from, while the validation dataset helps us monitor whether the model is truly learning or just memorizing.
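Before committing to a full training run, it's worth a quick sanity check on one batch. This assumes, as in the evaluation loop later in this post, that the dataloader yields (input, target) tensor pairs:
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # with batch_size=16 and max_length=128, roughly torch.Size([16, 128]) each
print(tokenizer.decode(xb[0][:20].tolist()))  # peek at the first few tokens of one example
If the shapes and the decoded text look sensible, the pipeline is wired up correctly.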
3. Training Our Model
Now let’s train our Mini-GPT:
# Initialize model
mini_gpt = MiniGPT(
    vocab_size=len(tokenizer),
    max_seq_length=128,
    embed_dim=256,
    num_heads=8,
    num_layers=4,
    ff_dim=1024,
    dropout=0.1
).to(device)
# Train model
train_losses, val_losses = train_model(
    model=mini_gpt,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=3,
    learning_rate=3e-4,
    warmup_steps=1000,
    checkpoint_dir="./mini_gpt_checkpoints"
)
Training a language model is a fascinating process. In the beginning, our model makes random guesses. But as training progresses, it starts recognizing patterns:
- First, it learns basic spelling and frequent words
- Then, it grasps simple grammar rules
- Next, it begins to understand context and relationships
- Finally, it develops a sense of coherence and factual knowledge
4. Visualizing Training Results
Let’s visualize our training progress:
import matplotlib.pyplot as plt
def plot_training_progress(train_losses, val_losses):
    """Plot training and validation loss curves."""
    plt.figure(figsize=(12, 5))
    # Plot training loss
    steps = [x['step'] for x in train_losses]
    losses = [x['loss'] for x in train_losses]
    plt.subplot(1, 2, 1)
    plt.plot(steps, losses)
    plt.xlabel('Steps')
    plt.ylabel('Training Loss')
    plt.title('Training Loss Curve')
    plt.grid(True)
    # Plot validation loss
    if val_losses:
        epochs = [x['epoch'] for x in val_losses]
        val_loss = [x['loss'] for x in val_losses]
        plt.subplot(1, 2, 2)
        plt.plot(epochs, val_loss, 'o-')
        plt.xlabel('Epochs')
        plt.ylabel('Validation Loss')
        plt.title('Validation Loss Curve')
        plt.grid(True)
    plt.tight_layout()
    plt.show()
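With the loss histories returned by train_model (the plotting code expects each training entry to carry 'step' and 'loss' keys, and each validation entry 'epoch' and 'loss'), rendering the curves is a single call:
plot_training_progress(train_losses, val_losses)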
When we run this visualization, we typically see both training and validation losses decrease rapidly at first, then gradually level off – indicating the model is learning but approaching its capacity. A healthy training plot looks something like a downward-trending line that gradually flattens.
Text Generation: Bringing Our Model to Life
Now for the most exciting part: generating text with our trained model. With training complete, let's see how it performs!
1. Basic Generation Function
Let’s implement a simple generation function:
from tqdm import tqdm
def generate_text(
    model,
    tokenizer,
    prompt,
    max_length=100,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    do_sample=True
):
    """
    Generate text from a prompt.
    Args:
        model: The trained Mini-GPT model
        tokenizer: Tokenizer for encoding/decoding
        prompt: Text prompt to start generation
        max_length: Maximum number of tokens to generate
        temperature: Controls randomness (lower = more deterministic)
        top_k: Only sample from the top k most likely tokens
        top_p: Only sample from tokens with cumulative probability < top_p
        do_sample: If False, use greedy decoding instead of sampling
    Returns:
        Generated text string
    """
    model.eval()
    # Encode prompt
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    generated = input_ids.clone()
    # Set up for generation
    with torch.no_grad():
        for _ in tqdm(range(max_length), desc="Generating"):
            # Get logits for the next token (for very long generations you may
            # need to truncate `generated` to the model's max_seq_length)
            logits = model(generated)
            # Get logits for the last token only
            next_token_logits = logits[:, -1, :].squeeze()
            # Apply temperature
            next_token_logits = next_token_logits / temperature
            # Filter out unlikely tokens (top-k)
            if top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = -float('Inf')
            # Nucleus (top-p) filtering: keep the smallest set of tokens
            # whose cumulative probability reaches top_p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = -float('Inf')
            # Pick the next token: sample from the filtered distribution,
            # or take the argmax for greedy decoding
            if do_sample:
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # Append the new token and continue generating
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
    # Decode the full sequence (prompt + generated tokens) back to text
    return tokenizer.decode(generated[0].tolist())
This generation function is where the magic happens! After all that training, our model can now create text one token at a time. The process works like this:
- We start with a prompt (e.g., “Once upon a time”)
- For each step, the model predicts probabilities for all possible next tokens
- We apply “sampling strategies” to select the next token:
- Temperature: Controls randomness (higher = more creative but potentially less coherent)
- Top-k: Only consider the k most likely tokens
- Top-p/nucleus sampling: Only consider the smallest set of tokens whose cumulative probability reaches p
When we run this generation with a temperature of 0.8, we get a nice balance between creativity and coherence.
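For example, a call like this (using the prompt from the samples below) produces the kind of output shown in the next section:
sample = generate_text(
    mini_gpt, tokenizer,
    prompt="Once upon a time, in a land far away,",
    max_length=100, temperature=0.8, top_k=50, top_p=0.95
)
print(sample)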
When we run the function, we’ll see a progress bar as tokens are generated:
Generating: 100%|██████████| 100/100 [00:04<00:00, 22.73it/s]
2. Sample Generated Outputs
Here are some sample outputs from our trained Mini-GPT:
Prompt 1: “Once upon a time, in a land far away,”
Generated output:
Once upon a time, in a land far away, there lived a young prince named Elian. He was known throughout the kingdom for his wisdom and kindness. One day, while wandering through the enchanted forest, he encountered a mysterious old woman who offered him a golden apple.
"Take this apple," she said, "but be warned - it will reveal your true destiny."
The prince hesitated, but curiosity overcame him. He accepted the apple and took a bite...
Prompt 2: “The future of artificial intelligence depends on”
Generated:
The future of artificial intelligence depends on our ability to create systems that can understand context and nuance. Current models have impressive capabilities but struggle with common sense reasoning and true understanding.
Researchers are focusing on developing more robust training methodologies that incorporate ethical considerations and reduce biases. The field is moving toward models that can explain their decisions and collaborate effectively with humans.
The next breakthrough may come from...
Notice how the model generates coherent text that flows naturally from the prompt. It’s not perfect – it might occasionally produce repetitive phrases or factual inaccuracies – but it demonstrates a remarkable understanding of language patterns and context.
3. Visualizing the Generation Process
Let’s visualize how the model’s attention patterns work during generation:
def visualize_generation_attention(model, tokenizer, prompt, num_tokens=5):
    """Visualize attention during text generation."""
    model.eval()
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    # Setup visualization
    plt.figure(figsize=(15, 5*num_tokens))
    # Store generated tokens
    generated = []
    with torch.no_grad():
        for i in range(num_tokens):
            # Forward pass
            logits = model(input_ids)
            # Get next token (greedy decoding keeps the visualization deterministic)
            next_token_logits = logits[0, -1, :]
            next_token = torch.argmax(next_token_logits)
            generated.append(next_token.item())
            # Add to input_ids
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
            # The attention-heatmap plotting itself (one subplot per generated
            # token, built from the model's attention weights) follows here
            # and is omitted from this excerpt.
When we run this visualization on a prompt like “The scientist discovered”, we’ll see colourful heatmaps showing how different attention heads focus on different words.
The visualization reveals fascinating patterns:
- Some attention heads focus on adjacent words
- Others connect related concepts across distances
- Some heads specialize in specific grammatical relationships
In our visualization, brighter colours indicate stronger attention. Notice how the model pays particular attention to key context words when generating new tokens.
Evaluating Our Model: How Well Did We Do?
Now let’s evaluate our model’s performance:
import torch.nn.functional as F
def evaluate_perplexity(model, eval_loader):
    """Calculate perplexity on evaluation data."""
    model.eval()
    total_loss = 0
    total_tokens = 0
    with torch.no_grad():
        for x, y in tqdm(eval_loader, desc="Calculating perplexity"):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            logits = logits.reshape(-1, logits.size(-1))
            y = y.reshape(-1)
            loss = F.cross_entropy(logits, y, reduction='sum')
            total_loss += loss.item()
            total_tokens += y.numel()
    # Perplexity = exp(average cross-entropy loss)
    perplexity = math.exp(total_loss / total_tokens)
    return perplexity
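Running it on the validation loader from earlier takes only a few seconds:
val_ppl = evaluate_perplexity(mini_gpt, val_loader)
print(f"Model perplexity: {val_ppl:.2f}")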
Perplexity is the standard evaluation metric for language models. It measures how “surprised” the model is by the test data – lower is better. When we run this evaluation, we might see:
Calculating perplexity: 100%|██████████| 78/78 [00:03<00:00, 22.45it/s]
Model perplexity: 37.21
A perplexity of around 35-45 is quite good for our small model!
Comparing With Larger Models
Here’s how our Mini-GPT typically compares with larger models:
| Model | Parameters | Perplexity | Training Time |
|---|---|---|---|
| Mini-GPT (ours) | ~22M | ~35-45 | ~2-4 hours on consumer GPU |
| GPT-2 Small | 124M | ~29 | Multiple days on many GPUs |
| GPT-2 Medium | 355M | ~22 | Weeks on many GPUs |
| GPT-3 | 175B | ~14 | Months on hundreds of GPUs |
Our model achieves impressive results considering its small size and modest training resources!
Practical Applications: What Can We Do With Mini-GPT?
Our Mini-GPT can be used for several interesting applications:
1. Simple Text Completion
def complete_text(model, tokenizer, prompt, max_length=50):
"""Complete a given text prompt."""
return generate_text(model, tokenizer, prompt, max_length)
# Example
completion = complete_text(mini_gpt, tokenizer, "The key to success is ")
print(completion)
Example output:
The key to success is persistence and adaptability. Those who continue learning and adjusting their approach in response to challenges tend to achieve their goals more consistently than those who remain rigid or give up easily.
2. Creative Writing Assistant
def generate_story(model, tokenizer, theme, length=200):
"""Generate a short story based on a theme."""
prompt = f"Write a short story about {theme}. Once upon a time,"
return generate_text(model, tokenizer, prompt, max_length=length)
# Example
story = generate_story(mini_gpt, tokenizer, "a magical forest")
print(story)
Example output:
Write a short story about a magical forest. Once upon a time, deep within the ancient woods of Eldoria, there stood a circle of silver-barked trees that glowed with a soft blue light when the moon was full. The locals called it the Whispering Grove, for those who ventured there claimed to hear gentle voices carried on the evening breeze.
Ten-year-old Lily discovered the grove by accident while chasing her runaway puppy, Spark. As she stepped into the circle of luminescent trees, the whispers became clear - the forest was speaking to her! The ancient trees told tales of forgotten magic and warned of a darkness spreading from the northern mountains...
3. Simple Q&A System
def answer_question(model, tokenizer, question):
"""Generate an answer to a question."""
prompt = f"Q: {question}\nA:"
return generate_text(model, tokenizer, prompt, max_length=100)
# Example
answer = answer_question(mini_gpt, tokenizer, "What is machine learning?")
print(answer)
Example output:
Q: What is machine learning?
A: Machine learning is a branch of artificial intelligence that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed to perform a task, these systems use algorithms to analyze patterns in data, build mathematical models from these patterns, and then use those models to make predictions or decisions without human intervention. Common examples include image recognition, recommendation systems, and natural language processing.
Extending Mini-GPT: Where To Go From Here
Now that you’ve built and trained your language model, here are some ways to extend it:
1. Increase Model Size
Try scaling up your model by doing the following (a configuration sketch comes after this list):
- Increasing embedding dimensions
- Adding more layers
- Using more attention heads
- Training on more data
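As a starting point, a "medium" Mini-GPT might look like the configuration below. This is a hedged sketch rather than a tuned recipe: it assumes the same MiniGPT constructor from Part 3, and the specific numbers (and the GPU memory they require) are values to experiment with, not recommendations.
# A possible scaled-up configuration (illustrative values, several times more parameters)
bigger_gpt = MiniGPT(
    vocab_size=len(tokenizer),
    max_seq_length=256,   # longer context window
    embed_dim=512,        # wider embeddings
    num_heads=8,
    num_layers=8,         # deeper network
    ff_dim=2048,
    dropout=0.1
).to(device)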
2. Implement Fine-Tuning
Fine-tune your pre-trained model on specific tasks (see the sketch after this list) such as:
- Sentiment analysis
- Text summarization
- Code generation
- Domain-specific content
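One minimal way to do this with the code we already have is to keep the trained weights in mini_gpt and continue training on a task- or domain-specific dataloader at a lower learning rate. Note that domain_data and domain_loader below are hypothetical placeholders: you would build them exactly like train_loader in step 2, just from your own corpus.
# Hypothetical: build domain_loader from your own text, as in step 2 above
domain_dataset = prepare_dataset_from_huggingface(domain_data, tokenizer, max_length=128)
domain_loader = create_dataloader(domain_dataset, batch_size=16)
# Continue training the already-trained model with a gentler learning rate
ft_losses, ft_val_losses = train_model(
    model=mini_gpt,
    train_loader=domain_loader,
    val_loader=val_loader,
    epochs=1,
    learning_rate=5e-5,   # much lower than the 3e-4 used for pre-training
    warmup_steps=100,
    checkpoint_dir="./mini_gpt_finetuned"
)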
3. Add Advanced Techniques
Implement more advanced techniques like:
- Gradient checkpointing for memory efficiency
- Mixed precision training for speed (see the sketch after this list)
- Parameter-efficient fine-tuning (LoRA, adapters)
- Retrieval-augmented generation
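As an example of the second item, here is a minimal sketch of how mixed precision could slot into the inner training loop using torch.cuda.amp. It assumes the same (x, y) batches, model, optimizer, device, and max_grad_norm as in our train_model function; treat it as an illustration rather than a drop-in replacement.
import torch
import torch.nn.functional as F
scaler = torch.cuda.amp.GradScaler()
for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    # Run the forward pass and loss in (mostly) float16
    with torch.cuda.amp.autocast():
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    # Scale the loss to avoid float16 gradient underflow, then backprop
    scaler.scale(loss).backward()
    # Unscale before clipping so max_grad_norm applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
On most modern GPUs this roughly halves memory use for activations and noticeably speeds up each step, with little or no loss in final quality.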
Key Takeaways: What We’ve Learned
Through this Mini-GPT LLM training journey, we’ve learned:
- The Core Architecture: Transformers, attention, and how they process language
- Data Processing: How to prepare and tokenize text data
- Training Dynamics: Learning rate schedules, optimization, and monitoring
- Generation Strategies: Temperature, top-k, top-p sampling
- Model Evaluation: Perplexity and qualitative assessment
Most importantly, we’ve demystified how modern language models work by building one from scratch!
Conclusion: Your AI Journey Continues
Congratulations! By building Mini-GPT from scratch, you’ve gained deep insight into how modern language models work. This understanding puts you in a great position to:
- Experiment with your own model improvements
- Better understand cutting-edge AI research
- Build practical applications with language models
- Contribute to the field of AI
Want to take your language model knowledge even further? Explore these fascinating resources:
- The Patterns in Transformer Attention
- EleutherAI’s LM Evaluation Harness
- LLM Visualization Playground
Remember that Mini-GPT is just the beginning. The LLM training principles you’ve learned scale to the largest models being developed today. What started as a simple token embedding has become a system capable of generating coherent text. As you continue your AI journey, keep experimenting, keep learning, and keep pushing the boundaries of what’s possible!
You can find the complete code for this project in our GitHub repository. Happy coding!