视频字幕提取(2)-基于opencv和paddleocr

视频内字幕提取

video-subtitles-ocr

视频字幕提取,基于 opencv 和 paddleocr

视频内字幕提取

这里针对的是内嵌了硬字幕的视频,即字幕已经成为画面的一部分。

思路:先用 opencv 提取视频内的所有帧,再对图片进行 OCR 识别;现在可以使用 paddleocr,识别起来更加方便。

0. 首先需要配置一下

0.1 安装 python 库

  • opencv-python
  • scikit-image
  • paddleocr
  • pytesseract(可选,仅 5.1 节使用)

ref: https://paddlepaddle.github.io/PaddleOCR/latest/quick_start.html#1-paddlepaddle

1. 读取视频

使用 opencv 读取视频

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import cv2

# Path of the video whose burned-in subtitles will be extracted.
video_path = 'd7.mp4'

v = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a path it cannot open, so fail here instead
# of carrying meaningless properties into the later steps.
if not v.isOpened():
    raise IOError(f'cannot open video: {video_path}')

# Stream properties as reported by the capture backend.
num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
fps = v.get(cv2.CAP_PROP_FPS)
height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(v.get(cv2.CAP_PROP_FRAME_WIDTH))

print(f'video : {video_path}\n'
      f'num_frames : {num_frames}\n'
      f'fps : {fps}\n'
      f'resolution : {width} x {height}')

2. 提取所有帧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import datetime

def get_frame_index(time_str: str, fps: float):
    """Convert an ``H:M:S`` or ``M:S`` time string into a frame index.

    Args:
        time_str: Colon-separated time. Every field is parsed with
            ``float``, so fractional seconds such as ``'0:01.5'`` work.
        fps: Frame rate used to scale the time into a frame count.

    Returns:
        The frame index at that time, truncated toward zero.

    Raises:
        ValueError: If ``time_str`` does not have exactly two or three
            fields, or a field cannot be parsed as a number.
    """
    t = [float(part) for part in time_str.split(':')]
    if len(t) == 3:
        td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
    elif len(t) == 2:
        td = datetime.timedelta(minutes=t[0], seconds=t[1])
    else:
        raise ValueError(
            'Time data "{}" does not match format "%H:%M:%S"'.format(time_str))
    # Scale the integer microsecond count rather than float seconds:
    # 1.16 * 25.0 evaluates to 28.999999999999996 and would truncate to 28,
    # one frame early.
    total_us = td // datetime.timedelta(microseconds=1)
    return int(total_us * fps / 1_000_000)

# Start and end time of the span to OCR ('M:S' or 'H:M:S').
# An empty string means "from the first frame" / "to the last frame".
time_start = '0:00'
time_end = '0:10'
ocr_start = get_frame_index(time_start, fps) if time_start else 0
ocr_end = get_frame_index(time_end, fps) if time_end else num_frames
# A time_end beyond the end of the video would request frames that do not
# exist; clamp it when the frame count is known.
if num_frames > 0:
    ocr_end = min(ocr_end, num_frames)
# Keep the range well-formed so the frame count is never negative.
ocr_start = max(0, min(ocr_start, ocr_end))
num_ocr_frames = ocr_end - ocr_start
print(f'ocr_start : {ocr_start}\n'
      f'ocr_end : {ocr_end}\n'
      f'num_ocr_frames : {num_ocr_frames}')

3. 只保留画面中有字幕的区域

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# *** Adjust the vertical band that contains the subtitles, given as
# fractions of the frame height ***
h1, h2 = 0.86, 0.94
h1, h2 = int(height * h1), int(height * h2)

# Seek to the first frame of the OCR range and decode the requested frames.
v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
frames = []
for _ in range(num_ocr_frames):
    ok, frame = v.read()
    if not ok:
        # No frame could be read (e.g. end of stream); slicing the missing
        # frame below would raise, so stop here.
        break
    frames.append(frame)
# Keep the count in sync with what was actually decoded, since later steps
# index z_frames by it.
num_ocr_frames = len(frames)
# Crop every frame to the subtitle band (all columns, rows h1..h2).
z_frames = [frame[h1:h2, :] for frame in frames]

