#!/usr/bin/env python3
"""
AI-powered scientific schematic generation using Nano Banana 2.

This script uses a smart iterative refinement approach:
1. Generate initial image with Nano Banana 2
2. AI quality review using Gemini 3.1 Pro Preview for scientific critique
3. Only regenerate if quality is below threshold for document type
4. Repeat until quality meets standards (max iterations)

Requirements:
- OPENROUTER_API_KEY environment variable
- requests library

Usage:
    python generate_schematic_ai.py "Create a flowchart showing CONSORT participant flow" -o flowchart.png
    python generate_schematic_ai.py "Neural network architecture diagram" -o architecture.png --iterations 2
    python generate_schematic_ai.py "Simple block diagram" -o diagram.png --doc-type poster
"""

import argparse
import base64
import json
import os
import re
import shutil
import sys
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

try:
    import requests
except ImportError:
    # Do not terminate the interpreter at import time: the module may be
    # imported by other code. The CLI (main) and _make_request report the
    # missing dependency when it is actually needed.
    requests = None

_REQUESTS_MISSING_MESSAGE = (
    "Error: requests library not found. Install with: pip install requests"
)

# Matches an inline base64 image data URL embedded in free text.
_DATA_URL_RE = re.compile(
    r'data:image/[^;]+;base64,([A-Za-z0-9+/=\n\r]+)', re.DOTALL
)
# Primary score format requested from the reviewer: "SCORE: X" or "SCORE: X/10".
_SCORE_RE = re.compile(r'SCORE:\s*(\d+(?:\.\d+)?)', re.IGNORECASE)
# Looser fallback when the reviewer did not follow the requested format.
_SCORE_FALLBACK_RE = re.compile(
    r'(?:score|rating|quality)[:\s]+(\d+(?:\.\d+)?)\s*(?:/\s*10)?',
    re.IGNORECASE,
)


def _find_env_upwards(start: Path, max_levels: int = 5) -> Optional[Path]:
    """Return the first ``.env`` file found at *start* or its ancestors.

    At most *max_levels* directories are inspected, and the walk stops as
    soon as the filesystem root is reached. Only regular files count: a
    directory named ``.env`` (a common virtualenv name) is skipped so it
    cannot end the search early.
    """
    current = start
    for _ in range(max_levels):
        candidate = current / ".env"
        if candidate.is_file():
            return candidate
        current = current.parent
        if current == current.parent:  # Reached root
            break
    return None


# Try to load .env file from multiple potential locations
def _load_env_file():
    """Load .env file from current directory, parent directories, or package directory.

    The current working directory chain is searched first, then the chain
    starting at this script's own directory. The first .env file found is
    loaded and the search stops.

    Returns True if a .env file was found and loaded, False otherwise.
    Note: This does NOT override existing environment variables.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        return False  # python-dotenv not installed

    for start in (Path.cwd(), Path(__file__).resolve().parent):
        env_path = _find_env_upwards(start)
        if env_path is not None:
            load_dotenv(dotenv_path=env_path, override=False)
            return True

    return False


class ScientificSchematicGenerator:
    """Generate scientific schematics using AI with smart iterative refinement.

    Uses Gemini 3.1 Pro Preview for quality review to determine if regeneration
    is needed. Multiple passes only occur if the generated schematic doesn't
    meet the quality threshold for the target document type.
    """

    # Quality thresholds by document type (score out of 10)
    # Higher thresholds for more formal publications
    QUALITY_THRESHOLDS = {
        "journal": 8.5,       # Nature, Science, etc. - highest standards
        "conference": 8.0,    # Conference papers - high standards
        "poster": 7.0,        # Academic posters - good quality
        "presentation": 6.5,  # Slides/talks - clear but less formal
        "report": 7.5,        # Technical reports - professional
        "grant": 8.0,         # Grant proposals - must be compelling
        "thesis": 8.0,        # Dissertations - formal academic
        "preprint": 7.5,      # arXiv, etc. - good quality
        "default": 7.5,       # Default threshold
    }

    # Seconds to wait for a single OpenRouter HTTP call.
    REQUEST_TIMEOUT_SECONDS = 120

    # Scientific diagram best practices prompt template
    SCIENTIFIC_DIAGRAM_GUIDELINES = """
