Compare commits

No commits in common. "multi_user" and "master" have entirely different histories.

multi_user ... master

34 README.md Normal file
@@ -0,0 +1,34 @@
# Real Time Whisper Transcription

This is a demo of real time speech to text with OpenAI's Whisper model. It works by constantly recording audio in a thread and concatenating the raw bytes over multiple recordings.
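For illustration, the core of that loop boils down to something like this (a condensed sketch using the same `speech_recognition` helper the demo relies on, not the demo's exact code):

```
import time
from queue import Queue
import speech_recognition as sr

data_queue = Queue()
recorder = sr.Recognizer()
recorder.dynamic_energy_threshold = False

def record_callback(_, audio: sr.AudioData) -> None:
    # Each finished recording hands us a chunk of raw 16-bit PCM bytes.
    data_queue.put(audio.get_raw_data())

source = sr.Microphone(sample_rate=16000)
with source:
    recorder.adjust_for_ambient_noise(source)

# Record in a background thread, in 2-second chunks.
recorder.listen_in_background(source, record_callback, phrase_time_limit=2)

audio_data = b''
while True:
    if not data_queue.empty():
        # Concatenate the raw bytes over multiple recordings, then hand
        # the growing buffer to Whisper on each pass.
        audio_data += b''.join(data_queue.queue)
        data_queue.queue.clear()
    time.sleep(0.25)
```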
To install dependencies, simply run

```
pip install -r requirements.txt
```

in an environment of your choosing.

Whisper also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:

```
# on Ubuntu or Debian
sudo apt update && sudo apt install ffmpeg

# on Arch Linux
sudo pacman -S ffmpeg

# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg

# on Windows using Chocolatey (https://chocolatey.org/)
choco install ffmpeg

# on Windows using Scoop (https://scoop.sh/)
scoop install ffmpeg
```

For more information on Whisper, please see https://github.com/openai/whisper

The code in this repository is public domain.
196 audiosource.py
@@ -1,196 +0,0 @@
#!/usr/bin/python3

import socket
import select
import time
from queue import Queue
import json
import threading
import speech_recognition
import wave
from datetime import datetime, timedelta

from pyogg.opus_decoder import OpusDecoder

# For debugging.
#wave_out = wave.open("tmp/mic.wav", "wb")
#wave_out.setnchannels(1)
#wave_out.setframerate(16000)
#wave_out.setsampwidth(2)


class AudioSource:

    def __init__(self):
        # Thread safe Queue for passing data from the threaded recording
        # callback.
        self.data_queue = Queue()

        self.time_of_last_input = datetime.utcnow()

        self._data_mutex = threading.Lock()

    def add_data(self, data):
        with self._data_mutex:
            self.time_of_last_input = datetime.utcnow()
            self.data_queue.put(bytearray(data))
            #wave_out.writeframes(data)

    def is_done(self):
        return True


# -----------------------------------------------------------------------------
# Microphone

# How real time the recording is in seconds.
record_timeout = 2

class MicrophoneAudioSource(AudioSource):

    def __init__(self):
        super().__init__()

        self._recorder = speech_recognition.Recognizer()
        self._recorder.energy_threshold = 200

        # Definitely do this, dynamic energy compensation lowers the energy
        # threshold dramatically to a point where the SpeechRecognizer
        # never stops recording.
        self._recorder.dynamic_energy_threshold = False

        self._source = speech_recognition.Microphone(sample_rate=16000)

        with self._source:
            self._recorder.adjust_for_ambient_noise(self._source)

        def record_callback(_, audio: speech_recognition.AudioData) -> None:
            """
            Threaded callback function to receive audio data when recordings finish.
            audio: An AudioData containing the recorded bytes.
            """
            # Grab the raw bytes and push it into the thread safe queue.
            data = audio.get_raw_data()
            self.time_of_last_input = datetime.utcnow()
            #self.data_queue.put(bytearray(data))
            self.add_data(data)

        # Create a background thread that will pass us raw audio bytes.
        # We could do this manually but SpeechRecognizer provides a nice helper.
        self._stopper = self._recorder.listen_in_background(
            self._source, record_callback,
            phrase_time_limit=record_timeout)

    def stop(self):
        assert(self._stopper)
        self._stopper()

        self._recorder = None
        self._stopper = None
        self._source = None

    def is_done(self):
        return self._recorder == None


# -----------------------------------------------------------------------------
# Opus stream

# For debugging
# wave_out = wave.open("wave2.wav", "wb")
# wave_out.setnchannels(1)
# wave_out.setframerate(16000)
# wave_out.setsampwidth(2)

class OpusStreamAudioSource(AudioSource):

    def __init__(self, sock):
        super().__init__()

        self._socket = sock

        self._opus_decoder = OpusDecoder()
        self._opus_decoder.set_channels(1)
        self._opus_decoder.set_sampling_frequency(16000)

        # Fetch user info.
        user_info_tmp = self._read_packet(self._socket)
        self._user_info = json.loads(user_info_tmp.decode("utf-8"))
        print("User connection...")
        print(json.dumps(self._user_info, indent=4))

        self._is_done = False

        # Start input thread.
        self._input_thread = threading.Thread(
            target=self._input_thread_func, daemon=True)
        self._input_thread.start()

    def _read_packet(self, sock):
        try:
            # Each packet is a 4-byte little-endian size followed by the
            # payload itself. Check each recv() result directly so a closed
            # socket can't leave us spinning on a partial buffer.
            input_buffer = b''
            while len(input_buffer) < 4:
                chunk = sock.recv(1)
                if not chunk:
                    raise Exception("Failed to read size of packet.")
                input_buffer = input_buffer + chunk

            packet_size = int.from_bytes(input_buffer, "little")

            input_buffer = b''
            while len(input_buffer) < packet_size:
                chunk = sock.recv(1)
                if not chunk:
                    raise Exception("Failed to read packet.")
                input_buffer = input_buffer + chunk

            return input_buffer

        except Exception as e: # FIXME: Use socket-specific exception type.
            return None

    def _input_thread_func(self):

        print("input thread start")
        try:
            while not self._is_done:

                next_packet = self._read_packet(self._socket)

                if next_packet:
                    # If we don't use bytearray here to copy, we run into a weird
                    # exception about the memory not being writeable.
                    decoded_data = self._opus_decoder.decode(bytearray(next_packet))

                    # For debugging.
                    #wave_out.writeframes(decoded_data)

                    # We need to copy decoded_data here or we end up with
                    # recycled buffers in our queue, which leads to broken
                    # audio.
                    #self.data_queue.put(bytearray(decoded_data))
                    self.add_data(decoded_data)
                else:
                    break

        except Exception as e:
            # Probably disconnected. We don't care. Just clean up.
            # FIXME: Limit exception to socket errors.
            print("INPUT EXCEPTION: ", e)

        print("input thread done")

        self._is_done = True

    def stop(self):
        self._is_done = True

        # We won't join() the input thread because we don't want to sit around
        # and wait for a packet. It'll die on its own, so whatever.

    def is_done(self):
        return self._is_done
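`OpusStreamAudioSource._read_packet` above implies a small wire protocol: each packet is a 4-byte little-endian length followed by that many payload bytes, and the first packet on a connection is a JSON blob carrying the user's `userId` and `displayName`. A hypothetical sender matching that framing (host, port, and field names taken from elsewhere in this diff; the payloads after the first would be encoded Opus frames):

```
import json
import socket

def send_packet(sock: socket.socket, payload: bytes) -> None:
    # 4-byte little-endian size, then the payload, mirroring _read_packet.
    sock.sendall(len(payload).to_bytes(4, "little"))
    sock.sendall(payload)

sock = socket.create_connection(("127.0.0.1", 9967))
send_packet(sock, json.dumps(
    {"userId": "1234", "displayName": "Kiri"}).encode("utf-8"))
# ...then one send_packet() per encoded Opus frame.
```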
48 diffstuff.py
@@ -1,48 +0,0 @@
import textwrap
import difflib


def onestepchange(start, dest):
    """Return `start` moved one single-character edit toward `dest`."""

    ret = ""

    for i, s in enumerate(difflib.ndiff(start, dest)):

        # Remove a character from the start.
        if s[0] == '-':
            return ret + start[i+1:]

        # Add a character.
        if s[0] == '+':
            return ret + s[-1] + start[i:]

        # Keep moving through the stream.
        ret = ret + s[-1]

        # If we're at the length of the starting string plus one, then we've
        # added our one character. Let's bounce.
        if len(ret) > len(start):
            return ret

        if ret[i] != start[i]:
            return ret + start[i:]

    # Hack.
    if ret == "":
        return dest

    return ret


def countsteps(start, dest):
    step_count = 0
    while start != dest:
        start = onestepchange(start, dest)
        step_count += 1
    return step_count
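For context on how these two helpers behave: `onestepchange` returns `start` moved one single-character insertion or deletion toward `dest`, and `countsteps` counts how many such edits remain; the transcriber uses them to animate captions one character at a time. A small illustration, assuming `diffstuff.py` is on the import path:

```
import diffstuff

s = "hello wrld"
while s != "hello world":
    s = diffstuff.onestepchange(s, "hello world")
    print(s)  # prints "hello world" once the missing 'o' is inserted

# Edit-steps remaining between two strings (insertions/deletions only).
print(diffstuff.countsteps("kitten", "sitting"))
```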
requirements.txt
@@ -1,8 +1,7 @@
 setuptools
 pyaudio
 SpeechRecognition
---extra-index-url https://download.pytorch.org/whl/rocm6.2.4
+--extra-index-url https://download.pytorch.org/whl/cu116
 torch
 numpy
 git+https://github.com/openai/whisper.git
-git+https://github.com/TeamPyOgg/PyOgg.git@4118fc40067eb475468726c6bccf1242abfc24fc
@@ -1,28 +1,9 @@
 #! python3.7
 
-# Recent phrases to include in the text buffer before the current transcription.
-recent_phrase_count = 8
-
-# How real time the recording is in seconds.
-record_timeout = 2
-
-# Delete Discord users after a minute of no activity.
-discord_transcriber_timeout = 60
-
-import socket
-
-# Create socket for listening for incoming Opus audio streams from the Discord
-# bot.
-opus_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-opus_server_socket.bind(("127.0.0.1", 9967))
-opus_server_socket.listen()
-
 import argparse
 import os
 import numpy as np
-import speech_recognition
+import speech_recognition as sr
 import whisper
 import torch
 
@@ -31,153 +12,7 @@ from queue import Queue
 from time import sleep
 from sys import platform
 
-import textwrap
-import pygame
-
-import wave
-#from pyogg.opus import OpusEncoder
-
-import select
-import time
-import json
-import threading
-
-import diffstuff
-import audiosource
-from transcriber import Transcriber
-
-
-pygame_font_height = 32
-pygame.init()
-pygame_display_surface = pygame.display.set_mode((960-75, pygame_font_height * 2 * 2.5))
-pygame.display.set_caption("Transcription")
-pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height)
-
-
-wave_out = wave.open("wave.wav", "wb")
-wave_out.setnchannels(1)
-wave_out.setframerate(16000)
-wave_out.setsampwidth(2)
-
-transcribers = []
-
-transcriber1 = Transcriber()
-mic_source1 = audiosource.MicrophoneAudioSource()
-transcriber1.set_source(mic_source1)
-transcriber1.username = "Kiri"
-transcribers.append(transcriber1)
-
-# transcriber2 = Transcriber()
-# mic_source2 = audiosource.MicrophoneAudioSource()
-# transcriber2.set_source(mic_source2)
-# transcriber2.username = "Kiri2"
-# transcribers.append(transcriber2)
-
-
-discord_transcribers_per_user_id = {}
-
-while True:
-
-    # Check for new opus connections.
-    #print("Checking for new connections...")
-    s = select.select([opus_server_socket], [], [], 0)
-    if len(s[0]):
-        accepted_socket, addr = opus_server_socket.accept()
-        #print("Accepted new Opus stream: ", accepted_socket)
-        new_stream = audiosource.OpusStreamAudioSource(accepted_socket)
-        if new_stream._user_info["userId"] in discord_transcribers_per_user_id:
-            discord_transcribers_per_user_id[new_stream._user_info["userId"]].set_source(new_stream)
-        else:
-            new_transcriber = Transcriber()
-            new_transcriber.set_source(new_stream)
-            discord_transcribers_per_user_id[new_stream._user_info["userId"]] = new_transcriber
-            new_transcriber.username = new_stream._user_info["displayName"]
-
-            if not new_transcriber in transcribers:
-                transcribers.append(new_transcriber)
-
-    removal_queue = []
-
-    # Run updates.
-    print("Running updates...")
-    for transcriber in transcribers:
-        #print("Running updates for... ", transcriber.username)
-        transcriber.update()
-        #print("Done running updates for... ", transcriber.username)
-
-        if transcriber._phrase_time + timedelta(seconds=discord_transcriber_timeout) < datetime.utcnow():
-            if transcriber._audio_source == None:
-                removal_queue.append(transcriber)
-
-    # Note that this will not remove them from discord_transcribers_per_user_id.
-    # It's probably fine, though.
-    #print("Running removals...")
-    for removal in removal_queue:
-        #print("Removing inactive user: ", removal.username)
-        transcribers.remove(removal)
-
-    # Sleep.
-    print("Sleeping...")
-    time.sleep(0.05)
-
-    #print("Rendering...")
-
-    # Do rendering.
-    pygame_display_surface.fill((0, 0, 0))
-
-    # Render text.
-    for transcriber in transcribers:
-        pygame_text_surface = pygame_font.render(transcriber.scrolling_text, (0, 0, 0), (255, 255, 255))
-
-        pygame_text_rect = pygame_text_surface.get_rect()
-        pygame_text_rect.center = (
-            pygame_display_surface.get_width() / 2,
-            pygame_font_height * (1 + transcribers.index(transcriber)))
-        pygame_text_rect.right = pygame_display_surface.get_width()
-
-        pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
-
-    # Render a background for the names.
-    fill_rect = pygame_display_surface.get_rect()
-    fill_rect.width = 220
-    fill_rect.left = 0
-    pygame_display_surface.fill((0, 0, 0), fill_rect)
-
-    # Render names.
-    for transcriber in transcribers:
-
-        # Keep only ASCII characters the font can render.
-        username_for_display = ""
-        for c in transcriber.username:
-            if ord(c) <= 127:
-                username_for_display += c
-
-        pygame_username_surface = pygame_font.render(username_for_display, (0, 0, 0), (255, 255, 255))
-
-        pygame_text_rect = pygame_username_surface.get_rect()
-        pygame_text_rect.center = (
-            pygame_display_surface.get_width() / 2,
-            pygame_font_height * (1 + transcribers.index(transcriber)))
-        pygame_text_rect.left = 16
-
-        pygame_display_surface.blit(pygame_username_surface, pygame_text_rect)
-
-    pygame.display.update()
-
-exit(0)
-
-
 def main():
     parser = argparse.ArgumentParser()
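The deleted loop above never blocks on the listening socket: it polls with a zero `select()` timeout each frame and only calls `accept()` when a connection is already pending. The pattern in isolation, as a sketch:

```
import select
import socket
import time

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 9967))
server.listen()

while True:
    # A timeout of 0 makes select() return immediately with whatever is ready.
    readable, _, _ = select.select([server], [], [], 0)
    if readable:
        conn, addr = server.accept()  # safe: select() said a client is waiting
        print("accepted connection from", addr)
    # ...per-frame work (update transcribers, render) goes here...
    time.sleep(0.05)
```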
@@ -187,6 +22,8 @@ def main():
                         help="Don't use the english model.")
     parser.add_argument("--energy_threshold", default=1000,
                         help="Energy level for mic to detect.", type=int)
+    parser.add_argument("--record_timeout", default=2,
+                        help="How real time the recording is in seconds.", type=float)
     parser.add_argument("--phrase_timeout", default=3,
                         help="How much empty space between recordings before we "
                              "consider it a new line in the transcription.", type=float)
@@ -198,8 +35,30 @@ def main():
 
     # The last time a recording was retrieved from the queue.
     phrase_time = None
-    #data_queue = Queue()
+    # Thread safe Queue for passing data from the threaded recording callback.
+    data_queue = Queue()
     # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
+    recorder = sr.Recognizer()
+    recorder.energy_threshold = args.energy_threshold
+    # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
+    recorder.dynamic_energy_threshold = False
+
+    # Important for linux users.
+    # Prevents permanent application hang and crash by using the wrong Microphone
+    if 'linux' in platform:
+        mic_name = args.default_microphone
+        if not mic_name or mic_name == 'list':
+            print("Available microphone devices are: ")
+            for index, name in enumerate(sr.Microphone.list_microphone_names()):
+                print(f"Microphone with name \"{name}\" found")
+            return
+        else:
+            for index, name in enumerate(sr.Microphone.list_microphone_names()):
+                if mic_name in name:
+                    source = sr.Microphone(sample_rate=16000, device_index=index)
+                    break
+    else:
+        source = sr.Microphone(sample_rate=16000)
+
     # Load / Download model
     model = args.model
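For reference, the Linux branch above resolves a microphone by substring match against `sr.Microphone.list_microphone_names()`; running the script with `--default_microphone list` prints the candidates. The lookup in isolation (note that, as written above, `source` stays unset when nothing matches; the fallback below is added for this sketch only):

```
import speech_recognition as sr

def find_microphone(query: str) -> sr.Microphone:
    # Pick the first input device whose name contains the query string.
    for index, name in enumerate(sr.Microphone.list_microphone_names()):
        if query in name:
            return sr.Microphone(sample_rate=16000, device_index=index)
    # Hypothetical fallback; the script itself has no fallback here.
    return sr.Microphone(sample_rate=16000)

for index, name in enumerate(sr.Microphone.list_microphone_names()):
    print(f'Microphone with name "{name}" found at index {index}')
```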
@@ -207,81 +66,49 @@ def main():
         model = model + ".en"
     audio_model = whisper.load_model(model)
 
+    record_timeout = args.record_timeout
     phrase_timeout = args.phrase_timeout
 
     transcription = ['']
 
+    with source:
+        recorder.adjust_for_ambient_noise(source)
+
+    def record_callback(_, audio: sr.AudioData) -> None:
+        """
+        Threaded callback function to receive audio data when recordings finish.
+        audio: An AudioData containing the recorded bytes.
+        """
+        # Grab the raw bytes and push it into the thread safe queue.
+        data = audio.get_raw_data()
+        data_queue.put(data)
+
+    # Create a background thread that will pass us raw audio bytes.
+    # We could do this manually but SpeechRecognizer provides a nice helper.
+    recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
+
     # Cue the user that we're ready to go.
     print("Model loaded.\n")
 
-    # Rolling output text buffer.
-    #
-    # This is the one that animates. Stored as a single string.
-    rolling_output_text = ""
-    # This is the one that updates in big chunks at lower frequency.
-    # Stored as an array of phrases.
-    output_text = [""]
-
-    mic_audio_source = MicrophoneAudioSource()
-    mic_audio_source.start()
-    data_queue = mic_audio_source.data_queue
-
-    # Rolling audio input buffer.
     audio_data = b''
 
-    diffsize = 0
-
     while True:
         try:
-            for event in pygame.event.get():
-                if event.type == pygame.QUIT:
-                    pygame.quit()
-                    exit(0)
-
-            rolling_text_target = " ".join(output_text)[-160:]
-            if rolling_text_target != rolling_output_text:
-
-                # Chop off the start all at once. It's not needed for the animation to look good.
-                new_rolling_output_text = onestepchange(rolling_output_text, rolling_text_target)
-                while rolling_output_text.endswith(new_rolling_output_text):
-                    new_rolling_output_text = onestepchange(new_rolling_output_text, rolling_text_target)
-                rolling_output_text = new_rolling_output_text
-
-                if countsteps(rolling_output_text, rolling_text_target) > 80:
-                    rolling_output_text = rolling_text_target
-
-                print(rolling_output_text)
-
-                pygame_text_surface = pygame_font.render(rolling_output_text, (0, 0, 0), (255, 255, 255))
-                pygame_text_rect = pygame_text_surface.get_rect()
-                pygame_text_rect.center = (640, pygame_font_height)
-                pygame_text_rect.right = 1280
-                pygame_display_surface.fill((0, 0, 0))
-                pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
-
-                pygame.display.update()
-
-                diffsize = abs(len(rolling_output_text) - len(rolling_text_target))
-
-            else:
             now = datetime.utcnow()
             # Pull raw recorded audio from the queue.
             if not data_queue.empty():
                 phrase_complete = False
                 # If enough time has passed between recordings, consider the phrase complete.
                 # Clear the current working audio buffer to start over with the new data.
-                #
-                # FIXME: Shouldn't we cut off the phrase here instead of
-                # waiting for later?
                 if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
                     phrase_complete = True
 
                 # This is the last time we received new audio data from the queue.
                 phrase_time = now
 
+                # for d in data_queue:
+                #     if d > 0.5:
+                #         print("Got something: ", d)
 
                 # Combine audio data from queue
                 audio_data += b''.join(data_queue.queue)
                 data_queue.queue.clear()
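One caveat worth noting in the lines above: both versions drain the queue by reading `data_queue.queue` (the `Queue`'s internal deque) and then clearing it, so a chunk enqueued between the join and the clear would be lost. A drain using only the public `Queue` API avoids that window:

```
from queue import Empty, Queue

def drain_queue(q: Queue) -> bytes:
    # get_nowait() pops atomically, so nothing can slip in between
    # "read everything" and "clear".
    chunks = []
    while True:
        try:
            chunks.append(q.get_nowait())
        except Empty:
            return b''.join(chunks)
```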
@@ -291,39 +118,55 @@ def main():
                 # Clamp the audio stream frequency to a PCM wavelength compatible default of 32768hz max.
                 audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
 
-                # Run the transcription model, and extract the text.
+                # Read the transcription.
                 result = audio_model.transcribe(audio_np, fp16=torch.cuda.is_available())
                 text = result['text'].strip()
 
+                # # If we detected a pause between recordings, add a new item to our transcription.
+                # # Otherwise edit the existing one.
+                # if phrase_complete:
+                #     transcription.append(text)
+                # else:
+                #     transcription[-1] += text
+                print(text)
 
                 # Update rolling transcription file.
-                # Start with all our recent-but-complete phrases.
-                output_text = transcription[-recent_phrase_count:]
-
-                # Append the phrase-in-progress. (TODO: Can we make this a
-                # different color or something?)
+                f = open("transcription.txt", "w+")
+                output_text = transcription[-4:]
                 output_text.append(text)
+                f.write(" ".join(output_text))
+                f.close()
 
-                # If we're done with the phrase, we can go ahead and stuff
-                # it into the list and clear out the current audio data
-                # buffer.
                 if phrase_complete:
                     # Append to full transcription.
-                    if text != "":
-                        transcription.append(text)
-                    # Clear audio buffer.
+                    transcription.append(text)
+                    # text += "\n"
+                    # f = open("transcription.txt", "w+")
+                    # f.write("\n".join(textwrap.wrap(text)))
+                    # f.close()
+
+                    print("* Phrase complete.")
                     audio_data = b''
 
-            # Infinite loops are bad for processors, must sleep. Also, limit the anim speed.
-            if diffsize > 30:
-                sleep(0.0025)
-            else:
-                sleep(0.0125)
+                # Clear the console to reprint the updated transcription.
+                # os.system('cls' if os.name=='nt' else 'clear')
+                for line in transcription:
+                    print(line)
+                # Flush stdout.
+                print('', end='', flush=True)
+            else:
+                # Infinite loops are bad for processors, must sleep.
+                sleep(0.01)
         except KeyboardInterrupt:
             break
 
+    print("\n\nTranscription:")
+    for line in transcription:
+        print(line)
+
 
 if __name__ == "__main__":
     main()
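A note on the "Clamp the audio stream frequency ... 32768hz" comment that both versions carry: the division is amplitude normalization, not anything to do with frequency. The buffer is signed 16-bit PCM, so dividing by 32768.0 maps samples into the [-1.0, 1.0) range that Whisper's `transcribe` expects from a float32 array. Spelled out, on a little-endian machine:

```
import numpy as np

# Three little-endian int16 samples: 0, 32767 (max), -32768 (min).
pcm = b'\x00\x00\xff\x7f\x00\x80'
samples = np.frombuffer(pcm, dtype=np.int16)
normalized = samples.astype(np.float32) / 32768.0
print(normalized)  # approximately [0.0, 0.99997, -1.0]
```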
264 transcriber.py
@@ -1,264 +0,0 @@
#!/usr/bin/python3

# Recent phrases to include in the text buffer before the current transcription.
recent_phrase_count = 8

# Seconds of silence before we start a new phrase.
phrase_timeout = 1.0

# Higher is less restrictive on what it lets pass through.
no_speech_prob_threshold = 0.05 # 0.15

# Minimum number of seconds before we fire off the model again.
min_time_between_updates = 2

import numpy as np
import speech_recognition
import whisper
import torch
import wave
from datetime import datetime, timedelta
import json
import diffstuff
import threading
import time
from queue import Queue

_audio_model = whisper.load_model("medium.en") # "large"
_audio_model_mutex = threading.Lock()

# For debugging...
#wave_out = wave.open("wave.wav", "wb")
#wave_out.setnchannels(1)
#wave_out.setframerate(16000)
#wave_out.setsampwidth(2)


class Transcriber:

    def __init__(self):
        self._audio_source = None

        # Audio data for the current phrase.
        self._current_data = b''

        self.phrases = [""]

        self.scrolling_text = ""

        # Time since the last data came in for the current phrase.
        self._phrase_time = datetime.utcnow()

        # Last time that we ran the model.
        self._last_model_time = datetime.utcnow()

        self._should_stop = False

        self._phrases_list_mutex = threading.Lock()

        self._update_thread = threading.Thread(
            target=self._update_thread_func, daemon=True)
        self._update_thread.start()

        self.username = ""

        self._audio_source_queue = Queue()

    def stop(self):
        self._should_stop = True

    def set_source(self, source):
        self._audio_source_queue.put(source)

    def phrase_probably_silent(self):
        """Whisper hallucinates a LOT on silence, so let's just ignore stuff
        that's mostly silence. First line of defense here."""

        threshold = 100
        threshold_pass = 0
        threshold_fail = 0
        avg = 0
        for k in self._current_data:
            avg += k
            if abs(k) > threshold:
                threshold_pass += 1
            else:
                threshold_fail += 1

        avg = avg / len(self._current_data)
        threshold_pct = threshold_pass / len(self._current_data)

        if threshold_pct < 0.1:
            return True

        return False

    def run_transcription_update(self):

        # Switch to whatever the newest source is.
        if self._audio_source == None and not self._audio_source_queue.empty():
            self._audio_source = self._audio_source_queue.get()

        if not self._audio_source:
            return

        now = datetime.utcnow()
        if not self._audio_source.data_queue.empty():

            # We got some new data. Let's process it!

            # Get all the new data since last tick.
            new_data = []
            with self._audio_source._data_mutex:
                while not self._audio_source.data_queue.empty():
                    new_packet = self._audio_source.data_queue.get()
                    new_data.append(new_packet)
            new_data_joined = b''.join(new_data)

            # For debugging...
            #wave_out.writeframes(new_data_joined)

            # Append it to the current buffer.
            self._current_data = self._current_data + new_data_joined

            # if self.phrase_probably_silent():
            #     with self._phrases_list_mutex:
            #         self.phrases[-1] = ""
            # else:

            # Convert in-ram buffer to something the model can use
            # directly without needing a temp file. Convert data from 16
            # bit wide integers to floating point with a width of 32
            # bits. Clamp the audio stream frequency to a PCM wavelength
            # compatible default of 32768hz max.
            audio_np = np.frombuffer(
                self._current_data[-16000 * 10:],
                dtype=np.int16).astype(np.float32) / 32768.0

            # Run the transcription model, and extract the text.
            with _audio_model_mutex:
                result = _audio_model.transcribe(
                    audio_np, fp16=torch.cuda.is_available(),
                    word_timestamps=True,
                    hallucination_silence_threshold=2)
            self._last_model_time = now

            with self._phrases_list_mutex:
                wave_out = wave.open("tmp/wave%0.4d.wav" % len(self.phrases), "wb")
                wave_out.setnchannels(1)
                wave_out.setframerate(16000)
                wave_out.setsampwidth(2)
                wave_out.writeframes(self._current_data)

            # Filter out text segments with a high no_speech_prob.
            combined_text = ""
            for seg in result["segments"]:
                if seg["no_speech_prob"] <= no_speech_prob_threshold:
                    combined_text += seg["text"]

            text = combined_text.strip()

            # FIXME: This overrides the no_speech_prob filter above.
            text = result["text"]

            with self._phrases_list_mutex:
                self.phrases[-1] = text
                #print("phrases: ", json.dumps(self.phrases, indent=4))

            # Update phrase time at the end so waiting for the mutex doesn't
            # cause us to split phrases.
            self._phrase_time = now

        # If enough time has passed between recordings, consider the
        # last phrase complete and start a new one. Clear the current
        # working audio buffer to start over with the new data.
        if now - self._audio_source.time_of_last_input > timedelta(seconds=phrase_timeout):

            # Only add a new phrase if we actually have data in the last
            # one.
            with self._phrases_list_mutex:
                if self.phrases[-1] != "":
                    self.phrases.append("")

            self._current_data = b''

            # Automatically drop audio sources when we're finished with them.
            if self._audio_source.is_done():
                self._audio_source = None

    def _update_thread_func(self):
        while not self._should_stop:
            time.sleep(0.1)
            now = datetime.utcnow()
            if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
                self.run_transcription_update()

    def update(self):
        self.update_scrolling_text()

        now = datetime.utcnow()
        #if self._last_model_time + timedelta(seconds=min_time_between_updates) < now:
        #    self.run_transcription_update()

    def update_scrolling_text(self):

        # Combine all the known phrases.
        with self._phrases_list_mutex:
            rolling_text_target = " ".join(self.phrases).strip()[-160:]

        if rolling_text_target != self.scrolling_text:

            # Start the diff off.
            new_rolling_output_text = diffstuff.onestepchange(
                self.scrolling_text, rolling_text_target)

            # Chop off the start all at once. It's not needed for the animation
            # to look good.
            while self.scrolling_text.endswith(new_rolling_output_text):
                new_rolling_output_text = diffstuff.onestepchange(
                    new_rolling_output_text, rolling_text_target)

            # Set the new text.
            self.scrolling_text = new_rolling_output_text

            # Just jump ahead if we're still too far behind.
            # FIXME: Hardcoded value.
            if diffstuff.countsteps(self.scrolling_text, rolling_text_target) > 80:
                self.scrolling_text = rolling_text_target

            #print("%s: %s" % (self.username, self.scrolling_text))
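One more observation on `phrase_probably_silent` above: iterating a `bytes` object yields individual unsigned bytes (0-255), not the 16-bit samples the threshold of 100 seems aimed at. A sketch of the same heuristic computed on decoded samples, assuming the buffer is the 16 kHz mono int16 stream used everywhere else in this diff (names and the fraction cutoff mirror the original, but this is a variant, not the file's code):

```
import numpy as np

def probably_silent(raw: bytes, threshold: int = 100,
                    min_active_fraction: float = 0.1) -> bool:
    # Decode the raw buffer as int16 samples, then measure the fraction
    # of samples whose amplitude clears the threshold. Widen to int32
    # before abs() so -32768 doesn't overflow.
    samples = np.frombuffer(raw, dtype=np.int16)
    if samples.size == 0:
        return True
    active = np.count_nonzero(np.abs(samples.astype(np.int32)) > threshold)
    return (active / samples.size) < min_active_fraction
```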