Add /api/transcribe endpoint with Whisper
This commit is contained in:
+55
-1
@@ -2,10 +2,15 @@
|
|||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, File, HTTPException, UploadFile
|
||||||
|
|
||||||
|
from app.services.transcriber import transcribe_bytes
|
||||||
|
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
|
# Upload size limit in bytes (10 MB); larger uploads are rejected with HTTP 413.
|
||||||
|
MAX_UPLOAD_SIZE = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
@@ -16,3 +21,52 @@ async def health():
|
|||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/transcribe")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = "base",
):
    """Transcribe an uploaded audio file using Whisper.

    Supported formats: wav, mp3, m4a, ogg, flac, webm.
    Model options: tiny, base, small, medium (default: base).

    Args:
        file: Multipart upload holding the audio data.
        model: Name of the Whisper model to run.

    Returns:
        A dict containing the uploaded ``filename`` merged with the mapping
        returned by ``transcribe_bytes``.

    Raises:
        HTTPException: 400 for an unknown model, an unsupported content type
            or an empty upload; 413 when the upload is larger than
            ``MAX_UPLOAD_SIZE``; 500 when transcription itself fails.
    """
    # Validate model
    valid_models = {"tiny", "base", "small", "medium"}
    if model not in valid_models:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model '{model}'. Use: {', '.join(sorted(valid_models))}",
        )

    # Validate file type. Clients may append parameters or vary the case
    # (e.g. "audio/webm;codecs=opus"), so compare only the bare media type.
    allowed = {
        "audio/wav", "audio/mpeg", "audio/mp4", "audio/x-m4a",
        "audio/ogg", "audio/flac", "audio/webm", "audio/x-wav",
    }
    if file.content_type:
        media_type = file.content_type.split(";", 1)[0].strip().lower()
        if media_type not in allowed:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported format: {file.content_type}. Supported: wav, mp3, m4a, ogg, flac, webm",
            )

    # Read at most one byte past the limit: that is enough to detect an
    # oversized upload without copying the whole payload into memory.
    contents = await file.read(MAX_UPLOAD_SIZE + 1)
    if len(contents) > MAX_UPLOAD_SIZE:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max {MAX_UPLOAD_SIZE // (1024*1024)} MB.",
        )
    if not contents:
        raise HTTPException(status_code=400, detail="Empty file.")

    # Transcribe.
    # NOTE(review): transcribe_bytes is a plain synchronous call, so it runs
    # on the event loop and presumably stalls other requests while Whisper
    # works. Offloading it to a worker thread needs the model cache in the
    # transcriber service to be made thread-safe first -- TODO confirm.
    try:
        result = transcribe_bytes(contents, model_name=model)
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(
            status_code=500, detail=f"Transcription failed: {e}"
        ) from e

    return {
        "filename": file.filename,
        **result,
    }
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""Whisper transcription service — CPU-only, async-ready."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Single-slot cache: the most recently loaded Whisper model and its name.
_model = None
_model_name = None


