# main.py
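"""Chaplin: lipread from a webcam and type the result.

Toggles webcam recording with a hotkey, runs each recorded clip through a
visual speech recognition (VSR) model, has a local LLM (llama3.2 via Ollama)
correct the raw transcription, and types the corrected text at the cursor.
"""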
import os
import time
from concurrent.futures import ThreadPoolExecutor

import cv2
import hydra
import keyboard
import torch
from ollama import chat
from pydantic import BaseModel

from pipelines.pipeline import InferencePipeline
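
# third-party deps (assumed package names): torch, hydra-core, opencv-python
# (cv2), ollama, pydantic, keyboard; pipelines.pipeline is local to this repo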


# pydantic model for the structured chat output
class ChaplinOutput(BaseModel):
    list_of_changes: str
    corrected_text: str
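
# the ChaplinOutput schema is passed to Ollama's `format` parameter in
# perform_inference() below, which constrains the LLM's reply to JSON that
# parses back into this model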


class Chaplin:
    def __init__(self):
        self.vsr_model = None

        # flag to toggle recording
        self.recording = False

        # single worker so clips are transcribed one at a time, in order
        self.executor = ThreadPoolExecutor(max_workers=1)

        # video params
        self.output_prefix = "webcam"
        self.res_factor = 3
        self.fps = 16
        self.frame_interval = 1 / self.fps
        self.frame_compression = 25
    def perform_inference(self, video_path):
        # perform inference on the video with the vsr model
        output = self.vsr_model(video_path)

        # write the raw output immediately
        keyboard.write(output)

        # shift+left once per character to select the raw output, so the
        # corrected text typed below replaces it
        cmd = ', '.join(['shift+left'] * len(output))
        keyboard.press_and_release(cmd)

        # perform inference on the raw output to get back a "correct" version
        response = chat(
            model='llama3.2',
            messages=[
                {
                    'role': 'system',
                    'content': "You are an assistant that helps make corrections to the output of a lipreading model. The text you will receive was transcribed using a video-to-text system that attempts to lipread the subject speaking in the video, so the text will likely be imperfect.\n\nIf something seems unusual, assume it was mistranscribed. Do your best to infer the words actually spoken, and make changes to the mistranscriptions in your response. Do not add more words or content, just change the ones that seem to be out of place (and, therefore, mistranscribed). Do not change even the wording of sentences, just individual words that look nonsensical in the context of all of the other words in the sentence.\n\nAlso, add correct punctuation to the entire text. ALWAYS end each sentence with the appropriate sentence ending: '.', '?', or '!'. The input text is in all-caps, although your response should be capitalized correctly and should NOT be in all-caps.\n\nReturn the corrected text in the format of 'list_of_changes' and 'corrected_text'."
                },
                {
                    'role': 'user',
                    'content': f"Transcription:\n\n{output}"
                }
            ],
            format=ChaplinOutput.model_json_schema()
        )

        # parse the structured response back into the pydantic model
        chat_output = ChaplinOutput.model_validate_json(response.message.content)

        # if the last character isn't a sentence ending (happens sometimes), add a period
        if chat_output.corrected_text[-1] not in ['.', '?', '!']:
            chat_output.corrected_text += '.'

        # write the corrected text over the selected raw output
        keyboard.write(chat_output.corrected_text + " ")

        # return the corrected text and the video path
        return {
            "output": chat_output.corrected_text,
            "video_path": video_path
        }
    def start_webcam(self):
        # init webcam
        cap = cv2.VideoCapture(0)

        # set webcam resolution, and get the actual frame dimensions
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640 // self.res_factor)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480 // self.res_factor)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        last_frame_time = time.time()
        futures = []
        output_path = ""
        out = None
        frame_count = 0

        while True:
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                # remove any remaining videos that were saved to disk
                for file in os.listdir():
                    if file.startswith(self.output_prefix) and file.endswith('.mp4'):
                        os.remove(file)
                break

            current_time = time.time()
            # conditional ensures that the video is recorded at the correct frame rate
            if current_time - last_frame_time >= self.frame_interval:
                ret, frame = cap.read()
                if ret:
                    # JPEG-compress the frame, then decode it back as grayscale
                    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), self.frame_compression]
                    _, buffer = cv2.imencode('.jpg', frame, encode_param)
                    compressed_frame = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE)

                    if self.recording:
                        if out is None:
                            output_path = self.output_prefix + str(time.time_ns() // 1_000_000) + '.mp4'
                            out = cv2.VideoWriter(
                                output_path,
                                cv2.VideoWriter_fourcc(*'mp4v'),
                                self.fps,
                                (frame_width, frame_height),
                                False  # isColor
                            )

                        out.write(compressed_frame)
                        last_frame_time = current_time

                        # circle to indicate recording; drawn after the frame is written,
                        # so it only appears in the window, not in the video on disk
                        cv2.circle(compressed_frame, (frame_width - 20, 20), 10, (0, 0, 0), -1)

                        frame_count += 1
                    # recording just stopped: finalize the clip and start a fresh writer
                    elif not self.recording and frame_count > 0:
                        if out is not None:
                            out.release()

                            # only run inference if the video is at least 2 seconds long
                            if frame_count >= self.fps * 2:
                                futures.append(self.executor.submit(
                                    self.perform_inference, output_path))
                            else:
                                os.remove(output_path)

                            output_path = self.output_prefix + str(time.time_ns() // 1_000_000) + '.mp4'
                            out = cv2.VideoWriter(
                                output_path,
                                cv2.VideoWriter_fourcc(*'mp4v'),
                                self.fps,
                                (frame_width, frame_height),
                                False  # isColor
                            )

                        frame_count = 0

                    # display the mirrored frame in the window
                    cv2.imshow('Chaplin', cv2.flip(compressed_frame, 1))

            # pop finished futures from the front of the list, so videos are
            # handled in the order they were recorded
            while futures and futures[0].done():
                result = futures.pop(0).result()
                # once done processing, delete the video with the video path
                os.remove(result["video_path"])

        # release everything
        cap.release()
        if out:
            out.release()
        cv2.destroyAllWindows()
    def on_action(self, event):
        # toggle recording when the alt key is pressed
        if event.event_type == keyboard.KEY_DOWN and event.name == 'alt':
            self.recording = not self.recording


@hydra.main(version_base=None, config_path="hydra_configs", config_name="default")
def main(cfg):
    chaplin = Chaplin()

    # hook to toggle recording
    keyboard.hook(lambda e: chaplin.on_action(e))

    # load the model
    device = torch.device(
        f"cuda:{cfg.gpu_idx}" if torch.cuda.is_available() and cfg.gpu_idx >= 0 else "cpu")
    chaplin.vsr_model = InferencePipeline(
        cfg.config_filename, device=device, detector=cfg.detector, face_track=True)
    print("Model loaded successfully!")

    # start the webcam video capture
    chaplin.start_webcam()


if __name__ == '__main__':
    main()
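
# usage sketch (assumes hydra_configs/default.yaml from this repo is in place):
#   python main.py
# press alt to toggle recording; press q in the 'Chaplin' window to quit and
# clean up any leftover webcam*.mp4 clips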