Create a high-quality scientific diagram with these requirements:

VISUAL QUALITY:
- Clean white or light background (no textures or gradients)
- High contrast for readability and printing
- Professional, publication-ready appearance
- Sharp, clear lines and text
- Adequate spacing between elements to prevent crowding

TYPOGRAPHY:
- Clear, readable sans-serif fonts (Arial, Helvetica style)
- Minimum 10pt font size for all labels
- Consistent font sizes throughout
- All text horizontal or clearly readable
- No overlapping text

SCIENTIFIC STANDARDS:
- Accurate representation of concepts
- Clear labels for all components
- Include scale bars, legends, or axes where appropriate
- Use standard scientific notation and symbols
- Include units where applicable

ACCESSIBILITY:
- Colorblind-friendly color palette (use Okabe-Ito colors if using color)
- High contrast between elements
- Redundant encoding (shapes + colors, not just colors)
- Works well in grayscale

LAYOUT:
- Logical flow (left-to-right or top-to-bottom)
- Clear visual hierarchy
- Balanced composition
- Appropriate use of whitespace
- No clutter or unnecessary decorative elements

IMPORTANT - NO FIGURE NUMBERS:
- Do NOT include "Figure 1:", "Fig. 1", or any figure numbering in the image
- Do NOT add captions or titles like "Figure: ..." at the top or bottom
- Figure numbers and captions are added separately in the document/LaTeX
- The diagram should contain only the visual content itself
"""

    def __init__(self, api_key: Optional[str] = None, verbose: bool = False):
        """
        Initialize the generator.

        Args:
            api_key: OpenRouter API key (or use OPENROUTER_API_KEY env var)
            verbose: Print detailed progress information

        Raises:
            ValueError: If no API key is found in the argument, the
                environment, or a discoverable .env file.
        """
        # Priority: 1) explicit api_key param, 2) environment variable, 3) .env file
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")

        # If not found in environment, try loading from .env file
        if not self.api_key:
            _load_env_file()
            self.api_key = os.getenv("OPENROUTER_API_KEY")

        if not self.api_key:
            raise ValueError(
                "OPENROUTER_API_KEY not found. Please either:\n"
                "  1. Set the OPENROUTER_API_KEY environment variable\n"
                "  2. Add OPENROUTER_API_KEY to your .env file\n"
                "  3. Pass api_key parameter to the constructor\n"
                "Get your API key from: https://openrouter.ai/keys"
            )

        self.verbose = verbose
        self._last_error = None  # Track last error for better reporting
        self.base_url = "https://openrouter.ai/api/v1"
        # Nano Banana 2 - image generation model, addressed by its OpenRouter
        # identifier.
        # NOTE(review): an earlier comment linked the
        # "gemini-3-pro-image-preview" page, which does not match the
        # identifier below - confirm which model is intended.
        self.image_model = "google/gemini-3.1-flash-image-preview"
        # Gemini 3.1 Pro Preview for quality review - excellent vision and reasoning
        self.review_model = "google/gemini-3.1-pro-preview"

    def _log(self, message: str):
        """Log message if verbose mode is enabled."""
        if self.verbose:
            print(f"[{time.strftime('%H:%M:%S')}] {message}")

    def _make_request(self, model: str, messages: List[Dict[str, Any]],
                      modalities: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Make a request to OpenRouter API.

        Args:
            model: Model identifier
            messages: List of message dictionaries
            modalities: Optional list of modalities (e.g., ["image", "text"])

        Returns:
            API response as dictionary

        Raises:
            RuntimeError: If the requests library is missing, on timeout or
                transport failure, on a non-200 status, or when the body is
                not a JSON object.
        """
        if requests is None:
            raise RuntimeError(_REQUESTS_MISSING_MESSAGE)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/scientific-writer",
            "X-Title": "Scientific Schematic Generator",
        }

        payload = {
            "model": model,
            "messages": messages,
        }
        if modalities:
            payload["modalities"] = modalities

        self._log(f"Making request to {model}...")

        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
        except requests.exceptions.Timeout as exc:
            raise RuntimeError(
                f"API request timed out after {self.REQUEST_TIMEOUT_SECONDS} seconds"
            ) from exc
        except requests.exceptions.RequestException as exc:
            raise RuntimeError(f"API request failed: {str(exc)}") from exc

        # Try to get response body even on error. ValueError covers every
        # JSON decode error requests can raise, whichever JSON backend
        # (stdlib json or simplejson) it is using.
        try:
            response_json = response.json()
        except ValueError:
            response_json = {"raw_text": response.text[:500]}

        # Check for HTTP errors but include response body in error message
        if response.status_code != 200:
            if isinstance(response_json, dict):
                error_detail = response_json.get("error", response_json)
            else:
                error_detail = response_json
            self._log(f"HTTP {response.status_code}: {error_detail}")
            raise RuntimeError(
                f"API request failed (HTTP {response.status_code}): {error_detail}"
            )

        if not isinstance(response_json, dict):
            # Callers rely on dict access; fail with a clear message instead.
            raise RuntimeError(
                f"API request failed: unexpected response body: {str(response_json)[:500]}"
            )

        return response_json

    @staticmethod
    def _image_url_of(block: Any) -> Any:
        """Return the URL carried by an ``image_url`` block, or None for other blocks."""
        if not isinstance(block, dict) or block.get("type") != "image_url":
            return None
        url = block.get("image_url", {})
        if isinstance(url, dict):
            url = url.get("url", "")
        return url

    @staticmethod
    def _decode_data_url(url: Any) -> Optional[bytes]:
        """Decode a ``data:image...;base64,...`` URL to bytes; None if *url* is not one."""
        if not isinstance(url, str) or not url.startswith("data:image") or "," not in url:
            return None
        # Everything after the first comma is the payload; drop whitespace.
        encoded = "".join(url.split(",", 1)[1].split())
        return base64.b64decode(encoded)

    def _extract_image_from_response(self, response: Dict[str, Any]) -> Optional[bytes]:
        """
        Extract base64-encoded image from API response.

        For Nano Banana 2, images are returned in the 'images' field of the
        message, not in the 'content' field. The 'content' field (string or
        list of blocks) is checked as a fallback.

        Args:
            response: API response dictionary

        Returns:
            Image bytes or None if not found (decoding errors are logged in
            verbose mode and also yield None)
        """
        try:
            choices = response.get("choices", [])
            if not choices:
                self._log("No choices in response")
                return None

            message = choices[0].get("message", {})

            # IMPORTANT: Nano Banana 2 returns images in the 'images' field
            images = message.get("images") or []
            if images:
                self._log(f"Found {len(images)} image(s) in 'images' field")
                # Use the first entry that actually carries a data URL.
                for image in images:
                    decoded = self._decode_data_url(self._image_url_of(image))
                    if decoded is not None:
                        self._log(f"Extracted image data ({len(decoded)} bytes)")
                        return decoded

            # Fallback: check content field (for other models or future changes)
            content = message.get("content", "")

            if self.verbose:
                self._log(f"Content type: {type(content)}, length: {len(str(content))}")

            # Handle string content
            if isinstance(content, str) and "data:image" in content:
                match = _DATA_URL_RE.search(content)
                if match:
                    base64_str = "".join(match.group(1).split())
                    self._log(f"Found image in content field (length: {len(base64_str)})")
                    return base64.b64decode(base64_str)

            # Handle list content
            if isinstance(content, list):
                for i, block in enumerate(content):
                    decoded = self._decode_data_url(self._image_url_of(block))
                    if decoded is not None:
                        self._log(f"Found image in content block {i}")
                        return decoded

            self._log("No image data found in response")
            return None

        except Exception as e:
            # Best effort by design: a malformed response must not abort the run.
            self._log(f"Error extracting image: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return None

    def _image_to_base64(self, image_path: str) -> str:
        """
        Convert image file to base64 data URL.

        Args:
            image_path: Path to image file

        Returns:
            Base64 data URL string
        """
        with open(image_path, "rb") as f:
            image_data = f.read()

        # Determine image type from extension (unknown extensions fall back to PNG)
        ext = Path(image_path).suffix.lower()
        mime_type = {
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".gif": "image/gif",
            ".webp": "image/webp",
        }.get(ext, "image/png")

        base64_data = base64.b64encode(image_data).decode("utf-8")
        return f"data:{mime_type};base64,{base64_data}"

    def _log_response_structure(self, response: Dict[str, Any]) -> None:
        """Log the shape of an API response without dumping large base64 payloads."""
        self._log(f"Response keys: {response.keys()}")
        if "error" in response:
            self._log(f"API Error: {response['error']}")
        if response.get("choices"):
            msg = response["choices"][0].get("message", {})
            self._log(f"Message keys: {msg.keys()}")
            # Show content preview without printing huge base64 data
            content = msg.get("content", "")
            if isinstance(content, str):
                preview = content[:200] + "..." if len(content) > 200 else content
                self._log(f"Content preview: {preview}")
            elif isinstance(content, list):
                self._log(f"Content is list with {len(content)} items")
                for i, item in enumerate(content[:3]):
                    if isinstance(item, dict):
                        self._log(f"  Item {i}: type={item.get('type')}")

    def generate_image(self, prompt: str) -> Optional[bytes]:
        """
        Generate an image using Nano Banana 2.

        On failure the reason is stored in ``self._last_error``.

        Args:
            prompt: Description of the diagram to generate

        Returns:
            Image bytes or None if generation failed
        """
        self._last_error = None  # Reset error

        messages = [
            {
                "role": "user",
                "content": prompt,
            }
        ]

        try:
            response = self._make_request(
                model=self.image_model,
                messages=messages,
                modalities=["image", "text"],
            )

            # Debug: print response structure if verbose
            if self.verbose:
                self._log_response_structure(response)

            # Check for API errors in response
            if "error" in response:
                error_msg = response["error"]
                if isinstance(error_msg, dict):
                    error_msg = error_msg.get("message", str(error_msg))
                self._last_error = f"API Error: {error_msg}"
                print(f"✗ {self._last_error}")
                return None

            image_data = self._extract_image_from_response(response)
            if image_data:
                self._log(f"✓ Generated image ({len(image_data)} bytes)")
            else:
                self._last_error = "No image data in API response - model may not support image generation"
                self._log(f"✗ {self._last_error}")
                # Additional debug info when image extraction fails
                if self.verbose and response.get("choices"):
                    msg = response["choices"][0].get("message", {})
                    self._log(f"Full message structure: {json.dumps({k: type(v).__name__ for k, v in msg.items()})}")

            return image_data
        except RuntimeError as e:
            self._last_error = str(e)
            self._log(f"✗ Generation failed: {self._last_error}")
            return None
        except Exception as e:
            self._last_error = f"Unexpected error: {str(e)}"
            self._log(f"✗ Generation failed: {self._last_error}")
            if self.verbose:
                traceback.print_exc()
            return None

    def review_image(self, image_path: str, original_prompt: str,
                     iteration: int, doc_type: str = "default",
                     max_iterations: int = 2) -> Tuple[str, float, bool]:
        """
        Review generated image using Gemini 3.1 Pro Preview for quality analysis.

        The review never aborts the pipeline: if the API call or parsing
        fails, a neutral "review skipped" result is returned.

        Args:
            image_path: Path to the generated image
            original_prompt: Original user prompt
            iteration: Current iteration number
            doc_type: Document type (journal, poster, presentation, etc.)
            max_iterations: Maximum iterations allowed

        Returns:
            Tuple of (critique text, quality score 0-10, needs_improvement bool)
        """
        # Use Gemini 3.1 Pro Preview for review - excellent vision and analysis
        image_data_url = self._image_to_base64(image_path)

        # Get quality threshold for this document type
        threshold = self.QUALITY_THRESHOLDS.get(doc_type.lower(),
                                                self.QUALITY_THRESHOLDS["default"])

        review_prompt = f"""You are an expert reviewer evaluating a scientific diagram for publication quality.

ORIGINAL REQUEST: {original_prompt}

DOCUMENT TYPE: {doc_type} (quality threshold: {threshold}/10)

ITERATION: {iteration}/{max_iterations}

Carefully evaluate this diagram on these criteria:

1. **Scientific Accuracy** (0-2 points)
   - Correct representation of concepts
   - Proper notation and symbols
   - Accurate relationships shown

2. **Clarity and Readability** (0-2 points)
   - Easy to understand at a glance
   - Clear visual hierarchy
   - No ambiguous elements

3. **Label Quality** (0-2 points)
   - All important elements labeled
   - Labels are readable (appropriate font size)
   - Consistent labeling style

4. **Layout and Composition** (0-2 points)
   - Logical flow (top-to-bottom or left-to-right)
   - Balanced use of space
   - No overlapping elements

5. **Professional Appearance** (0-2 points)
   - Publication-ready quality
   - Clean, crisp lines and shapes
   - Appropriate colors/contrast

RESPOND IN THIS EXACT FORMAT:
SCORE: [total score 0-10]

STRENGTHS:
- [strength 1]
- [strength 2]

ISSUES:
- [issue 1 if any]
- [issue 2 if any]

VERDICT: [ACCEPTABLE or NEEDS_IMPROVEMENT]

If score >= {threshold}, the diagram is ACCEPTABLE for {doc_type} publication.
If score < {threshold}, mark as NEEDS_IMPROVEMENT with specific suggestions."""

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": review_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_data_url,
                        },
                    },
                ],
            }
        ]

        try:
            # Use Gemini 3.1 Pro Preview for high-quality review
            response = self._make_request(
                model=self.review_model,
                messages=messages,
            )

            # Extract text response
            choices = response.get("choices", [])
            if not choices:
                # No review available: accept the image. Must be a 3-tuple,
                # callers unpack (critique, score, needs_improvement).
                return "Image generated successfully", 8.0, False

            message = choices[0].get("message", {})
            # "content" may be present but null; normalise to an empty string.
            content = message.get("content") or ""

            # Some responses carry the analysis in a separate 'reasoning'
            # field; use it when the content itself is empty.
            reasoning = message.get("reasoning") or ""
            if reasoning and not content:
                content = reasoning

            if isinstance(content, list):
                # Extract text from content blocks
                content = "\n".join(
                    block.get("text", "")
                    for block in content
                    if isinstance(block, dict) and block.get("type") == "text"
                )
            elif not isinstance(content, str):
                content = str(content)

            # Try to extract score
            score = 7.5  # Default score if extraction fails

            # Look for SCORE: X or SCORE: X/10 format, then any score pattern
            score_match = _SCORE_RE.search(content) or _SCORE_FALLBACK_RE.search(content)
            if score_match:
                # Keep the reported value on the documented 0-10 scale.
                score = max(0.0, min(10.0, float(score_match.group(1))))

            # Determine if improvement is needed based on verdict or score
            needs_improvement = "NEEDS_IMPROVEMENT" in content.upper() or score < threshold

            self._log(f"✓ Review complete (Score: {score}/10, Threshold: {threshold}/10)")
            self._log(f"  Verdict: {'Needs improvement' if needs_improvement else 'Acceptable'}")

            return (content if content else "Image generated successfully",
                    score,
                    needs_improvement)
        except Exception as e:
            self._log(f"Review skipped: {str(e)}")
            # Don't fail the whole process if review fails - assume acceptable
            return "Image generated successfully (review skipped)", 7.5, False

    def improve_prompt(self, original_prompt: str, critique: str,
                       iteration: int) -> str:
        """
        Improve the generation prompt based on critique.

        Args:
            original_prompt: Original user prompt
            critique: Review critique from previous iteration
            iteration: Current iteration number

        Returns:
            Improved prompt for next generation
        """
        improved_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}

