caption
This commit is contained in:
Binary file not shown.
Binary file not shown.
393
engine.py
393
engine.py
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
import base64
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from contextlib import contextmanager
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
@@ -68,6 +71,27 @@ class SorterEngine:
|
||||
for cat in ["_TRASH", "control", "Default", "Action", "Solo"]:
|
||||
cursor.execute("INSERT OR IGNORE INTO categories VALUES (?)", (cat,))
|
||||
|
||||
# --- CAPTION TABLES ---
|
||||
# Per-category prompt templates
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts
|
||||
(profile TEXT, category TEXT, prompt_template TEXT,
|
||||
PRIMARY KEY (profile, category))''')
|
||||
|
||||
# Stored captions
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions
|
||||
(image_path TEXT PRIMARY KEY, caption TEXT, model TEXT,
|
||||
generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')
|
||||
|
||||
# Caption API settings per profile
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings
|
||||
(profile TEXT PRIMARY KEY,
|
||||
api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions',
|
||||
model_name TEXT DEFAULT 'local-model',
|
||||
max_tokens INTEGER DEFAULT 300,
|
||||
temperature REAL DEFAULT 0.7,
|
||||
timeout_seconds INTEGER DEFAULT 60,
|
||||
batch_size INTEGER DEFAULT 4)''')
|
||||
|
||||
# --- PERFORMANCE INDEXES ---
|
||||
# Index for staging_area queries filtered by category
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_staging_category ON staging_area(target_category)")
|
||||
@@ -75,6 +99,8 @@ class SorterEngine:
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_folder_tags_profile ON folder_tags(profile, folder_path)")
|
||||
# Index for profile_categories lookups
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_profile_categories ON profile_categories(profile)")
|
||||
# Index for caption lookups by image path
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_image_captions ON image_captions(image_path)")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -827,3 +853,370 @@ class SorterEngine:
|
||||
result = {row[0]: {"cat": row[1], "index": row[2]} for row in cursor.fetchall()}
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
# --- 8. CAPTION SETTINGS & PROMPTS ---
|
||||
@staticmethod
|
||||
def get_caption_settings(profile):
    """Return the caption API settings stored for a profile.

    Args:
        profile: Profile name; primary key of the ``caption_settings`` table.

    Returns:
        Dict with the keys ``profile``, ``api_endpoint``, ``model_name``,
        ``max_tokens``, ``temperature``, ``timeout_seconds`` and
        ``batch_size``. When the profile has no saved row, the built-in
        defaults (same values as the column defaults) are returned.
    """
    columns = (
        "profile", "api_endpoint", "model_name", "max_tokens",
        "temperature", "timeout_seconds", "batch_size",
    )

    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        # Ensure table exists
        cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings
            (profile TEXT PRIMARY KEY,
            api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions',
            model_name TEXT DEFAULT 'local-model',
            max_tokens INTEGER DEFAULT 300,
            temperature REAL DEFAULT 0.7,
            timeout_seconds INTEGER DEFAULT 60,
            batch_size INTEGER DEFAULT 4)''')

        # Name the columns explicitly so the key mapping below does not
        # depend on the physical column order of the table.
        cursor.execute(
            f"SELECT {', '.join(columns)} FROM caption_settings WHERE profile = ?",
            (profile,),
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even when a statement raises.
        conn.close()

    if row:
        return dict(zip(columns, row))

    # Return defaults
    return {
        "profile": profile,
        "api_endpoint": "http://localhost:8080/v1/chat/completions",
        "model_name": "local-model",
        "max_tokens": 300,
        "temperature": 0.7,
        "timeout_seconds": 60,
        "batch_size": 4,
    }
|
||||
|
||||
@staticmethod
|
||||
def save_caption_settings(profile, api_endpoint=None, model_name=None, max_tokens=None,
                          temperature=None, timeout_seconds=None, batch_size=None):
    """Save caption API settings for a profile.

    Only the arguments that are not ``None`` are changed; every other field
    keeps its stored value, or the built-in default when the profile has no
    row yet.

    Args:
        profile: Profile name; primary key of the ``caption_settings`` table.
        api_endpoint: Chat-completions URL, or ``None`` to keep the current one.
        model_name: Model identifier, or ``None`` to keep the current one.
        max_tokens: Token limit, or ``None`` to keep the current one.
        temperature: Sampling temperature, or ``None`` to keep the current one.
        timeout_seconds: Request timeout, or ``None`` to keep the current one.
        batch_size: Batch size, or ``None`` to keep the current one.
    """
    columns = (
        "profile", "api_endpoint", "model_name", "max_tokens",
        "temperature", "timeout_seconds", "batch_size",
    )
    defaults = (profile, "http://localhost:8080/v1/chat/completions", "local-model", 300, 0.7, 60, 4)
    overrides = (api_endpoint, model_name, max_tokens, temperature, timeout_seconds, batch_size)

    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        # Ensure table exists
        cursor.execute('''CREATE TABLE IF NOT EXISTS caption_settings
            (profile TEXT PRIMARY KEY,
            api_endpoint TEXT DEFAULT 'http://localhost:8080/v1/chat/completions',
            model_name TEXT DEFAULT 'local-model',
            max_tokens INTEGER DEFAULT 300,
            temperature REAL DEFAULT 0.7,
            timeout_seconds INTEGER DEFAULT 60,
            batch_size INTEGER DEFAULT 4)''')

        # Get existing values (explicit column order keeps the zip below aligned)
        cursor.execute(
            f"SELECT {', '.join(columns)} FROM caption_settings WHERE profile = ?",
            (profile,),
        )
        current = cursor.fetchone() or defaults

        # Keep the stored value wherever the caller passed None.
        new_values = (profile,) + tuple(
            new if new is not None else old
            for new, old in zip(overrides, current[1:])
        )

        cursor.execute(
            f"INSERT OR REPLACE INTO caption_settings ({', '.join(columns)}) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            new_values,
        )
        conn.commit()
    finally:
        # Close the connection even when a statement raises.
        conn.close()
|
||||
|
||||
@staticmethod
|
||||
def get_category_prompt(profile, category):
    """Get the prompt template for a category.

    Args:
        profile: Profile name.
        category: Category name.

    Returns:
        The stored template, or the built-in default prompt when no row
        exists or the stored template is empty.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts
            (profile TEXT, category TEXT, prompt_template TEXT,
            PRIMARY KEY (profile, category))''')

        cursor.execute(
            "SELECT prompt_template FROM category_prompts WHERE profile = ? AND category = ?",
            (profile, category),
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even when a statement raises.
        conn.close()

    if row and row[0]:
        return row[0]

    # Default prompt (an empty stored template also falls back to this)
    return "Describe this image in detail for training purposes. Include subjects, actions, setting, colors, and composition."
|
||||
|
||||
@staticmethod
|
||||
def save_category_prompt(profile, category, prompt):
    """Save (insert or replace) the prompt template for a category.

    Args:
        profile: Profile name.
        category: Category name.
        prompt: Template text to store for ``(profile, category)``.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts
            (profile TEXT, category TEXT, prompt_template TEXT,
            PRIMARY KEY (profile, category))''')

        cursor.execute(
            "INSERT OR REPLACE INTO category_prompts VALUES (?, ?, ?)",
            (profile, category, prompt),
        )
        conn.commit()
    finally:
        # Close the connection even when a statement raises.
        conn.close()
|
||||
|
||||
@staticmethod
|
||||
def get_all_category_prompts(profile):
    """Get all stored prompt templates for a profile.

    Args:
        profile: Profile name.

    Returns:
        Dict mapping category name to its stored template. Categories
        without a row are absent from the dict.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS category_prompts
            (profile TEXT, category TEXT, prompt_template TEXT,
            PRIMARY KEY (profile, category))''')

        cursor.execute(
            "SELECT category, prompt_template FROM category_prompts WHERE profile = ?",
            (profile,),
        )
        return {category: template for category, template in cursor.fetchall()}
    finally:
        # Close the connection even when a statement raises.
        conn.close()
|
||||
|
||||
# --- 9. CAPTION STORAGE ---
|
||||
@staticmethod
|
||||
def save_caption(image_path, caption, model):
    """Save (insert or replace) a generated caption in the database.

    Args:
        image_path: Path of the image; primary key of ``image_captions``.
        caption: Caption text to store.
        model: Name of the model recorded alongside the caption.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions
            (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT,
            generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')

        # generated_at is written explicitly as a local-time ISO string.
        cursor.execute(
            "INSERT OR REPLACE INTO image_captions VALUES (?, ?, ?, ?)",
            (image_path, caption, model, datetime.now().isoformat()),
        )
        conn.commit()
    finally:
        # Close the connection even when a statement raises.
        conn.close()
|
||||
|
||||
@staticmethod
|
||||
def get_caption(image_path):
    """Get the stored caption for an image.

    Args:
        image_path: Path of the image; primary key of ``image_captions``.

    Returns:
        Dict with ``caption``, ``model`` and ``generated_at``, or ``None``
        when the image has no stored caption.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions
            (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT,
            generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')

        cursor.execute(
            "SELECT caption, model, generated_at FROM image_captions WHERE image_path = ?",
            (image_path,),
        )
        row = cursor.fetchone()
    finally:
        # Close the connection even when a statement raises.
        conn.close()

    if row:
        return {"caption": row[0], "model": row[1], "generated_at": row[2]}
    return None
|
||||
|
||||
@staticmethod
|
||||
def get_captions_batch(image_paths):
    """Get stored captions for multiple images.

    The lookup is split into chunks because SQLite limits the number of
    bound parameters per statement (999 by default in older builds); a
    single ``IN (...)`` over a large gallery would fail with
    "too many SQL variables".

    Args:
        image_paths: Collection of image paths to look up.

    Returns:
        Dict mapping each path that has a stored caption to a dict with
        ``caption``, ``model`` and ``generated_at``. Paths without a caption
        are absent. An empty input yields ``{}``.
    """
    if not image_paths:
        return {}

    paths = list(image_paths)
    chunk_size = 500  # comfortably below SQLite's historical 999-variable limit
    result = {}

    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions
            (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT,
            generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')

        for start in range(0, len(paths), chunk_size):
            chunk = paths[start:start + chunk_size]
            placeholders = ','.join('?' * len(chunk))
            cursor.execute(
                f"SELECT image_path, caption, model, generated_at FROM image_captions WHERE image_path IN ({placeholders})",
                chunk,
            )
            for path, caption, model, generated_at in cursor.fetchall():
                result[path] = {"caption": caption, "model": model, "generated_at": generated_at}
    finally:
        # Close the connection even when a statement raises.
        conn.close()

    return result
|
||||
|
||||
@staticmethod
|
||||
def delete_caption(image_path):
    """Delete the stored caption for an image.

    Deleting a path that has no caption is a no-op.

    Args:
        image_path: Path of the image; primary key of ``image_captions``.
    """
    conn = sqlite3.connect(SorterEngine.DB_PATH)
    try:
        cursor = conn.cursor()

        # Ensure the table exists, as the other caption accessors do, so a
        # delete against a database without it does not raise.
        cursor.execute('''CREATE TABLE IF NOT EXISTS image_captions
            (image_path TEXT PRIMARY KEY, caption TEXT, model TEXT,
            generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')

        cursor.execute("DELETE FROM image_captions WHERE image_path = ?", (image_path,))
        conn.commit()
    finally:
        # Close the connection even when a statement raises.
        conn.close()
|
||||
|
||||
# --- 10. VLLM API CAPTIONING ---
|
||||
@staticmethod
|
||||
def caption_image_vllm(image_path, prompt, settings):
    """
    Generate caption for an image using VLLM API.

    The image is read from disk, base64-encoded into a data URL and sent
    together with the prompt as one OpenAI-style chat-completions request.

    Args:
        image_path: Path to the image file
        prompt: Text prompt for captioning
        settings: Dict with api_endpoint, model_name, max_tokens, temperature, timeout_seconds

    Returns:
        Tuple of (caption_text, error_message). If successful, error is None.
        On failure caption_text is None and error_message is a non-empty
        string; an empty or non-text caption counts as a failure.
    """
    timeout_seconds = settings.get('timeout_seconds', 60)

    try:
        # Read and encode image
        with open(image_path, 'rb') as f:
            b64_image = base64.b64encode(f.read()).decode('utf-8')

        # Determine MIME type (unknown extensions are sent as JPEG)
        ext = os.path.splitext(image_path)[1].lower()
        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.webp': 'image/webp',
            '.bmp': 'image/bmp',
            '.tiff': 'image/tiff',
            '.tif': 'image/tiff',
            '.gif': 'image/gif',
        }
        mime_type = mime_types.get(ext, 'image/jpeg')

        # Build request payload (OpenAI-compatible format)
        payload = {
            "model": settings.get('model_name', 'local-model'),
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_image}"}},
                ],
            }],
            "max_tokens": settings.get('max_tokens', 300),
            "temperature": settings.get('temperature', 0.7),
        }

        # Make API request
        response = requests.post(
            settings.get('api_endpoint', 'http://localhost:8080/v1/chat/completions'),
            json=payload,
            timeout=timeout_seconds,
        )
        response.raise_for_status()

        result = response.json()

        # Parse in its own try so a malformed body is reported as an API
        # response problem rather than a generic error.
        try:
            caption = result['choices'][0]['message']['content']
        except KeyError as e:
            return None, f"Invalid API response: missing {str(e)}"
        except (IndexError, TypeError):
            return None, "Invalid API response: no usable choices"

        # Callers treat a falsy caption as failure and then display the
        # error, so never return an empty caption with error None.
        if not isinstance(caption, str) or not caption.strip():
            return None, "Invalid API response: caption is empty or not text"

        return caption.strip(), None

    except requests.Timeout:
        return None, f"API timeout after {timeout_seconds}s"
    except requests.RequestException as e:
        return None, f"API error: {str(e)}"
    except KeyError as e:
        return None, f"Invalid API response: missing {str(e)}"
    except Exception as e:
        # Best-effort boundary: file errors and anything unexpected are
        # reported to the caller as an error string instead of raising.
        return None, f"Error: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def caption_batch_vllm(image_paths, get_prompt_fn, settings, progress_cb=None):
    """
    Caption multiple images using VLLM API.

    Images are processed one at a time. Every outcome is written to the
    database: successful captions as-is, failures as an "[ERROR] ..." caption.

    Args:
        image_paths: List of (image_path, category) tuples
        get_prompt_fn: Function(category) -> prompt string
        settings: Caption settings dict
        progress_cb: Optional callback(current, total, status_msg) for progress updates

    Returns:
        Dict with results: {"success": count, "failed": count,
        "captions": {path: caption}, "errors": {path: error_message}}
    """
    results = {"success": 0, "failed": 0, "captions": {}, "errors": {}}
    total = len(image_paths)
    model_name = settings.get('model_name', 'local-model')

    for i, (image_path, category) in enumerate(image_paths):
        if progress_cb:
            progress_cb(i, total, f"Captioning {os.path.basename(image_path)}...")

        prompt = get_prompt_fn(category)
        caption, error = SorterEngine.caption_image_vllm(image_path, prompt, settings)

        if caption:
            # Save to database
            SorterEngine.save_caption(image_path, caption, model_name)
            results["captions"][image_path] = caption
            results["success"] += 1
        else:
            # A falsy caption with no error message would otherwise be
            # stored as "[ERROR] None"; give it a meaningful reason.
            error = error or "Empty caption returned"
            # Store error
            SorterEngine.save_caption(image_path, f"[ERROR] {error}", model_name)
            results["errors"][image_path] = error
            results["failed"] += 1

    if progress_cb:
        progress_cb(total, total, "Complete!")

    return results