def _load_model(name: str = "base"):
    """Return the Whisper model called *name*, loading it on a cache miss.

    Only one model is kept at a time: requesting a different name replaces
    the cached instance. The ``whisper`` import is deferred until the first
    call; per the original note, the first load of a model may download it.
    """
    global _model, _model_name
    import whisper

    cached = _model is not None and _model_name == name
    if not cached:
        # The right-hand side is evaluated first, so a failed load leaves
        # the previous cache entry untouched.
        _model, _model_name = whisper.load_model(name), name
    return _model
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe_bytes(audio_bytes: bytes, model_name: str = "base") -> dict:
    """Transcribe audio held in memory with a Whisper model.

    Args:
        audio_bytes: Raw contents of the audio file.
        model_name: Name of the Whisper model, passed to ``_load_model``.

    Returns:
        A dict with the keys:
            ``text``: full transcript, stripped of surrounding whitespace;
            ``segments``: list of ``{"start", "end", "text"}`` dicts with
                times rounded to two decimals;
            ``language``: language reported by Whisper, or ``"unknown"``;
            ``duration_seconds``: elapsed processing time of this call
                (including model loading), rounded to one decimal -- not
                the length of the audio;
            ``model``: the ``model_name`` that was requested.
    """
    # Monotonic clock: the measurement is unaffected by system clock changes.
    started = time.perf_counter()

    model = _load_model(model_name)

    # Whisper is given a file path here, so the bytes go to a temp file.
    # delete=False lets the file be closed before Whisper opens it by name;
    # cleanup is therefore manual and must also cover a failed write.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(audio_bytes)
        result = model.transcribe(tmp_path, fp16=False)  # fp16=False for CPU
    finally:
        Path(tmp_path).unlink(missing_ok=True)

    elapsed = round(time.perf_counter() - started, 1)

    segments = [
        {
            "start": round(seg["start"], 2),
            "end": round(seg["end"], 2),
            "text": seg["text"].strip(),
        }
        for seg in result.get("segments", [])
    ]

    return {
        "text": result["text"].strip(),
        "segments": segments,
        "language": result.get("language", "unknown"),
        "duration_seconds": elapsed,
        "model": model_name,
    }
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
FROM python:3.11-slim
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ fastapi>=0.115.0
|
|||||||
uvicorn[standard]>=0.30.0
|
uvicorn[standard]>=0.30.0
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
httpx>=0.27.0
|
httpx>=0.27.0
|
||||||
|
python-multipart>=0.0.9
|
||||||
|
openai-whisper>=20240930
|
||||||
|
|||||||
+54
-6
@@ -1,4 +1,6 @@
|
|||||||
"""Tests for the FastAPI application."""
|
"""Tests for the FastAPI micro-api."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
@@ -26,30 +28,76 @@ class TestHealth:
|
|||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscribe:
    """Tests for POST /api/transcribe."""

    # Minimal 44-byte PCM WAV header (mono, 44.1 kHz, 16-bit) whose data
    # chunk is empty, i.e. a structurally valid file with no samples.
    EMPTY_WAV = (
        b"RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00"
        b"\x01\x00\x01\x00\x44\xac\x00\x00\x88\x58\x01\x00"
        b"\x02\x00\x10\x00data\x00\x00\x00\x00"
    )

    def test_transcribe_no_file(self):
        """A request without the file part fails FastAPI validation."""
        response = client.post("/api/transcribe")
        assert response.status_code == 422  # FastAPI validation error

    def test_transcribe_invalid_model(self):
        """An unknown model name is rejected with 400."""
        response = client.post(
            "/api/transcribe?model=invalid_model",
            files={"file": ("test.wav", io.BytesIO(self.EMPTY_WAV), "audio/wav")},
        )
        assert response.status_code == 400
        assert "Invalid model" in response.json()["detail"]

    def test_transcribe_invalid_format(self):
        """A non-audio content type is rejected with 400."""
        response = client.post(
            "/api/transcribe",
            files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
        )
        assert response.status_code == 400
        assert "Unsupported format" in response.json()["detail"]

    def test_transcribe_empty_file(self):
        """A zero-byte upload is rejected with 400."""
        response = client.post(
            "/api/transcribe",
            files={"file": ("empty.wav", io.BytesIO(b""), "audio/wav")},
        )
        assert response.status_code == 400
        assert "Empty" in response.json()["detail"]

    def test_transcribe_accepts_wav_with_tiny_model(self):
        """A well-formed WAV upload gets past request validation."""
        response = client.post(
            "/api/transcribe?model=tiny",
            files={"file": ("test.wav", io.BytesIO(self.EMPTY_WAV), "audio/wav")},
        )
        # 200 if Whisper returns a (possibly empty) transcript; 500 if the
        # model cannot be loaded or transcription fails (expected in CI).
        # 422 is not accepted: with a file and a valid model supplied it
        # would only indicate a request-validation regression.
        assert response.status_code in (200, 500)
|
||||||
|
|
||||||
|
|
||||||
class TestDatabase:
    """Smoke test for database initialisation."""

    def test_init_db_creates_tables(self):
        """Calling init_db() completes without raising."""
        from app import database

        database.init_db()
|
||||||
|
|
||||||
|
|
||||||
class TestModels:
|
class TestModels:
|
||||||
def test_item_create_validation(self):
|
def test_item_create_validation(self):
|
||||||
from app.models import ItemCreate
|
from app.models import ItemCreate
|
||||||
|
|
||||||
item = ItemCreate(name="test-item")
|
item = ItemCreate(name="test-item")
|
||||||
assert item.name == "test-item"
|
assert item.name == "test-item"
|
||||||
|
|
||||||
def test_item_create_requires_name(self):
|
def test_item_create_requires_name(self):
|
||||||
from app.models import ItemCreate
|
from app.models import ItemCreate
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
ItemCreate()
|
ItemCreate()
|
||||||
|
|
||||||
def test_item_response_serialization(self):
|
def test_item_response_serialization(self):
|
||||||
from app.models import ItemResponse
|
from app.models import ItemResponse
|
||||||
|
|
||||||
item = ItemResponse(id=1, name="test")
|
item = ItemResponse(id=1, name="test")
|
||||||
data = item.model_dump()
|
data = item.model_dump()
|
||||||
assert data["id"] == 1
|
assert data["id"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user