asdfasdfsdf
This commit is contained in:
parent
4dc7385239
commit
2fcbd16bd9
@ -159,7 +159,7 @@ class OpusStreamAudioSource(AudioSource):
|
||||
except Exception as e:
|
||||
# Probably disconnected. We don't care. Just clean up.
|
||||
# FIXME: Limit exception to socket errors.
|
||||
pass
|
||||
print("INPUT EXCEPTION: ", e)
|
||||
|
||||
print("input thread done")
|
||||
|
||||
|
19
diffstuff.py
19
diffstuff.py
@ -6,23 +6,38 @@ def onestepchange(start, dest):
|
||||
ret = ""
|
||||
|
||||
for i, s in enumerate(difflib.ndiff(start, dest)):
|
||||
# print(i)
|
||||
# print(s)
|
||||
#print("i: ", i)
|
||||
#print("S: ", s)
|
||||
|
||||
# Remove a character from the start.
|
||||
if s[0] == '-':
|
||||
#print("ret1")
|
||||
return ret + start[i+1:]
|
||||
|
||||
# Add a character.
|
||||
if s[1] == '+':
|
||||
#print("ret2")
|
||||
return ret + s[-1] + start[i:]
|
||||
|
||||
# Keep moving through the stream.
|
||||
ret = ret + s[-1]
|
||||
|
||||
# If we're at the length of the starting string plus one, then we've
|
||||
# added our one character. Let's bounce.
|
||||
if len(ret) > len(start):
|
||||
#print("ret3")
|
||||
return ret
|
||||
|
||||
|
||||
if ret[i] != start[i]:
|
||||
#print("ret4")
|
||||
return ret + start[i:]
|
||||
|
||||
# Hack.
|
||||
if ret == "":
|
||||
return dest
|
||||
|
||||
#print("ret5 - ret")
|
||||
return ret
|
||||
|
||||
def countsteps(start, dest):
|
||||
|
@ -6,6 +6,18 @@ recent_phrase_count = 8
|
||||
# How real time the recording is in seconds.
|
||||
record_timeout = 2
|
||||
|
||||
# Delete Discord users after a minute of no-activity.
|
||||
discord_transcriber_timeout = 60 # 60
|
||||
|
||||
|
||||
import socket
|
||||
|
||||
# Create socket for listening for incoming Opus audio streams from the Discord
|
||||
# bot.
|
||||
opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
opus_server_socket.bind(("127.0.0.1", 9967))
|
||||
opus_server_socket.listen()
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
@ -25,7 +37,6 @@ import pygame
|
||||
import wave
|
||||
#from pyogg.opus import OpusEncoder
|
||||
|
||||
import socket
|
||||
import select
|
||||
import time
|
||||
import json
|
||||
@ -36,50 +47,129 @@ import diffstuff
|
||||
import audiosource
|
||||
from transcriber import Transcriber
|
||||
|
||||
pygame_font_height = 16
|
||||
|
||||
|
||||
|
||||
|
||||
pygame_font_height = 32
|
||||
pygame.init()
|
||||
pygame_display_surface = pygame.display.set_mode((1280, pygame_font_height * 2))
|
||||
pygame_display_surface = pygame.display.set_mode((960-75, pygame_font_height * 2 * 2.5))
|
||||
pygame.display.set_caption("Transcription")
|
||||
pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height)
|
||||
|
||||
|
||||
|
||||
|
||||
opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
print("binding")
|
||||
opus_server_socket.bind(("127.0.0.1", 9967))
|
||||
print("set non blocking")
|
||||
#opus_server_socket.setblocking(False)
|
||||
print("listening")
|
||||
opus_server_socket.listen()
|
||||
|
||||
|
||||
wave_out = wave.open("wave.wav", "wb")
|
||||
wave_out.setnchannels(1)
|
||||
wave_out.setframerate(16000)
|
||||
wave_out.setsampwidth(2)
|
||||
|
||||
transcriber = Transcriber()
|
||||
mic_source = audiosource.MicrophoneAudioSource()
|
||||
transcribers = []
|
||||
|
||||
transcriber1 = Transcriber()
|
||||
mic_source1 = audiosource.MicrophoneAudioSource()
|
||||
transcriber1.set_source(mic_source1)
|
||||
transcriber1.username = "Kiri"
|
||||
transcribers.append(transcriber1)
|
||||
|
||||
# transcriber2 = Transcriber()
|
||||
# mic_source2 = audiosource.MicrophoneAudioSource()
|
||||
# transcriber2.set_source(mic_source2)
|
||||
# transcriber2.username = "Kiri2"
|
||||
# transcribers.append(transcriber2)
|
||||
|
||||
|
||||
|
||||
discord_transcribers_per_user_id = {}
|
||||
|
||||
while True:
|
||||
print("looping")
|
||||
|
||||
# Wait for a connection.
|
||||
while True:
|
||||
s = select.select([opus_server_socket], [], [], 0)
|
||||
time.sleep(0.01)
|
||||
if len(s[0]):
|
||||
accepted_socket, addr = opus_server_socket.accept()
|
||||
print(accepted_socket)
|
||||
break
|
||||
# Check for new opus connections.
|
||||
#print("Checking for new connections...")
|
||||
s = select.select([opus_server_socket], [], [], 0)
|
||||
if len(s[0]):
|
||||
accepted_socket, addr = opus_server_socket.accept()
|
||||
#print("Accepted new Opus stream: ", accepted_socket)
|
||||
new_stream = audiosource.OpusStreamAudioSource(accepted_socket)
|
||||
if (new_stream._user_info["userId"] in discord_transcribers_per_user_id):
|
||||
discord_transcribers_per_user_id[new_stream._user_info["userId"]].set_source(new_stream)
|
||||
else:
|
||||
new_transcriber = Transcriber()
|
||||
new_transcriber.set_source(new_stream)
|
||||
discord_transcribers_per_user_id[new_stream._user_info["userId"]] = new_transcriber
|
||||
new_transcriber.username = new_stream._user_info["displayName"]
|
||||
|
||||
opusSource = audiosource.OpusStreamAudioSource(accepted_socket)
|
||||
if not new_transcriber in transcribers:
|
||||
transcribers.append(new_transcriber)
|
||||
|
||||
transcriber.set_source(opusSource)
|
||||
while not opusSource.is_done():
|
||||
time.sleep(0.1)
|
||||
removal_queue = []
|
||||
|
||||
# Run updates.
|
||||
print("Running updates...")
|
||||
for transcriber in transcribers:
|
||||
#print("Running updates for... ", transcriber.username)
|
||||
transcriber.update()
|
||||
#print("Done running updates for... ", transcriber.username)
|
||||
|
||||
if transcriber._phrase_time + timedelta(seconds=discord_transcriber_timeout) < datetime.utcnow():
|
||||
if transcriber._audio_source == None:
|
||||
removal_queue.append(transcriber)
|
||||
|
||||
#print("Running removals...")
|
||||
|
||||
# Note that this will not remove them from discord_transcribers_per_user_id.
|
||||
# It's probably fine, though.
|
||||
#print("Running removals...")
|
||||
for removal in removal_queue:
|
||||
#print("Removing inactive user: ", removal.username)
|
||||
transcribers.remove(removal)
|
||||
|
||||
# Sleep.
|
||||
print("Sleeping...")
|
||||
time.sleep(0.05)
|
||||
|
||||
#print("Rendering...")
|
||||
|
||||
# Do rendering.
|
||||
pygame_display_surface.fill((0, 0, 0))
|
||||
|
||||
# Render text.
|
||||
for transcriber in transcribers:
|
||||
pygame_text_surface = pygame_font.render(transcriber.scrolling_text, (0, 0, 0), (255, 255, 255))
|
||||
|
||||
pygame_text_rect = pygame_text_surface.get_rect()
|
||||
pygame_text_rect.center = (
|
||||
pygame_display_surface.get_width() / 2,
|
||||
pygame_font_height * (1 + transcribers.index(transcriber)))
|
||||
pygame_text_rect.right = pygame_display_surface.get_width()
|
||||
|
||||
pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
|
||||
|
||||
# Render a background for the names.
|
||||
fill_rect = pygame_display_surface.get_rect()
|
||||
fill_rect.width = 220
|
||||
fill_rect.left = 0
|
||||
pygame_display_surface.fill((0, 0, 0), fill_rect)
|
||||
|
||||
# Render names.
|
||||
for transcriber in transcribers:
|
||||
|
||||
username_for_display = ""
|
||||
for c in transcriber.username:
|
||||
if ord(c) <= 127:
|
||||
username_for_display += c
|
||||
|
||||
pygame_username_surface = pygame_font.render(username_for_display, (0, 0, 0), (255, 255, 255))
|
||||
|
||||
pygame_text_rect = pygame_username_surface.get_rect()
|
||||
pygame_text_rect.center = (
|
||||
pygame_display_surface.get_width() / 2,
|
||||
pygame_font_height * (1 + transcribers.index(transcriber)))
|
||||
pygame_text_rect.left = 16
|
||||
|
||||
pygame_display_surface.blit(pygame_username_surface, pygame_text_rect)
|
||||
|
||||
pygame.display.update()
|
||||
|
||||
|
||||
exit(0)
|
||||
|
233
transcriber.py
233
transcriber.py
@ -6,8 +6,11 @@ recent_phrase_count = 8
|
||||
# Seconds of silence before we start a new phrase.
|
||||
phrase_timeout = 3
|
||||
|
||||
# Higher is more restrictive on what it lets pass through.
|
||||
no_speech_prob_threshold = 0.2
|
||||
# Higher is less restrictive on what it lets pass through.
|
||||
no_speech_prob_threshold = 0.25 # 0.15
|
||||
|
||||
# Minimum number of seconds before we fire off the model again.
|
||||
min_time_between_updates = 2
|
||||
|
||||
import numpy as np
|
||||
import speech_recognition
|
||||
@ -16,14 +19,19 @@ import torch
|
||||
import wave
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import diffstuff
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
_audio_model = whisper.load_model("medium.en") # "large"
|
||||
_audio_model_mutex = threading.Lock()
|
||||
|
||||
# For debugging...
|
||||
# wave_out = wave.open("wave.wav", "wb")
|
||||
# wave_out.setnchannels(1)
|
||||
# wave_out.setframerate(16000)
|
||||
# wave_out.setsampwidth(2)
|
||||
#wave_out = wave.open("wave.wav", "wb")
|
||||
#wave_out.setnchannels(1)
|
||||
#wave_out.setframerate(16000)
|
||||
#wave_out.setsampwidth(2)
|
||||
|
||||
class Transcriber:
|
||||
|
||||
@ -35,11 +43,31 @@ class Transcriber:
|
||||
|
||||
self.phrases = [""]
|
||||
|
||||
self.scrolling_text = ""
|
||||
|
||||
# Time since the last data came in for the current phrase.
|
||||
self._phrase_time = datetime.utcnow()
|
||||
|
||||
# Last time that we ran the model.
|
||||
self._last_model_time = datetime.utcnow()
|
||||
|
||||
self._should_stop = False
|
||||
|
||||
self._phrases_list_mutex = threading.Lock()
|
||||
|
||||
self._update_thread = threading.Thread(
|
||||
target=self._update_thread_func, daemon=True)
|
||||
self._update_thread.start()
|
||||
|
||||
self.username = ""
|
||||
|
||||
self._audio_source_queue = Queue()
|
||||
|
||||
def stop(self):
|
||||
self._should_stop = True
|
||||
|
||||
def set_source(self, source):
|
||||
self._audio_source = source
|
||||
self._audio_source_queue.put(source)
|
||||
|
||||
def phrase_probably_silent(self):
|
||||
"""Whisper hallucinates a LOT on silence, so let's just ignore stuff
|
||||
@ -64,72 +92,161 @@ class Transcriber:
|
||||
|
||||
return False
|
||||
|
||||
def update(self):
|
||||
|
||||
def run_transcription_update(self):
|
||||
|
||||
#print("run_transcription_update Checking queue...")
|
||||
|
||||
# Switch to whatever the newest source is.
|
||||
if self._audio_source == None and not self._audio_source_queue.empty():
|
||||
self._audio_source = self._audio_source_queue.get()
|
||||
|
||||
if not self._audio_source:
|
||||
#print("run_transcription_update returning early...")
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
if not self._audio_source.data_queue.empty():
|
||||
|
||||
if self._audio_source:
|
||||
# We got some new data. Let's process it!
|
||||
|
||||
if not self._audio_source.data_queue.empty():
|
||||
# If enough time has passed between recordings, consider the
|
||||
# last phrase complete and start a new one. Clear the current
|
||||
# working audio buffer to start over with the new data.
|
||||
if self._phrase_time and now - self._phrase_time > timedelta(seconds=phrase_timeout):
|
||||
|
||||
# We got some new data. Let's process it!
|
||||
|
||||
# If enough time has passed between recordings, consider the
|
||||
# last phrase complete and start a new one. Clear the current
|
||||
# working audio buffer to start over with the new data.
|
||||
if self._phrase_time and now - self._phrase_time > timedelta(seconds=phrase_timeout):
|
||||
|
||||
# Only add a new phrase if we actually have data in the last
|
||||
# one.
|
||||
# Only add a new phrase if we actually have data in the last
|
||||
# one.
|
||||
with self._phrases_list_mutex:
|
||||
if self.phrases[-1] != "":
|
||||
self.phrases.append("")
|
||||
|
||||
self._current_data = b''
|
||||
self._current_data = b''
|
||||
|
||||
self._phrase_time = now
|
||||
# Get all the new data since last tick,
|
||||
new_data = []
|
||||
while not self._audio_source.data_queue.empty():
|
||||
new_packet = self._audio_source.data_queue.get()
|
||||
new_data.append(new_packet)
|
||||
new_data_joined = b''.join(new_data)
|
||||
|
||||
# For debugging...
|
||||
#wave_out.writeframes(new_data_joined)
|
||||
|
||||
# Append it to the current buffer.
|
||||
self._current_data = self._current_data + new_data_joined
|
||||
|
||||
# Get all the new data since last tick,
|
||||
new_data = []
|
||||
while not self._audio_source.data_queue.empty():
|
||||
new_packet = self._audio_source.data_queue.get()
|
||||
new_data.append(new_packet)
|
||||
new_data_joined = b''.join(new_data)
|
||||
|
||||
# For debugging...
|
||||
#wave_out.writeframes(new_data_joined)
|
||||
|
||||
# Append it to the current buffer.
|
||||
self._current_data = self._current_data + new_data_joined
|
||||
# if self.phrase_probably_silent():
|
||||
# with self._phrases_list_mutex:
|
||||
# self.phrases[-1] = ""
|
||||
# else:
|
||||
|
||||
if self.phrase_probably_silent():
|
||||
self.phrases[-1] = ""
|
||||
else:
|
||||
# Convert in-ram buffer to something the model can use
|
||||
# directly without needing a temp file. Convert data from 16
|
||||
# bit wide integers to floating point with a width of 32
|
||||
# bits. Clamp the audio stream frequency to a PCM wavelength
|
||||
# compatible default of 32768hz max.
|
||||
audio_np = np.frombuffer(
|
||||
self._current_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
# Convert in-ram buffer to something the model can use
|
||||
# directly without needing a temp file. Convert data from 16
|
||||
# bit wide integers to floating point with a width of 32
|
||||
# bits. Clamp the audio stream frequency to a PCM wavelength
|
||||
# compatible default of 32768hz max.
|
||||
audio_np = np.frombuffer(
|
||||
self._current_data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
#print("run_transcription_update About to run transcription...")
|
||||
|
||||
# Run the transcription model, and extract the text.
|
||||
result = _audio_model.transcribe(
|
||||
audio_np, fp16=torch.cuda.is_available())
|
||||
# Run the transcription model, and extract the text.
|
||||
with _audio_model_mutex:
|
||||
#print("Transcribe start ", len(self._current_data))
|
||||
result = _audio_model.transcribe(
|
||||
audio_np, fp16=torch.cuda.is_available())
|
||||
#print("Transcribe end")
|
||||
self._last_model_time = now
|
||||
|
||||
# Filter out text segments with a high no_speech_prob.
|
||||
combined_text = ""
|
||||
for seg in result["segments"]:
|
||||
if seg["no_speech_prob"] <= no_speech_prob_threshold:
|
||||
combined_text += seg["text"]
|
||||
# Filter out text segments with a high no_speech_prob.
|
||||
combined_text = ""
|
||||
for seg in result["segments"]:
|
||||
if seg["no_speech_prob"] <= no_speech_prob_threshold:
|
||||
combined_text += seg["text"]
|
||||
|
||||
text = combined_text.strip()
|
||||
text = combined_text.strip()
|
||||
|
||||
self.phrases[-1] = text
|
||||
# FIXME:
|
||||
text = result["text"]
|
||||
|
||||
print("phrases: ", json.dumps(self.phrases, indent=4))
|
||||
with self._phrases_list_mutex:
|
||||
self.phrases[-1] = text
|
||||
#print("phrases: ", json.dumps(self.phrases, indent=4))
|
||||
|
||||
# Update phrase time at the end so waiting for the mutex doesn't
|
||||
# cause us to split phrases.
|
||||
self._phrase_time = now
|
||||
|
||||
# Automatically drop audio sources when we're finished with them.
|
||||
if self._audio_source.is_done():
|
||||
self._audio_source = None
|
||||
|
||||
|
||||
# Automatically drop audio sources when we're finished with them.
|
||||
if self._audio_source.is_done():
|
||||
self._audio_source = None
|
||||
def _update_thread_func(self):
|
||||
while not self._should_stop:
|
||||
time.sleep(0.1)
|
||||
now = datetime.utcnow()
|
||||
if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
|
||||
self.run_transcription_update()
|
||||
|
||||
|
||||
def update(self):
|
||||
#print("update updating scrolling text...")
|
||||
self.update_scrolling_text()
|
||||
#print("update running transcription update...")
|
||||
|
||||
now = datetime.utcnow()
|
||||
#if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
|
||||
# self.run_transcription_update()
|
||||
|
||||
def update_scrolling_text(self):
|
||||
|
||||
#print("update_scrolling_text 1")
|
||||
|
||||
# Combine all the known phrases.
|
||||
with self._phrases_list_mutex:
|
||||
rolling_text_target = " ".join(self.phrases).strip()[-160:]
|
||||
|
||||
#print("update_scrolling_text 2")
|
||||
|
||||
if rolling_text_target != self.scrolling_text:
|
||||
|
||||
#print("update_scrolling_text 3")
|
||||
|
||||
# Start the diff off.
|
||||
new_rolling_output_text = diffstuff.onestepchange(
|
||||
self.scrolling_text, rolling_text_target)
|
||||
|
||||
#print("update_scrolling_text 4")
|
||||
|
||||
# Chop off the start all at once. It's not needed for the animation
|
||||
# to look good.
|
||||
#print("update_scrolling_text - ", self.scrolling_text)
|
||||
#print("update_scrolling_text - ", new_rolling_output_text)
|
||||
while self.scrolling_text.endswith(new_rolling_output_text):
|
||||
|
||||
#print("update_scrolling_text - start - ", self.scrolling_text)
|
||||
#print("update_scrolling_text - end - ", new_rolling_output_text)
|
||||
|
||||
new_rolling_output_text = diffstuff.onestepchange(
|
||||
new_rolling_output_text, rolling_text_target)
|
||||
|
||||
#print("update_scrolling_text 5")
|
||||
|
||||
# Set the new text.
|
||||
self.scrolling_text = new_rolling_output_text
|
||||
|
||||
#print("update_scrolling_text 6")
|
||||
|
||||
# Just jump ahead if we're still too far behind.
|
||||
# FIXME: Hardcoded value.
|
||||
if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
|
||||
self.scrolling_text = rolling_text_target
|
||||
|
||||
#print("update_scrolling_text 7")
|
||||
|
||||
#print("%s: %s" % (self.username, self.scrolling_text))
|
||||
|
||||
#print("update_scrolling_text 8")
|
||||
|
||||
#print("update_scrolling_text done")
|
||||
|
Loading…
Reference in New Issue
Block a user