fix: address review bugs in server implementation
- Fix keyframe 6-tuple → 4-tuple mismatch crashing ExportRunner - Fix ws.broadcast() using wrong event loop from background threads - Fix export counter hardcoded to 1, now auto-increments - Add path traversal protection to file/stream/delete endpoints - Use proper HTTP error codes (was returning 200 for errors) - Add thread safety to WebSocket connection list - Record exports to DB so markers appear - Move WS endpoint to /ws/export (was /api/ws/export) - Prune dead threads from cache job tracker Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+7
-1
@@ -1,8 +1,9 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, WebSocket
|
||||||
|
|
||||||
from core.db import ProcessedDB
|
from core.db import ProcessedDB
|
||||||
from .config import DB_PATH
|
from .config import DB_PATH
|
||||||
from .routes import files, stream, markers, export, hidden
|
from .routes import files, stream, markers, export, hidden
|
||||||
|
from . import ws
|
||||||
|
|
||||||
app = FastAPI(title="8-cut Server")
|
app = FastAPI(title="8-cut Server")
|
||||||
|
|
||||||
@@ -13,3 +14,8 @@ app.include_router(stream.router, prefix="/api")
|
|||||||
app.include_router(markers.router, prefix="/api")
|
app.include_router(markers.router, prefix="/api")
|
||||||
app.include_router(export.router, prefix="/api")
|
app.include_router(export.router, prefix="/api")
|
||||||
app.include_router(hidden.router, prefix="/api")
|
app.include_router(hidden.router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/export")
|
||||||
|
async def export_ws(websocket: WebSocket):
|
||||||
|
await ws.connect(websocket)
|
||||||
|
|||||||
@@ -130,6 +130,13 @@ def _audio_extract_worker(source_path: str) -> None:
|
|||||||
os.unlink(tmp)
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_dead_jobs() -> None:
|
||||||
|
"""Remove finished threads from _active_jobs. Must be called under _jobs_lock."""
|
||||||
|
dead = [k for k, t in _active_jobs.items() if not t.is_alive()]
|
||||||
|
for k in dead:
|
||||||
|
del _active_jobs[k]
|
||||||
|
|
||||||
|
|
||||||
def ensure_transcode(source_path: str, quality: str) -> CacheStatus:
|
def ensure_transcode(source_path: str, quality: str) -> CacheStatus:
|
||||||
"""Start transcode if not cached. Returns current status."""
|
"""Start transcode if not cached. Returns current status."""
|
||||||
status = get_status(source_path, quality)
|
status = get_status(source_path, quality)
|
||||||
@@ -138,6 +145,7 @@ def ensure_transcode(source_path: str, quality: str) -> CacheStatus:
|
|||||||
|
|
||||||
job_key = f"{source_path}:{quality}"
|
job_key = f"{source_path}:{quality}"
|
||||||
with _jobs_lock:
|
with _jobs_lock:
|
||||||
|
_prune_dead_jobs()
|
||||||
if job_key in _active_jobs and _active_jobs[job_key].is_alive():
|
if job_key in _active_jobs and _active_jobs[job_key].is_alive():
|
||||||
return CacheStatus.TRANSCODING
|
return CacheStatus.TRANSCODING
|
||||||
t = threading.Thread(target=_transcode_worker, args=(source_path, quality), daemon=True)
|
t = threading.Thread(target=_transcode_worker, args=(source_path, quality), daemon=True)
|
||||||
@@ -154,6 +162,7 @@ def ensure_audio(source_path: str) -> CacheStatus:
|
|||||||
|
|
||||||
job_key = f"{source_path}:audio"
|
job_key = f"{source_path}:audio"
|
||||||
with _jobs_lock:
|
with _jobs_lock:
|
||||||
|
_prune_dead_jobs()
|
||||||
if job_key in _active_jobs and _active_jobs[job_key].is_alive():
|
if job_key in _active_jobs and _active_jobs[job_key].is_alive():
|
||||||
return CacheStatus.TRANSCODING
|
return CacheStatus.TRANSCODING
|
||||||
t = threading.Thread(target=_audio_extract_worker, args=(source_path,), daemon=True)
|
t = threading.Thread(target=_audio_extract_worker, args=(source_path,), daemon=True)
|
||||||
|
|||||||
+40
-10
@@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket
|
from fastapi import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.export import ExportRunner
|
from core.export import ExportRunner
|
||||||
@@ -36,20 +37,34 @@ class ExportRequest(BaseModel):
|
|||||||
encoder: str = "libx264"
|
encoder: str = "libx264"
|
||||||
|
|
||||||
|
|
||||||
|
def _next_counter(folder: str, basename: str) -> int:
|
||||||
|
"""Scan folder for existing {basename}_NNN dirs and return max + 1."""
|
||||||
|
pattern = re.compile(rf'^{re.escape(basename)}_(\d{{3}})$')
|
||||||
|
highest = 0
|
||||||
|
if os.path.isdir(folder):
|
||||||
|
for entry in os.listdir(folder):
|
||||||
|
m = pattern.match(entry)
|
||||||
|
if m:
|
||||||
|
highest = max(highest, int(m.group(1)))
|
||||||
|
return highest + 1
|
||||||
|
|
||||||
|
|
||||||
@router.post("/export")
|
@router.post("/export")
|
||||||
def start_export(req: ExportRequest):
|
def start_export(req: ExportRequest):
|
||||||
|
from ..app import db
|
||||||
|
|
||||||
job_id = str(uuid.uuid4())[:8]
|
job_id = str(uuid.uuid4())[:8]
|
||||||
folder = EXPORT_DIR
|
folder = EXPORT_DIR
|
||||||
if req.folder_suffix:
|
if req.folder_suffix:
|
||||||
folder = folder + req.folder_suffix
|
folder = folder + req.folder_suffix
|
||||||
|
|
||||||
image_sequence = req.format == "WebP"
|
image_sequence = req.format == "WebP"
|
||||||
|
counter = _next_counter(folder, req.name)
|
||||||
|
|
||||||
# Build job list: (start, output_path, portrait_ratio, crop_center)
|
# Build job list: (start, output_path, portrait_ratio, crop_center)
|
||||||
jobs = []
|
jobs = []
|
||||||
for i in range(req.clips):
|
for i in range(req.clips):
|
||||||
start = req.cursor + i * req.spread
|
start = req.cursor + i * req.spread
|
||||||
counter = 1 # server uses simple incrementing
|
|
||||||
if image_sequence:
|
if image_sequence:
|
||||||
out = build_sequence_dir(folder, req.name, counter, sub=i if req.clips > 1 else None)
|
out = build_sequence_dir(folder, req.name, counter, sub=i if req.clips > 1 else None)
|
||||||
else:
|
else:
|
||||||
@@ -57,18 +72,34 @@ def start_export(req: ExportRequest):
|
|||||||
os.makedirs(os.path.dirname(out), exist_ok=True)
|
os.makedirs(os.path.dirname(out), exist_ok=True)
|
||||||
jobs.append((start, out, req.portrait_ratio, req.crop_center))
|
jobs.append((start, out, req.portrait_ratio, req.crop_center))
|
||||||
|
|
||||||
# Apply keyframes if provided
|
# Apply keyframes if provided — returns 6-tuples, strip back to 4
|
||||||
if req.crop_keyframes:
|
if req.crop_keyframes:
|
||||||
jobs = apply_keyframes_to_jobs(
|
widened = apply_keyframes_to_jobs(
|
||||||
jobs, req.crop_keyframes,
|
jobs, req.crop_keyframes,
|
||||||
req.crop_center, req.portrait_ratio,
|
req.crop_center, req.portrait_ratio,
|
||||||
req.rand_portrait, req.rand_square,
|
req.rand_portrait, req.rand_square,
|
||||||
)
|
)
|
||||||
|
jobs = [(s, o, r, c) for s, o, r, c, _rp, _rs in widened]
|
||||||
|
|
||||||
completed = []
|
completed = []
|
||||||
|
|
||||||
def on_clip_done(path: str):
|
def on_clip_done(path: str):
|
||||||
completed.append(path)
|
completed.append(path)
|
||||||
|
# Record in DB so markers show up
|
||||||
|
db.add(
|
||||||
|
filename=os.path.basename(req.input_path),
|
||||||
|
start_time=req.cursor,
|
||||||
|
output_path=path,
|
||||||
|
label=req.label,
|
||||||
|
category=req.category,
|
||||||
|
short_side=req.short_side,
|
||||||
|
portrait_ratio=req.portrait_ratio or "",
|
||||||
|
crop_center=req.crop_center,
|
||||||
|
fmt=req.format,
|
||||||
|
clip_count=req.clips,
|
||||||
|
spread=req.spread,
|
||||||
|
profile=req.profile,
|
||||||
|
)
|
||||||
ws_module.broadcast({"type": "clip_done", "job_id": job_id, "path": path})
|
ws_module.broadcast({"type": "clip_done", "job_id": job_id, "path": path})
|
||||||
|
|
||||||
def on_all_done():
|
def on_all_done():
|
||||||
@@ -106,7 +137,7 @@ def start_export(req: ExportRequest):
|
|||||||
def get_export_status(job_id: str):
|
def get_export_status(job_id: str):
|
||||||
job = _jobs.get(job_id)
|
job = _jobs.get(job_id)
|
||||||
if job is None:
|
if job is None:
|
||||||
return {"error": "job not found"}
|
raise HTTPException(status_code=404, detail="job not found")
|
||||||
return {
|
return {
|
||||||
"status": job["status"],
|
"status": job["status"],
|
||||||
"total": job["total"],
|
"total": job["total"],
|
||||||
@@ -119,14 +150,13 @@ def get_export_status(job_id: str):
|
|||||||
@router.delete("/export/{output_path:path}")
|
@router.delete("/export/{output_path:path}")
|
||||||
def delete_export(output_path: str):
|
def delete_export(output_path: str):
|
||||||
from ..app import db
|
from ..app import db
|
||||||
|
# Validate path is under EXPORT_DIR
|
||||||
|
real = os.path.realpath(output_path)
|
||||||
|
if not real.startswith(os.path.realpath(EXPORT_DIR) + os.sep):
|
||||||
|
raise HTTPException(status_code=403, detail="path outside export directory")
|
||||||
db.delete_by_output_path(output_path)
|
db.delete_by_output_path(output_path)
|
||||||
if os.path.isfile(output_path):
|
if os.path.isfile(output_path):
|
||||||
os.unlink(output_path)
|
os.unlink(output_path)
|
||||||
elif os.path.isdir(output_path):
|
elif os.path.isdir(output_path):
|
||||||
shutil.rmtree(output_path)
|
shutil.rmtree(output_path)
|
||||||
return {"deleted": output_path}
|
return {"deleted": output_path}
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/export")
|
|
||||||
async def export_ws(websocket: WebSocket):
|
|
||||||
await ws_module.connect(websocket)
|
|
||||||
|
|||||||
+13
-5
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
from ..config import MEDIA_DIRS, VIDEO_EXTENSIONS
|
||||||
@@ -38,11 +38,19 @@ def list_roots():
|
|||||||
return MEDIA_DIRS
|
return MEDIA_DIRS
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_resolve(path: str, root: str) -> str:
|
||||||
|
"""Join path to root and verify it stays within the root directory."""
|
||||||
|
if root not in MEDIA_DIRS:
|
||||||
|
raise HTTPException(status_code=400, detail="invalid root")
|
||||||
|
full = os.path.realpath(os.path.join(root, path))
|
||||||
|
if not full.startswith(os.path.realpath(root) + os.sep):
|
||||||
|
raise HTTPException(status_code=403, detail="path outside media root")
|
||||||
|
return full
|
||||||
|
|
||||||
|
|
||||||
@router.get("/video/{path:path}")
|
@router.get("/video/{path:path}")
|
||||||
def serve_video(path: str, root: str = Query(...)):
|
def serve_video(path: str, root: str = Query(...)):
|
||||||
if root not in MEDIA_DIRS:
|
full = _safe_resolve(path, root)
|
||||||
return {"error": "invalid root"}
|
|
||||||
full = os.path.join(root, path)
|
|
||||||
if not os.path.isfile(full):
|
if not os.path.isfile(full):
|
||||||
return {"error": "not found"}
|
raise HTTPException(status_code=404, detail="not found")
|
||||||
return FileResponse(full, media_type="video/mp4")
|
return FileResponse(full, media_type="video/mp4")
|
||||||
|
|||||||
+11
-12
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
|
|
||||||
from ..config import MEDIA_DIRS, QUALITY_PRESETS
|
from ..config import MEDIA_DIRS, QUALITY_PRESETS
|
||||||
@@ -9,20 +9,23 @@ from .. import cache
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _resolve_source(path: str, root: str) -> str | None:
|
def _resolve_source(path: str, root: str) -> str:
|
||||||
|
"""Join path to root, verify it stays within root, and exists."""
|
||||||
if root not in MEDIA_DIRS:
|
if root not in MEDIA_DIRS:
|
||||||
return None
|
raise HTTPException(status_code=400, detail="invalid root")
|
||||||
full = os.path.join(root, path)
|
full = os.path.realpath(os.path.join(root, path))
|
||||||
return full if os.path.isfile(full) else None
|
if not full.startswith(os.path.realpath(root) + os.sep):
|
||||||
|
raise HTTPException(status_code=403, detail="path outside media root")
|
||||||
|
if not os.path.isfile(full):
|
||||||
|
raise HTTPException(status_code=404, detail="not found")
|
||||||
|
return full
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stream/{path:path}")
|
@router.get("/stream/{path:path}")
|
||||||
def stream_video(path: str, root: str = Query(...), quality: str = Query("low")):
|
def stream_video(path: str, root: str = Query(...), quality: str = Query("low")):
|
||||||
if quality not in QUALITY_PRESETS:
|
if quality not in QUALITY_PRESETS:
|
||||||
return JSONResponse({"error": f"invalid quality: {quality}"}, status_code=400)
|
raise HTTPException(status_code=400, detail=f"invalid quality: {quality}")
|
||||||
source = _resolve_source(path, root)
|
source = _resolve_source(path, root)
|
||||||
if source is None:
|
|
||||||
return JSONResponse({"error": "not found"}, status_code=404)
|
|
||||||
|
|
||||||
status = cache.ensure_transcode(source, quality)
|
status = cache.ensure_transcode(source, quality)
|
||||||
if status == cache.CacheStatus.READY:
|
if status == cache.CacheStatus.READY:
|
||||||
@@ -33,8 +36,6 @@ def stream_video(path: str, root: str = Query(...), quality: str = Query("low"))
|
|||||||
@router.get("/audio/{path:path}")
|
@router.get("/audio/{path:path}")
|
||||||
def stream_audio(path: str, root: str = Query(...)):
|
def stream_audio(path: str, root: str = Query(...)):
|
||||||
source = _resolve_source(path, root)
|
source = _resolve_source(path, root)
|
||||||
if source is None:
|
|
||||||
return JSONResponse({"error": "not found"}, status_code=404)
|
|
||||||
|
|
||||||
status = cache.ensure_audio(source)
|
status = cache.ensure_audio(source)
|
||||||
if status == cache.CacheStatus.READY:
|
if status == cache.CacheStatus.READY:
|
||||||
@@ -45,6 +46,4 @@ def stream_audio(path: str, root: str = Query(...)):
|
|||||||
@router.get("/cache/status/{path:path}")
|
@router.get("/cache/status/{path:path}")
|
||||||
def cache_status(path: str, root: str = Query(...)):
|
def cache_status(path: str, root: str = Query(...)):
|
||||||
source = _resolve_source(path, root)
|
source = _resolve_source(path, root)
|
||||||
if source is None:
|
|
||||||
return JSONResponse({"error": "not found"}, status_code=404)
|
|
||||||
return cache.get_all_statuses(source)
|
return cache.get_all_statuses(source)
|
||||||
|
|||||||
+20
-16
@@ -1,37 +1,41 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
|
||||||
from fastapi import WebSocket, WebSocketDisconnect
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
_connections: list[WebSocket] = []
|
_connections: list[WebSocket] = []
|
||||||
|
_loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
|
||||||
async def connect(ws: WebSocket):
|
async def connect(ws: WebSocket):
|
||||||
|
global _loop
|
||||||
|
_loop = asyncio.get_running_loop()
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
_connections.append(ws)
|
with _lock:
|
||||||
|
_connections.append(ws)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
await ws.receive_text() # keep alive
|
await ws.receive_text() # keep alive
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
_connections.remove(ws)
|
with _lock:
|
||||||
|
if ws in _connections:
|
||||||
|
_connections.remove(ws)
|
||||||
|
|
||||||
|
|
||||||
def broadcast(msg: dict):
|
def broadcast(msg: dict):
|
||||||
"""Send a message to all connected WebSocket clients.
|
"""Send a message to all connected WebSocket clients.
|
||||||
|
|
||||||
Called from sync code (export callbacks), so we schedule the coroutine
|
Called from sync code (export callbacks running in background threads),
|
||||||
on each connection's event loop.
|
so we schedule sends on uvicorn's event loop.
|
||||||
"""
|
"""
|
||||||
|
if _loop is None:
|
||||||
|
return
|
||||||
data = json.dumps(msg)
|
data = json.dumps(msg)
|
||||||
stale = []
|
with _lock:
|
||||||
for ws in _connections:
|
for ws in list(_connections):
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
asyncio.run_coroutine_threadsafe(ws.send_text(data), _loop)
|
||||||
if loop.is_running():
|
except Exception:
|
||||||
asyncio.run_coroutine_threadsafe(ws.send_text(data), loop)
|
pass
|
||||||
else:
|
|
||||||
loop.run_until_complete(ws.send_text(data))
|
|
||||||
except Exception:
|
|
||||||
stale.append(ws)
|
|
||||||
for ws in stale:
|
|
||||||
_connections.remove(ws)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user