Add /api/transcribe endpoint with Whisper
This commit is contained in:
+54
-6
@@ -1,4 +1,6 @@
|
||||
"""Tests for the FastAPI application."""
|
||||
"""Tests for the FastAPI micro-api."""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -26,30 +28,76 @@ class TestHealth:
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
class TestTranscribe:
|
||||
def test_transcribe_no_file(self):
|
||||
response = client.post("/api/transcribe")
|
||||
assert response.status_code == 422 # FastAPI validation error
|
||||
|
||||
def test_transcribe_invalid_model(self):
|
||||
# Create a tiny WAV file (44-byte header + silence)
|
||||
wav_header = (
|
||||
b"RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00"
|
||||
b"\x01\x00\x01\x00\x44\xac\x00\x00\x88\x58\x01\x00"
|
||||
b"\x02\x00\x10\x00data\x00\x00\x00\x00"
|
||||
)
|
||||
response = client.post(
|
||||
"/api/transcribe?model=invalid_model",
|
||||
files={"file": ("test.wav", io.BytesIO(wav_header), "audio/wav")},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Invalid model" in response.json()["detail"]
|
||||
|
||||
def test_transcribe_invalid_format(self):
|
||||
response = client.post(
|
||||
"/api/transcribe",
|
||||
files={"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported format" in response.json()["detail"]
|
||||
|
||||
def test_transcribe_empty_file(self):
|
||||
response = client.post(
|
||||
"/api/transcribe",
|
||||
files={"file": ("empty.wav", io.BytesIO(b""), "audio/wav")},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Empty" in response.json()["detail"]
|
||||
|
||||
def test_transcribe_accepts_wav_with_tiny_model(self):
|
||||
"""Test that the endpoint accepts a valid WAV file (tiny model won't load)."""
|
||||
wav_header = (
|
||||
b"RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00"
|
||||
b"\x01\x00\x01\x00\x44\xac\x00\x00\x88\x58\x01\x00"
|
||||
b"\x02\x00\x10\x00data\x00\x00\x00\x00"
|
||||
)
|
||||
response = client.post(
|
||||
"/api/transcribe?model=tiny",
|
||||
files={"file": ("test.wav", io.BytesIO(wav_header), "audio/wav")},
|
||||
)
|
||||
# 422 if file format is too broken for whisper, 200 if it just returns empty
|
||||
# 500 if model not downloaded (expected in CI)
|
||||
assert response.status_code in (200, 422, 500)
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
def test_init_db_creates_tables(self):
|
||||
from app.database import init_db
|
||||
|
||||
# Should not raise
|
||||
init_db()
|
||||
|
||||
|
||||
class TestModels:
|
||||
def test_item_create_validation(self):
|
||||
from app.models import ItemCreate
|
||||
|
||||
item = ItemCreate(name="test-item")
|
||||
assert item.name == "test-item"
|
||||
|
||||
def test_item_create_requires_name(self):
|
||||
from app.models import ItemCreate
|
||||
|
||||
with pytest.raises(Exception):
|
||||
ItemCreate()
|
||||
|
||||
def test_item_response_serialization(self):
|
||||
from app.models import ItemResponse
|
||||
|
||||
item = ItemResponse(id=1, name="test")
|
||||
data = item.model_dump()
|
||||
assert data["id"] == 1
|
||||
|
||||
Reference in New Issue
Block a user