#!/usr/bin/python3
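
"""Live transcription: pull PCM audio from a queued audio source, periodically
run Whisper over the current phrase buffer on a background thread, and keep a
scrolling-text view of the recognized phrases."""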

# Recent phrases to include in the text buffer before the current
# transcription.
recent_phrase_count = 8

# Seconds of silence before we start a new phrase.
phrase_timeout = 1.0

# Higher is less restrictive on what it lets pass through.
no_speech_prob_threshold = 0.05  # 0.15

# Minimum number of seconds before we fire off the model again.
min_time_between_updates = 2

import numpy as np
import speech_recognition
import whisper
import torch
import wave
from datetime import datetime, timedelta
import json
import diffstuff
import threading
import time
from queue import Queue

_audio_model = whisper.load_model("medium.en")  # "large"
_audio_model_mutex = threading.Lock()

# For debugging...
#wave_out = wave.open("wave.wav", "wb")
#wave_out.setnchannels(1)
#wave_out.setframerate(16000)
#wave_out.setsampwidth(2)

class Transcriber:
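    """Streaming transcriber.

    Pulls raw PCM data off the current audio source, periodically runs
    Whisper over the audio accumulated for the current phrase (on a daemon
    thread), and keeps a scrolling-text view of the recognized phrases that
    update_scrolling_text() advances one diff step at a time."""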

    def __init__(self):
        self._audio_source = None

        # Audio data for the current phrase.
        self._current_data = b''

        self.phrases = [""]

        self.scrolling_text = ""

        # Time since the last data came in for the current phrase.
        self._phrase_time = datetime.utcnow()

        # Last time that we ran the model.
        self._last_model_time = datetime.utcnow()

        self._should_stop = False

        self._phrases_list_mutex = threading.Lock()

        self._update_thread = threading.Thread(
            target=self._update_thread_func, daemon=True)
        self._update_thread.start()

        self.username = ""

        self._audio_source_queue = Queue()

    def stop(self):
        self._should_stop = True

    def set_source(self, source):
        self._audio_source_queue.put(source)

    def phrase_probably_silent(self):
        """Whisper hallucinates a LOT on silence, so let's just ignore stuff
        that's mostly silence. First line of defense here."""

        # Interpret the raw buffer as signed 16-bit samples. Iterating the
        # bytes object directly would yield individual unsigned bytes, not
        # audio samples.
        samples = np.frombuffer(self._current_data, dtype=np.int16)
        if len(samples) == 0:
            return True

        # Count the fraction of samples whose amplitude clears the threshold.
        # Widen to int32 first so abs() can't overflow on -32768.
        threshold = 100
        loud_samples = np.count_nonzero(
            np.abs(samples.astype(np.int32)) > threshold)

        return (loud_samples / len(samples)) < 0.1

    def run_transcription_update(self):

        #print("run_transcription_update Checking queue...")

        # If we have no current source, switch to the next one waiting in
        # the queue.
        if self._audio_source is None and not self._audio_source_queue.empty():
            self._audio_source = self._audio_source_queue.get()

        if not self._audio_source:
            #print("run_transcription_update returning early...")
            return

        now = datetime.utcnow()
        if not self._audio_source.data_queue.empty():

            # We got some new data. Let's process it!

            # Get all the new data since last tick.
            new_data = []
            with self._audio_source._data_mutex:
                while not self._audio_source.data_queue.empty():
                    new_packet = self._audio_source.data_queue.get()
                    new_data.append(new_packet)
            new_data_joined = b''.join(new_data)

            # For debugging...
            #wave_out.writeframes(new_data_joined)

            # Append it to the current buffer.
            self._current_data = self._current_data + new_data_joined

            # if self.phrase_probably_silent():
            #     with self._phrases_list_mutex:
            #         self.phrases[-1] = ""
            # else:

            # Convert the in-RAM buffer to something the model can use
            # directly without needing a temp file: reinterpret the raw
            # bytes as 16-bit integers, widen them to 32-bit floats, and
            # divide by 32768 (the int16 range) to normalize to [-1.0, 1.0].
            # Keep only the most recent 10 seconds of 16 kHz mono audio; the
            # factor of 2 accounts for two bytes per 16-bit sample.
            audio_np = np.frombuffer(
                self._current_data[-16000 * 10 * 2:],
                dtype=np.int16).astype(np.float32) / 32768.0
#print("run_transcription_update About to run transcription...")
|
|
|
|
# Run the transcription model, and extract the text.
|
|
with _audio_model_mutex:
|
|
#print("Transcribe start ", len(self._current_data))
|
|
result = _audio_model.transcribe(
|
|
audio_np, fp16=torch.cuda.is_available(),
|
|
word_timestamps=True,
|
|
hallucination_silence_threshold=2)
|
|
#print("Transcribe end")
|
|
self._last_model_time = now
|
|
|
|

            # For debugging: dump this phrase's audio to a wave file.
            with self._phrases_list_mutex:
                wave_out = wave.open(
                    "tmp/wave%0.4d.wav" % len(self.phrases), "wb")
                wave_out.setnchannels(1)
                wave_out.setframerate(16000)
                wave_out.setsampwidth(2)
                wave_out.writeframes(self._current_data)
                wave_out.close()

            # Filter out text segments with a high no_speech_prob.
            combined_text = ""
            for seg in result["segments"]:
                if seg["no_speech_prob"] <= no_speech_prob_threshold:
                    combined_text += seg["text"]

            text = combined_text.strip()

            # FIXME: This overrides the no_speech_prob filtering above.
            text = result["text"]

            with self._phrases_list_mutex:
                self.phrases[-1] = text
                #print("phrases: ", json.dumps(self.phrases, indent=4))

            # Update phrase time at the end so waiting for the mutex doesn't
            # cause us to split phrases.
            self._phrase_time = now

        # If enough time has passed between recordings, consider the last
        # phrase complete and start a new one. Clear the current working
        # audio buffer to start over with the new data.
        if now - self._audio_source.time_of_last_input > timedelta(seconds=phrase_timeout):

            # Only add a new phrase if we actually have data in the last
            # one.
            with self._phrases_list_mutex:
                if self.phrases[-1] != "":
                    self.phrases.append("")

            self._current_data = b''

        # Automatically drop audio sources when we're finished with them.
        if self._audio_source.is_done():
            self._audio_source = None

    def _update_thread_func(self):
        while not self._should_stop:
            time.sleep(0.1)
            now = datetime.utcnow()
            if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
                self.run_transcription_update()

    def update(self):
        #print("update updating scrolling text...")
        self.update_scrolling_text()
        #print("update running transcription update...")

        # The model update now runs on the background thread instead.
        #now = datetime.utcnow()
        #if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
        #    self.run_transcription_update()

    def update_scrolling_text(self):

        #print("update_scrolling_text 1")

        # Combine all the known phrases.
        with self._phrases_list_mutex:
            rolling_text_target = " ".join(self.phrases).strip()[-160:]

        #print("update_scrolling_text 2")
        if rolling_text_target != self.scrolling_text:

            #print("update_scrolling_text 3")

            # Start the diff off.
            new_rolling_output_text = diffstuff.onestepchange(
                self.scrolling_text, rolling_text_target)

            #print("update_scrolling_text 4")

            # Chop off the start all at once. It's not needed for the
            # animation to look good. Stop stepping once we've reached the
            # target, so the loop can't spin forever when the target is a
            # suffix of the old text.
            #print("update_scrolling_text - ", self.scrolling_text)
            #print("update_scrolling_text - ", new_rolling_output_text)
            while (new_rolling_output_text != rolling_text_target
                   and self.scrolling_text.endswith(new_rolling_output_text)):

                #print("update_scrolling_text - start - ", self.scrolling_text)
                #print("update_scrolling_text - end - ", new_rolling_output_text)

                new_rolling_output_text = diffstuff.onestepchange(
                    new_rolling_output_text, rolling_text_target)

            #print("update_scrolling_text 5")

            # Set the new text.
            self.scrolling_text = new_rolling_output_text

            #print("update_scrolling_text 6")

            # Just jump ahead if we're still too far behind.
            # FIXME: Hardcoded value.
            if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
                self.scrolling_text = rolling_text_target

            #print("update_scrolling_text 7")

        #print("%s: %s" % (self.username, self.scrolling_text))

        #print("update_scrolling_text 8")

        #print("update_scrolling_text done")
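
# Example usage (a rough sketch; "FakeSource" below is hypothetical and only
# models the interface this file expects of an audio source: data_queue,
# _data_mutex, time_of_last_input, and is_done()).
#
#   class FakeSource:
#       def __init__(self):
#           self.data_queue = Queue()           # raw 16-bit, 16 kHz mono PCM
#           self._data_mutex = threading.Lock()
#           self.time_of_last_input = datetime.utcnow()
#       def is_done(self):
#           return False
#
#   transcriber = Transcriber()
#   transcriber.set_source(FakeSource())
#   try:
#       while True:
#           transcriber.update()                # advances the scrolling text
#           print(transcriber.scrolling_text)
#           time.sleep(0.5)
#   finally:
#       transcriber.stop()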