asdfasdfsdf

This commit is contained in:
Kiri 2025-08-22 06:46:03 -07:00
parent 4dc7385239
commit 2fcbd16bd9
4 changed files with 311 additions and 89 deletions

View File

@ -159,7 +159,7 @@ class OpusStreamAudioSource(AudioSource):
except Exception as e: except Exception as e:
# Probably disconnected. We don't care. Just clean up. # Probably disconnected. We don't care. Just clean up.
# FIXME: Limit exception to socket errors. # FIXME: Limit exception to socket errors.
pass print("INPUT EXCEPTION: ", e)
print("input thread done") print("input thread done")

View File

@ -6,23 +6,38 @@ def onestepchange(start, dest):
ret = "" ret = ""
for i, s in enumerate(difflib.ndiff(start, dest)): for i, s in enumerate(difflib.ndiff(start, dest)):
# print(i) #print("i: ", i)
# print(s) #print("S: ", s)
# Remove a character from the start.
if s[0] == '-': if s[0] == '-':
#print("ret1")
return ret + start[i+1:] return ret + start[i+1:]
# Add a character.
if s[1] == '+': if s[1] == '+':
#print("ret2")
return ret + s[-1] + start[i:] return ret + s[-1] + start[i:]
# Keep moving through the stream.
ret = ret + s[-1] ret = ret + s[-1]
# If we're at the length of the starting string plus one, then we've
# added our one character. Let's bounce.
if len(ret) > len(start): if len(ret) > len(start):
#print("ret3")
return ret return ret
if ret[i] != start[i]: if ret[i] != start[i]:
#print("ret4")
return ret + start[i:] return ret + start[i:]
# Hack.
if ret == "":
return dest
#print("ret5 - ret")
return ret return ret
def countsteps(start, dest): def countsteps(start, dest):

View File

@ -6,6 +6,18 @@ recent_phrase_count = 8
# How real time the recording is in seconds. # How real time the recording is in seconds.
record_timeout = 2 record_timeout = 2
# Delete Discord users after a minute of no-activity.
discord_transcriber_timeout = 60 # 60
import socket
# Create socket for listening for incoming Opus audio streams from the Discord
# bot.
opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
opus_server_socket.bind(("127.0.0.1", 9967))
opus_server_socket.listen()
import argparse import argparse
import os import os
@ -25,7 +37,6 @@ import pygame
import wave import wave
#from pyogg.opus import OpusEncoder #from pyogg.opus import OpusEncoder
import socket
import select import select
import time import time
import json import json
@ -36,50 +47,129 @@ import diffstuff
import audiosource import audiosource
from transcriber import Transcriber from transcriber import Transcriber
pygame_font_height = 16
pygame_font_height = 32
pygame.init() pygame.init()
pygame_display_surface = pygame.display.set_mode((1280, pygame_font_height * 2)) pygame_display_surface = pygame.display.set_mode((960-75, pygame_font_height * 2 * 2.5))
pygame.display.set_caption("Transcription") pygame.display.set_caption("Transcription")
pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height) pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height)
opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print("binding")
opus_server_socket.bind(("127.0.0.1", 9967))
print("set non blocking")
#opus_server_socket.setblocking(False)
print("listening")
opus_server_socket.listen()
wave_out = wave.open("wave.wav", "wb") wave_out = wave.open("wave.wav", "wb")
wave_out.setnchannels(1) wave_out.setnchannels(1)
wave_out.setframerate(16000) wave_out.setframerate(16000)
wave_out.setsampwidth(2) wave_out.setsampwidth(2)
transcriber = Transcriber() transcribers = []
mic_source = audiosource.MicrophoneAudioSource()
transcriber1 = Transcriber()
mic_source1 = audiosource.MicrophoneAudioSource()
transcriber1.set_source(mic_source1)
transcriber1.username = "Kiri"
transcribers.append(transcriber1)
# transcriber2 = Transcriber()
# mic_source2 = audiosource.MicrophoneAudioSource()
# transcriber2.set_source(mic_source2)
# transcriber2.username = "Kiri2"
# transcribers.append(transcriber2)
discord_transcribers_per_user_id = {}
while True: while True:
print("looping")
# Wait for a connection. # Check for new opus connections.
while True: #print("Checking for new connections...")
s = select.select([opus_server_socket], [], [], 0) s = select.select([opus_server_socket], [], [], 0)
time.sleep(0.01)
if len(s[0]): if len(s[0]):
accepted_socket, addr = opus_server_socket.accept() accepted_socket, addr = opus_server_socket.accept()
print(accepted_socket) #print("Accepted new Opus stream: ", accepted_socket)
break new_stream = audiosource.OpusStreamAudioSource(accepted_socket)
if (new_stream._user_info["userId"] in discord_transcribers_per_user_id):
discord_transcribers_per_user_id[new_stream._user_info["userId"]].set_source(new_stream)
else:
new_transcriber = Transcriber()
new_transcriber.set_source(new_stream)
discord_transcribers_per_user_id[new_stream._user_info["userId"]] = new_transcriber
new_transcriber.username = new_stream._user_info["displayName"]
opusSource = audiosource.OpusStreamAudioSource(accepted_socket) if not new_transcriber in transcribers:
transcribers.append(new_transcriber)
transcriber.set_source(opusSource) removal_queue = []
while not opusSource.is_done():
time.sleep(0.1) # Run updates.
print("Running updates...")
for transcriber in transcribers:
#print("Running updates for... ", transcriber.username)
transcriber.update() transcriber.update()
#print("Done running updates for... ", transcriber.username)
if transcriber._phrase_time + timedelta(seconds=discord_transcriber_timeout) < datetime.utcnow():
if transcriber._audio_source == None:
removal_queue.append(transcriber)
#print("Running removals...")
# Note that this will not remove them from discord_transcribers_per_user_id.
# It's probably fine, though.
#print("Running removals...")
for removal in removal_queue:
#print("Removing inactive user: ", removal.username)
transcribers.remove(removal)
# Sleep.
print("Sleeping...")
time.sleep(0.05)
#print("Rendering...")
# Do rendering.
pygame_display_surface.fill((0, 0, 0))
# Render text.
for transcriber in transcribers:
pygame_text_surface = pygame_font.render(transcriber.scrolling_text, (0, 0, 0), (255, 255, 255))
pygame_text_rect = pygame_text_surface.get_rect()
pygame_text_rect.center = (
pygame_display_surface.get_width() / 2,
pygame_font_height * (1 + transcribers.index(transcriber)))
pygame_text_rect.right = pygame_display_surface.get_width()
pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
# Render a background for the names.
fill_rect = pygame_display_surface.get_rect()
fill_rect.width = 220
fill_rect.left = 0
pygame_display_surface.fill((0, 0, 0), fill_rect)
# Render names.
for transcriber in transcribers:
username_for_display = ""
for c in transcriber.username:
if ord(c) <= 127:
username_for_display += c
pygame_username_surface = pygame_font.render(username_for_display, (0, 0, 0), (255, 255, 255))
pygame_text_rect = pygame_username_surface.get_rect()
pygame_text_rect.center = (
pygame_display_surface.get_width() / 2,
pygame_font_height * (1 + transcribers.index(transcriber)))
pygame_text_rect.left = 16
pygame_display_surface.blit(pygame_username_surface, pygame_text_rect)
pygame.display.update()
exit(0) exit(0)