# Preview the cropped band with the frame index drawn on top.
title = 'preview'
cv2.startWindowThread()
cv2.namedWindow(title)
for idx, img in enumerate(z_frames):
    # Draw on a copy so the frames passed on to OCR stay unannotated.
    tmp_img = img.copy()
    cv2.putText(tmp_img, f'idx:{idx}', (5, 25),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
    # Show only the annotated copy; a second imshow of the raw image would
    # immediately replace it and hide the index.
    cv2.imshow(title, tmp_img)
    cv2.waitKey(50)
cv2.destroyWindow(title)
cv2.destroyAllWindows()

4. 去除相似度较高的帧,保留关键帧

为了减少识别量,先去除一部分相似度较高的图片。

  • 计算两个图片的均方误差(即 MSE), 采用 skimage.metrics.mean_squared_error 函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# Consecutive frames whose MSE is below this threshold are treated as showing
# the same subtitle.
mse_threshold = 100

from skimage.metrics import mean_squared_error

# Each key frame records the first and last frame index of a run of similar
# frames, the cropped image of the run's first frame, and the OCR text
# (filled in later).
k_frames = []
if z_frames:
    k_frames.append({'start': 0,
                     'end': 0,
                     'frame': z_frames[0],
                     'text': ''})

# Iterate over the frames actually present so a short read cannot cause an
# IndexError.
for idx in range(1, len(z_frames)):
    mse = mean_squared_error(z_frames[idx - 1], z_frames[idx])

    if mse < mse_threshold:
        # Still similar to the previous frame: extend the current run.
        k_frames[-1]['end'] = idx
    else:
        # The picture changed: start a new run at this frame.
        k_frames.append({'start': idx,
                         'end': idx,
                         'frame': z_frames[idx],
                         'text': ''})

for kf in k_frames:
    print(f"{kf['start']} --> {kf['end']} : {kf['text']}")

5.1 识别字幕 pytesseract

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import pytesseract

# NOTE(review): tessdata_dir and lang are not defined in the visible part of
# this file; they must be set before this block runs -- confirm.
config = f'--tessdata-dir "{tessdata_dir}" --psm 7'

for kf in k_frames:
    # Recognise the cropped subtitle band as a string.
    ocr_str = pytesseract.image_to_string(kf['frame'], lang=lang, config=config)
    ocr_str = ocr_str.strip().replace(' ', '')

    if ocr_str:
        kf['text'] = ocr_str
        print(f"{kf['start']} --> {kf['end']} : {kf['text']}")

# Drop key frames with no recognised text. Rebuild the list in place:
# calling remove() while iterating the same list skips the element after
# each removal, so runs of empty frames were only partly removed.
num_before = len(k_frames)
k_frames[:] = [kf for kf in k_frames if kf['text']]
print(f'removed {num_before - len(k_frames)} key frames without text')

5.2 识别字幕 paddleocr

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

from paddleocr import PaddleOCR

ocr = PaddleOCR(lang='ch')

for idx, kf in enumerate(k_frames):
    # Recognise the text in the cropped subtitle band.
    result = ocr.ocr(kf['frame'])
    print(result)
    # NOTE(review): the unpacking below assumes each entry of result is
    # either None or a list of (box, (text, score)) pairs -- confirm against
    # the installed PaddleOCR version.
    for line in result or []:
        if line is None:
            break

        # Concatenate the text of every detected box.
        words = ''.join(word[0] for rect, word in line)
        print(idx, words)

        kf['text'] = words

# Drop key frames with no recognised text. Rebuild the list in place:
# calling remove() while iterating the same list skips the element after
# each removal, so runs of empty frames were only partly removed.
num_before = len(k_frames)
k_frames[:] = [kf for kf in k_frames if kf['text']]
print(f'removed {num_before - len(k_frames)} key frames without text')

6. 格式化字幕

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

# Show the frame span and recognised text of every remaining key frame.
for kf in k_frames:
    print(f"{kf['start']} --> {kf['end']} : {kf['text']}")


def get_srt_timestamp(frame_index: int, fps: float):
    """Format the time of a frame as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        frame_index: Frame number counted from the start of the video.
        fps: Frame rate of the video.

    Returns:
        The time ``frame_index / fps`` as ``HH:MM:SS,mmm``, with the
        millisecond part truncated.

    Raises:
        ZeroDivisionError: If ``fps`` is zero.
    """
    td = datetime.timedelta(seconds=frame_index / fps)
    # Work from the total millisecond count: td.seconds excludes the days
    # component, which made the hours wrap around after 24 h.
    total_ms = td // datetime.timedelta(milliseconds=1)
    total_s, ms = divmod(total_ms, 1000)
    m, s = divmod(total_s, 60)
    h, m = divmod(m, 60)
    return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)


# Print one SRT cue per key frame: sequence number, time range, text.
for number, kf in enumerate(k_frames, start=1):
    # Key-frame indices count from ocr_start, so shift them back to absolute
    # video frames before converting to timestamps.
    time1 = get_srt_timestamp(ocr_start + kf['start'], fps)
    # 'end' is the last frame that still shows the subtitle; the cue ends
    # when the next frame begins, so a one-frame subtitle keeps a non-zero
    # duration.
    time2 = get_srt_timestamp(ocr_start + kf['end'] + 1, fps)

    print(f"{number}\n{time1} --> {time2}\n{kf['text']}\n")