diff --git a/kiri_reqs.txt b/kiri_reqs.txt
new file mode 100644
index 0000000..cb9519c
--- /dev/null
+++ b/kiri_reqs.txt
@@ -0,0 +1,2 @@
+whisper-live
+tokenizers==0.20.3
diff --git a/requirements.txt b/requirements.txt
index 565934d..18c62c5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,5 @@ SpeechRecognition
 torch
 numpy
 git+https://github.com/openai/whisper.git
+
+pygame
diff --git a/requirements2.txt b/requirements2.txt
new file mode 100644
index 0000000..bf909eb
--- /dev/null
+++ b/requirements2.txt
@@ -0,0 +1,7 @@
+setuptools
+pyaudio
+SpeechRecognition
+--extra-index-url https://download.pytorch.org/whl/rocm6.2.4
+torch
+numpy
+git+https://github.com/openai/whisper.git
diff --git a/transcribe_demo.py b/transcribe_demo.py
index 8f90fb7..edc824b 100644
--- a/transcribe_demo.py
+++ b/transcribe_demo.py
@@ -1,9 +1,16 @@
 #! python3.7
+# Number of recent completed phrases to keep in the text buffer ahead of the current transcription.
+recent_phrase_count = 8
+
+# Maximum length of each recorded audio chunk, in seconds (controls how "real time" the transcription feels).
+record_timeout = 2
+
+
 import argparse
 import os
 import numpy as np
-import speech_recognition as sr
+import speech_recognition
 import whisper
 import torch
@@ -13,6 +20,110 @@ from time import sleep
 from sys import platform
 import textwrap
+import difflib
+
+import pygame
+
+
+pygame_font_height = 16
+pygame.init()
+pygame_display_surface = pygame.display.set_mode((1280, pygame_font_height * 2))
+pygame.display.set_caption("Transcription")
+# NOTE: hard-coded, user-specific font path; adjust for your own system.
+pygame_font = pygame.font.Font("/home/kiri/.fonts/Sigmar-Regular.ttf", pygame_font_height)
+
+
+
+
+class AudioSource:
+    def __init__(self):
+        # Thread safe Queue for passing data from the threaded recording callback.
+        self.data_queue = Queue()
+
+class MicrophoneAudioSource(AudioSource):
+    def __init__(self):
+        super().__init__()
+
+        self.recorder = speech_recognition.Recognizer()
+        self.recorder.energy_threshold = 1000
+
+        # Definitely do this, dynamic energy compensation lowers the energy
+        # threshold dramatically to a point where the SpeechRecognizer
+        # never stops recording.
+        self.recorder.dynamic_energy_threshold = False
+
+        self.source = speech_recognition.Microphone(sample_rate=16000)
+
+
+        with self.source:
+            self.recorder.adjust_for_ambient_noise(self.source)
+
+        def record_callback(_, audio:speech_recognition.AudioData) -> None:
+            """
+            Threaded callback function to receive audio data when recordings finish.
+            audio: An AudioData containing the recorded bytes.
+            """
+            # Grab the raw bytes and push it into the thread safe queue.
+
+            print("GOT SOME DATA!!!")
+
+            data = audio.get_raw_data()
+            self.data_queue.put(data)
+
+        # Create a background thread that will pass us raw audio bytes.
+        # We could do this manually but SpeechRecognizer provides a nice helper.
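+        # phrase_time_limit=record_timeout caps each background recording at
+        # record_timeout seconds, so short chunks reach the queue frequently
+        # instead of only arriving when the speaker pauses.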
+        self.recorder.listen_in_background(self.source, record_callback, phrase_time_limit=record_timeout)
+
+        print("--------------------------------------------------------------")
+        print("Done setting up mic!")
+        print("--------------------------------------------------------------")
+
+
+
+# while True:
+#     pygame_text_surface = pygame_font.render("Test test test", (0, 0, 0), (255, 255, 255))
+#     pygame_text_rect = pygame_text_surface.get_rect()
+#     pygame_text_rect.center = (640, 32)
+#     pygame_display_surface.fill((0, 0, 0))
+#     pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
+
+#     for event in pygame.event.get():
+#         if event.type == pygame.QUIT:
+#             pygame.quit()
+
+#     pygame.display.update()
+
+# exit(0)
+
+def onestepchange(start, dest):
+    """Return `start` after applying a single one-character edit toward `dest`."""
+
+    ret = ""
+
+    for i, s in enumerate(difflib.ndiff(start, dest)):
+        # print(i)
+        # print(s)
+
+        if s[0] == '-':
+            return ret + start[i+1:]
+
+        # ndiff marks insertions with '+' in the first column.
+        if s[0] == '+':
+            return ret + s[-1] + start[i:]
+
+        ret = ret + s[-1]
+
+        if len(ret) > len(start):
+            return ret
+
+        if ret[i] != start[i]:
+            return ret + start[i:]
+
+    return ret
+
+def countsteps(start, dest):
+    """Count how many onestepchange() edits it takes to turn `start` into `dest`."""
+    step_count = 0
+    while start != dest:
+        start = onestepchange(start, dest)
+        step_count += 1
+    return step_count
 
 def main():
     parser = argparse.ArgumentParser()
@@ -22,8 +133,6 @@ def main():
                         help="Don't use the english model.")
     parser.add_argument("--energy_threshold", default=1000,
                         help="Energy level for mic to detect.", type=int)
-    parser.add_argument("--record_timeout", default=2,
-                        help="How real time the recording is in seconds.", type=float)
     parser.add_argument("--phrase_timeout", default=3,
                         help="How much empty space between recordings before we "
                              "consider it a new line in the transcription.", type=float)
@@ -35,33 +144,8 @@ def main():
 
     # The last time a recording was retrieved from the queue.
     phrase_time = None
-    # Thread safe Queue for passing data from the threaded recording callback.
-    data_queue = Queue()
+    #data_queue = Queue()
     # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
-    recorder = sr.Recognizer()
-    recorder.energy_threshold = args.energy_threshold
-    # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
-    recorder.dynamic_energy_threshold = False
-
-    # Important for linux users.
-    # Prevents permanent application hang and crash by using the wrong Microphone
-    source = None
-    # if 'linux' in platform:
-    #     mic_name = args.default_microphone
-    #     if not mic_name or mic_name == 'list':
-    #         print("Available microphone devices are: ")
-    #         for index, name in enumerate(sr.Microphone.list_microphone_names()):
-    #             print(f"Microphone with name \"{name}\" found")
-    #         return
-    #     else:
-    #         for index, name in enumerate(sr.Microphone.list_microphone_names()):
-    #             if mic_name in name:
-    #                 source = sr.Microphone(sample_rate=16000, device_index=index)
-    #                 break
-    #     else:
-    #         source = sr.Microphone(sample_rate=16000)
-
-    source = sr.Microphone(sample_rate=16000)
 
     # Load / Download model
     model = args.model
@@ -69,107 +153,122 @@ def main():
         model = model + ".en"
     audio_model = whisper.load_model(model)
 
-    record_timeout = args.record_timeout
     phrase_timeout = args.phrase_timeout
 
     transcription = ['']
 
-    with source:
-        recorder.adjust_for_ambient_noise(source)
-
-    def record_callback(_, audio:sr.AudioData) -> None:
-        """
-        Threaded callback function to receive audio data when recordings finish.
-        audio: An AudioData containing the recorded bytes.
-        """
-        # Grab the raw bytes and push it into the thread safe queue.
-        data = audio.get_raw_data()
-        data_queue.put(data)
-
-    # Create a background thread that will pass us raw audio bytes.
-    # We could do this manually but SpeechRecognizer provides a nice helper.
-    recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
-
     # Cue the user that we're ready to go.
     print("Model loaded.\n")
 
+    # Rolling output text buffer.
+
+    # This is the one that animates. Stored as a single string.
+    rolling_output_text = ""
+    # This is the one that updates in big chunks at lower frequency.
+    # Stored as an array of phrases.
+    output_text = [""]
+
+    mic_audio_source = MicrophoneAudioSource()
+    data_queue = mic_audio_source.data_queue
+
+
+    # Rolling audio input buffer.
     audio_data = b''
 
+    # Character-count gap between the displayed text and its target; used below to pick the animation sleep.
+    diffsize = 0
+
     while True:
         try:
-            now = datetime.utcnow()
-            # Pull raw recorded audio from the queue.
-            if not data_queue.empty():
-                phrase_complete = False
-                # If enough time has passed between recordings, consider the phrase complete.
-                # Clear the current working audio buffer to start over with the new data.
-                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
-                    phrase_complete = True
-                # This is the last time we received new audio data from the queue.
-                phrase_time = now
-                # for d in data_queue:
-                #     if d > 0.5:
-                #         print("Got something: ", d)
+            for event in pygame.event.get():
+                if event.type == pygame.QUIT:
+                    pygame.quit()
+                    exit(0)
 
-                # Combine audio data from queue
-                audio_data += b''.join(data_queue.queue)
-                data_queue.queue.clear()
+            rolling_text_target = " ".join(output_text)[-160:]
+            if rolling_text_target != rolling_output_text:
 
-                # Convert in-ram buffer to something the model can use directly without needing a temp file.
-                # Convert data from 16 bit wide integers to floating point with a width of 32 bits.
-                # Clamp the audio stream frequency to a PCM wavelength compatible default of 32768hz max.
-                audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+                # Chop characters off the start all at once; animating those removals doesn't add anything visually.
+                new_rolling_output_text = onestepchange(rolling_output_text, rolling_text_target)
+                # As long as the old text still ends with the new text, the only change so far is
+                # characters dropping off the front, so keep stepping.
+                while rolling_output_text.endswith(new_rolling_output_text):
+                    new_rolling_output_text = onestepchange(new_rolling_output_text, rolling_text_target)
+                rolling_output_text = new_rolling_output_text
 
-                # Read the transcription.
-                result = audio_model.transcribe(audio_np, fp16=torch.cuda.is_available())
-                text = result['text'].strip()
+                # If we've fallen too far behind, skip the animation and jump straight to the target.
+                if countsteps(rolling_output_text, rolling_text_target) > 80:
+                    rolling_output_text = rolling_text_target
 
-                # # If we detected a pause between recordings, add a new item to our transcription.
-                # # Otherwise edit the existing one.
-                # if phrase_complete:
-                #     transcription.append(text)
-                # else:
-                #     transcription[-1] += text
                print(text)
-
+                print(rolling_output_text)
 
-                # Update rolling transcription file.
-                f = open("transcription.txt", "w+")
-                output_text = transcription[-4:]
-                output_text.append(text)
-                f.write(" ".join(output_text))
-                f.close()
+                # Draw the rolling text: antialiased white text on the black window background.
+                pygame_text_surface = pygame_font.render(rolling_output_text, True, (255, 255, 255))
+                pygame_text_rect = pygame_text_surface.get_rect()
+                # Centre the text vertically, then pin it to the right edge so the newest characters stay visible.
+                pygame_text_rect.center = (640, pygame_font_height)
+                pygame_text_rect.right = 1280
+                pygame_display_surface.fill((0, 0, 0))
+                pygame_display_surface.blit(pygame_text_surface, pygame_text_rect)
 
-                if phrase_complete:
+                pygame.display.update()
 
-                    # Append to full transcription.
-                    transcription.append(text)
+                diffsize = abs(len(rolling_output_text) - len(rolling_text_target))
 
-                    # text += "\n"
-                    # f = open("transcription.txt", "w+")
-                    # f.write("\n".join(textwrap.wrap(text)))
-                    # f.close()
-
-                    print("* Phrase complete.")
-                    audio_data = b''
-
-
-                # Clear the console to reprint the updated transcription.
-                # os.system('cls' if os.name=='nt' else 'clear')
-                for line in transcription:
-                    print(line)
-                # Flush stdout.
-                print('', end='', flush=True)
             else:
-                # Infinite loops are bad for processors, must sleep.
+
+                now = datetime.utcnow()
+                # Pull raw recorded audio from the queue.
+                if not data_queue.empty():
+
+                    phrase_complete = False
+                    # If enough time has passed between recordings, consider the phrase complete.
+                    # Clear the current working audio buffer to start over with the new data.
+                    #
+                    # FIXME: Shouldn't we cut off the phrase here instead of
+                    # waiting for later?
+                    if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
+                        phrase_complete = True
+
+                    # This is the last time we received new audio data from the queue.
+                    phrase_time = now
+
+                    # Combine audio data from queue
+                    audio_data += b''.join(data_queue.queue)
+                    data_queue.queue.clear()
+
+                    # Convert in-ram buffer to something the model can use directly without needing a temp file.
+                    # Convert data from 16 bit wide integers to floating point with a width of 32 bits.
+                    # Normalize the 16-bit PCM samples to floats in the range [-1.0, 1.0].
+                    audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+
+                    # Run the transcription model, and extract the text.
+                    result = audio_model.transcribe(audio_np, fp16=torch.cuda.is_available())
+                    text = result['text'].strip()
+
+                    # Update rolling transcription file.
+
+                    # Start with all our recent-but-complete phrases.
+                    output_text = transcription[-recent_phrase_count:]
+
+                    # Append the phrase-in-progress. (TODO: Can we make this a
+                    # different color or something?)
+                    output_text.append(text)
+
+                    # If we're done with the phrase, we can go ahead and stuff
+                    # it into the list and clear out the current audio data
+                    # buffer.
+                    if phrase_complete:
+
+                        # Append to full transcription.
+                        if text != "":
+                            transcription.append(text)
+
+                        # Clear audio buffer.
+                        audio_data = b''
+
+            # Infinite loops are bad for processors, must sleep. Also, limit the anim speed.
+            if diffsize > 30:
                 sleep(0.01)
+            else:
+                sleep(0.05)
+
         except KeyboardInterrupt:
             break
 
-    print("\n\nTranscription:")
-    for line in transcription:
-        print(line)
-
-
 if __name__ == "__main__":
     main()
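Below is a minimal stand-alone sketch (illustrative only, not part of the patch) of the one-character-per-step morph that onestepchange() and countsteps() implement, written against difflib.SequenceMatcher rather than ndiff; the helper name one_step is made up for the example.

import difflib

def one_step(start: str, dest: str) -> str:
    # Apply one single-character edit that moves `start` toward `dest`.
    for tag, i1, i2, j1, j2 in difflib.SequenceMatcher(None, start, dest).get_opcodes():
        if tag in ("replace", "delete"):
            # Drop one mismatched character from `start`.
            return start[:i1] + start[i1 + 1:]
        if tag == "insert":
            # Pull one missing character in from `dest`.
            return start[:i1] + dest[j1] + start[i1:]
    return start

text, target = "hello wrld", "hello world!"
for _ in range(200):  # safety cap so the demo always terminates
    if text == target:
        break
    text = one_step(text, target)
    print(text)  # "hello world", then "hello world!"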