USER REQUEST: {original_prompt}

ITERATION {iteration}: Based on previous feedback, address these specific improvements:
{critique}

Generate an improved version that addresses all the critique points while maintaining scientific accuracy and professional quality."""

        return improved_prompt

    def generate_iterative(self, user_prompt: str, output_path: str,
                           iterations: int = 2,
                           doc_type: str = "default") -> Dict[str, Any]:
        """
        Generate scientific schematic with smart iterative refinement.

        Only regenerates if the quality score is below the threshold for the
        specified document type. This saves API calls and time when the first
        generation is already good enough.

        Each successful iteration is saved as ``<stem>_v<i><ext>`` next to the
        output path, and a ``<stem>_review_log.json`` file records the run.
        If a later iteration fails to generate, the best-scoring image from
        the earlier successful iterations is used as the final result.

        Args:
            user_prompt: User's description of desired diagram
            output_path: Path to save final image
            iterations: Maximum refinement iterations (default: 2, max: 2)
            doc_type: Document type for quality threshold (journal, poster, etc.)

        Returns:
            Dictionary with generation results and metadata
        """
        output_path = Path(output_path)
        output_dir = output_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)

        base_name = output_path.stem
        extension = output_path.suffix or ".png"

        # Get quality threshold for this document type
        threshold = self.QUALITY_THRESHOLDS.get(doc_type.lower(),
                                                self.QUALITY_THRESHOLDS["default"])

        results = {
            "user_prompt": user_prompt,
            "doc_type": doc_type,
            "quality_threshold": threshold,
            "iterations": [],
            "final_image": None,
            "final_score": 0.0,
            "success": False,
            "early_stop": False,
            "early_stop_reason": None,
        }

        current_prompt = f"""{self.SCIENTIFIC_DIAGRAM_GUIDELINES}

