fix: add transcription model fallback chain
This commit is contained in:
+24
-11
@@ -36,17 +36,22 @@ def _audio_format(audio_path: str) -> str:
|
|||||||
return suffix or "wav"
|
return suffix or "wav"
|
||||||
|
|
||||||
|
|
||||||
def _build_transcription_payload(audio_path: str) -> dict[str, Any]:
|
TRANSCRIPTION_MODELS = [
|
||||||
|
"openai/gpt-4o-mini-transcribe",
|
||||||
|
"openai/whisper-large-v3",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_transcription_payload(audio_path: str, model: str) -> dict[str, Any]:
|
||||||
with open(audio_path, "rb") as audio_file:
|
with open(audio_path, "rb") as audio_file:
|
||||||
encoded = base64.b64encode(audio_file.read()).decode("ascii")
|
encoded = base64.b64encode(audio_file.read()).decode("ascii")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": "openai/whisper-large-v3",
|
"model": model,
|
||||||
"input_audio": {
|
"input_audio": {
|
||||||
"data": encoded,
|
"data": encoded,
|
||||||
"format": _audio_format(audio_path),
|
"format": _audio_format(audio_path),
|
||||||
},
|
},
|
||||||
"language": "en",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -84,31 +89,39 @@ async def _normalize_audio_for_transcription(audio_path: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def transcribe(audio_path: str) -> str:
|
async def transcribe(audio_path: str) -> str:
|
||||||
"""Send audio to OpenRouter's whisper model and return transcript text."""
|
"""Send audio to OpenRouter STT models and return transcript text."""
|
||||||
headers = _auth_headers()
|
headers = _auth_headers()
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
|
headers["X-OpenRouter-Title"] = "discord-meeting-bot"
|
||||||
|
|
||||||
normalized_path = await _normalize_audio_for_transcription(audio_path)
|
normalized_path = await _normalize_audio_for_transcription(audio_path)
|
||||||
|
failures: list[str] = []
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
|
for model in TRANSCRIPTION_MODELS:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{OPENROUTER_BASE}/audio/transcriptions",
|
f"{OPENROUTER_BASE}/audio/transcriptions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
content=json.dumps(_build_transcription_payload(normalized_path)),
|
json=_build_transcription_payload(normalized_path, model),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError:
|
||||||
detail = summarize_error(_safe_json(resp), fallback=resp.text)
|
detail = summarize_error(_safe_json(resp), fallback=resp.text)
|
||||||
raise RuntimeError(
|
generation_id = resp.headers.get("x-generation-id")
|
||||||
f"OpenRouter transcription failed ({resp.status_code}): {detail}"
|
suffix = f"; generation_id={generation_id}" if generation_id else ""
|
||||||
) from exc
|
failures.append(f"{model}: HTTP {resp.status_code}: {detail}{suffix}")
|
||||||
|
continue
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
text = data.get("text", "")
|
text = data.get("text", "")
|
||||||
if not text.strip():
|
if text.strip():
|
||||||
raise RuntimeError("OpenRouter transcription returned empty text")
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
failures.append(f"{model}: returned empty text")
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"OpenRouter transcription failed across all models: " + " | ".join(failures)
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
os.remove(normalized_path)
|
os.remove(normalized_path)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from openrouter_client import _audio_format, _build_transcription_payload
|
from openrouter_client import TRANSCRIPTION_MODELS, _audio_format, _build_transcription_payload
|
||||||
|
|
||||||
|
|
||||||
def test_audio_format_defaults_to_wav_when_missing_suffix(tmp_path: Path):
|
def test_audio_format_defaults_to_wav_when_missing_suffix(tmp_path: Path):
|
||||||
@@ -18,9 +18,8 @@ def test_build_transcription_payload_uses_base64_json_shape(tmp_path: Path):
|
|||||||
path = tmp_path / "meeting.wav"
|
path = tmp_path / "meeting.wav"
|
||||||
path.write_bytes(b"RIFFdemo")
|
path.write_bytes(b"RIFFdemo")
|
||||||
|
|
||||||
payload = _build_transcription_payload(str(path))
|
payload = _build_transcription_payload(str(path), TRANSCRIPTION_MODELS[0])
|
||||||
|
|
||||||
assert payload["model"] == "openai/whisper-large-v3"
|
assert payload["model"] == "openai/gpt-4o-mini-transcribe"
|
||||||
assert payload["language"] == "en"
|
|
||||||
assert payload["input_audio"]["format"] == "wav"
|
assert payload["input_audio"]["format"] == "wav"
|
||||||
assert payload["input_audio"]["data"] == base64.b64encode(b"RIFFdemo").decode("ascii")
|
assert payload["input_audio"]["data"] == base64.b64encode(b"RIFFdemo").decode("ascii")
|
||||||
|
|||||||
Reference in New Issue
Block a user