fix: add transcription model fallback chain
This commit is contained in:
+34
-21
@@ -36,17 +36,22 @@ def _audio_format(audio_path: str) -> str:
|
||||
return suffix or "wav"
|
||||
|
||||
|
||||
def _build_transcription_payload(audio_path: str) -> dict[str, Any]:
|
||||
TRANSCRIPTION_MODELS = [
|
||||
"openai/gpt-4o-mini-transcribe",
|
||||
"openai/whisper-large-v3",
|
||||
]
|
||||
|
||||
|
||||
def _build_transcription_payload(audio_path: str, model: str) -> dict[str, Any]:
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
encoded = base64.b64encode(audio_file.read()).decode("ascii")
|
||||
|
||||
return {
|
||||
"model": "openai/whisper-large-v3",
|
||||
"model": model,
|
||||
"input_audio": {
|
||||
"data": encoded,
|
||||
"format": _audio_format(audio_path),
|
||||
},
|
||||
"language": "en",
|
||||
}
|
||||
|
||||
|
||||
@@ -84,31 +89,39 @@ async def _normalize_audio_for_transcription(audio_path: str) -> str:
|
||||
|
||||
|
||||
async def transcribe(audio_path: str) -> str:
|
||||
"""Send audio to OpenRouter's whisper model and return transcript text."""
|
||||
"""Send audio to OpenRouter STT models and return transcript text."""
|
||||
headers = _auth_headers()
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["X-OpenRouter-Title"] = "discord-meeting-bot"
|
||||
|
||||
normalized_path = await _normalize_audio_for_transcription(audio_path)
|
||||
failures: list[str] = []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
resp = await client.post(
|
||||
f"{OPENROUTER_BASE}/audio/transcriptions",
|
||||
headers=headers,
|
||||
content=json.dumps(_build_transcription_payload(normalized_path)),
|
||||
)
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = summarize_error(_safe_json(resp), fallback=resp.text)
|
||||
raise RuntimeError(
|
||||
f"OpenRouter transcription failed ({resp.status_code}): {detail}"
|
||||
) from exc
|
||||
for model in TRANSCRIPTION_MODELS:
|
||||
resp = await client.post(
|
||||
f"{OPENROUTER_BASE}/audio/transcriptions",
|
||||
headers=headers,
|
||||
json=_build_transcription_payload(normalized_path, model),
|
||||
)
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
detail = summarize_error(_safe_json(resp), fallback=resp.text)
|
||||
generation_id = resp.headers.get("x-generation-id")
|
||||
suffix = f"; generation_id={generation_id}" if generation_id else ""
|
||||
failures.append(f"{model}: HTTP {resp.status_code}: {detail}{suffix}")
|
||||
continue
|
||||
|
||||
data = resp.json()
|
||||
text = data.get("text", "")
|
||||
if not text.strip():
|
||||
raise RuntimeError("OpenRouter transcription returned empty text")
|
||||
return text.strip()
|
||||
data = resp.json()
|
||||
text = data.get("text", "")
|
||||
if text.strip():
|
||||
return text.strip()
|
||||
failures.append(f"{model}: returned empty text")
|
||||
|
||||
raise RuntimeError(
|
||||
"OpenRouter transcription failed across all models: " + " | ".join(failures)
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.remove(normalized_path)
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from openrouter_client import _audio_format, _build_transcription_payload
|
||||
from openrouter_client import TRANSCRIPTION_MODELS, _audio_format, _build_transcription_payload
|
||||
|
||||
|
||||
def test_audio_format_defaults_to_wav_when_missing_suffix(tmp_path: Path):
|
||||
@@ -18,9 +18,8 @@ def test_build_transcription_payload_uses_base64_json_shape(tmp_path: Path):
|
||||
path = tmp_path / "meeting.wav"
|
||||
path.write_bytes(b"RIFFdemo")
|
||||
|
||||
payload = _build_transcription_payload(str(path))
|
||||
payload = _build_transcription_payload(str(path), TRANSCRIPTION_MODELS[0])
|
||||
|
||||
assert payload["model"] == "openai/whisper-large-v3"
|
||||
assert payload["language"] == "en"
|
||||
assert payload["model"] == "openai/gpt-4o-mini-transcribe"
|
||||
assert payload["input_audio"]["format"] == "wav"
|
||||
assert payload["input_audio"]["data"] == base64.b64encode(b"RIFFdemo").decode("ascii")
|
||||
|
||||
Reference in New Issue
Block a user