USER REQUEST: {user_prompt}

Generate a publication-quality scientific diagram that meets all the guidelines above."""

        print(f"\n{'='*60}")
        print("Generating Scientific Schematic")
        print(f"{'='*60}")
        print(f"Description: {user_prompt}")
        print(f"Document Type: {doc_type}")
        print(f"Quality Threshold: {threshold}/10")
        print(f"Max Iterations: {iterations}")
        print(f"Output: {output_path}")
        print(f"{'='*60}\n")

        for i in range(1, iterations + 1):
            print(f"\n[Iteration {i}/{iterations}]")
            print("-" * 40)

            # Generate image
            print("Generating image...")
            image_data = self.generate_image(current_prompt)

            if not image_data:
                # _last_error exists but may be None, so a getattr default
                # alone would never supply the fallback text.
                error_msg = (getattr(self, "_last_error", None)
                             or "Image generation failed - no image data returned")
                print(f"✗ Generation failed: {error_msg}")
                results["iterations"].append({
                    "iteration": i,
                    "success": False,
                    "error": error_msg,
                })
                continue

            # Save iteration image
            iter_path = output_dir / f"{base_name}_v{i}{extension}"
            with open(iter_path, "wb") as f:
                f.write(image_data)

            print(f"✓ Saved: {iter_path}")

            # Review image using Gemini 3.1 Pro Preview
            print("Reviewing image with Gemini 3.1 Pro Preview...")
            critique, score, needs_improvement = self.review_image(
                str(iter_path), user_prompt, i, doc_type, iterations
            )
            print(f"✓ Score: {score}/10 (threshold: {threshold}/10)")

            # Save iteration results
            results["iterations"].append({
                "iteration": i,
                "image_path": str(iter_path),
                "prompt": current_prompt,
                "critique": critique,
                "score": score,
                "needs_improvement": needs_improvement,
                "success": True,
            })

            # Check if quality is acceptable - STOP EARLY if so
            if not needs_improvement:
                print(f"\n✓ Quality meets {doc_type} threshold ({score} >= {threshold})")
                print("  No further iterations needed!")
                results["final_image"] = str(iter_path)
                results["final_score"] = score
                results["success"] = True
                results["early_stop"] = True
                results["early_stop_reason"] = f"Quality score {score} meets threshold {threshold} for {doc_type}"
                break

            # If this is the last iteration, we're done regardless
            if i == iterations:
                print("\n⚠ Maximum iterations reached")
                results["final_image"] = str(iter_path)
                results["final_score"] = score
                results["success"] = True
                break

            # Quality below threshold - improve prompt for next iteration
            print(f"\n⚠ Quality below threshold ({score} < {threshold})")
            print("Improving prompt based on feedback...")
            current_prompt = self.improve_prompt(user_prompt, critique, i + 1)

        # A later generation may have failed after an earlier one produced a
        # usable (if below-threshold) image; fall back to the best of those
        # instead of reporting a total failure.
        if results["final_image"] is None:
            usable = [r for r in results["iterations"] if r.get("success")]
            if usable:
                best = max(usable, key=lambda r: r["score"])
                print(f"\n⚠ Using best available image from iteration {best['iteration']}")
                results["final_image"] = best["image_path"]
                results["final_score"] = best["score"]
                results["success"] = True

        # Copy final version to output path
        if results["success"] and results["final_image"]:
            final_iter_path = Path(results["final_image"])
            if final_iter_path != output_path:
                shutil.copy(final_iter_path, output_path)
            print(f"\n✓ Final image: {output_path}")

        # Save review log
        log_path = output_dir / f"{base_name}_review_log.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"✓ Review log: {log_path}")

        print(f"\n{'='*60}")
        if results["success"]:
            print("Generation Complete!")
        else:
            print("Generation Failed")
        print(f"Final Score: {results['final_score']}/10")
        if results["early_stop"]:
            used = sum(1 for r in results["iterations"] if r.get("success"))
            print(f"Iterations Used: {used}/{iterations} (early stop)")
        print(f"{'='*60}\n")

        return results


def main():
    """Command-line interface."""
    # The HTTP client is mandatory for the CLI; report its absence up front.
    if requests is None:
        print(_REQUESTS_MISSING_MESSAGE)
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description="Generate scientific schematics using AI with smart iterative refinement",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate a flowchart for a journal paper
  python generate_schematic_ai.py "CONSORT participant flow diagram" -o flowchart.png --doc-type journal

  # Generate neural network architecture for presentation (lower threshold)
  python generate_schematic_ai.py "Transformer encoder-decoder architecture" -o transformer.png --doc-type presentation

  # Generate with custom max iterations for poster
  python generate_schematic_ai.py "Biological signaling pathway" -o pathway.png --iterations 2 --doc-type poster

  # Verbose output
  python generate_schematic_ai.py "Circuit diagram" -o circuit.png -v

Document Types (quality thresholds):
  journal      8.5/10  - Nature, Science, peer-reviewed journals
  conference   8.0/10  - Conference papers
  thesis       8.0/10  - Dissertations, theses
  grant        8.0/10  - Grant proposals
  preprint     7.5/10  - arXiv, bioRxiv, etc.
  report       7.5/10  - Technical reports
  poster       7.0/10  - Academic posters
  presentation 6.5/10  - Slides, talks
  default      7.5/10  - General purpose

Note: Multiple iterations only occur if quality is BELOW the threshold.
      If the first generation meets the threshold, no extra API calls are made.

Environment:
  OPENROUTER_API_KEY    OpenRouter API key (required)
        """
    )

    parser.add_argument("prompt", help="Description of the diagram to generate")
    parser.add_argument("-o", "--output", required=True,
                        help="Output image path (e.g., diagram.png)")
    parser.add_argument("--iterations", type=int, default=2,
                        help="Maximum refinement iterations (default: 2, max: 2)")
    parser.add_argument("--doc-type", default="default",
                        choices=["journal", "conference", "poster", "presentation",
                                 "report", "grant", "thesis", "preprint", "default"],
                        help="Document type for quality threshold (default: default)")
    parser.add_argument("--api-key", help="OpenRouter API key (or set OPENROUTER_API_KEY)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Verbose output")

    args = parser.parse_args()

    # Check for API key: flag, then environment, then a discoverable .env
    # file (the same sources the generator's constructor accepts).
    api_key = args.api_key or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        _load_env_file()
        api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("Error: OPENROUTER_API_KEY environment variable not set")
        print("\nSet it with:")
        print("  export OPENROUTER_API_KEY='your_api_key'")
        print("\nOr provide via --api-key flag")
        sys.exit(1)

    # Validate iterations - enforce max of 2
    if args.iterations < 1 or args.iterations > 2:
        print("Error: Iterations must be between 1 and 2")
        sys.exit(1)

    try:
        generator = ScientificSchematicGenerator(api_key=api_key, verbose=args.verbose)
        results = generator.generate_iterative(
            user_prompt=args.prompt,
            output_path=args.output,
            iterations=args.iterations,
            doc_type=args.doc_type,
        )

        if results["success"]:
            print(f"\n✓ Success! Image saved to: {args.output}")
            if results.get("early_stop"):
                used = sum(1 for r in results["iterations"] if r.get("success"))
                print(f"  (Completed in {used} iteration(s) - quality threshold met)")
            sys.exit(0)
        else:
            print("\n✗ Generation failed. Check review log for details.")
            sys.exit(1)
    except Exception as e:
        # Top-level boundary: report the error and signal failure to the shell.
        print(f"\n✗ Error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()