|
||||
|
||||
@staticmethod
|
||||
def write_caption_sidecar(image_path, caption):
    """
    Write caption to a .txt sidecar file next to the image.

    The sidecar shares the image's path with the extension replaced by
    ".txt". Any failure is printed as a warning instead of being raised.

    Args:
        image_path: Path to the image file
        caption: Caption text to write

    Returns:
        Path to sidecar file, or None on error
    """
    try:
        # Same name as the image, .txt extension
        sidecar_path = f"{os.path.splitext(image_path)[0]}.txt"

        with open(sidecar_path, 'w', encoding='utf-8') as handle:
            handle.write(caption)

        # Fix permissions
        SorterEngine.fix_permissions(sidecar_path)
    except Exception as exc:
        print(f"Warning: Could not write sidecar for {image_path}: {exc}")
        return None

    return sidecar_path
|
||||
|
||||
@staticmethod
|
||||
def read_caption_sidecar(image_path):
    """
    Read caption from a .txt sidecar file if it exists.

    The sidecar shares the image's path with the extension replaced by
    ".txt". A missing sidecar is the normal case and is silent; a sidecar
    that exists but cannot be read or decoded is reported with a printed
    warning rather than being swallowed.

    Args:
        image_path: Path to the image file

    Returns:
        Caption text (whitespace-stripped) or None if no sidecar exists or
        it could not be read
    """
    try:
        sidecar_path = f"{os.path.splitext(image_path)[0]}.txt"

        # Open directly instead of checking existence first: avoids the
        # race between the check and the open.
        with open(sidecar_path, 'r', encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        return None
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers undecodable (non-UTF-8) content; TypeError
        # covers a non-path argument. Still best-effort: never raise.
        print(f"Warning: Could not read sidecar for {image_path}: {e}")
        return None
|
||||
379
gallery_app.py
379
gallery_app.py
@@ -81,6 +81,12 @@ class AppState:
|
||||
# Pairing mode index maps (index -> (main_path, adj_path))
|
||||
self.pair_index_map: Dict[int, Dict] = {} # {idx: {"main": path, "adj": path}}
|
||||
|
||||
# === CAPTION STATE ===
|
||||
self.caption_settings: Dict = {}
|
||||
self.captioning_in_progress: bool = False
|
||||
self.caption_on_apply: bool = False # Toggle for captioning during APPLY
|
||||
self.caption_cache: Set[str] = set() # Paths that have captions
|
||||
|
||||
def load_active_profile(self):
|
||||
"""Load paths from active profile."""
|
||||
p_data = self.profiles.get(self.profile_name, {})
|
||||
@@ -147,6 +153,17 @@ class AppState:
|
||||
self.active_cat = cats[0]
|
||||
return cats
|
||||
|
||||
def load_caption_settings(self):
    """Reload ``self.caption_settings`` from the database for the current profile."""
    active_profile = self.profile_name
    self.caption_settings = SorterEngine.get_caption_settings(active_profile)
|
||||
|
||||
def refresh_caption_cache(self, image_paths: List[str] = None):
    """Rebuild the set of image paths that have a stored caption.

    Args:
        image_paths: Paths to check. When omitted or empty, all loaded
            images (``self.all_images``) are checked instead.
    """
    targets = image_paths if image_paths else self.all_images
    if not targets:
        # Nothing to look up; the existing cache is left untouched.
        return

    found = SorterEngine.get_captions_batch(targets)
    self.caption_cache = set(found)
|
||||
|
||||
def get_filtered_images(self) -> List[str]:
|
||||
"""Get images based on current filter mode."""
|
||||
if self.filter_mode == "all":
|
||||
@@ -242,6 +259,10 @@ def load_images():
|
||||
state.page = 0
|
||||
|
||||
refresh_staged_info()
|
||||
# Refresh caption cache for loaded images
|
||||
state.refresh_caption_cache()
|
||||
# Load caption settings
|
||||
state.load_caption_settings()
|
||||
refresh_ui()
|
||||
|
||||
# ==========================================
|
||||
@@ -674,14 +695,42 @@ def action_save_tags():
|
||||
else:
|
||||
ui.notify("No tags to save", type='info')
|
||||
|
||||
def action_apply_page():
|
||||
async def action_apply_page():
|
||||
"""Apply staged changes for current page only."""
|
||||
batch = state.get_current_batch()
|
||||
if not batch:
|
||||
ui.notify("No images on current page", type='warning')
|
||||
return
|
||||
|
||||
# Get tagged images and their categories before commit (they'll be moved/copied)
|
||||
tagged_batch = []
|
||||
for img_path in batch:
|
||||
if img_path in state.staged_data:
|
||||
info = state.staged_data[img_path]
|
||||
# Calculate destination path
|
||||
dest_path = os.path.join(state.output_dir, info['name'])
|
||||
tagged_batch.append((img_path, info['cat'], dest_path))
|
||||
|
||||
SorterEngine.commit_batch(batch, state.output_dir, state.cleanup_mode, state.batch_mode)
|
||||
|
||||
# Caption on apply if enabled
|
||||
if state.caption_on_apply and tagged_batch:
|
||||
state.load_caption_settings()
|
||||
caption_count = 0
|
||||
for orig_path, category, dest_path in tagged_batch:
|
||||
if os.path.exists(dest_path):
|
||||
prompt = SorterEngine.get_category_prompt(state.profile_name, category)
|
||||
caption, error = await run.io_bound(
|
||||
SorterEngine.caption_image_vllm,
|
||||
dest_path, prompt, state.caption_settings
|
||||
)
|
||||
if caption:
|
||||
SorterEngine.save_caption(dest_path, caption, state.caption_settings.get('model_name', 'local-model'))
|
||||
SorterEngine.write_caption_sidecar(dest_path, caption)
|
||||
caption_count += 1
|
||||
if caption_count > 0:
|
||||
ui.notify(f"Captioned {caption_count} images", type='info')
|
||||
|
||||
ui.notify(f"Page processed ({state.batch_mode})", type='positive')
|
||||
# Force disk rescan since files were committed
|
||||
state._last_disk_scan_key = ""
|
||||
@@ -690,6 +739,14 @@ def action_apply_page():
|
||||
async def action_apply_global():
|
||||
"""Apply all staged changes globally."""
|
||||
ui.notify("Starting global apply... This may take a while.", type='info')
|
||||
|
||||
# Capture staged data before commit for captioning
|
||||
staged_before_commit = {}
|
||||
if state.caption_on_apply:
|
||||
for img_path, info in state.staged_data.items():
|
||||
dest_path = os.path.join(state.output_dir, info['name'])
|
||||
staged_before_commit[img_path] = {'cat': info['cat'], 'dest': dest_path}
|
||||
|
||||
await run.io_bound(
|
||||
SorterEngine.commit_global,
|
||||
state.output_dir,
|
||||
@@ -698,6 +755,29 @@ async def action_apply_global():
|
||||
state.source_dir,
|
||||
state.profile_name
|
||||
)
|
||||
|
||||
# Caption on apply if enabled
|
||||
if state.caption_on_apply and staged_before_commit:
|
||||
state.load_caption_settings()
|
||||
ui.notify(f"Captioning {len(staged_before_commit)} images...", type='info')
|
||||
|
||||
caption_count = 0
|
||||
for orig_path, info in staged_before_commit.items():
|
||||
dest_path = info['dest']
|
||||
if os.path.exists(dest_path):
|
||||
prompt = SorterEngine.get_category_prompt(state.profile_name, info['cat'])
|
||||
caption, error = await run.io_bound(
|
||||
SorterEngine.caption_image_vllm,
|
||||
dest_path, prompt, state.caption_settings
|
||||
)
|
||||
if caption:
|
||||
SorterEngine.save_caption(dest_path, caption, state.caption_settings.get('model_name', 'local-model'))
|
||||
SorterEngine.write_caption_sidecar(dest_path, caption)
|
||||
caption_count += 1
|
||||
|
||||
if caption_count > 0:
|
||||
ui.notify(f"Captioned {caption_count} images", type='info')
|
||||
|
||||
# Force disk rescan since files were committed
|
||||
state._last_disk_scan_key = ""
|
||||
load_images()
|
||||
@@ -854,6 +934,274 @@ def open_hotkey_dialog(category: str):
|
||||
|
||||
dialog.open()
|
||||
|
||||
def open_caption_settings_dialog():
    """Open dialog to configure caption API settings.

    The form is pre-filled from the current profile's stored settings.
    "Save" writes the values back through ``SorterEngine.save_caption_settings``
    and reloads ``state.caption_settings``; "Cancel" closes without saving.
    """
    state.load_caption_settings()
    settings = state.caption_settings.copy()

    with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-96'):
        ui.label('Caption API Settings').classes('text-xl font-bold text-white mb-4')

        api_endpoint = ui.input(
            label='API Endpoint',
            value=settings.get('api_endpoint', 'http://localhost:8080/v1/chat/completions')
        ).props('dark outlined dense').classes('w-full mb-2')

        model_name = ui.input(
            label='Model Name',
            value=settings.get('model_name', 'local-model')
        ).props('dark outlined dense').classes('w-full mb-2')

        max_tokens = ui.number(
            label='Max Tokens',
            value=settings.get('max_tokens', 300),
            min=50, max=2000
        ).props('dark outlined dense').classes('w-full mb-2')

        ui.label('Temperature').classes('text-gray-400 text-sm')
        temperature = ui.slider(
            min=0, max=1, step=0.1,
            value=settings.get('temperature', 0.7)
        ).props('color=purple label-always').classes('w-full mb-2')

        timeout = ui.number(
            label='Timeout (seconds)',
            value=settings.get('timeout_seconds', 60),
            min=10, max=300
        ).props('dark outlined dense').classes('w-full mb-2')

        batch_size = ui.number(
            label='Batch Size',
            value=settings.get('batch_size', 4),
            min=1, max=16
        ).props('dark outlined dense').classes('w-full mb-4')

        def save_settings():
            # A cleared number field yields None, and int(None) would raise
            # TypeError inside the click handler; validate before converting.
            numeric_fields = {
                'Max Tokens': max_tokens.value,
                'Timeout (seconds)': timeout.value,
                'Batch Size': batch_size.value,
            }
            missing = [label for label, value in numeric_fields.items() if value is None]
            if missing:
                ui.notify(f"Please enter a value for: {', '.join(missing)}", type='warning')
                return

            SorterEngine.save_caption_settings(
                state.profile_name,
                api_endpoint=api_endpoint.value,
                model_name=model_name.value,
                max_tokens=int(max_tokens.value),
                temperature=float(temperature.value),
                timeout_seconds=int(timeout.value),
                batch_size=int(batch_size.value)
            )
            state.load_caption_settings()
            ui.notify('Caption settings saved!', type='positive')
            dialog.close()

        with ui.row().classes('w-full justify-end gap-2'):
            ui.button('Cancel', on_click=dialog.close).props('flat')
            ui.button('Save', on_click=save_settings).props('color=purple')

    dialog.open()
|
||||
|
||||
def open_prompt_editor_dialog():
    """Open dialog to edit the per-category caption prompts.

    One textarea is shown per category of the current profile, pre-filled
    with the stored prompt. "Save All" stores every textarea's stripped
    text; an empty value is stored as an empty prompt so the default is used.
    """
    category_names = state.get_categories()
    stored_prompts = SorterEngine.get_all_category_prompts(state.profile_name)

    with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-[600px] max-h-[80vh]'):
        ui.label('Category Prompts').classes('text-xl font-bold text-white mb-2')
        ui.label('Set custom prompts for each category. Leave empty for default.').classes('text-gray-400 text-sm mb-4')

        default_prompt = "Describe this image in detail for training purposes. Include subjects, actions, setting, colors, and composition."
        ui.label(f'Default: "{default_prompt[:60]}..."').classes('text-gray-500 text-xs mb-4')

        # Textareas keyed by category, read back when saving
        editors = {}

        with ui.scroll_area().classes('w-full max-h-96'):
            for name in category_names:
                with ui.card().classes('w-full p-3 bg-gray-700 mb-2'):
                    ui.label(name).classes('font-bold text-purple-400 mb-1')
                    editor = ui.textarea(
                        value=stored_prompts.get(name, ''),
                        placeholder=default_prompt
                    ).props('dark outlined dense rows=2').classes('w-full')
                    editors[name] = editor

        def save_all_prompts():
            for name, editor in editors.items():
                # The stripped text is '' when the field is blank, which
                # clears the prompt so the default is used.
                SorterEngine.save_category_prompt(state.profile_name, name, editor.value.strip())
            ui.notify(f'Prompts saved for {len(editors)} categories!', type='positive')
            dialog.close()

        with ui.row().classes('w-full justify-end gap-2 mt-4'):
            ui.button('Cancel', on_click=dialog.close).props('flat')
            ui.button('Save All', on_click=save_all_prompts).props('color=purple')

    dialog.open()
|
||||
|
||||
def open_caption_dialog(img_path: str):
    """Open dialog to view/edit/generate caption for a single image.

    Args:
        img_path: Path of the image whose caption is shown. The prompt
            category is the image's staged category when it is staged,
            otherwise the currently active category.
    """
    # Local import so this block does not depend on a file-level import.
    from urllib.parse import quote

    existing = SorterEngine.get_caption(img_path)
    state.load_caption_settings()

    # Get category for this image
    staged_info = state.staged_data.get(img_path)
    category = staged_info['cat'] if staged_info else state.active_cat

    with ui.dialog() as dialog, ui.card().classes('p-6 bg-gray-800 w-[500px]'):
        ui.label('Image Caption').classes('text-xl font-bold text-white mb-2')
        ui.label(os.path.basename(img_path)).classes('text-gray-400 text-sm mb-4 truncate')

        # Thumbnail preview. The path is percent-encoded so characters such
        # as '&', '#', '+' or spaces cannot corrupt the query string.
        thumb_url = f"/thumbnail?path={quote(img_path, safe='')}&size=300&q=60"
        ui.image(thumb_url).classes('w-full h-48 bg-black rounded mb-4').props('fit=contain')

        # Caption textarea
        caption_text = ui.textarea(
            label='Caption',
            value=existing['caption'] if existing else '',
            placeholder='Caption will appear here...'
        ).props('dark outlined rows=4').classes('w-full mb-2')

        # Model info
        if existing:
            ui.label(f"Model: {existing.get('model', 'unknown')} | {existing.get('generated_at', '')}").classes('text-gray-500 text-xs mb-4')

        # Status label for progress
        status_label = ui.label('').classes('text-purple-400 text-sm mb-2')

        async def generate_caption():
            status_label.set_text('Generating caption...')
            prompt = SorterEngine.get_category_prompt(state.profile_name, category)

            caption, error = await run.io_bound(
                SorterEngine.caption_image_vllm,
                img_path, prompt, state.caption_settings
            )

            if caption:
                caption_text.set_value(caption)
                status_label.set_text('Caption generated!')
            else:
                status_label.set_text(f'Error: {error}')

        def _persist(write_sidecar):
            """Store the edited caption, optionally also writing the .txt sidecar."""
            text = (caption_text.value or '').strip()
            if not text:
                ui.notify('Caption is empty', type='warning')
                return

            SorterEngine.save_caption(img_path, text, state.caption_settings.get('model_name', 'manual'))
            state.caption_cache.add(img_path)

            if not write_sidecar:
                ui.notify('Caption saved!', type='positive')
            elif SorterEngine.write_caption_sidecar(img_path, text):
                ui.notify('Caption saved + sidecar written!', type='positive')
            else:
                ui.notify('Caption saved (sidecar failed)', type='warning')

            dialog.close()
            refresh_grid_only()

        def save_caption():
            _persist(False)

        def save_with_sidecar():
            _persist(True)

        with ui.row().classes('w-full justify-between gap-2'):
            ui.button('Generate', icon='auto_awesome', on_click=generate_caption).props('color=purple')
            with ui.row().classes('gap-2'):
                ui.button('Cancel', on_click=dialog.close).props('flat')
                ui.button('Save', on_click=save_caption).props('color=green')
                ui.button('Save + Sidecar', on_click=save_with_sidecar).props('color=blue').tooltip('Also write .txt file')

    dialog.open()
|
||||
|
||||
async def action_caption_category():
    """Caption all images tagged with the active category.

    Shows a progress dialog with a Cancel button, captions each staged image
    of the active category one at a time, and stores every outcome in the
    database (failures as an "[ERROR] ..." caption). Does nothing if a
    captioning run is already in progress.
    """
    if state.captioning_in_progress:
        ui.notify('Captioning already in progress', type='warning')
        return

    # Find all images tagged with active category
    active_cat = state.active_cat
    images_to_caption = [
        img_path
        for img_path, info in state.staged_data.items()
        if info['cat'] == active_cat
    ]

    if not images_to_caption:
        ui.notify(f'No images tagged with {active_cat}', type='warning')
        return

    state.load_caption_settings()
    state.captioning_in_progress = True

    # Create progress dialog
    with ui.dialog() as progress_dialog, ui.card().classes('p-6 bg-gray-800 w-96'):
        ui.label('Captioning Images...').classes('text-xl font-bold text-white mb-4')
        progress_bar = ui.linear_progress(value=0).props('instant-feedback color=purple').classes('w-full mb-2')
        progress_label = ui.label('0 / 0').classes('text-gray-400 text-center w-full mb-2')
        status_label = ui.label('Starting...').classes('text-purple-400 text-sm text-center w-full')

        # Mutable holder so the click handler can signal the loop below.
        cancel_requested = {'value': False}

        def request_cancel():
            cancel_requested['value'] = True
            status_label.set_text('Cancelling...')

        ui.button('Cancel', on_click=request_cancel).props('flat color=red').classes('w-full mt-4')

    progress_dialog.open()

    try:
        total = len(images_to_caption)
        success_count = 0
        fail_count = 0

        # Snapshot the settings and prompt once: every image in this run
        # shares the same category, so the prompt is loop-invariant.
        settings = state.caption_settings
        model_name = settings.get('model_name', 'local-model')
        prompt = SorterEngine.get_category_prompt(state.profile_name, active_cat)

        for i, img_path in enumerate(images_to_caption):
            if cancel_requested['value']:
                break

            progress_bar.set_value(i / total)
            progress_label.set_text(f'{i + 1} / {total}')
            status_label.set_text(f'Captioning {os.path.basename(img_path)}...')

            caption, error = await run.io_bound(
                SorterEngine.caption_image_vllm,
                img_path, prompt, settings
            )

            if caption:
                SorterEngine.save_caption(img_path, caption, model_name)
                state.caption_cache.add(img_path)
                success_count += 1
            else:
                # Avoid storing "[ERROR] None" when no message came back.
                error_caption = f"[ERROR] {error or 'Empty caption returned'}"
                SorterEngine.save_caption(img_path, error_caption, model_name)
                fail_count += 1

        # Report what was actually processed, so a cancelled run does not
        # display "total / total".
        processed = success_count + fail_count
        progress_bar.set_value(processed / total)
        progress_label.set_text(f'{processed} / {total}')

        if cancel_requested['value']:
            status_label.set_text(f'Cancelled. {success_count} OK, {fail_count} failed')
        else:
            status_label.set_text(f'Done! {success_count} OK, {fail_count} failed')

        # Leave the final status visible briefly before the dialog closes.
        await asyncio.sleep(1.5)

    finally:
        state.captioning_in_progress = False
        # Close here so the dialog cannot stay open if the loop raised.
        progress_dialog.close()
        refresh_grid_only()

    ui.notify(f'Captioned {success_count}/{total} images', type='positive' if fail_count == 0 else 'warning')
|
||||
|
||||
def render_sidebar():
|
||||
"""Render category management sidebar."""
|
||||
state.sidebar_container.clear()
|
||||
@@ -932,7 +1280,10 @@ def render_sidebar():
|
||||
.classes('w-full border border-gray-800')
|
||||
|
||||
# Category Manager (expanded)
|
||||
ui.label("📂 Categories").classes('text-sm font-bold text-gray-400 mt-2')
|
||||
with ui.row().classes('w-full justify-between items-center mt-2'):
|
||||
ui.label("📂 Categories").classes('text-sm font-bold text-gray-400')
|
||||
ui.button(icon='edit_note', on_click=open_prompt_editor_dialog) \
|
||||
.props('flat dense color=purple size=sm').tooltip('Edit Prompts')
|
||||
|
||||
categories = state.get_categories()
|
||||
|
||||
@@ -1055,6 +1406,7 @@ def render_image_card(img_path: str):
|
||||
"""Render individual image card.
|
||||
Uses functools.partial instead of lambdas for better memory efficiency."""
|
||||
is_staged = img_path in state.staged_data
|
||||
has_caption = img_path in state.caption_cache
|
||||
thumb_size = 800
|
||||
|
||||
card = ui.card().classes('p-2 bg-gray-900 border border-gray-700 no-shadow hover:border-green-500 transition-colors')
|
||||
@@ -1066,8 +1418,16 @@ def render_image_card(img_path: str):
|
||||
|
||||
# Header with filename and actions
|
||||
with ui.row().classes('w-full justify-between no-wrap mb-1'):
|
||||
ui.label(os.path.basename(img_path)[:15]).classes('text-xs text-gray-400 truncate')
|
||||
with ui.row().classes('items-center gap-1'):
|
||||
ui.label(os.path.basename(img_path)[:15]).classes('text-xs text-gray-400 truncate')
|
||||
# Caption indicator
|
||||
if has_caption:
|
||||
ui.icon('description', size='xs').classes('text-purple-400').tooltip('Has caption')
|
||||
with ui.row().classes('gap-0'):
|
||||
ui.button(
|
||||
icon='auto_awesome',
|
||||
on_click=partial(open_caption_dialog, img_path)
|
||||
).props('flat size=sm dense color=purple').tooltip('Caption')
|
||||
ui.button(
|
||||
icon='zoom_in',
|
||||
on_click=partial(open_zoom_dialog, img_path)
|
||||
@@ -1384,6 +1744,11 @@ def build_header():
|
||||
on_change=lambda e: (setattr(state, 'preview_quality', e.value), refresh_ui())
|
||||
).props('color=green label-always')
|
||||
|
||||
ui.separator().classes('my-2')
|
||||
ui.label('CAPTION SETTINGS').classes('text-xs font-bold mb-2 text-purple-400')
|
||||
ui.button('Configure API', icon='api', on_click=open_caption_settings_dialog) \
|
||||
.props('flat color=purple').classes('w-full')
|
||||
|
||||
ui.switch('Dark', value=True, on_change=lambda e: ui.dark_mode().set_value(e.value)) \
|
||||
.props('color=green')
|
||||
|
||||
@@ -1424,6 +1789,14 @@ def build_main_content():
|
||||
.bind_value(state, 'cleanup_mode') \
|
||||
.props('inline dark color=green')
|
||||
|
||||
# Caption options
|
||||
with ui.column():
|
||||
ui.label('CAPTIONING:').classes('text-gray-500 text-xs font-bold')
|
||||
ui.checkbox('Caption on Apply').bind_value(state, 'caption_on_apply') \
|
||||
.props('color=purple dark')
|
||||
ui.button('CAPTION CATEGORY', icon='auto_awesome', on_click=action_caption_category) \
|
||||
.props('outline color=purple')
|
||||
|
||||
# Action buttons
|
||||
with ui.row().classes('items-center gap-6'):
|
||||
ui.button('APPLY PAGE', on_click=action_apply_page) \
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
streamlit
|
||||
Pillow
|
||||
nicegui
|
||||
requests
|
||||
Reference in New Issue
Block a user