#!/usr/bin/env python3
"""
Text generation with various strategies.

This script demonstrates different generation strategies:
- Greedy decoding
- Beam search
- Sampling with temperature
- Top-k and top-p sampling

Usage:
    python generate_text.py --model gpt2 --prompt "The future of AI" --strategy sampling
"""

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
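
# Dependency note: this script needs `torch` and `transformers` installed.
# The optional --quantize flag additionally assumes `bitsandbytes` and
# `accelerate` are available, and --device auto likewise relies on `accelerate`.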


def generate_with_greedy(model, tokenizer, prompt, max_length):
    """Greedy decoding (deterministic)."""
    print("\n" + "=" * 60)
    print("GREEDY DECODING")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
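
    # Greedy decoding: at every step the single highest-probability token is
    # chosen, so the output is deterministic for a given prompt. This is
    # model.generate's default behavior when do_sample is not set.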
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_beam_search(model, tokenizer, prompt, max_length, num_beams=5):
    """Beam search for higher quality."""
    print("\n" + "=" * 60)
    print(f"BEAM SEARCH (num_beams={num_beams})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
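
    # Beam search keeps the num_beams highest-scoring partial sequences at
    # each step instead of committing to a single one. no_repeat_ngram_size=2
    # forbids any bigram from appearing twice, and early_stopping=True ends
    # the search once every beam has produced an EOS token.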
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        num_beams=num_beams,
        early_stopping=True,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_sampling(model, tokenizer, prompt, max_length, temperature=0.8):
    """Sampling with temperature."""
    print("\n" + "=" * 60)
    print(f"SAMPLING (temperature={temperature})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
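
    # Temperature rescales the logits before the softmax, i.e.
    # p_i = softmax(logits / temperature)_i. Values below 1.0 sharpen the
    # distribution toward likely tokens; values above 1.0 flatten it,
    # increasing diversity at the cost of coherence.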
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_with_top_k_top_p(
    model, tokenizer, prompt, max_length, top_k=50, top_p=0.95, temperature=0.8
):
    """Top-k and top-p (nucleus) sampling."""
    print("\n" + "=" * 60)
    print(f"TOP-K TOP-P SAMPLING (k={top_k}, p={top_p}, temp={temperature})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
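
    # Top-k keeps only the k most probable tokens; top-p (nucleus) keeps the
    # smallest set of tokens whose cumulative probability reaches p. When both
    # are set, transformers typically applies the top-k filter before top-p.
    # repetition_penalty > 1.0 down-weights tokens that have already appeared.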
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"\nGenerated:\n{text}")


def generate_multiple(model, tokenizer, prompt, max_length, num_sequences=3):
    """Generate multiple diverse sequences."""
    print("\n" + "=" * 60)
    print(f"MULTIPLE SEQUENCES (n={num_sequences})")
    print("=" * 60)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
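
    # Sampling must be enabled here: without it every returned sequence would
    # be an identical greedy continuation (and recent transformers versions
    # reject num_return_sequences > 1 for pure greedy search).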
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        num_return_sequences=num_sequences,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    print(f"\nPrompt: {prompt}\n")
    for i, output in enumerate(outputs, 1):
        text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"\n--- Sequence {i} ---\n{text}\n")


def main():
    parser = argparse.ArgumentParser(description="Text generation with various strategies")
    parser.add_argument(
        "--model",
        type=str,
        default="gpt2",
        help="Model name or path",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt for generation",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="all",
        choices=["greedy", "beam", "sampling", "top_k_top_p", "multiple", "all"],
        help="Generation strategy to use",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=100,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device (cuda, cpu, or auto)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Use 8-bit quantization",
    )

    args = parser.parse_args()

    print("=" * 60)
    print("Text Generation Demo")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Strategy: {args.strategy}")
    print(f"Max length: {args.max_length}")
    print(f"Device: {args.device}")
    print("=" * 60)

    # Load model and tokenizer
    print("\nLoading model...")
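
    # device_map="auto" lets accelerate place model layers on whatever GPUs
    # (or CPU) are available; a concrete device string instead loads the whole
    # model normally and moves it to that device after loading.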
    if args.device == "auto":
        device_map = "auto"
        device = None
    else:
        device_map = None
        device = args.device

    model_kwargs = {"device_map": device_map} if device_map else {}
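
    # 8-bit loading is handled by bitsandbytes and generally assumes a CUDA
    # GPU; on a CPU-only machine the --quantize flag will fail at load time.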
    if args.quantize:
        print("Using 8-bit quantization...")
        model_kwargs["load_in_8bit"] = True

    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    if device and not device_map and not args.quantize:
        # .to() is not supported for 8-bit models; bitsandbytes already places
        # the quantized weights on the GPU during loading.
        model = model.to(device)

    print(f"Model loaded on: {model.device if hasattr(model, 'device') else 'multiple devices'}")
    # Generate based on strategy; --temperature is honored by both sampling
    # strategies.
    strategies = {
        "greedy": lambda: generate_with_greedy(model, tokenizer, args.prompt, args.max_length),
        "beam": lambda: generate_with_beam_search(model, tokenizer, args.prompt, args.max_length),
        "sampling": lambda: generate_with_sampling(
            model, tokenizer, args.prompt, args.max_length, args.temperature
        ),
        "top_k_top_p": lambda: generate_with_top_k_top_p(
            model, tokenizer, args.prompt, args.max_length, temperature=args.temperature
        ),
        "multiple": lambda: generate_multiple(model, tokenizer, args.prompt, args.max_length),
    }

    if args.strategy == "all":
        for strategy_fn in strategies.values():
            strategy_fn()
    else:
        strategies[args.strategy]()

    print("\n" + "=" * 60)
    print("Generation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()