View File

@ -6,8 +6,11 @@ recent_phrase_count = 8
# Seconds of silence before we start a new phrase. # Seconds of silence before we start a new phrase.
phrase_timeout = 3 phrase_timeout = 3
# Higher is more restrictive on what it lets pass through. # Higher is less restrictive on what it lets pass through.
no_speech_prob_threshold = 0.2 no_speech_prob_threshold = 0.25 # 0.15
# Minimum number of seconds before we fire off the model again.
min_time_between_updates = 2
import numpy as np import numpy as np
import speech_recognition import speech_recognition
@ -16,8 +19,13 @@ import torch
import wave import wave
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import diffstuff
import threading
import time
from queue import Queue
_audio_model = whisper.load_model("medium.en") # "large" _audio_model = whisper.load_model("medium.en") # "large"
_audio_model_mutex = threading.Lock()
# For debugging... # For debugging...
#wave_out = wave.open("wave.wav", "wb") #wave_out = wave.open("wave.wav", "wb")
@ -35,11 +43,31 @@ class Transcriber:
self.phrases = [""] self.phrases = [""]
self.scrolling_text = ""
# Time since the last data came in for the current phrase. # Time since the last data came in for the current phrase.
self._phrase_time = datetime.utcnow() self._phrase_time = datetime.utcnow()
# Last time that we ran the model.
self._last_model_time = datetime.utcnow()
self._should_stop = False
self._phrases_list_mutex = threading.Lock()
self._update_thread = threading.Thread(
target=self._update_thread_func, daemon=True)
self._update_thread.start()
self.username = ""
self._audio_source_queue = Queue()
def stop(self):
self._should_stop = True
def set_source(self, source): def set_source(self, source):
self._audio_source = source self._audio_source_queue.put(source)
def phrase_probably_silent(self): def phrase_probably_silent(self):
"""Whisper hallucinates a LOT on silence, so let's just ignore stuff """Whisper hallucinates a LOT on silence, so let's just ignore stuff
@ -64,12 +92,19 @@ class Transcriber:
return False return False
def update(self): def run_transcription_update(self):
#print("run_transcription_update Checking queue...")
# Switch to whatever the newest source is.
if self._audio_source == None and not self._audio_source_queue.empty():
self._audio_source = self._audio_source_queue.get()
if not self._audio_source:
#print("run_transcription_update returning early...")
return
now = datetime.utcnow() now = datetime.utcnow()
if self._audio_source:
if not self._audio_source.data_queue.empty(): if not self._audio_source.data_queue.empty():
# We got some new data. Let's process it! # We got some new data. Let's process it!
@ -81,13 +116,12 @@ class Transcriber:
# Only add a new phrase if we actually have data in the last # Only add a new phrase if we actually have data in the last
# one. # one.
with self._phrases_list_mutex:
if self.phrases[-1] != "": if self.phrases[-1] != "":
self.phrases.append("") self.phrases.append("")
self._current_data = b'' self._current_data = b''
self._phrase_time = now
# Get all the new data since last tick, # Get all the new data since last tick,
new_data = [] new_data = []
while not self._audio_source.data_queue.empty(): while not self._audio_source.data_queue.empty():
@ -101,9 +135,10 @@ class Transcriber:
# Append it to the current buffer. # Append it to the current buffer.
self._current_data = self._current_data + new_data_joined self._current_data = self._current_data + new_data_joined
if self.phrase_probably_silent(): # if self.phrase_probably_silent():
self.phrases[-1] = "" # with self._phrases_list_mutex:
else: # self.phrases[-1] = ""
# else:
# Convert in-ram buffer to something the model can use # Convert in-ram buffer to something the model can use
# directly without needing a temp file. Convert data from 16 # directly without needing a temp file. Convert data from 16
@ -113,9 +148,15 @@ class Transcriber:
audio_np = np.frombuffer( audio_np = np.frombuffer(
self._current_data, dtype=np.int16).astype(np.float32) / 32768.0 self._current_data, dtype=np.int16).astype(np.float32) / 32768.0
#print("run_transcription_update About to run transcription...")
# Run the transcription model, and extract the text. # Run the transcription model, and extract the text.
with _audio_model_mutex:
#print("Transcribe start ", len(self._current_data))
result = _audio_model.transcribe( result = _audio_model.transcribe(
audio_np, fp16=torch.cuda.is_available()) audio_np, fp16=torch.cuda.is_available())
#print("Transcribe end")
self._last_model_time = now
# Filter out text segments with a high no_speech_prob. # Filter out text segments with a high no_speech_prob.
combined_text = "" combined_text = ""
@ -125,11 +166,87 @@ class Transcriber:
text = combined_text.strip() text = combined_text.strip()
# FIXME:
text = result["text"]
with self._phrases_list_mutex:
self.phrases[-1] = text self.phrases[-1] = text
#print("phrases: ", json.dumps(self.phrases, indent=4))
print("phrases: ", json.dumps(self.phrases, indent=4)) # Update phrase time at the end so waiting for the mutex doesn't
# cause us to split phrases.
self._phrase_time = now
# Automatically drop audio sources when we're finished with them. # Automatically drop audio sources when we're finished with them.
if self._audio_source.is_done(): if self._audio_source.is_done():
self._audio_source = None self._audio_source = None
def _update_thread_func(self):
while not self._should_stop:
time.sleep(0.1)
now = datetime.utcnow()
if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
self.run_transcription_update()
def update(self):
#print("update updating scrolling text...")
self.update_scrolling_text()
#print("update running transcription update...")
now = datetime.utcnow()
#if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
# self.run_transcription_update()
def update_scrolling_text(self):
#print("update_scrolling_text 1")
# Combine all the known phrases.
with self._phrases_list_mutex:
rolling_text_target = " ".join(self.phrases).strip()[-160:]
#print("update_scrolling_text 2")
if rolling_text_target != self.scrolling_text:
#print("update_scrolling_text 3")
# Start the diff off.
new_rolling_output_text = diffstuff.onestepchange(
self.scrolling_text, rolling_text_target)
#print("update_scrolling_text 4")
# Chop off the start all at once. It's not needed for the animation
# to look good.
#print("update_scrolling_text - ", self.scrolling_text)
#print("update_scrolling_text - ", new_rolling_output_text)
while self.scrolling_text.endswith(new_rolling_output_text):
#print("update_scrolling_text - start - ", self.scrolling_text)
#print("update_scrolling_text - end - ", new_rolling_output_text)
new_rolling_output_text = diffstuff.onestepchange(
new_rolling_output_text, rolling_text_target)
#print("update_scrolling_text 5")
# Set the new text.
self.scrolling_text = new_rolling_output_text
#print("update_scrolling_text 6")
# Just jump ahead if we're still too far behind.
# FIXME: Hardcoded value.
if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
self.scrolling_text = rolling_text_target
#print("update_scrolling_text 7")
#print("%s: %s" % (self.username, self.scrolling_text))
#print("update_scrolling_text 8")
#print("update_scrolling_text done")