Text Generation Strategies
Comprehensive guide to text generation methods in Transformers for controlling output quality, creativity, and diversity.
Overview
Text generation is the process of predicting tokens sequentially using a language model. The choice of generation strategy significantly impacts output quality, diversity, and computational cost.
When to use each strategy:
- Greedy: Fast, deterministic, good for short outputs or when consistency is critical
- Beam Search: Better quality for tasks with clear "correct" answers (translation, summarization)
- Sampling: Creative, diverse outputs for open-ended generation (stories, dialogue)
- Top-k/Top-p: Balanced creativity and coherence
Basic Generation Methods
Greedy Decoding
Selects the highest probability token at each step. Fast but prone to repetition and suboptimal sequences.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The future of AI", return_tensors="pt")
# Greedy decoding (default)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
Characteristics:
- Deterministic (always same output for same input)
- Fast (single forward pass per token)
- Prone to repetition in longer sequences
- Best for: Short generations, deterministic applications
Parameters:
outputs = model.generate(
**inputs,
max_new_tokens=50, # Number of tokens to generate
min_length=10, # Minimum total length
pad_token_id=tokenizer.pad_token_id,
)
Beam Search
Maintains multiple hypotheses (beams) and selects the sequence with highest overall probability.
outputs = model.generate(
**inputs,
max_new_tokens=50,
num_beams=5, # Number of beams
early_stopping=True, # Stop when all beams finish
no_repeat_ngram_size=2, # Prevent 2-gram repetition
)
Characteristics:
- Higher quality than greedy for tasks with "correct" answers
- Slower than greedy (roughly num_beams× the compute per step)
- Still can suffer from repetition
- Best for: Translation, summarization, QA generation
Advanced Parameters:
outputs = model.generate(
**inputs,
num_beams=5,
num_beam_groups=1, # Diverse beam search groups
diversity_penalty=0.0, # Penalty for similar beams
length_penalty=1.0, # >1: longer sequences, <1: shorter
early_stopping=True, # Stop when num_beams sequences finish
no_repeat_ngram_size=2, # Block repeating n-grams
num_return_sequences=1, # Number of sequences to return (≤ num_beams)
)
Length Penalty:
- length_penalty > 1.0: Favor longer sequences
- length_penalty = 1.0: No penalty
- length_penalty < 1.0: Favor shorter sequences
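Under the hood, beam hypotheses are typically ranked by cumulative log-probability divided by the sequence length raised to length_penalty. A rough sketch of that scoring rule (illustrative only, not the library's internal code):
def beam_score(sum_log_probs, length, length_penalty=1.0):
    # Cumulative log-probability normalized by length**length_penalty.
    # With length_penalty > 1.0 the denominator grows faster with length,
    # so longer hypotheses are penalized less per token and tend to win.
    return sum_log_probs / (length ** length_penalty)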
Sampling (Multinomial)
Randomly sample tokens according to the probability distribution.
outputs = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True, # Enable sampling
temperature=1.0, # Sampling temperature
num_beams=1, # 1 = pure multinomial sampling (the default)
)
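Because sampling is stochastic, repeated calls produce different outputs. For reproducible samples (e.g., while debugging), seed the generator first:
from transformers import set_seed
set_seed(42)  # fixes the RNG so sampling is reproducible across runs
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=50)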
Characteristics:
- Non-deterministic (different output each time)
- More diverse and creative than greedy/beam search
- Can produce incoherent output if not controlled
- Best for: Creative writing, dialogue, open-ended generation
Temperature Parameter:
# Low temperature (0.1-0.7): More focused, less random
outputs = model.generate(**inputs, do_sample=True, temperature=0.5)
# Medium temperature (0.7-1.0): Balanced
outputs = model.generate(**inputs, do_sample=True, temperature=0.8)
# High temperature (1.0-2.0): More random, more creative
outputs = model.generate(**inputs, do_sample=True, temperature=1.5)
- temperature → 0: Approaches greedy decoding
- temperature = 1.0: Sample from the original distribution
- temperature > 1.0: Flatter distribution, more random
- temperature < 1.0: Sharper distribution, more confident
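Conceptually, temperature divides the logits before the softmax. A small illustrative sketch of the effect (toy logits, not library code):
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # example next-token logits
for temperature in (0.5, 1.0, 1.5):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# Lower temperature sharpens the distribution toward the top token;
# higher temperature flattens it toward uniform.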
Advanced Sampling Methods
Top-k Sampling
Sample from only the k most likely tokens.
outputs = model.generate(
**inputs,
do_sample=True,
max_new_tokens=50,
top_k=50, # Consider top 50 tokens
temperature=0.8,
)
How it works:
- Filter to top-k most probable tokens
- Renormalize probabilities
- Sample from filtered distribution
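A minimal sketch of this filtering step, as an illustrative helper rather than the library's implementation:
import torch

def top_k_filter(logits, k):
    # Keep the k largest logits, mask the rest to -inf, then renormalize
    values, _ = torch.topk(logits, k, dim=-1)
    cutoff = values[..., -1, None]  # smallest logit that survives
    filtered = logits.masked_fill(logits < cutoff, float("-inf"))
    return torch.softmax(filtered, dim=-1)

probs = top_k_filter(torch.randn(1, 50257), k=50)
next_token = torch.multinomial(probs, num_samples=1)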
Choosing k:
- k=1: Equivalent to greedy decoding
- k=10-50: More focused, coherent output
- k=100-500: More diverse output
- Too high k: Includes low-probability tokens (noise)
- Too low k: Less diverse, may miss good alternatives
Top-p (Nucleus) Sampling
Sample from the smallest set of tokens whose cumulative probability ≥ p.
outputs = model.generate(
**inputs,
do_sample=True,
max_new_tokens=50,
top_p=0.95, # Nucleus probability
temperature=0.8,
)
How it works:
- Sort tokens by probability
- Find smallest set with cumulative probability ≥ p
- Sample from this set
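A minimal sketch of nucleus filtering, again illustrative rather than the library's implementation:
import torch

def top_p_filter(logits, p):
    # Sort descending and keep the smallest prefix whose cumulative probability >= p
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum_probs > p
    remove[..., 1:] = remove[..., :-1].clone()  # shift so the threshold-crossing token is kept
    remove[..., 0] = False                      # always keep the single most likely token
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary positions
    filtered = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.softmax(filtered, dim=-1)

probs = top_p_filter(torch.randn(1, 50257), p=0.95)
next_token = torch.multinomial(probs, num_samples=1)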
Choosing p:
- p=0.9-0.95: Good balance (recommended)
- p=1.0: Sample from the full distribution
- Higher p: More diverse, might include unlikely tokens
- Lower p: More focused, like top-k with adaptive k
Top-p vs Top-k:
- Top-p adapts to probability distribution shape
- Top-k is fixed regardless of distribution
- Top-p generally better for variable-quality contexts
- Can combine: top_k=50, top_p=0.95 (apply both filters)
Combining Strategies
# Recommended for high-quality open-ended generation
outputs = model.generate(
**inputs,
do_sample=True,
max_new_tokens=100,
temperature=0.8, # Moderate temperature
top_k=50, # Limit to top 50 tokens
top_p=0.95, # Nucleus sampling
repetition_penalty=1.2, # Discourage repetition
no_repeat_ngram_size=3, # Block 3-gram repetition
)
Controlling Generation Quality
Repetition Control
Prevent models from repeating themselves:
outputs = model.generate(
**inputs,
max_new_tokens=100,
# Method 1: Repetition penalty
repetition_penalty=1.2, # Penalize repeated tokens (>1.0)
# Method 2: Block n-gram repetition
no_repeat_ngram_size=3, # Never repeat 3-grams
# Method 3: Encoder repetition penalty (for seq2seq)
encoder_repetition_penalty=1.0, # >1.0 penalizes tokens not in the input; 1.0 = off
)
Repetition Penalty Values:
- 1.0: No penalty
- 1.0-1.5: Mild penalty (recommended: 1.1-1.3)
- >1.5: Strong penalty (may harm coherence)
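The penalty follows the CTRL-style formulation: logits of tokens that already appear in the sequence are rescaled so repeats become less likely. A rough illustrative sketch over 1-D logits (not the library's processor):
import torch

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    # Positive logits of already-seen tokens are divided by the penalty,
    # negative ones are multiplied, lowering their probability either way.
    scores = logits.clone()
    for token_id in set(generated_ids.tolist()):
        if scores[token_id] > 0:
            scores[token_id] /= penalty
        else:
            scores[token_id] *= penalty
    return scores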
Length Control
outputs = model.generate(
**inputs,
# Hard constraints (prefer setting either max_length or max_new_tokens, not both)
min_length=20, # Minimum total length (input + new tokens)
max_length=100, # Maximum total length (input + new tokens)
max_new_tokens=50, # Maximum new tokens (excluding input)
# Soft constraints (with beam search)
length_penalty=1.0, # Encourage longer/shorter outputs
# Early stopping
early_stopping=True, # Stop when condition met
)
Bad Words and Forced Tokens
# Prevent specific tokens
bad_words_ids = [
tokenizer.encode("badword1", add_special_tokens=False),
tokenizer.encode("badword2", add_special_tokens=False),
]
outputs = model.generate(
**inputs,
bad_words_ids=bad_words_ids,
)
# Force specific tokens
force_words_ids = [
tokenizer.encode("important", add_special_tokens=False),
]
outputs = model.generate(
**inputs,
force_words_ids=force_words_ids,
num_beams=5, # force_words_ids requires beam search (num_beams > 1)
)
Streaming Generation
Generate and process tokens as they're produced:
from transformers import TextStreamer, TextIteratorStreamer
from threading import Thread
# Simple streaming (prints to stdout)
streamer = TextStreamer(tokenizer, skip_prompt=True)
outputs = model.generate(**inputs, streamer=streamer, max_new_tokens=100)
# Iterator streaming (for custom processing)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=100)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for text in streamer:
print(text, end="", flush=True)
thread.join()
Advanced Techniques
Contrastive Search
Balance coherence and diversity using contrastive objective:
outputs = model.generate(
**inputs,
max_new_tokens=50,
penalty_alpha=0.6, # Contrastive penalty
top_k=4, # Consider top-4 tokens
)
When to use:
- Open-ended text generation
- Reduces repetition without sacrificing coherence
- Good alternative to sampling
Diverse Beam Search
Generate multiple diverse outputs:
outputs = model.generate(
**inputs,
max_new_tokens=50,
num_beams=10,
num_beam_groups=5, # 5 groups of 2 beams each
diversity_penalty=1.0, # Penalty for similar beams
num_return_sequences=5, # Return 5 diverse outputs
)
Constrained Beam Search
Force output to include specific phrases:
from transformers import PhrasalConstraint
constraints = [
PhrasalConstraint(
tokenizer("machine learning", add_special_tokens=False).input_ids
),
]
outputs = model.generate(
**inputs,
constraints=constraints,
num_beams=10, # Requires beam search
)
Speculative Decoding
Accelerate generation using a smaller draft model:
from transformers import AutoModelForCausalLM
# Load main and assistant models
model = AutoModelForCausalLM.from_pretrained("large-model")
assistant_model = AutoModelForCausalLM.from_pretrained("small-model")  # should share the main model's tokenizer/vocabulary
# Generate with speculative decoding
outputs = model.generate(
**inputs,
assistant_model=assistant_model,
do_sample=True,
temperature=0.8,
)
Benefits:
- 2-3x faster generation
- Identical output distribution to regular generation
- Works with sampling and greedy decoding
Recipe: Recommended Settings by Task
Creative Writing / Dialogue
outputs = model.generate(
**inputs,
do_sample=True,
max_new_tokens=200,
temperature=0.9,
top_p=0.95,
top_k=50,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
)
Translation / Summarization
outputs = model.generate(
**inputs,
num_beams=5,
max_new_tokens=150,
early_stopping=True,
length_penalty=1.0,
no_repeat_ngram_size=2,
)
Code Generation
outputs = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.2, # Low temperature for correctness
top_p=0.95,
do_sample=True,
)
Chatbot / Instruction Following
outputs = model.generate(
**inputs,
do_sample=True,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.15,
)
Factual QA / Information Extraction
outputs = model.generate(
**inputs,
max_new_tokens=50,
num_beams=3,
early_stopping=True,
# Or greedy for very short answers:
# (no special parameters needed)
)
Debugging Generation
Check Token Probabilities
outputs = model.generate(
**inputs,
max_new_tokens=20,
output_scores=True, # Return generation scores
return_dict_in_generate=True, # Return as dict
)
# Access generation scores
scores = outputs.scores # Tuple of length num_generated_tokens, each of shape (batch_size, vocab_size)
# Get token probabilities
import torch
probs = torch.softmax(scores[0], dim=-1)
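Transformers also provides a helper that maps these scores to per-token log-probabilities of the tokens that were actually generated; a short sketch (check compute_transition_scores in your installed version for the exact signature):
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
prompt_len = inputs["input_ids"].shape[1]
for token_id, score in zip(outputs.sequences[0, prompt_len:], transition_scores[0]):
    print(tokenizer.decode(token_id), score.item())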
Monitor Generation Process
from transformers import LogitsProcessor, LogitsProcessorList
class DebugLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids, scores):
# Print the top-5 scores and token ids at each step (batch element 0)
top_tokens = scores[0].topk(5)
print(f"Top 5 tokens: {top_tokens}")
return scores
outputs = model.generate(
**inputs,
max_new_tokens=10,
logits_processor=LogitsProcessorList([DebugLogitsProcessor()]),
)
Common Issues and Solutions
Issue: Repetitive output
- Solution: Increase repetition_penalty (1.2-1.5), set no_repeat_ngram_size=3
- For sampling: Increase temperature, enable top_p
Issue: Incoherent output
- Solution: Lower temperature (0.5-0.8), use beam search
- Set top_k=50 or top_p=0.9 to filter unlikely tokens
Issue: Too short output
- Solution: Increase min_length, set length_penalty > 1.0 (beam search)
- Check if the EOS token is being generated early
Issue: Too slow generation
- Solution: Use greedy instead of beam search
- Reduce num_beams
- Use smaller model variant
Issue: Output doesn't follow format
- Solution: Use constrained beam search
- Add format examples to prompt
- Use bad_words_ids to prevent format-breaking tokens
Performance Optimization
import torch

# Use half precision
model = AutoModelForCausalLM.from_pretrained(
"model-name",
torch_dtype=torch.float16,
device_map="auto"
)
# Use KV cache optimization (default, but can be disabled)
outputs = model.generate(**inputs, use_cache=True)
# Batch generation
inputs = tokenizer(["Prompt 1", "Prompt 2"], return_tensors="pt", padding=True)
outputs = model.generate(**inputs, max_new_tokens=50)
# Static cache for longer sequences (if supported)
outputs = model.generate(**inputs, cache_implementation="static")
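For decoder-only models such as GPT-2, batched generation generally requires left padding (and a defined pad token) so that new tokens continue each prompt correctly; a minimal setup sketch:
# Left-pad prompts for decoder-only models before batched generation
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
# then tokenize with padding=True and call model.generate as above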
This guide covers the main generation strategies. For task-specific examples, see task_patterns.md.