diff --git a/__pycache__/engine.cpython-312.pyc b/__pycache__/engine.cpython-312.pyc index fc73241..b1fbcdd 100644 Binary files a/__pycache__/engine.cpython-312.pyc and b/__pycache__/engine.cpython-312.pyc differ diff --git a/__pycache__/gallery_app.cpython-312.pyc b/__pycache__/gallery_app.cpython-312.pyc index ea14328..9f8a4f5 100644 Binary files a/__pycache__/gallery_app.cpython-312.pyc and b/__pycache__/gallery_app.cpython-312.pyc differ diff --git a/engine.py b/engine.py index c34910c..1a03618 100644 --- a/engine.py +++ b/engine.py @@ -1,6 +1,9 @@ import os import shutil import sqlite3 +import base64 +import requests +from datetime import datetime from contextlib import contextmanager from PIL import Image from io import BytesIO @@ -68,6 +71,27 @@ class SorterEngine: for cat in ["_TRASH", "control", "Default", "Action", "Solo"]: cursor.execute("INSERT OR IGNORE INTO categories VALUES (?)", (cat,)) + # --- CAPTION TABLES --- + # Per-category prompt templates + cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts + (profile TEXT, category TEXT, prompt_template TEXT, + PRIMARY KEY (profile, category))''') + + # Stored captions + cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions + (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT, + generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + + # Caption API settings per profile + cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings + (profile TEXT PRIMARY KEY, + api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions', + model_name TEXT DEFAULT 'local-model', + max_tokens INTEGER DEFAULT 300, + temperature REAL DEFAULT 0.7, + timeout_seconds INTEGER DEFAULT 60, + batch_size INTEGER DEFAULT 4)''') + # --- PERFORMANCE INDEXES --- # Index for staging_area queries filtered by category cursor.execute("CREATE INDEX IF NOT EXISTS idx_staging_category ON staging_area(target_category)") @@ -75,6 +99,8 @@ class SorterEngine: cursor.execute("CREATE INDEX IF NOT EXISTS idx_folder_tags_profile ON folder_tags(profile, folder_path)") # Index for profile_categories lookups cursor.execute("CREATE INDEX IF NOT EXISTS idx_profile_categories ON profile_categories(profile)") + # Index for caption lookups by image path + cursor.execute("CREATE INDEX IF NOT EXISTS idx_image_captions ON image_captions(image_path)") conn.commit() conn.close() @@ -826,4 +852,371 @@ class SorterEngine: ) result = {row[0]: {"cat": row[1], "index": row[2]} for row in cursor.fetchall()} conn.close() - return result \ No newline at end of file + return result + + # --- 8. CAPTION SETTINGS & PROMPTS --- + @staticmethod + def get_caption_settings(profile): + """Get caption API settings for a profile.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + # Ensure table exists + cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings + (profile TEXT PRIMARY KEY, + api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions', + model_name TEXT DEFAULT 'local-model', + max_tokens INTEGER DEFAULT 300, + temperature REAL DEFAULT 0.7, + timeout_seconds INTEGER DEFAULT 60, + batch_size INTEGER DEFAULT 4)''') + + cursor.execute("SELECT * FROM caption_settings WHERE profile = ?", (profile,)) + row = cursor.fetchone() + conn.close() + + if row: + return { + "profile": row[0], + "api_endpoint": row[1], + "model_name": row[2], + "max_tokens": row[3], + "temperature": row[4], + "timeout_seconds": row[5], + "batch_size": row[6] + } + else: + # Return defaults + return { + "profile": profile, + "api_endpoint": "http://localhost:8080/v1/chat/completions", + "model_name": "local-model", + "max_tokens": 300, + "temperature": 0.7, + "timeout_seconds": 60, + "batch_size": 4 + } + + @staticmethod + def save_caption_settings(profile, api_endpoint=None, model_name=None, max_tokens=None, + temperature=None, timeout_seconds=None, batch_size=None): + """Save caption API settings for a profile.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + # Ensure table exists + cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings + (profile TEXT PRIMARY KEY, + api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions', + model_name TEXT DEFAULT 'local-model', + max_tokens INTEGER DEFAULT 300, + temperature REAL DEFAULT 0.7, + timeout_seconds INTEGER DEFAULT 60, + batch_size INTEGER DEFAULT 4)''') + + # Get existing values + cursor.execute("SELECT * FROM caption_settings WHERE profile = ?", (profile,)) + row = cursor.fetchone() + + if not row: + row = (profile, "http://localhost:8080/v1/chat/completions", "local-model", 300, 0.7, 60, 4) + + new_values = ( + profile, + api_endpoint if api_endpoint is not None else row[1], + model_name if model_name is not None else row[2], + max_tokens if max_tokens is not None else row[3], + temperature if temperature is not None else row[4], + timeout_seconds if timeout_seconds is not None else row[5], + batch_size if batch_size is not None else row[6] + ) + + cursor.execute("INSERT OR REPLACE INTO caption_settings VALUES (?, ?, ?, ?, ?, ?, ?)", new_values) + conn.commit() + conn.close() + + @staticmethod + def get_category_prompt(profile, category): + """Get prompt template for a category.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts + (profile TEXT, category TEXT, prompt_template TEXT, + PRIMARY KEY (profile, category))''') + + cursor.execute( + "SELECT prompt_template FROM category_prompts WHERE profile = ? AND category = ?", + (profile, category) + ) + row = cursor.fetchone() + conn.close() + + if row and row[0]: + return row[0] + else: + # Default prompt + return "Describe this image in detail for training purposes. Include subjects, actions, setting, colors, and composition." + + @staticmethod + def save_category_prompt(profile, category, prompt): + """Save prompt template for a category.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts + (profile TEXT, category TEXT, prompt_template TEXT, + PRIMARY KEY (profile, category))''') + + cursor.execute( + "INSERT OR REPLACE INTO category_prompts VALUES (?, ?, ?)", + (profile, category, prompt) + ) + conn.commit() + conn.close() + + @staticmethod + def get_all_category_prompts(profile): + """Get all prompt templates for a profile.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts + (profile TEXT, category TEXT, prompt_template TEXT, + PRIMARY KEY (profile, category))''') + + cursor.execute( + "SELECT category, prompt_template FROM category_prompts WHERE profile = ?", + (profile,) + ) + result = {row[0]: row[1] for row in cursor.fetchall()} + conn.close() + return result + + # --- 9. CAPTION STORAGE --- + @staticmethod + def save_caption(image_path, caption, model): + """Save a generated caption to the database.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions + (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT, + generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + + cursor.execute( + "INSERT OR REPLACE INTO image_captions VALUES (?, ?, ?, ?)", + (image_path, caption, model, datetime.now().isoformat()) + ) + conn.commit() + conn.close() + + @staticmethod + def get_caption(image_path): + """Get caption for an image.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions + (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT, + generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + + cursor.execute( + "SELECT caption, model, generated_at FROM image_captions WHERE image_path = ?", + (image_path,) + ) + row = cursor.fetchone() + conn.close() + + if row: + return {"caption": row[0], "model": row[1], "generated_at": row[2]} + return None + + @staticmethod + def get_captions_batch(image_paths): + """Get captions for multiple images.""" + if not image_paths: + return {} + + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + + cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions + (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT, + generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''') + + placeholders = ','.join('?' * len(image_paths)) + cursor.execute( + f"SELECT image_path, caption, model, generated_at FROM image_captions WHERE image_path IN ({placeholders})", + image_paths + ) + result = {row[0]: {"caption": row[1], "model": row[2], "generated_at": row[3]} for row in cursor.fetchall()} + conn.close() + return result + + @staticmethod + def delete_caption(image_path): + """Delete caption for an image.""" + conn = sqlite3.connect(SorterEngine.DB_PATH) + cursor = conn.cursor() + cursor.execute("DELETE FROM image_captions WHERE image_path = ?", (image_path,)) + conn.commit() + conn.close() + + # --- 10. VLLM API CAPTIONING --- + @staticmethod + def caption_image_vllm(image_path, prompt, settings): + """ + Generate caption for an image using VLLM API. + + Args: + image_path: Path to the image file + prompt: Text prompt for captioning + settings: Dict with api_endpoint, model_name, max_tokens, temperature, timeout_seconds + + Returns: + Tuple of (caption_text, error_message). If successful, error is None. + """ + try: + # Read and encode image + with open(image_path, 'rb') as f: + img_bytes = f.read() + b64_image = base64.b64encode(img_bytes).decode('utf-8') + + # Determine MIME type + ext = os.path.splitext(image_path)[1].lower() + mime_types = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.webp': 'image/webp', + '.bmp': 'image/bmp', + '.tiff': 'image/tiff' + } + mime_type = mime_types.get(ext, 'image/jpeg') + + # Build request payload (OpenAI-compatible format) + payload = { + "model": settings.get('model_name', 'local-model'), + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_image}"}} + ] + }], + "max_tokens": settings.get('max_tokens', 300), + "temperature": settings.get('temperature', 0.7) + } + + # Make API request + response = requests.post( + settings.get('api_endpoint', 'http://localhost:8080/v1/chat/completions'), + json=payload, + timeout=settings.get('timeout_seconds', 60) + ) + response.raise_for_status() + + result = response.json() + caption = result['choices'][0]['message']['content'] + return caption.strip(), None + + except requests.Timeout: + return None, f"API timeout after {settings.get('timeout_seconds', 60)}s" + except requests.RequestException as e: + return None, f"API error: {str(e)}" + except KeyError as e: + return None, f"Invalid API response: missing {str(e)}" + except Exception as e: + return None, f"Error: {str(e)}" + + @staticmethod + def caption_batch_vllm(image_paths, get_prompt_fn, settings, progress_cb=None): + """ + Caption multiple images using VLLM API. + + Args: + image_paths: List of (image_path, category) tuples + get_prompt_fn: Function(category) -> prompt string + settings: Caption settings dict + progress_cb: Optional callback(current, total, status_msg) for progress updates + + Returns: + Dict with results: {"success": count, "failed": count, "captions": {path: caption}} + """ + results = {"success": 0, "failed": 0, "captions": {}, "errors": {}} + total = len(image_paths) + + for i, (image_path, category) in enumerate(image_paths): + if progress_cb: + progress_cb(i, total, f"Captioning {os.path.basename(image_path)}...") + + prompt = get_prompt_fn(category) + caption, error = SorterEngine.caption_image_vllm(image_path, prompt, settings) + + if caption: + # Save to database + SorterEngine.save_caption(image_path, caption, settings.get('model_name', 'local-model')) + results["captions"][image_path] = caption + results["success"] += 1 + else: + # Store error + error_caption = f"[ERROR] {error}" + SorterEngine.save_caption(image_path, error_caption, settings.get('model_name', 'local-model')) + results["errors"][image_path] = error + results["failed"] += 1 + + if progress_cb: + progress_cb(total, total, "Complete!") + + return results + + @staticmethod + def write_caption_sidecar(image_path, caption): + """ + Write caption to a .txt sidecar file next to the image. + + Args: + image_path: Path to the image file + caption: Caption text to write + + Returns: + Path to sidecar file, or None on error + """ + try: + # Create sidecar path (same name, .txt extension) + base_path = os.path.splitext(image_path)[0] + sidecar_path = f"{base_path}.txt" + + with open(sidecar_path, 'w', encoding='utf-8') as f: + f.write(caption) + + # Fix permissions + SorterEngine.fix_permissions(sidecar_path) + + return sidecar_path + except Exception as e: + print(f"Warning: Could not write sidecar for {image_path}: {e}") + return None + + @staticmethod + def read_caption_sidecar(image_path): + """ + Read caption from a .txt sidecar file if it exists. + + Args: + image_path: Path to the image file + + Returns: + Caption text or None if no sidecar exists + """ + try: + base_path = os.path.splitext(image_path)[0] + sidecar_path = f"{base_path}.txt" + + if os.path.exists(sidecar_path): + with open(sidecar_path, 'r', encoding='utf-8') as f: + return f.read().strip() + except Exception: + pass + return None \ No newline at end of file diff --git a/gallery_app.py b/gallery_app.py index b4efe3b..8049725 100644 --- a/gallery_app.py +++ b/gallery_app.py @@ -81,6 +81,12 @@ class AppState: # Pairing mode index maps (index -> (main_path, adj_path)) self.pair_index_map: Dict[int, Dict] = {} # {idx: {"main": path, "adj": path}} + # === CAPTION STATE === + self.caption_settings: Dict = {} + self.captioning_in_progress: bool = False + self.caption_on_apply: bool = False # Toggle for captioning during APPLY + self.caption_cache: Set[str] = set() # Paths that have captions + def load_active_profile(self): """Load paths from active profile.""" p_data = self.profiles.get(self.profile_name, {}) @@ -147,6 +153,17 @@ class AppState: self.active_cat = cats[0] return cats + def load_caption_settings(self): + """Load caption settings for current profile.""" + self.caption_settings = SorterEngine.get_caption_settings(self.profile_name) + + def refresh_caption_cache(self, image_paths: List[str] = None): + """Refresh the cache of which images have captions.""" + paths = image_paths or self.all_images + if paths: + captions = SorterEngine.get_captions_batch(paths) + self.caption_cache = set(captions.keys()) + def get_filtered_images(self) -> List[str]: """Get images based on current filter mode.""" if self.filter_mode == "all": @@ -220,28 +237,32 @@ def load_images(): if not os.path.exists(state.source_dir): ui.notify(f"Source not found: {state.source_dir}", type='warning') return - + # Auto-save current tags before switching folders if state.all_images and state.staged_data: saved = SorterEngine.save_folder_tags(state.source_dir, state.profile_name) if saved > 0: ui.notify(f"Auto-saved {saved} tags", type='info') - + # Clear staging area when loading a new folder SorterEngine.clear_staging_area() - + state.all_images = SorterEngine.get_images(state.source_dir, recursive=True) - + # Restore previously saved tags for this folder and profile restored = SorterEngine.restore_folder_tags(state.source_dir, state.all_images, state.profile_name) if restored > 0: ui.notify(f"Restored {restored} tags from previous session", type='info') - + # Reset page if out of bounds if state.page >= state.total_pages: state.page = 0 - + refresh_staged_info() + # Refresh caption cache for loaded images + state.refresh_caption_cache() + # Load caption settings + state.load_caption_settings() refresh_ui() # ========================================== @@ -674,14 +695,42 @@ def action_save_tags(): else: ui.notify("No tags to save", type='info') -def action_apply_page(): +async def action_apply_page(): """Apply staged changes for current page only.""" batch = state.get_current_batch() if not batch: ui.notify("No images on current page", type='warning') return + # Get tagged images and their categories before commit (they'll be moved/copied) + tagged_batch = [] + for img_path in batch: + if img_path in state.staged_data: + info = state.staged_data[img_path] + # Calculate destination path + dest_path = os.path.join(state.output_dir, info['name']) + tagged_batch.append((img_path, info['cat'], dest_path)) + SorterEngine.commit_batch(batch, state.output_dir, state.cleanup_mode, state.batch_mode) + + # Caption on apply if enabled + if state.caption_on_apply and tagged_batch: + state.load_caption_settings() + caption_count = 0 + for orig_path, category, dest_path in tagged_batch: + if os.path.exists(dest_path): + prompt = SorterEngine.get_category_prompt(state.profile_name, category) + caption, error = await run.io_bound( + SorterEngine.caption_image_vllm, + dest_path, prompt, state.caption_settings + ) + if caption: + SorterEngine.save_caption(dest_path, caption, state.caption_settings.get('model_name', 'local-model')) + SorterEngine.write_caption_sidecar(dest_path, caption) + caption_count += 1 + if caption_count > 0: + ui.notify(f"Captioned {caption_count} images", type='info') + ui.notify(f"Page processed ({state.batch_mode})", type='positive') # Force disk rescan since files were committed state._last_disk_scan_key = "" @@ -690,6 +739,14 @@ def action_apply_page(): async def action_apply_global(): """Apply all staged changes globally.""" ui.notify("Starting global apply... This may take a while.", type='info') + + # Capture staged data before commit for captioning + staged_before_commit = {} + if state.caption_on_apply: + for img_path, info in state.staged_data.items(): + dest_path = os.path.join(state.output_dir, info['name']) + staged_before_commit[img_path] = {'cat': info['cat'], 'dest': dest_path} + await run.io_bound( SorterEngine.commit_global, state.output_dir, @@ -698,6 +755,29 @@ async def action_apply_global(): state.source_dir, state.profile_name ) + + # Caption on apply if enabled + if state.caption_on_apply and staged_before_commit: + state.load_caption_settings() + ui.notify(f"Captioning {len(staged_before_commit)} images...", type='info') + + caption_count = 0 + for orig_path, info in staged_before_commit.items(): + dest_path = info['dest'] + if os.path.exists(dest_path): + prompt = SorterEngine.get_category_prompt(state.profile_name, info['cat']) + caption, error = await run.io_bound( + SorterEngine.caption_image_vllm, + dest_path, prompt, state.caption_settings + ) + if caption: + SorterEngine.save_caption(dest_path, caption, state.caption_settings.get('model_name', 'local-model')) + SorterEngine.write_caption_sidecar(dest_path, caption) + caption_count += 1 + + if caption_count > 0: + ui.notify(f"Captioned {caption_count} images", type='info') + # Force disk rescan since files were committed state._last_disk_scan_key = "" load_images() @@ -802,20 +882,20 @@ def open_hotkey_dialog(category: str): if cat == category: current_hotkey = hk break - + with ui.dialog() as dialog, ui.card().classes('p-4 bg-gray-800'): ui.label(f'Set Hotkey for "{category}"').classes('font-bold text-white mb-2') - + ui.label('Press a letter key (A-Z) to assign as hotkey').classes('text-gray-400 text-sm mb-4') - + if current_hotkey: ui.label(f'Current: {current_hotkey.upper()}').classes('text-blue-400 mb-2') - + hotkey_input = ui.input( placeholder='Type a letter...', value=current_hotkey or '' ).props('dark outlined dense autofocus').classes('w-full') - + def save_hotkey(): key = hotkey_input.value.lower().strip() if key and len(key) == 1 and key.isalpha(): @@ -823,11 +903,11 @@ def open_hotkey_dialog(category: str): to_remove = [hk for hk, c in state.category_hotkeys.items() if c == category] for hk in to_remove: del state.category_hotkeys[hk] - + # Remove if another category had this hotkey if key in state.category_hotkeys: del state.category_hotkeys[key] - + # Set new hotkey state.category_hotkeys[key] = category ui.notify(f'Hotkey "{key.upper()}" set for {category}', type='positive') @@ -843,7 +923,7 @@ def open_hotkey_dialog(category: str): render_sidebar() else: ui.notify('Please enter a single letter (A-Z)', type='warning') - + with ui.row().classes('w-full justify-end gap-2 mt-4'): ui.button('Clear', on_click=lambda: ( hotkey_input.set_value(''), @@ -851,9 +931,277 @@ def open_hotkey_dialog(category: str): )).props('flat color=grey') ui.button('Cancel', on_click=dialog.close).props('flat') ui.button('Save', on_click=save_hotkey).props('color=green') - + dialog.open() +def open_caption_settings_dialog(): + """Open dialog to configure caption API settings.""" + state.load_caption_settings() + settings = state.caption_settings.copy() + + with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-96'): + ui.label('Caption API Settings').classes('text-xl font-bold text-white mb-4') + + api_endpoint = ui.input( + label='API Endpoint', + value=settings.get('api_endpoint', 'http://localhost:8080/v1/chat/completions') + ).props('dark outlined dense').classes('w-full mb-2') + + model_name = ui.input( + label='Model Name', + value=settings.get('model_name', 'local-model') + ).props('dark outlined dense').classes('w-full mb-2') + + max_tokens = ui.number( + label='Max Tokens', + value=settings.get('max_tokens', 300), + min=50, max=2000 + ).props('dark outlined dense').classes('w-full mb-2') + + ui.label('Temperature').classes('text-gray-400 text-sm') + temperature = ui.slider( + min=0, max=1, step=0.1, + value=settings.get('temperature', 0.7) + ).props('color=purple label-always').classes('w-full mb-2') + + timeout = ui.number( + label='Timeout (seconds)', + value=settings.get('timeout_seconds', 60), + min=10, max=300 + ).props('dark outlined dense').classes('w-full mb-2') + + batch_size = ui.number( + label='Batch Size', + value=settings.get('batch_size', 4), + min=1, max=16 + ).props('dark outlined dense').classes('w-full mb-4') + + def save_settings(): + SorterEngine.save_caption_settings( + state.profile_name, + api_endpoint=api_endpoint.value, + model_name=model_name.value, + max_tokens=int(max_tokens.value), + temperature=float(temperature.value), + timeout_seconds=int(timeout.value), + batch_size=int(batch_size.value) + ) + state.load_caption_settings() + ui.notify('Caption settings saved!', type='positive') + dialog.close() + + with ui.row().classes('w-full justify-end gap-2'): + ui.button('Cancel', on_click=dialog.close).props('flat') + ui.button('Save', on_click=save_settings).props('color=purple') + + dialog.open() + +def open_prompt_editor_dialog(): + """Open dialog to edit category prompts.""" + categories = state.get_categories() + prompts = SorterEngine.get_all_category_prompts(state.profile_name) + + with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-[600px] max-h-[80vh]'): + ui.label('Category Prompts').classes('text-xl font-bold text-white mb-2') + ui.label('Set custom prompts for each category. Leave empty for default.').classes('text-gray-400 text-sm mb-4') + + default_prompt = "Describe this image in detail for training purposes. Include subjects, actions, setting, colors, and composition." + ui.label(f'Default: "{default_prompt[:60]}..."').classes('text-gray-500 text-xs mb-4') + + # Store text areas for later access + prompt_inputs = {} + + with ui.scroll_area().classes('w-full max-h-96'): + for cat in categories: + current_prompt = prompts.get(cat, '') + with ui.card().classes('w-full p-3 bg-gray-700 mb-2'): + ui.label(cat).classes('font-bold text-purple-400 mb-1') + prompt_inputs[cat] = ui.textarea( + value=current_prompt, + placeholder=default_prompt + ).props('dark outlined dense rows=2').classes('w-full') + + def save_all_prompts(): + for cat, textarea in prompt_inputs.items(): + prompt = textarea.value.strip() + if prompt: + SorterEngine.save_category_prompt(state.profile_name, cat, prompt) + else: + # Clear the prompt to use default + SorterEngine.save_category_prompt(state.profile_name, cat, '') + ui.notify(f'Prompts saved for {len(prompt_inputs)} categories!', type='positive') + dialog.close() + + with ui.row().classes('w-full justify-end gap-2 mt-4'): + ui.button('Cancel', on_click=dialog.close).props('flat') + ui.button('Save All', on_click=save_all_prompts).props('color=purple') + + dialog.open() + +def open_caption_dialog(img_path: str): + """Open dialog to view/edit/generate caption for a single image.""" + existing = SorterEngine.get_caption(img_path) + state.load_caption_settings() + + # Get category for this image + staged_info = state.staged_data.get(img_path) + category = staged_info['cat'] if staged_info else state.active_cat + + with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-[500px]'): + ui.label('Image Caption').classes('text-xl font-bold text-white mb-2') + ui.label(os.path.basename(img_path)).classes('text-gray-400 text-sm mb-4 truncate') + + # Thumbnail preview + ui.image(f"/thumbnail?path={img_path}&size=300&q=60").classes('w-full h-48 bg-black rounded mb-4').props('fit=contain') + + # Caption textarea + caption_text = ui.textarea( + label='Caption', + value=existing['caption'] if existing else '', + placeholder='Caption will appear here...' + ).props('dark outlined rows=4').classes('w-full mb-2') + + # Model info + if existing: + ui.label(f"Model: {existing.get('model', 'unknown')} | {existing.get('generated_at', '')}").classes('text-gray-500 text-xs mb-4') + + # Status label for progress + status_label = ui.label('').classes('text-purple-400 text-sm mb-2') + + async def generate_caption(): + status_label.set_text('Generating caption...') + prompt = SorterEngine.get_category_prompt(state.profile_name, category) + + caption, error = await run.io_bound( + SorterEngine.caption_image_vllm, + img_path, prompt, state.caption_settings + ) + + if caption: + caption_text.set_value(caption) + status_label.set_text('Caption generated!') + else: + status_label.set_text(f'Error: {error}') + + def save_caption(): + text = caption_text.value.strip() + if text: + SorterEngine.save_caption(img_path, text, state.caption_settings.get('model_name', 'manual')) + state.caption_cache.add(img_path) + ui.notify('Caption saved!', type='positive') + dialog.close() + refresh_grid_only() + else: + ui.notify('Caption is empty', type='warning') + + def save_with_sidecar(): + text = caption_text.value.strip() + if text: + SorterEngine.save_caption(img_path, text, state.caption_settings.get('model_name', 'manual')) + sidecar_path = SorterEngine.write_caption_sidecar(img_path, text) + state.caption_cache.add(img_path) + if sidecar_path: + ui.notify(f'Caption saved + sidecar written!', type='positive') + else: + ui.notify('Caption saved (sidecar failed)', type='warning') + dialog.close() + refresh_grid_only() + else: + ui.notify('Caption is empty', type='warning') + + with ui.row().classes('w-full justify-between gap-2'): + ui.button('Generate', icon='auto_awesome', on_click=generate_caption).props('color=purple') + with ui.row().classes('gap-2'): + ui.button('Cancel', on_click=dialog.close).props('flat') + ui.button('Save', on_click=save_caption).props('color=green') + ui.button('Save + Sidecar', on_click=save_with_sidecar).props('color=blue').tooltip('Also write .txt file') + + dialog.open() + +async def action_caption_category(): + """Caption all images tagged with the active category.""" + if state.captioning_in_progress: + ui.notify('Captioning already in progress', type='warning') + return + + # Find all images tagged with active category + images_to_caption = [] + for img_path, info in state.staged_data.items(): + if info['cat'] == state.active_cat: + images_to_caption.append((img_path, state.active_cat)) + + if not images_to_caption: + ui.notify(f'No images tagged with {state.active_cat}', type='warning') + return + + state.load_caption_settings() + state.captioning_in_progress = True + + # Create progress dialog + with ui.dialog() as progress_dialog, ui.card().classes('p-6 bg-gray-800 w-96'): + ui.label('Captioning Images...').classes('text-xl font-bold text-white mb-4') + progress_bar = ui.linear_progress(value=0).props('instant-feedback color=purple').classes('w-full mb-2') + progress_label = ui.label('0 / 0').classes('text-gray-400 text-center w-full mb-2') + status_label = ui.label('Starting...').classes('text-purple-400 text-sm text-center w-full') + + cancel_requested = {'value': False} + + def request_cancel(): + cancel_requested['value'] = True + status_label.set_text('Cancelling...') + + ui.button('Cancel', on_click=request_cancel).props('flat color=red').classes('w-full mt-4') + + progress_dialog.open() + + try: + total = len(images_to_caption) + success_count = 0 + fail_count = 0 + + def get_prompt(cat): + return SorterEngine.get_category_prompt(state.profile_name, cat) + + for i, (img_path, category) in enumerate(images_to_caption): + if cancel_requested['value']: + break + + progress_bar.set_value(i / total) + progress_label.set_text(f'{i + 1} / {total}') + status_label.set_text(f'Captioning {os.path.basename(img_path)}...') + + prompt = get_prompt(category) + caption, error = await run.io_bound( + SorterEngine.caption_image_vllm, + img_path, prompt, state.caption_settings + ) + + if caption: + SorterEngine.save_caption(img_path, caption, state.caption_settings.get('model_name', 'local-model')) + state.caption_cache.add(img_path) + success_count += 1 + else: + error_caption = f"[ERROR] {error}" + SorterEngine.save_caption(img_path, error_caption, state.caption_settings.get('model_name', 'local-model')) + fail_count += 1 + + progress_bar.set_value(1) + progress_label.set_text(f'{total} / {total}') + + if cancel_requested['value']: + status_label.set_text(f'Cancelled. {success_count} OK, {fail_count} failed') + else: + status_label.set_text(f'Done! {success_count} OK, {fail_count} failed') + + await asyncio.sleep(1.5) + progress_dialog.close() + + finally: + state.captioning_in_progress = False + refresh_grid_only() + + ui.notify(f'Captioned {success_count}/{total} images', type='positive' if fail_count == 0 else 'warning') + def render_sidebar(): """Render category management sidebar.""" state.sidebar_container.clear() @@ -932,8 +1280,11 @@ def render_sidebar(): .classes('w-full border border-gray-800') # Category Manager (expanded) - ui.label("📂 Categories").classes('text-sm font-bold text-gray-400 mt-2') - + with ui.row().classes('w-full justify-between items-center mt-2'): + ui.label("📂 Categories").classes('text-sm font-bold text-gray-400') + ui.button(icon='edit_note', on_click=open_prompt_editor_dialog) \ + .props('flat dense color=purple size=sm').tooltip('Edit Prompts') + categories = state.get_categories() # Category list with hotkey buttons @@ -1055,6 +1406,7 @@ def render_image_card(img_path: str): """Render individual image card. Uses functools.partial instead of lambdas for better memory efficiency.""" is_staged = img_path in state.staged_data + has_caption = img_path in state.caption_cache thumb_size = 800 card = ui.card().classes('p-2 bg-gray-900 border border-gray-700 no-shadow hover:border-green-500 transition-colors') @@ -1066,8 +1418,16 @@ def render_image_card(img_path: str): # Header with filename and actions with ui.row().classes('w-full justify-between no-wrap mb-1'): - ui.label(os.path.basename(img_path)[:15]).classes('text-xs text-gray-400 truncate') + with ui.row().classes('items-center gap-1'): + ui.label(os.path.basename(img_path)[:15]).classes('text-xs text-gray-400 truncate') + # Caption indicator + if has_caption: + ui.icon('description', size='xs').classes('text-purple-400').tooltip('Has caption') with ui.row().classes('gap-0'): + ui.button( + icon='auto_awesome', + on_click=partial(open_caption_dialog, img_path) + ).props('flat size=sm dense color=purple').tooltip('Caption') ui.button( icon='zoom_in', on_click=partial(open_zoom_dialog, img_path) @@ -1369,20 +1729,25 @@ def build_header(): with ui.button(icon='tune', color='white').props('flat round'): with ui.menu().classes('bg-gray-800 text-white p-4'): ui.label('VIEW SETTINGS').classes('text-xs font-bold mb-2') - + ui.label('Grid Columns:') ui.slider( min=2, max=8, step=1, value=state.grid_cols, on_change=lambda e: (setattr(state, 'grid_cols', e.value), refresh_ui()) ).props('color=green') - + ui.label('Preview Quality:') ui.slider( min=10, max=100, step=10, value=state.preview_quality, on_change=lambda e: (setattr(state, 'preview_quality', e.value), refresh_ui()) ).props('color=green label-always') + + ui.separator().classes('my-2') + ui.label('CAPTION SETTINGS').classes('text-xs font-bold mb-2 text-purple-400') + ui.button('Configure API', icon='api', on_click=open_caption_settings_dialog) \ + .props('flat color=purple').classes('w-full') ui.switch('Dark', value=True, on_change=lambda e: ui.dark_mode().set_value(e.value)) \ .props('color=green') @@ -1416,19 +1781,27 @@ def build_main_content(): ui.radio(['Copy', 'Move'], value=state.batch_mode) \ .bind_value(state, 'batch_mode') \ .props('inline dark color=green') - + # Untagged files mode with ui.column(): ui.label('UNTAGGED FILES:').classes('text-gray-500 text-xs font-bold') ui.radio(['Keep', 'Move to Unused', 'Delete'], value=state.cleanup_mode) \ .bind_value(state, 'cleanup_mode') \ .props('inline dark color=green') - + + # Caption options + with ui.column(): + ui.label('CAPTIONING:').classes('text-gray-500 text-xs font-bold') + ui.checkbox('Caption on Apply').bind_value(state, 'caption_on_apply') \ + .props('color=purple dark') + ui.button('CAPTION CATEGORY', icon='auto_awesome', on_click=action_caption_category) \ + .props('outline color=purple') + # Action buttons with ui.row().classes('items-center gap-6'): ui.button('APPLY PAGE', on_click=action_apply_page) \ .props('outline color=white lg') - + with ui.column().classes('items-center'): ui.button('APPLY GLOBAL', on_click=action_apply_global) \ .props('lg color=red-900') diff --git a/requirements.txt b/requirements.txt index 5682ef1..c538cae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ streamlit Pillow -nicegui \ No newline at end of file +nicegui +requests \ No newline at end of file