#!/usr/bin/env python3
"""
Text generation with various strategies.
This script demonstrates different generation strategies:
- Greedy decoding
- Beam search
- Sampling with temperature
- Top-k and top-p sampling
Usage:
python generate_text.py --model gpt2 --prompt "The future of AI" --strategy sampling
"""
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def generate_with_greedy(model, tokenizer, prompt, max_length):
    """Greedy decoding (deterministic)."""
    print("\n" + "=" * 60)
    print("GREEDY DECODING")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
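    # With do_sample=False (the default) and num_beams=1, generate() runs
    # greedy decoding: each step takes the single most likely next token,
    # so repeated runs on the same prompt produce identical output.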
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_beam_search(model, tokenizer, prompt, max_length, num_beams=5):
    """Beam search for higher quality."""
    print("\n" + "=" * 60)
    print(f"BEAM SEARCH (num_beams={num_beams})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
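    # Beam search keeps the num_beams highest-scoring partial sequences at
    # each step and returns the best finished one. early_stopping=True halts
    # the search as soon as num_beams complete candidates are found, and
    # no_repeat_ngram_size=2 forbids any bigram from appearing twice.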
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        num_beams=num_beams,
        early_stopping=True,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_sampling(model, tokenizer, prompt, max_length, temperature=0.8):
    """Sampling with temperature."""
    print("\n" + "=" * 60)
    print(f"SAMPLING (temperature={temperature})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
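    # Temperature divides the logits before the softmax: values below 1.0
    # sharpen the distribution (safer, more repetitive text), values above
    # 1.0 flatten it (more diverse, riskier text).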
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_top_k_top_p(model, tokenizer, prompt, max_length, top_k=50, top_p=0.95, temperature=0.8):
    """Top-k and top-p (nucleus) sampling."""
    print("\n" + "=" * 60)
    print(f"TOP-K TOP-P SAMPLING (k={top_k}, p={top_p}, temp={temperature})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
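    # top_k keeps only the k most probable tokens at each step; top_p
    # (nucleus sampling) keeps the smallest set of tokens whose cumulative
    # probability reaches p. repetition_penalty > 1.0 down-weights tokens
    # that already appear in the text, and no_repeat_ngram_size=3 blocks
    # repeated trigrams.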
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_multiple(model, tokenizer, prompt, max_length, num_sequences=3):
    """Generate multiple diverse sequences."""
    print("\n" + "=" * 60)
    print(f"MULTIPLE SEQUENCES (n={num_sequences})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
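    # num_return_sequences draws several sequences in one call; this
    # requires do_sample=True (or num_beams >= num_return_sequences),
    # since pure greedy decoding has only one candidate sequence.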
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        num_return_sequences=num_sequences,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    print(f"\nPrompt: {prompt}\n")
    for i, output in enumerate(outputs, 1):
        text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"\n--- Sequence {i} ---\n{text}\n")


def main():
    parser = argparse.ArgumentParser(description="Text generation with various strategies")
    parser.add_argument(
        "--model",
        type=str,
        default="gpt2",
        help="Model name or path",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt for generation",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="all",
        choices=["greedy", "beam", "sampling", "top_k_top_p", "multiple", "all"],
        help="Generation strategy to use",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=100,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device (cuda, cpu, or auto)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Use 8-bit quantization",
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Text Generation Demo")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Strategy: {args.strategy}")
    print(f"Max new tokens: {args.max_length}")
    print(f"Device: {args.device}")
    print("=" * 60)
    # Load model and tokenizer
    print("\nLoading model...")
    if args.device == "auto":
        device_map = "auto"
        device = None
    else:
        device_map = None
        device = args.device
    model_kwargs = {"device_map": device_map} if device_map else {}
    if args.quantize:
        # 8-bit loading needs the bitsandbytes package; passing load_in_8bit
        # directly to from_pretrained is deprecated, so use BitsAndBytesConfig.
        print("Using 8-bit quantization...")
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if device and not device_map:
        model = model.to(device)
    print(f"Model loaded on: {model.device if hasattr(model, 'device') else 'multiple devices'}")
    # Generate based on strategy
    strategies = {
        "greedy": lambda: generate_with_greedy(model, tokenizer, args.prompt, args.max_length),
        "beam": lambda: generate_with_beam_search(model, tokenizer, args.prompt, args.max_length),
        "sampling": lambda: generate_with_sampling(model, tokenizer, args.prompt, args.max_length, args.temperature),
        "top_k_top_p": lambda: generate_with_top_k_top_p(model, tokenizer, args.prompt, args.max_length, temperature=args.temperature),
        "multiple": lambda: generate_multiple(model, tokenizer, args.prompt, args.max_length),
    }
    if args.strategy == "all":
        for strategy_fn in strategies.values():
            strategy_fn()
    else:
        strategies[args.strategy]()
print("\n" + "=" * 60)
print("Generation complete!")
print("=" * 60)
if __name__ == "__main__":
main()