From 2fcbd16bd9b410cefa7ac4e4deb5d3d93884696b Mon Sep 17 00:00:00 2001
From: Kiri
Date: Fri, 22 Aug 2025 06:46:03 -0700
Subject: [PATCH] Support multiple simultaneous transcribers with per-user
 Discord streams and scrolling captions

---
 audiosource.py     |   2 +-
 diffstuff.py       |  19 +++-
 transcribe_demo.py | 146 ++++++++++++++++++++++------
 transcriber.py     | 233 ++++++++++++++++++++++++++++++++++-----------
 4 files changed, 311 insertions(+), 89 deletions(-)

diff --git a/audiosource.py b/audiosource.py
index 05fe474..73e6855 100644
--- a/audiosource.py
+++ b/audiosource.py
@@ -159,7 +159,7 @@ class OpusStreamAudioSource(AudioSource):
         except Exception as e:
             # Probably disconnected. We don't care. Just clean up.
             # FIXME: Limit exception to socket errors.
-            pass
+            print("INPUT EXCEPTION: ", e)
 
         print("input thread done")
 
diff --git a/diffstuff.py b/diffstuff.py
index 2173c92..96bc2d2 100644
--- a/diffstuff.py
+++ b/diffstuff.py
@@ -6,23 +6,38 @@ def onestepchange(start, dest):
     ret = ""
 
     for i, s in enumerate(difflib.ndiff(start, dest)):
-        # print(i)
-        # print(s)
+        #print("i: ", i)
+        #print("S: ", s)
 
+        # Remove a character from the start.
         if s[0] == '-':
+            #print("ret1")
             return ret + start[i+1:]
 
+        # Add a character.
         if s[1] == '+':
+            #print("ret2")
             return ret + s[-1] + start[i:]
 
+        # Keep moving through the stream.
         ret = ret + s[-1]
 
+        # If we're at the length of the starting string plus one, then we've
+        # added our one character. Let's bounce.
         if len(ret) > len(start):
+            #print("ret3")
             return ret
 
+
         if ret[i] != start[i]:
+            #print("ret4")
             return ret + start[i:]
 
+
+    # Hack.
+    if ret == "":
+        return dest
+
+    #print("ret5 - ret")
     return ret
 
 def countsteps(start, dest):
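
A note on the onestepchange() contract above: each call applies one
single-character edit toward the destination string, so a caller animates text
by invoking it once per frame until the two strings converge. A minimal sketch
of that driver, assuming diffstuff.py is importable (the iteration cap is just
a safety bound for the sketch, not part of the module):

    import diffstuff

    current = "hello wrld"
    target = "hello world"

    # Step one character per iteration until the strings match.
    for _ in range(200):
        if current == target:
            break
        current = diffstuff.onestepchange(current, target)
        print(current)
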
diff --git a/transcribe_demo.py b/transcribe_demo.py
index 5c96cf6..0fb55dc 100644
--- a/transcribe_demo.py
+++ b/transcribe_demo.py
@@ -6,6 +6,18 @@ recent_phrase_count = 8
 # How real time the recording is in seconds.
 record_timeout = 2
 
+# Drop Discord users' transcribers after a minute of inactivity.
+discord_transcriber_timeout = 60
+
+
+import socket
+
+# Create socket for listening for incoming Opus audio streams from the Discord
+# bot.
+opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+opus_server_socket.bind(("127.0.0.1", 9967))
+opus_server_socket.listen()
+
 
 import argparse
 import os
@@ -25,7 +37,6 @@
 import pygame
 import wave
 #from pyogg.opus import OpusEncoder
-import socket
 import select
 import time
 import json
@@ -36,50 +47,129 @@
 import diffstuff
 import audiosource
 from transcriber import Transcriber
 
-pygame_font_height = 16
+
+
+
+
+pygame_font_height = 32
 
 pygame.init()
-pygame_display_surface = pygame.display.set_mode((1280, pygame_font_height * 2))
+pygame_display_surface = pygame.display.set_mode((960-75, int(pygame_font_height * 2 * 2.5)))
 pygame.display.set_caption("Transcription")
 pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height)
 
-
-opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-print("binding")
-opus_server_socket.bind(("127.0.0.1", 9967))
-print("set non blocking")
-#opus_server_socket.setblocking(False)
-print("listening")
-opus_server_socket.listen()
-
-
 wave_out = wave.open("wave.wav", "wb")
 wave_out.setnchannels(1)
 wave_out.setframerate(16000)
 wave_out.setsampwidth(2)
 
-transcriber = Transcriber()
-mic_source = audiosource.MicrophoneAudioSource()
+transcribers = []
+
+transcriber1 = Transcriber()
+mic_source1 = audiosource.MicrophoneAudioSource()
+transcriber1.set_source(mic_source1)
+transcriber1.username = "Kiri"
+transcribers.append(transcriber1)
+
+# transcriber2 = Transcriber()
+# mic_source2 = audiosource.MicrophoneAudioSource()
+# transcriber2.set_source(mic_source2)
+# transcriber2.username = "Kiri2"
+# transcribers.append(transcriber2)
+
+
+
+discord_transcribers_per_user_id = {}
 
 while True:
-    print("looping")
-
-    # Wait for a connection.
-    while True:
-        s = select.select([opus_server_socket], [], [], 0)
-        time.sleep(0.01)
-        if len(s[0]):
-            accepted_socket, addr = opus_server_socket.accept()
-            print(accepted_socket)
-            break
+    # Check for new opus connections.
+    #print("Checking for new connections...")
+    s = select.select([opus_server_socket], [], [], 0)
+    if len(s[0]):
+        accepted_socket, addr = opus_server_socket.accept()
+        #print("Accepted new Opus stream: ", accepted_socket)
+        new_stream = audiosource.OpusStreamAudioSource(accepted_socket)
+        if new_stream._user_info["userId"] in discord_transcribers_per_user_id:
+            discord_transcribers_per_user_id[new_stream._user_info["userId"]].set_source(new_stream)
+        else:
+            new_transcriber = Transcriber()
+            new_transcriber.set_source(new_stream)
+            discord_transcribers_per_user_id[new_stream._user_info["userId"]] = new_transcriber
+            new_transcriber.username = new_stream._user_info["displayName"]
 
-    opusSource = audiosource.OpusStreamAudioSource(accepted_socket)
+            if new_transcriber not in transcribers:
+                transcribers.append(new_transcriber)
 
-    transcriber.set_source(opusSource)
+    removal_queue = []
 
-    while not opusSource.is_done():
-        time.sleep(0.1)
+    # Run updates.
+    print("Running updates...")
+    for transcriber in transcribers:
+        #print("Running updates for... ", transcriber.username)
         transcriber.update()
+        #print("Done running updates for... ", transcriber.username)
+
+        if transcriber._phrase_time + timedelta(seconds=discord_transcriber_timeout) < datetime.utcnow():
+            if transcriber._audio_source is None:
+                removal_queue.append(transcriber)
+
+    # Note that this will not remove them from discord_transcribers_per_user_id.
+    # It's probably fine, though.
+    #print("Running removals...")
+    for removal in removal_queue:
+        #print("Removing inactive user: ", removal.username)
+        transcribers.remove(removal)
+
+    # Sleep.
+    print("Sleeping...")
+    time.sleep(0.05)
+
+    #print("Rendering...")
+
+    # Do rendering.
+    pygame_display_surface.fill((0, 0, 0))
+
+    # Render text, right-aligned against the edge of the window.
+    for transcriber in transcribers:
+        pygame_text_surface = pygame_font.render(transcriber.scrolling_text, True, (255, 255, 255))
+
+        pygame_text_rect = pygame_text_surface.get_rect()
+        pygame_text_rect.center = (
+            pygame_display_surface.get_width() / 2,
+            pygame_font_height * (1 + transcribers.index(transcriber)))
+        pygame_text_rect.right = pygame_display_surface.get_width()
+
+        pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
+
+    # Render a background for the names.
+    fill_rect = pygame_display_surface.get_rect()
+    fill_rect.width = 220
+    fill_rect.left = 0
+    pygame_display_surface.fill((0, 0, 0), fill_rect)
+
+    # Render names.
+    for transcriber in transcribers:
+
+        # Strip characters outside of ASCII; the display font may not cover
+        # them.
+        username_for_display = ""
+        for c in transcriber.username:
+            if ord(c) <= 127:
+                username_for_display += c
+
+        pygame_username_surface = pygame_font.render(username_for_display, True, (255, 255, 255))
+
+        pygame_text_rect = pygame_username_surface.get_rect()
+        pygame_text_rect.center = (
+            pygame_display_surface.get_width() / 2,
+            pygame_font_height * (1 + transcribers.index(transcriber)))
+        pygame_text_rect.left = 16
+
+        pygame_display_surface.blit(pygame_username_surface, pygame_text_rect)
+
+    pygame.display.update()
 
 exit(0)
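
The Discord bot side of this handshake isn't part of the patch.
OpusStreamAudioSource evidently reads a user-info record off the socket before
the audio, with at least "userId" and "displayName" fields, judging by the keys
used above. Purely as a hypothetical illustration of a sender, assuming a
newline-delimited JSON header (the real framing is defined in audiosource.py,
which this patch only touches lightly):

    import json
    import socket

    # Hypothetical client: header framing here is an assumption, not the
    # actual protocol implemented by OpusStreamAudioSource.
    sock = socket.create_connection(("127.0.0.1", 9967))
    header = {"userId": "123456789012345678", "displayName": "Kiri"}
    sock.sendall((json.dumps(header) + "\n").encode("utf-8"))
    # ...Opus packets would follow on the same socket...
    sock.close()
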
+ #print("Running removals...") + for removal in removal_queue: + #print("Removing inactive user: ", removal.username) + transcribers.remove(removal) + + # Sleep. + print("Sleeping...") + time.sleep(0.05) + + #print("Rendering...") + + # Do rendering. + pygame_display_surface.fill((0, 0, 0)) + + # Render text. + for transcriber in transcribers: + pygame_text_surface = pygame_font.render(transcriber.scrolling_text, (0, 0, 0), (255, 255, 255)) + + pygame_text_rect = pygame_text_surface.get_rect() + pygame_text_rect.center = ( + pygame_display_surface.get_width() / 2, + pygame_font_height * (1 + transcribers.index(transcriber))) + pygame_text_rect.right = pygame_display_surface.get_width() + + pygame_display_surface.blit(pygame_text_surface, pygame_text_rect) + + # Render a background for the names. + fill_rect = pygame_display_surface.get_rect() + fill_rect.width = 220 + fill_rect.left = 0 + pygame_display_surface.fill((0, 0, 0), fill_rect) + + # Render names. + for transcriber in transcribers: + + username_for_display = "" + for c in transcriber.username: + if ord(c) <= 127: + username_for_display += c + + pygame_username_surface = pygame_font.render(username_for_display, (0, 0, 0), (255, 255, 255)) + + pygame_text_rect = pygame_username_surface.get_rect() + pygame_text_rect.center = ( + pygame_display_surface.get_width() / 2, + pygame_font_height * (1 + transcribers.index(transcriber))) + pygame_text_rect.left = 16 + + pygame_display_surface.blit(pygame_username_surface, pygame_text_rect) + + pygame.display.update() exit(0) diff --git a/transcriber.py b/transcriber.py index 63fa42c..116924f 100644 --- a/transcriber.py +++ b/transcriber.py @@ -6,8 +6,11 @@ recent_phrase_count = 8 # Seconds of silence before we start a new phrase. phrase_timeout = 3 -# Higher is more restrictive on what it lets pass through. -no_speech_prob_threshold = 0.2 +# Higher is less restrictive on what it lets pass through. +no_speech_prob_threshold = 0.25 # 0.15 + +# Minimum number of seconds before we fire off the model again. +min_time_between_updates = 2 import numpy as np import speech_recognition @@ -16,14 +19,19 @@ import torch import wave from datetime import datetime, timedelta import json +import diffstuff +import threading +import time +from queue import Queue _audio_model = whisper.load_model("medium.en") # "large" +_audio_model_mutex = threading.Lock() # For debugging... -# wave_out = wave.open("wave.wav", "wb") -# wave_out.setnchannels(1) -# wave_out.setframerate(16000) -# wave_out.setsampwidth(2) +#wave_out = wave.open("wave.wav", "wb") +#wave_out.setnchannels(1) +#wave_out.setframerate(16000) +#wave_out.setsampwidth(2) class Transcriber: @@ -35,11 +43,31 @@ class Transcriber: self.phrases = [""] + self.scrolling_text = "" + # Time since the last data came in for the current phrase. self._phrase_time = datetime.utcnow() + # Last time that we ran the model. 
@@ -35,11 +43,31 @@
 
         self.phrases = [""]
 
+        self.scrolling_text = ""
+
         # Time since the last data came in for the current phrase.
         self._phrase_time = datetime.utcnow()
 
+        # Last time that we ran the model.
+        self._last_model_time = datetime.utcnow()
+
+        self._should_stop = False
+
+        self._phrases_list_mutex = threading.Lock()
+
+        self._update_thread = threading.Thread(
+            target=self._update_thread_func, daemon=True)
+        self._update_thread.start()
+
+        self.username = ""
+
+        self._audio_source_queue = Queue()
+
+    def stop(self):
+        self._should_stop = True
+
     def set_source(self, source):
-        self._audio_source = source
+        self._audio_source_queue.put(source)
 
     def phrase_probably_silent(self):
         """Whisper hallucinates a LOT on silence, so let's just ignore stuff
@@ -64,72 +92,161 @@
         return False
 
-    def update(self):
-
+    def run_transcription_update(self):
+
+        #print("run_transcription_update Checking queue...")
+
+        # Switch to whatever the newest source is.
+        if self._audio_source is None and not self._audio_source_queue.empty():
+            self._audio_source = self._audio_source_queue.get()
+
+        if not self._audio_source:
+            #print("run_transcription_update returning early...")
+            return
+
         now = datetime.utcnow()
 
+        if not self._audio_source.data_queue.empty():
-        if self._audio_source:
+            # We got some new data. Let's process it!
 
-            if not self._audio_source.data_queue.empty():
+            # If enough time has passed between recordings, consider the
+            # last phrase complete and start a new one. Clear the current
+            # working audio buffer to start over with the new data.
+            if self._phrase_time and now - self._phrase_time > timedelta(seconds=phrase_timeout):
 
-                # We got some new data. Let's process it!
-
-                # If enough time has passed between recordings, consider the
-                # last phrase complete and start a new one. Clear the current
-                # working audio buffer to start over with the new data.
-                if self._phrase_time and now - self._phrase_time > timedelta(seconds=phrase_timeout):
-
-                    # Only add a new phrase if we actually have data in the last
-                    # one.
+                # Only add a new phrase if we actually have data in the last
+                # one.
+                with self._phrases_list_mutex:
                     if self.phrases[-1] != "":
                         self.phrases.append("")
-                    self._current_data = b''
+                self._current_data = b''
 
-                    self._phrase_time = now
+            # Get all the new data since last tick,
+            new_data = []
+            while not self._audio_source.data_queue.empty():
+                new_packet = self._audio_source.data_queue.get()
+                new_data.append(new_packet)
+            new_data_joined = b''.join(new_data)
+
+            # For debugging...
+            #wave_out.writeframes(new_data_joined)
+
+            # Append it to the current buffer.
+            self._current_data = self._current_data + new_data_joined
 
-                    # Get all the new data since last tick,
-                    new_data = []
-                    while not self._audio_source.data_queue.empty():
-                        new_packet = self._audio_source.data_queue.get()
-                        new_data.append(new_packet)
-                    new_data_joined = b''.join(new_data)
-
-                    # For debugging...
-                    #wave_out.writeframes(new_data_joined)
-
-                    # Append it to the current buffer.
-                    self._current_data = self._current_data + new_data_joined
+            # if self.phrase_probably_silent():
+            #     with self._phrases_list_mutex:
+            #         self.phrases[-1] = ""
+            # else:
 
-                    if self.phrase_probably_silent():
-                        self.phrases[-1] = ""
-                    else:
+            # Convert in-ram buffer to something the model can use
+            # directly without needing a temp file. Convert data from 16
+            # bit wide integers to floating point with a width of 32
+            # bits, dividing by 32768 to normalize sample amplitudes into
+            # the [-1.0, 1.0) range the model expects.
+            audio_np = np.frombuffer(
+                self._current_data, dtype=np.int16).astype(np.float32) / 32768.0
 
-                        # Convert in-ram buffer to something the model can use
-                        # directly without needing a temp file. Convert data from 16
-                        # bit wide integers to floating point with a width of 32
-                        # bits. Clamp the audio stream frequency to a PCM wavelength
-                        # compatible default of 32768hz max.
-                        audio_np = np.frombuffer(
-                            self._current_data, dtype=np.int16).astype(np.float32) / 32768.0
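
For clarity on the conversion above: the division by 32768.0 is amplitude
normalization, not a frequency clamp (the old comment being removed here had
that wrong). Signed 16-bit PCM samples span [-32768, 32767], and scaling by
1/32768 maps them into the [-1.0, 1.0) float range Whisper expects. In
isolation:

    import numpy as np

    # Most negative, zero, and most positive 16-bit samples.
    pcm = np.array([-32768, 0, 32767], dtype=np.int16)
    audio_np = pcm.astype(np.float32) / 32768.0
    print(audio_np)  # [-1.0  0.0  0.9999695]
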
+            #print("run_transcription_update About to run transcription...")
 
-                        # Run the transcription model, and extract the text.
-                        result = _audio_model.transcribe(
-                            audio_np, fp16=torch.cuda.is_available())
+            # Run the transcription model, and extract the text.
+            with _audio_model_mutex:
+                #print("Transcribe start ", len(self._current_data))
+                result = _audio_model.transcribe(
+                    audio_np, fp16=torch.cuda.is_available())
+                #print("Transcribe end")
+                self._last_model_time = now
 
-                        # Filter out text segments with a high no_speech_prob.
-                        combined_text = ""
-                        for seg in result["segments"]:
-                            if seg["no_speech_prob"] <= no_speech_prob_threshold:
-                                combined_text += seg["text"]
+            # Filter out text segments with a high no_speech_prob.
+            combined_text = ""
+            for seg in result["segments"]:
+                if seg["no_speech_prob"] <= no_speech_prob_threshold:
+                    combined_text += seg["text"]
 
-                        text = combined_text.strip()
+            text = combined_text.strip()
 
-                        self.phrases[-1] = text
+            # FIXME: This bypasses the no_speech_prob filtering above.
+            text = result["text"]
 
-                        print("phrases: ", json.dumps(self.phrases, indent=4))
+            with self._phrases_list_mutex:
+                self.phrases[-1] = text
+                #print("phrases: ", json.dumps(self.phrases, indent=4))
+
+            # Update phrase time at the end so waiting for the mutex doesn't
+            # cause us to split phrases.
+            self._phrase_time = now
+
+        # Automatically drop audio sources when we're finished with them.
+        if self._audio_source.is_done():
+            self._audio_source = None
 
-        # Automatically drop audio sources when we're finished with them.
-        if self._audio_source.is_done():
-            self._audio_source = None
\ No newline at end of file
+
+    def _update_thread_func(self):
+        while not self._should_stop:
+            time.sleep(0.1)
+            now = datetime.utcnow()
+            if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
+                self.run_transcription_update()
+
+
+    def update(self):
+        #print("update updating scrolling text...")
+        self.update_scrolling_text()
+        #print("update running transcription update...")
+
+        now = datetime.utcnow()
+        #if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
+        #    self.run_transcription_update()
+
+    def update_scrolling_text(self):
+
+        #print("update_scrolling_text 1")
+
+        # Combine all the known phrases.
+        with self._phrases_list_mutex:
+            rolling_text_target = " ".join(self.phrases).strip()[-160:]
+
+        #print("update_scrolling_text 2")
+
+        if rolling_text_target != self.scrolling_text:
+
+            #print("update_scrolling_text 3")
+
+            # Start the diff off.
+            new_rolling_output_text = diffstuff.onestepchange(
+                self.scrolling_text, rolling_text_target)
+
+            #print("update_scrolling_text 4")
+
+            # Chop off the start all at once. It's not needed for the animation
+            # to look good.
+            #print("update_scrolling_text - ", self.scrolling_text)
+            #print("update_scrolling_text - ", new_rolling_output_text)
+            while self.scrolling_text.endswith(new_rolling_output_text):
+
+                #print("update_scrolling_text - start - ", self.scrolling_text)
+                #print("update_scrolling_text - end - ", new_rolling_output_text)
+
+                new_rolling_output_text = diffstuff.onestepchange(
+                    new_rolling_output_text, rolling_text_target)
+
+            #print("update_scrolling_text 5")
+
+            # Set the new text.
+            self.scrolling_text = new_rolling_output_text
+
+            #print("update_scrolling_text 6")
+
+            # Just jump ahead if we're still too far behind.
+            # FIXME: Hardcoded value.
+            if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
+                self.scrolling_text = rolling_text_target
+
+            #print("update_scrolling_text 7")
+
+            #print("%s: %s" % (self.username, self.scrolling_text))
+
+        #print("update_scrolling_text 8")
+
+        #print("update_scrolling_text done")
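
The no_speech_prob gate in run_transcription_update() is the main hallucination
filter in this patch: Whisper attaches a per-segment probability that the audio
contained no speech, and segments above the threshold get dropped. Distilled,
with a result dict in the shape transcribe() returns (the values here are made
up for illustration):

    no_speech_prob_threshold = 0.25

    def filter_hallucinations(result):
        # Keep only segments Whisper believes actually contain speech.
        kept = [seg["text"] for seg in result["segments"]
                if seg["no_speech_prob"] <= no_speech_prob_threshold]
        return "".join(kept).strip()

    # A typical silence hallucination scores high on no_speech_prob:
    result = {"segments": [
        {"text": " Hello there.", "no_speech_prob": 0.02},
        {"text": " Thanks for watching!", "no_speech_prob": 0.91},
    ]}
    print(filter_hallucinations(result))  # Hello there.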