From 46265a6f89e714cd48b43ed052e8a12b2cd42463 Mon Sep 17 00:00:00 2001 From: founder Date: Fri, 5 Jun 2026 03:02:42 +0200 Subject: [PATCH] Add /api/transcribe endpoint with Whisper --- app/routes.py | 56 +++++++++++++++++++++++++++++++++- app/services/__init__.py | 0 app/services/transcriber.py | 55 ++++++++++++++++++++++++++++++++++ dockerfile | 2 ++ requirements.txt | 2 ++ tests/test_main.py | 60 +++++++++++++++++++++++++++++++++---- 6 files changed, 168 insertions(+), 7 deletions(-) create mode 100644 app/services/__init__.py create mode 100644 app/services/transcriber.py diff --git a/app/routes.py b/app/routes.py index a3b5caa..918f524 100644 --- a/app/routes.py +++ b/app/routes.py @@ -2,10 +2,15 @@ from datetime import datetime, timezone -from fastapi import APIRouter +from fastapi import APIRouter, File, HTTPException, UploadFile + +from app.services.transcriber import transcribe_bytes router = APIRouter(prefix="/api") +# File size limit: 10 MB +MAX_UPLOAD_SIZE = 10 * 1024 * 1024 + @router.get("/health") async def health(): @@ -16,3 +21,52 @@ async def health(): "version": "0.1.0", "timestamp": datetime.now(timezone.utc).isoformat(), } + + +@router.post("/transcribe") +async def transcribe_audio( + file: UploadFile = File(...), + model: str = "base", +): + """Transcribe an audio file using Whisper. + + Supported formats: wav, mp3, m4a, ogg, flac, webm. + Model options: tiny, base, small, medium (default: base). + """ + # Validate model + valid_models = {"tiny", "base", "small", "medium"} + if model not in valid_models: + raise HTTPException( + status_code=400, + detail=f"Invalid model '{model}'. Use: {', '.join(sorted(valid_models))}", + ) + + # Validate file type + allowed = {"audio/wav", "audio/mpeg", "audio/mp4", "audio/x-m4a", + "audio/ogg", "audio/flac", "audio/webm", "audio/x-wav"} + if file.content_type and file.content_type not in allowed: + raise HTTPException( + status_code=400, + detail=f"Unsupported format: {file.content_type}. Supported: wav, mp3, m4a, ogg, flac, webm", + ) + + # Read file + contents = await file.read() + if len(contents) > MAX_UPLOAD_SIZE: + raise HTTPException( + status_code=413, + detail=f"File too large. Max {MAX_UPLOAD_SIZE // (1024*1024)} MB.", + ) + if len(contents) == 0: + raise HTTPException(400, detail="Empty file.") + + # Transcribe + try: + result = transcribe_bytes(contents, model_name=model) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") + + return { + "filename": file.filename, + **result, + } diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/services/transcriber.py b/app/services/transcriber.py new file mode 100644 index 0000000..094652e --- /dev/null +++ b/app/services/transcriber.py @@ -0,0 +1,55 @@ +"""Whisper transcription service — CPU-only, async-ready.""" + +import io +import tempfile +import time +from pathlib import Path + +_model = None +_model_name = None + + +def _load_model(name: str = "base"): + """Lazy-load Whisper model (downloads on first use).""" + global _model, _model_name + import whisper + + if _model is None or _model_name != name: + _model = whisper.load_model(name) + _model_name = name + return _model + + +def transcribe_bytes(audio_bytes: bytes, model_name: str = "base") -> dict: + """Transcribe audio from bytes. Returns {"text": "...", "segments": [...], "language": "..."}""" + t0 = time.time() + + model = _load_model(model_name) + + # Write to temp file (whisper needs a file path or numpy array) + suffix = ".wav" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(audio_bytes) + tmp_path = tmp.name + + try: + result = model.transcribe(tmp_path, fp16=False) # fp16=False for CPU + finally: + Path(tmp_path).unlink(missing_ok=True) + + elapsed = round(time.time() - t0, 1) + + return { + "text": result["text"].strip(), + "segments": [ + { + "start": round(seg["start"], 2), + "end": round(seg["end"], 2), + "text": seg["text"].strip(), + } + for seg in result.get("segments", []) + ], + "language": result.get("language", "unknown"), + "duration_seconds": elapsed, + "model": model_name, + } diff --git a/dockerfile b/dockerfile index 94825db..d5b86e7 100644 --- a/dockerfile +++ b/dockerfile @@ -1,5 +1,7 @@ FROM python:3.11-slim +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg && rm -rf /var/lib/apt/lists/* + WORKDIR /app COPY requirements.txt . diff --git a/requirements.txt b/requirements.txt index 3a4da97..716ddde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ fastapi>=0.115.0 uvicorn[standard]>=0.30.0 pytest>=8.0.0 httpx>=0.27.0 +python-multipart>=0.0.9 +openai-whisper>=20240930 diff --git a/tests/test_main.py b/tests/test_main.py index 9b4bfe0..17f38ac 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,6 @@ -"""Tests for the FastAPI application.""" +"""Tests for the FastAPI micro-api.""" + +import io import pytest from fastapi.testclient import TestClient @@ -26,30 +28,76 @@ class TestHealth: assert data["status"] == "running" +class TestTranscribe: + def test_transcribe_no_file(self): + response = client.post("/api/transcribe") + assert response.status_code == 422 # FastAPI validation error + + def test_transcribe_invalid_model(self): + # Create a tiny WAV file (44-byte header + silence) + wav_header = ( + b"RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00" + b"\x01\x00\x01\x00\x44\xac\x00\x00\x88\x58\x01\x00" + b"\x02\x00\x10\x00data\x00\x00\x00\x00" + ) + response = client.post( + "/api/transcribe?model=invalid_model", + files={"file": ("test.wav", io.BytesIO(wav_header), "audio/wav")}, + ) + assert response.status_code == 400 + assert "Invalid model" in response.json()["detail"] + + def test_transcribe_invalid_format(self): + response = client.post( + "/api/transcribe", + files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")}, + ) + assert response.status_code == 400 + assert "Unsupported format" in response.json()["detail"] + + def test_transcribe_empty_file(self): + response = client.post( + "/api/transcribe", + files={"file": ("empty.wav", io.BytesIO(b""), "audio/wav")}, + ) + assert response.status_code == 400 + assert "Empty" in response.json()["detail"] + + def test_transcribe_accepts_wav_with_tiny_model(self): + """Test that the endpoint accepts a valid WAV file (tiny model won't load).""" + wav_header = ( + b"RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00" + b"\x01\x00\x01\x00\x44\xac\x00\x00\x88\x58\x01\x00" + b"\x02\x00\x10\x00data\x00\x00\x00\x00" + ) + response = client.post( + "/api/transcribe?model=tiny", + files={"file": ("test.wav", io.BytesIO(wav_header), "audio/wav")}, + ) + # 422 if file format is too broken for whisper, 200 if it just returns empty + # 500 if model not downloaded (expected in CI) + assert response.status_code in (200, 422, 500) + + class TestDatabase: def test_init_db_creates_tables(self): from app.database import init_db - - # Should not raise init_db() class TestModels: def test_item_create_validation(self): from app.models import ItemCreate - item = ItemCreate(name="test-item") assert item.name == "test-item" def test_item_create_requires_name(self): from app.models import ItemCreate - with pytest.raises(Exception): ItemCreate() def test_item_response_serialization(self): from app.models import ItemResponse - item = ItemResponse(id=1, name="test") data = item.model_dump() assert data["id"] == 1