# Mirror of https://github.com/K-Dense-AI/claude-scientific-skills.git
# Synced 2026-01-26 16:58:56 +08:00
# 839 lines, 32 KiB, Python
#!/usr/bin/env python3
|
|
"""
|
|
AI-powered scientific schematic generation using Nano Banana Pro.
|
|
|
|
This script uses a smart iterative refinement approach:
|
|
1. Generate initial image with Nano Banana Pro
|
|
2. AI quality review using Gemini 3 Pro for scientific critique
|
|
3. Only regenerate if quality is below threshold for document type
|
|
4. Repeat until quality meets standards (max iterations)
|
|
|
|
Requirements:
|
|
- OPENROUTER_API_KEY environment variable
|
|
- requests library
|
|
|
|
Usage:
|
|
python generate_schematic_ai.py "Create a flowchart showing CONSORT participant flow" -o flowchart.png
|
|
python generate_schematic_ai.py "Neural network architecture diagram" -o architecture.png --iterations 2
|
|
python generate_schematic_ai.py "Simple block diagram" -o diagram.png --doc-type poster
|
|
"""
|
|
|
|
import argparse
|
|
import base64
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|
|
|
try:
|
|
import requests
|
|
except ImportError:
|
|
print("Error: requests library not found. Install with: pip install requests")
|
|
sys.exit(1)
|
|
|
|
# Try to load .env file from multiple potential locations
|
|
def _load_env_file():
|
|
"""Load .env file from current directory, parent directories, or package directory.
|
|
|
|
Returns True if a .env file was found and loaded, False otherwise.
|
|
Note: This does NOT override existing environment variables.
|
|
"""
|
|
try:
|
|
from dotenv import load_dotenv
|
|
except ImportError:
|
|
return False # python-dotenv not installed
|
|
|
|
# Try current working directory first
|
|
env_path = Path.cwd() / ".env"
|
|
if env_path.exists():
|
|
load_dotenv(dotenv_path=env_path, override=False)
|
|
return True
|
|
|
|
# Try parent directories (up to 5 levels)
|
|
cwd = Path.cwd()
|
|
for _ in range(5):
|
|
env_path = cwd / ".env"
|
|
if env_path.exists():
|
|
load_dotenv(dotenv_path=env_path, override=False)
|
|
return True
|
|
cwd = cwd.parent
|
|
if cwd == cwd.parent: # Reached root
|
|
break
|
|
|
|
# Try the package's parent directory (scientific-writer project root)
|
|
script_dir = Path(__file__).resolve().parent
|
|
for _ in range(5):
|
|
env_path = script_dir / ".env"
|
|
if env_path.exists():
|
|
load_dotenv(dotenv_path=env_path, override=False)
|
|
return True
|
|
script_dir = script_dir.parent
|
|
if script_dir == script_dir.parent:
|
|
break
|
|
|
|
return False
|
|
|
|
|
|
class ScientificSchematicGenerator:
    """Generate scientific schematics using AI with smart iterative refinement.

    Uses Gemini 3 Pro for quality review to determine if regeneration is needed.
    Multiple passes only occur if the generated schematic doesn't meet the
    quality threshold for the target document type.

    Every model call is sent to the OpenRouter chat-completions endpoint via
    ``_make_request``.
    """

    # Quality thresholds by document type (score out of 10).
    # Higher thresholds for more formal publications.
    QUALITY_THRESHOLDS = {
        "journal": 8.5,  # Nature, Science, etc. - highest standards
        "conference": 8.0,  # Conference papers - high standards
        "poster": 7.0,  # Academic posters - good quality
        "presentation": 6.5,  # Slides/talks - clear but less formal
        "report": 7.5,  # Technical reports - professional
        "grant": 8.0,  # Grant proposals - must be compelling
        "thesis": 8.0,  # Dissertations - formal academic
        "preprint": 7.5,  # arXiv, etc. - good quality
        "default": 7.5,  # Default threshold
    }

    # Scientific diagram best practices prompt template; prepended to every
    # generation prompt.
    SCIENTIFIC_DIAGRAM_GUIDELINES = """
Create a high-quality scientific diagram with these requirements:

VISUAL QUALITY:
- Clean white or light background (no textures or gradients)
- High contrast for readability and printing
- Professional, publication-ready appearance
- Sharp, clear lines and text
- Adequate spacing between elements to prevent crowding

TYPOGRAPHY:
- Clear, readable sans-serif fonts (Arial, Helvetica style)
- Minimum 10pt font size for all labels
- Consistent font sizes throughout
- All text horizontal or clearly readable
- No overlapping text

SCIENTIFIC STANDARDS:
- Accurate representation of concepts
- Clear labels for all components
- Include scale bars, legends, or axes where appropriate
- Use standard scientific notation and symbols
- Include units where applicable

ACCESSIBILITY:
- Colorblind-friendly color palette (use Okabe-Ito colors if using color)
- High contrast between elements
- Redundant encoding (shapes + colors, not just colors)
- Works well in grayscale

LAYOUT:
- Logical flow (left-to-right or top-to-bottom)
- Clear visual hierarchy
- Balanced composition
- Appropriate use of whitespace
- No clutter or unnecessary decorative elements
"""

    def __init__(self, api_key: Optional[str] = None, verbose: bool = False):
        """
        Initialize the generator.

        Args:
            api_key: OpenRouter API key (or use OPENROUTER_API_KEY env var)
            verbose: Print detailed progress information

        Raises:
            ValueError: If no API key is given, none is set in the
                environment, and none becomes available after attempting to
                load a .env file.
        """
        # Priority: 1) explicit api_key param, 2) environment variable, 3) .env file
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")

        # If not found in environment, try loading from .env file
        if not self.api_key:
            _load_env_file()
            self.api_key = os.getenv("OPENROUTER_API_KEY")

        if not self.api_key:
            raise ValueError(
                "OPENROUTER_API_KEY not found. Please either:\n"
                " 1. Set the OPENROUTER_API_KEY environment variable\n"
                " 2. Add OPENROUTER_API_KEY to your .env file\n"
                " 3. Pass api_key parameter to the constructor\n"
                "Get your API key from: https://openrouter.ai/keys"
            )

        self.verbose = verbose
        # Message of the most recent generation failure, for better reporting.
        self._last_error: Optional[str] = None
        self.base_url = "https://openrouter.ai/api/v1"
        # Nano Banana Pro - Google's advanced image generation model
        # https://openrouter.ai/google/gemini-3-pro-image-preview
        self.image_model = "google/gemini-3-pro-image-preview"
        # Gemini 3 Pro for quality review - excellent vision and reasoning
        self.review_model = "google/gemini-3-pro"

    def _log(self, message: str):
        """Print a timestamped message, but only when verbose mode is on."""
        if self.verbose:
            print(f"[{time.strftime('%H:%M:%S')}] {message}")

    def _threshold_for(self, doc_type: str) -> float:
        """Return the acceptance score for ``doc_type`` (case-insensitive).

        Unknown document types fall back to the "default" threshold.
        """
        return self.QUALITY_THRESHOLDS.get(doc_type.lower(),
                                           self.QUALITY_THRESHOLDS["default"])

    def _make_request(self, model: str, messages: List[Dict[str, Any]],
                      modalities: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Make a request to OpenRouter API.

        Args:
            model: Model identifier
            messages: List of message dictionaries
            modalities: Optional list of modalities (e.g., ["image", "text"])

        Returns:
            API response as dictionary

        Raises:
            RuntimeError: On timeout, transport failure, a non-200 status
                (the response body is included in the message), or a 200
                response whose body is not a JSON object.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/scientific-writer",
            "X-Title": "Scientific Schematic Generator"
        }

        payload = {
            "model": model,
            "messages": messages
        }
        if modalities:
            payload["modalities"] = modalities

        self._log(f"Making request to {model}...")

        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=120
            )
        except requests.exceptions.Timeout as e:
            raise RuntimeError("API request timed out after 120 seconds") from e
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"API request failed: {str(e)}") from e

        # Try to get response body even on error. ValueError is the common
        # base class of the JSON decode errors raised by the different
        # requests versions / JSON backends.
        try:
            response_json = response.json()
        except ValueError:
            response_json = {"raw_text": response.text[:500]}

        # Check for HTTP errors but include response body in error message
        if response.status_code != 200:
            if isinstance(response_json, dict):
                error_detail = response_json.get("error", response_json)
            else:
                error_detail = response_json
            self._log(f"HTTP {response.status_code}: {error_detail}")
            raise RuntimeError(f"API request failed (HTTP {response.status_code}): {error_detail}")

        if not isinstance(response_json, dict):
            # Callers rely on dict access; fail loudly instead of letting an
            # AttributeError surface somewhere downstream.
            raise RuntimeError(f"API request failed: unexpected response payload: {str(response_json)[:500]}")

        return response_json

    @staticmethod
    def _data_url_payload(url: Any) -> Optional[str]:
        """Return the whitespace-free base64 part of an image data URL.

        Accepts either the URL string itself or an ``{"url": ...}`` mapping.
        Returns None when the value is not a ``data:image`` URL containing a
        comma separator.
        """
        if isinstance(url, dict):
            url = url.get("url", "")
        if not isinstance(url, str) or not url.startswith("data:image") or "," not in url:
            return None
        # Extract base64 data after the first comma, then clean whitespace.
        encoded = url.split(",", 1)[1]
        return encoded.replace('\n', '').replace('\r', '').replace(' ', '')

    def _extract_image_from_response(self, response: Dict[str, Any]) -> Optional[bytes]:
        """
        Extract base64-encoded image from API response.

        For Nano Banana Pro, images are returned in the 'images' field of the message,
        not in the 'content' field. The 'content' field (string or list of
        blocks) is still checked as a fallback.

        Args:
            response: API response dictionary

        Returns:
            Image bytes or None if not found (any error during extraction is
            logged and also yields None)
        """
        try:
            choices = response.get("choices", [])
            if not choices:
                self._log("No choices in response")
                return None

            message = choices[0].get("message", {})

            # IMPORTANT: Nano Banana Pro returns images in the 'images' field
            images = message.get("images", [])
            if images:
                self._log(f"Found {len(images)} image(s) in 'images' field")

                # Only the first image is considered.
                first_image = images[0]
                if isinstance(first_image, dict) and first_image.get("type") == "image_url":
                    base64_str = self._data_url_payload(first_image.get("image_url", {}))
                    if base64_str is not None:
                        self._log(f"Extracted base64 data (length: {len(base64_str)})")
                        return base64.b64decode(base64_str)

            # Fallback: check content field (for other models or future changes)
            content = message.get("content", "")
            self._log(f"Content type: {type(content)}, length: {len(str(content))}")

            # Handle string content
            if isinstance(content, str) and "data:image" in content:
                import re
                match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=\n\r]+)', content, re.DOTALL)
                if match:
                    base64_str = match.group(1).replace('\n', '').replace('\r', '').replace(' ', '')
                    self._log(f"Found image in content field (length: {len(base64_str)})")
                    return base64.b64decode(base64_str)

            # Handle list content
            if isinstance(content, list):
                for i, block in enumerate(content):
                    if isinstance(block, dict) and block.get("type") == "image_url":
                        base64_str = self._data_url_payload(block.get("image_url", {}))
                        if base64_str is not None:
                            self._log(f"Found image in content block {i}")
                            return base64.b64decode(base64_str)

            self._log("No image data found in response")
            return None

        except Exception as e:
            # Best effort: a malformed response must not abort the run.
            self._log(f"Error extracting image: {str(e)}")
            import traceback
            if self.verbose:
                traceback.print_exc()
            return None

    def _image_to_base64(self, image_path: str) -> str:
        """
        Convert image file to base64 data URL.

        Args:
            image_path: Path to image file

        Returns:
            Base64 data URL string; the MIME type is chosen from the file
            extension and defaults to image/png for unknown extensions
        """
        path = Path(image_path)
        image_data = path.read_bytes()

        # Determine image type from extension
        mime_type = {
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".gif": "image/gif",
            ".webp": "image/webp"
        }.get(path.suffix.lower(), "image/png")

        base64_data = base64.b64encode(image_data).decode("utf-8")
        return f"data:{mime_type};base64,{base64_data}"

    def _log_response_structure(self, response: Dict[str, Any]):
        """Log the shape of a generation response (verbose mode only)."""
        if not self.verbose:
            return
        self._log(f"Response keys: {response.keys()}")
        if "error" in response:
            self._log(f"API Error: {response['error']}")
        if "choices" in response and response["choices"]:
            msg = response["choices"][0].get("message", {})
            self._log(f"Message keys: {msg.keys()}")
            # Show content preview without printing huge base64 data
            content = msg.get("content", "")
            if isinstance(content, str):
                preview = content[:200] + "..." if len(content) > 200 else content
                self._log(f"Content preview: {preview}")
            elif isinstance(content, list):
                self._log(f"Content is list with {len(content)} items")
                for i, item in enumerate(content[:3]):
                    if isinstance(item, dict):
                        self._log(f" Item {i}: type={item.get('type')}")

    def generate_image(self, prompt: str) -> Optional[bytes]:
        """
        Generate an image using Nano Banana Pro.

        On failure the reason is stored in ``self._last_error``.

        Args:
            prompt: Description of the diagram to generate

        Returns:
            Image bytes or None if generation failed
        """
        self._last_error = None  # Reset error

        messages = [
            {
                "role": "user",
                "content": prompt
            }
        ]

        try:
            response = self._make_request(
                model=self.image_model,
                messages=messages,
                modalities=["image", "text"]
            )

            # Debug: print response structure if verbose
            self._log_response_structure(response)

            # Check for API errors in response
            if "error" in response:
                error_msg = response["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                self._last_error = f"API Error: {error_msg}"
                print(f"✗ {self._last_error}")
                return None

            image_data = self._extract_image_from_response(response)
            if image_data:
                self._log(f"✓ Generated image ({len(image_data)} bytes)")
            else:
                self._last_error = "No image data in API response - model may not support image generation"
                self._log(f"✗ {self._last_error}")
                # Additional debug info when image extraction fails. The
                # choices list may be empty, so check before indexing;
                # otherwise an IndexError would replace the real error.
                if self.verbose and response.get("choices"):
                    msg = response["choices"][0].get("message", {})
                    self._log(f"Full message structure: {json.dumps({k: type(v).__name__ for k, v in msg.items()})}")

            return image_data
        except RuntimeError as e:
            self._last_error = str(e)
            self._log(f"✗ Generation failed: {self._last_error}")
            return None
        except Exception as e:
            self._last_error = f"Unexpected error: {str(e)}"
            self._log(f"✗ Generation failed: {self._last_error}")
            import traceback
            if self.verbose:
                traceback.print_exc()
            return None

    @staticmethod
    def _parse_review_score(content: str, default: float) -> float:
        """Pull the numeric score out of review text, or return *default*."""
        import re

        # Look for SCORE: X or SCORE: X/10 format
        score_match = re.search(r'SCORE:\s*(\d+(?:\.\d+)?)', content, re.IGNORECASE)
        if not score_match:
            # Fallback: look for any score pattern
            score_match = re.search(r'(?:score|rating|quality)[:\s]+(\d+(?:\.\d+)?)\s*(?:/\s*10)?', content, re.IGNORECASE)
        return float(score_match.group(1)) if score_match else default

    def review_image(self, image_path: str, original_prompt: str,
                     iteration: int, doc_type: str = "default",
                     max_iterations: int = 2) -> Tuple[str, float, bool]:
        """
        Review generated image using Gemini 3 Pro for quality analysis.

        Sends the image and a scoring rubric to the review model, then parses
        the score and verdict from its reply. A failed review never aborts
        the run: the image is then treated as acceptable.

        Args:
            image_path: Path to the generated image
            original_prompt: Original user prompt
            iteration: Current iteration number
            doc_type: Document type (journal, poster, presentation, etc.)
            max_iterations: Maximum iterations allowed

        Returns:
            Tuple of (critique text, quality score 0-10, needs_improvement bool)
        """
        image_data_url = self._image_to_base64(image_path)

        # Get quality threshold for this document type
        threshold = self._threshold_for(doc_type)

        review_prompt = f"""You are an expert reviewer evaluating a scientific diagram for publication quality.

ORIGINAL REQUEST: {original_prompt}

DOCUMENT TYPE: {doc_type} (quality threshold: {threshold}/10)
ITERATION: {iteration}/{max_iterations}

Carefully evaluate this diagram on these criteria:

1. **Scientific Accuracy** (0-2 points)
- Correct representation of concepts
- Proper notation and symbols
- Accurate relationships shown

2. **Clarity and Readability** (0-2 points)
- Easy to understand at a glance
- Clear visual hierarchy
- No ambiguous elements

3. **Label Quality** (0-2 points)
- All important elements labeled
- Labels are readable (appropriate font size)
- Consistent labeling style

4. **Layout and Composition** (0-2 points)
- Logical flow (top-to-bottom or left-to-right)
- Balanced use of space
- No overlapping elements

5. **Professional Appearance** (0-2 points)
- Publication-ready quality
- Clean, crisp lines and shapes
- Appropriate colors/contrast

RESPOND IN THIS EXACT FORMAT:
SCORE: [total score 0-10]

STRENGTHS:
- [strength 1]
- [strength 2]

ISSUES:
- [issue 1 if any]
- [issue 2 if any]

VERDICT: [ACCEPTABLE or NEEDS_IMPROVEMENT]

If score >= {threshold}, the diagram is ACCEPTABLE for {doc_type} publication.
If score < {threshold}, mark as NEEDS_IMPROVEMENT with specific suggestions."""

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": review_prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_data_url
                        }
                    }
                ]
            }
        ]

        try:
            # Use Gemini 3 Pro for high-quality review
            response = self._make_request(
                model=self.review_model,
                messages=messages
            )

            # Extract text response
            choices = response.get("choices", [])
            if not choices:
                # Callers unpack three values, so the needs_improvement flag
                # must be present here too; no review means "accept as is".
                return "Image generated successfully", 8.0, False

            message = choices[0].get("message", {})
            content = message.get("content", "")

            # Check reasoning field (Nano Banana Pro puts analysis here)
            reasoning = message.get("reasoning", "")
            if reasoning and not content:
                content = reasoning

            if isinstance(content, list):
                # Extract text from content blocks
                content = "\n".join(
                    block.get("text", "")
                    for block in content
                    if isinstance(block, dict) and block.get("type") == "text"
                )

            # Try to extract score; 7.5 is used if extraction fails
            score = self._parse_review_score(content, 7.5)

            # Determine if improvement is needed based on verdict or score
            needs_improvement = "NEEDS_IMPROVEMENT" in content.upper() or score < threshold

            self._log(f"✓ Review complete (Score: {score}/10, Threshold: {threshold}/10)")
            self._log(f" Verdict: {'Needs improvement' if needs_improvement else 'Acceptable'}")

            return (content if content else "Image generated successfully",
                    score,
                    needs_improvement)
        except Exception as e:
            self._log(f"Review skipped: {str(e)}")
            # Don't fail the whole process if review fails - assume acceptable
            return "Image generated successfully (review skipped)", 7.5, False

    def improve_prompt(self, original_prompt: str, critique: str,
                       iteration: int) -> str:
        """
        Improve the generation prompt based on critique.

        Args:
            original_prompt: Original user prompt
            critique: Review critique from previous iteration
            iteration: Current iteration number

        Returns:
            Improved prompt for next generation
        """
        improved_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}

