#!/usr/bin/python3

# Recent phrases to include in the text buffer before the current
# transcription.
recent_phrase_count = 8

# Seconds of silence before we start a new phrase.
phrase_timeout = 1.0

# Segments with a no_speech_prob above this get dropped. Higher is less
# restrictive on what it lets pass through.
no_speech_prob_threshold = 0.05 # 0.15

# Minimum number of seconds before we fire off the model again.
min_time_between_updates = 2

import json
import threading
import time
import wave
from datetime import datetime, timedelta
from queue import Queue

import numpy as np
import speech_recognition
import torch
import whisper

import diffstuff

_audio_model = whisper.load_model("medium.en") # "large"
_audio_model_mutex = threading.Lock()

# For debugging...
#wave_out = wave.open("wave.wav", "wb")
#wave_out.setnchannels(1)
#wave_out.setframerate(16000)
#wave_out.setsampwidth(2)

class Transcriber:

    def __init__(self):
        self._audio_source = None

        # Audio data for the current phrase.
        self._current_data = b''

        self.phrases = [""]
        self.scrolling_text = ""

        # Time the last data came in for the current phrase.
        self._phrase_time = datetime.utcnow()

        # Last time that we ran the model.
        self._last_model_time = datetime.utcnow()

        self._should_stop = False

        self._phrases_list_mutex = threading.Lock()

        self._update_thread = threading.Thread(
            target=self._update_thread_func, daemon=True)
        self._update_thread.start()

        self.username = ""

        self._audio_source_queue = Queue()

    def stop(self):
        self._should_stop = True

    def set_source(self, source):
        self._audio_source_queue.put(source)

    def phrase_probably_silent(self):
        """Whisper hallucinates a LOT on silence, so let's just ignore stuff
        that's mostly silence. First line of defense here."""
        if not self._current_data:
            return True

        # Interpret the raw buffer as 16-bit samples. (Iterating over the
        # bytes directly would compare individual byte values, not sample
        # amplitudes.)
        samples = np.frombuffer(self._current_data, dtype=np.int16)

        # Consider the phrase silent if fewer than 10% of the samples rise
        # above a small amplitude threshold.
        threshold = 100
        threshold_pct = np.count_nonzero(np.abs(samples) > threshold) / len(samples)
        return threshold_pct < 0.1

    def run_transcription_update(self):

        # Switch to whatever the newest source is.
        if self._audio_source is None and not self._audio_source_queue.empty():
            self._audio_source = self._audio_source_queue.get()

        if not self._audio_source:
            return

        now = datetime.utcnow()

        if not self._audio_source.data_queue.empty():

            # We got some new data. Let's process it!

            # Get all the new data since the last tick.
            new_data = []
            with self._audio_source._data_mutex:
                while not self._audio_source.data_queue.empty():
                    new_packet = self._audio_source.data_queue.get()
                    new_data.append(new_packet)
            new_data_joined = b''.join(new_data)

            # For debugging...
            #wave_out.writeframes(new_data_joined)

            # Append it to the current buffer.
            self._current_data = self._current_data + new_data_joined

            # if self.phrase_probably_silent():
            #     with self._phrases_list_mutex:
            #         self.phrases[-1] = ""
            # else:

            # Convert the in-RAM buffer to something the model can use
            # directly without needing a temp file: reinterpret the raw
            # bytes as 16-bit integers and normalize to 32-bit floats in
            # [-1.0, 1.0) by dividing by 32768. Only feed the model the
            # last 10 seconds of audio (16000 samples/sec, 2 bytes each).
            audio_np = np.frombuffer(
                self._current_data[-16000 * 10 * 2:],
                dtype=np.int16).astype(np.float32) / 32768.0

            # Run the transcription model, and extract the text.
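            # A couple of notes on the transcribe() call below: fp16 runs
            # the model in half precision when a CUDA device is available
            # (Whisper falls back to FP32 on CPU), and
            # hallucination_silence_threshold only takes effect when
            # word_timestamps=True; it makes Whisper skip silent gaps
            # longer than the given number of seconds when a hallucination
            # is suspected.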
            with _audio_model_mutex:
                #print("Transcribe start ", len(self._current_data))
                result = _audio_model.transcribe(
                    audio_np,
                    fp16=torch.cuda.is_available(),
                    word_timestamps=True,
                    hallucination_silence_threshold=2)
                #print("Transcribe end")

            self._last_model_time = now

            # For debugging: dump the current phrase's audio to a WAV file.
            with self._phrases_list_mutex:
                wave_out = wave.open(
                    "tmp/wave%0.4d.wav" % len(self.phrases), "wb")
                wave_out.setnchannels(1)
                wave_out.setframerate(16000)
                wave_out.setsampwidth(2)
                wave_out.writeframes(self._current_data)
                wave_out.close()

            # Filter out text segments with a high no_speech_prob.
            combined_text = ""
            for seg in result["segments"]:
                if seg["no_speech_prob"] <= no_speech_prob_threshold:
                    combined_text += seg["text"]
            text = combined_text.strip()

            # FIXME:
            #text = result["text"]

            with self._phrases_list_mutex:
                self.phrases[-1] = text

            #print("phrases: ", json.dumps(self.phrases, indent=4))

            # Update the phrase time at the end so waiting for the mutex
            # doesn't cause us to split phrases.
            self._phrase_time = now

        # If enough time has passed between recordings, consider the last
        # phrase complete and start a new one. Clear the current working
        # audio buffer to start over with the new data.
        if now - self._audio_source.time_of_last_input > timedelta(seconds=phrase_timeout):

            # Only add a new phrase if we actually have text in the last one.
            with self._phrases_list_mutex:
                if self.phrases[-1] != "":
                    self.phrases.append("")

            self._current_data = b''

        # Automatically drop audio sources when we're finished with them.
        if self._audio_source.is_done():
            self._audio_source = None

    def _update_thread_func(self):
        while not self._should_stop:
            time.sleep(0.1)
            now = datetime.utcnow()
            if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
                self.run_transcription_update()

    def update(self):
        self.update_scrolling_text()

        # The model now runs on its own thread (_update_thread_func), so we
        # no longer kick it off from here.
        #now = datetime.utcnow()
        #if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
        #    self.run_transcription_update()

    def update_scrolling_text(self):

        # Combine all the known phrases.
        with self._phrases_list_mutex:
            rolling_text_target = " ".join(self.phrases).strip()[-160:]

        if rolling_text_target != self.scrolling_text:

            # Start the diff off.
            new_rolling_output_text = diffstuff.onestepchange(
                self.scrolling_text, rolling_text_target)

            # Chop off the start all at once. It's not needed for the
            # animation to look good. (A front deletion leaves the old text
            # ending with the new text, so keep stepping while that holds.
            # Stop once we reach the target so we can't loop forever.)
            while (self.scrolling_text.endswith(new_rolling_output_text)
                   and new_rolling_output_text != rolling_text_target):
                new_rolling_output_text = diffstuff.onestepchange(
                    new_rolling_output_text, rolling_text_target)

            # Set the new text.
            self.scrolling_text = new_rolling_output_text

            # Just jump ahead if we're still too far behind.
            # FIXME: Hardcoded value.
            if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
                self.scrolling_text = rolling_text_target

        #print("%s: %s" % (self.username, self.scrolling_text))
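
# ---------------------------------------------------------------------------
# Minimal usage sketch. The audio-source interface below (data_queue,
# _data_mutex, time_of_last_input, is_done) is inferred from how
# run_transcription_update() consumes it; MicSource is a hypothetical
# stand-in fed by the speech_recognition package, not an API this module
# actually provides, so treat it as illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    import speech_recognition as sr

    class MicSource:
        """Hypothetical audio source matching the interface the Transcriber
        expects."""
        def __init__(self):
            self.data_queue = Queue()
            self._data_mutex = threading.Lock()
            self.time_of_last_input = datetime.utcnow()

        def is_done(self):
            # A real source would report end-of-stream here.
            return False

    # run_transcription_update() dumps debug WAVs into tmp/, so make sure
    # the directory exists.
    os.makedirs("tmp", exist_ok=True)

    mic_source = MicSource()

    def record_callback(_, audio):
        # Push raw 16-bit mono PCM into the queue, matching the 16kHz
        # format the Transcriber assumes.
        with mic_source._data_mutex:
            mic_source.data_queue.put(audio.get_raw_data())
        mic_source.time_of_last_input = datetime.utcnow()

    recognizer = sr.Recognizer()
    recognizer.dynamic_energy_threshold = False
    mic = sr.Microphone(sample_rate=16000)
    with mic:
        recognizer.adjust_for_ambient_noise(mic)
    recognizer.listen_in_background(mic, record_callback, phrase_time_limit=2)

    transcriber = Transcriber()
    transcriber.set_source(mic_source)
    try:
        while True:
            transcriber.update()
            print(transcriber.scrolling_text)
            time.sleep(0.5)
    except KeyboardInterrupt:
        transcriber.stop()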