USER REQUEST: {original_prompt}

ITERATION {iteration}: Based on previous feedback, address these specific improvements:
{critique}

Generate an improved version that addresses all the critique points while maintaining scientific accuracy and professional quality."""

        return improved_prompt

    def generate_iterative(self, user_prompt: str, output_path: str,
                           iterations: int = 2,
                           doc_type: str = "default") -> Dict[str, Any]:
        """
        Generate scientific schematic with smart iterative refinement.

        Only regenerates if the quality score is below the threshold for the
        specified document type. This saves API calls and time when the first
        generation is already good enough.

        Each generated image is written next to ``output_path`` as
        ``<stem>_v<N><ext>``; the selected one is copied to ``output_path``
        and a ``<stem>_review_log.json`` file records every iteration. If a
        later iteration fails to generate, the best-scoring earlier image is
        used instead of discarding it.

        Args:
            user_prompt: User's description of desired diagram
            output_path: Path to save final image
            iterations: Maximum refinement iterations (default: 2; the
                command-line interface restricts this to 1-2)
            doc_type: Document type for quality threshold (journal, poster, etc.)

        Returns:
            Dictionary with generation results and metadata
        """
        output_path = Path(output_path)
        output_dir = output_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)

        base_name = output_path.stem
        extension = output_path.suffix or ".png"

        # Get quality threshold for this document type
        threshold = self._threshold_for(doc_type)

        results = {
            "user_prompt": user_prompt,
            "doc_type": doc_type,
            "quality_threshold": threshold,
            "iterations": [],
            "final_image": None,
            "final_score": 0.0,
            "success": False,
            "early_stop": False,
            "early_stop_reason": None
        }

        current_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}

USER REQUEST: {user_prompt}

Generate a publication-quality scientific diagram that meets all the guidelines above."""

        print(f"\n{'='*60}")
        print("Generating Scientific Schematic")
        print(f"{'='*60}")
        print(f"Description: {user_prompt}")
        print(f"Document Type: {doc_type}")
        print(f"Quality Threshold: {threshold}/10")
        print(f"Max Iterations: {iterations}")
        print(f"Output: {output_path}")
        print(f"{'='*60}\n")

        for i in range(1, iterations + 1):
            print(f"\n[Iteration {i}/{iterations}]")
            print("-" * 40)

            # Generate image
            print("Generating image...")
            image_data = self.generate_image(current_prompt)

            if not image_data:
                error_msg = self._last_error or 'Image generation failed - no image data returned'
                print(f"✗ Generation failed: {error_msg}")
                results["iterations"].append({
                    "iteration": i,
                    "success": False,
                    "error": error_msg
                })
                continue

            # Save iteration image
            iter_path = output_dir / f"{base_name}_v{i}{extension}"
            with open(iter_path, "wb") as f:
                f.write(image_data)
            print(f"✓ Saved: {iter_path}")

            # Review image using Gemini 3 Pro
            print("Reviewing image with Gemini 3 Pro...")
            critique, score, needs_improvement = self.review_image(
                str(iter_path), user_prompt, i, doc_type, iterations
            )
            print(f"✓ Score: {score}/10 (threshold: {threshold}/10)")

            # Save iteration results
            results["iterations"].append({
                "iteration": i,
                "image_path": str(iter_path),
                "prompt": current_prompt,
                "critique": critique,
                "score": score,
                "needs_improvement": needs_improvement,
                "success": True
            })

            # Check if quality is acceptable - STOP EARLY if so
            if not needs_improvement:
                print(f"\n✓ Quality meets {doc_type} threshold ({score} >= {threshold})")
                print(" No further iterations needed!")
                results["final_image"] = str(iter_path)
                results["final_score"] = score
                results["success"] = True
                results["early_stop"] = True
                results["early_stop_reason"] = f"Quality score {score} meets threshold {threshold} for {doc_type}"
                break

            # If this is the last iteration, we're done regardless
            if i == iterations:
                print("\n⚠ Maximum iterations reached")
                results["final_image"] = str(iter_path)
                results["final_score"] = score
                results["success"] = True
                break

            # Quality below threshold - improve prompt for next iteration
            print(f"\n⚠ Quality below threshold ({score} < {threshold})")
            print("Improving prompt based on feedback...")
            current_prompt = self.improve_prompt(user_prompt, critique, i + 1)

        # A failed final generation must not throw away an image that an
        # earlier iteration already produced: fall back to the best one.
        if not results["success"]:
            completed = [r for r in results["iterations"] if r.get("success")]
            if completed:
                best = max(completed, key=lambda r: r["score"])
                print(f"\n⚠ Using best available result from iteration {best['iteration']}")
                results["final_image"] = best["image_path"]
                results["final_score"] = best["score"]
                results["success"] = True

        # Copy final version to output path
        if results["success"] and results["final_image"]:
            final_iter_path = Path(results["final_image"])
            if final_iter_path != output_path:
                import shutil
                shutil.copy(final_iter_path, output_path)
            print(f"\n✓ Final image: {output_path}")

        # Save review log
        log_path = output_dir / f"{base_name}_review_log.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"✓ Review log: {log_path}")

        print(f"\n{'='*60}")
        print("Generation Complete!")
        print(f"Final Score: {results['final_score']}/10")
        if results["early_stop"]:
            print(f"Iterations Used: {len([r for r in results['iterations'] if r.get('success')])}/{iterations} (early stop)")
        print(f"{'='*60}\n")

        return results
|
|
|
|
|
|
def main():
    """Command-line interface.

    Parses arguments, resolves the OpenRouter API key (``--api-key`` flag,
    then the OPENROUTER_API_KEY environment variable, then a .env file),
    validates the iteration count and runs the iterative generation.

    Exits with status 0 on success and 1 on any failure (missing key,
    invalid iteration count, generation failure or unexpected error).
    """
    parser = argparse.ArgumentParser(
        description="Generate scientific schematics using AI with smart iterative refinement",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Generate a flowchart for a journal paper
python generate_schematic_ai.py "CONSORT participant flow diagram" -o flowchart.png --doc-type journal

# Generate neural network architecture for presentation (lower threshold)
python generate_schematic_ai.py "Transformer encoder-decoder architecture" -o transformer.png --doc-type presentation

# Generate with custom max iterations for poster
python generate_schematic_ai.py "Biological signaling pathway" -o pathway.png --iterations 2 --doc-type poster

# Verbose output
python generate_schematic_ai.py "Circuit diagram" -o circuit.png -v

Document Types (quality thresholds):
journal 8.5/10 - Nature, Science, peer-reviewed journals
conference 8.0/10 - Conference papers
thesis 8.0/10 - Dissertations, theses
grant 8.0/10 - Grant proposals
preprint 7.5/10 - arXiv, bioRxiv, etc.
report 7.5/10 - Technical reports
poster 7.0/10 - Academic posters
presentation 6.5/10 - Slides, talks
default 7.5/10 - General purpose

Note: Multiple iterations only occur if quality is BELOW the threshold.
If the first generation meets the threshold, no extra API calls are made.

Environment:
OPENROUTER_API_KEY OpenRouter API key (required)
"""
    )

    parser.add_argument("prompt", help="Description of the diagram to generate")
    parser.add_argument("-o", "--output", required=True,
                        help="Output image path (e.g., diagram.png)")
    parser.add_argument("--iterations", type=int, default=2,
                        help="Maximum refinement iterations (default: 2, max: 2)")
    parser.add_argument("--doc-type", default="default",
                        choices=["journal", "conference", "poster", "presentation",
                                 "report", "grant", "thesis", "preprint", "default"],
                        help="Document type for quality threshold (default: default)")
    parser.add_argument("--api-key", help="OpenRouter API key (or set OPENROUTER_API_KEY)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Verbose output")

    args = parser.parse_args()

    # Check for API key: flag first, then the process environment.
    api_key = args.api_key or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        # Give a .env file the chance to supply the key before giving up;
        # exiting here first would make the .env option unusable from the CLI.
        _load_env_file()
        api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("Error: OPENROUTER_API_KEY environment variable not set")
        print("\nSet it with:")
        print(" export OPENROUTER_API_KEY='your_api_key'")
        print("\nOr provide via --api-key flag")
        sys.exit(1)

    # Validate iterations - enforce max of 2
    if args.iterations < 1 or args.iterations > 2:
        print("Error: Iterations must be between 1 and 2")
        sys.exit(1)

    try:
        generator = ScientificSchematicGenerator(api_key=api_key, verbose=args.verbose)
        results = generator.generate_iterative(
            user_prompt=args.prompt,
            output_path=args.output,
            iterations=args.iterations,
            doc_type=args.doc_type
        )

        if results["success"]:
            print(f"\n✓ Success! Image saved to: {args.output}")
            if results.get("early_stop"):
                print(f" (Completed in {len([r for r in results['iterations'] if r.get('success')])} iteration(s) - quality threshold met)")
            # SystemExit is not an Exception subclass, so it passes through
            # the handler below untouched.
            sys.exit(0)
        else:
            print("\n✗ Generation failed. Check review log for details.")
            sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any unexpected failure and exit non-zero.
        print(f"\n✗ Error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
|