树莓派 tensorflow-lite 目标检测

0. 安装 tflite-runtime

ref: https://tensorflow.google.cn/lite/guide/python

# Prebuilt tflite_runtime 2.1.0 wheel for CPython 3.7 (cp37) on linux_armv7l; pick the wheel matching your Python version and platform.
pip3 install https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp37-cp37m-linux_armv7l.whl

1. TensorFlow 官方示例

TensorFlow 官方提供了一个基于 picamera 的示例。

ref: https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/

# 1. Shallow-clone the TensorFlow examples repository (latest commit only).
git clone https://github.com/tensorflow/examples --depth 1

# 2. Enter the Raspberry Pi object-detection example directory.
cd examples/lite/examples/object_detection/raspberry_pi

# The directory contains 5 files:
#   README.md
#   annotation.py       - draws the boxes and labels
#   detect_picamera.py  - main program
#   download.sh         - installs the Python dependencies and downloads the pretrained model
#   requirements.txt


# 3. Download the pretrained model into /tmp.
#    - installs the Python dependencies: numpy picamera Pillow
#    - downloads coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip, which contains
#      detect.tflite and labelmap.txt; the labels in labelmap.txt did not work
#      correctly here, so a separate label file is fetched below.
bash download.sh /tmp

# Fetch the working label file to the path that step 4 passes to --labels;
# nothing above creates /tmp/coco_labels.txt.
curl -o /tmp/coco_labels.txt https://dl.google.com/coral/canned_models/coco_labels.txt

# 4. Run the program.
python3 detect_picamera.py --model /tmp/detect.tflite --labels /tmp/coco_labels.txt

2. 使用 OpenCV 调用 USB 摄像头

我手头没有 picamera 摄像头模块，只有一个老式的 USB 摄像头，而 picamera 的 API 不支持 USB 摄像头。

下面修改代码，改用 OpenCV 来调用 USB 摄像头。

"""
Example using TF Lite to detect objects with the Raspberry USB camera.

Hardware:
- Pi 3b+
- usb camera

Software
- python 3.7.3
- tflite runtime 2.1
- opencv

Dataset
- coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip
"""

import re
import time

import numpy as np
import cv2

from tflite_runtime.interpreter import Interpreter

# Capture resolution (pixels) requested from the camera.
args_camera_width, args_camera_height = 640, 480

# Model and label files; plain relative paths, so they are looked up in the
# current working directory.
args_model = 'detect.tflite'
args_labels = 'coco_labels.txt'

# Detections scoring below this value are discarded.
args_threshold = 0.4


def load_labels(path):
    """Load a label file into a ``{class_id: name}`` dict.

    Each non-blank line may be either

    * ``"<id><sep><name>"`` where ``<sep>`` is a run of ``:`` and/or
      whitespace (e.g. ``"0  person"`` or ``"0:person"``) -- the explicit id
      becomes the key; or
    * ``"<name>"`` -- the 0-based line number becomes the key.

    Both forms may be mixed in one file. Blank lines are ignored.

    Args:
        path: Path of the UTF-8 encoded label file.

    Returns:
        dict mapping int class id to label string.
    """
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        for row_number, content in enumerate(f):
            line = content.strip()
            # Skip blank lines: storing '' under the row number could
            # overwrite a label whose explicit id equals that row number.
            if not line:
                continue
            pair = re.split(r'[:\s]+', line, maxsplit=1)
            # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
            if len(pair) == 2 and pair[0].strip().isdecimal():
                labels[int(pair[0])] = pair[1].strip()
            else:
                labels[row_number] = pair[0].strip()
    return labels


def detect_objects(interpreter, image, threshold):
    """Run one inference on ``image`` and keep detections scoring >= threshold.

    Args:
        interpreter: A TF Lite interpreter whose tensors are already allocated.
        image: Tensor fed as-is to the model's first input (the main loop
            builds it as a batch of one resized frame).
        threshold: Minimum score a detection needs to be returned.

    Returns:
        list of dicts with keys ``'box'`` (that detection's row of the first
        output tensor; annotate_objects reads it as ymin, xmin, ymax, xmax),
        ``'class_id'`` (int32) and ``'score'``.
    """
    # Query tensor indices from the interpreter we were given, and feed the
    # `image` argument, so the function does not depend on caller globals.
    input_index = interpreter.get_input_details()[0]['index']
    output_details = interpreter.get_output_details()

    interpreter.set_tensor(input_index, image)
    interpreter.invoke()

    # Output order this script relies on: 0 = boxes, 1 = class ids, 2 = scores.
    boxes = np.squeeze(interpreter.get_tensor(output_details[0]['index']))
    classes = np.squeeze(
        interpreter.get_tensor(output_details[1]['index'])).astype(np.int32)
    scores = np.squeeze(interpreter.get_tensor(output_details[2]['index']))

    # Drop weak detections.
    return [
        {'box': boxes[i], 'class_id': classes[i], 'score': score}
        for i, score in enumerate(scores)
        if score >= threshold
    ]


def annotate_objects(image, results, labels=None):
    """Draw a green box and a "<label> <score>" caption for each detection.

    Args:
        image: Frame (numpy array, height x width x channels) drawn on in place.
        results: Detections as returned by detect_objects; each ``'box'`` is
            (ymin, xmin, ymax, xmax) as fractions of the frame size.
        labels: Optional ``{class_id: name}`` mapping; defaults to the
            module-level ``labels_dict``. Unknown ids are shown as the number.
    """
    if labels is None:
        labels = labels_dict

    # Scale by the frame's actual size: the camera is not guaranteed to
    # deliver the resolution that was requested.
    frame_height, frame_width = image.shape[:2]

    for rst in results:
        ymin, xmin, ymax, xmax = rst['box']
        class_id = rst['class_id']
        name = labels.get(class_id, str(class_id))
        score = rst['score']

        left = int(xmin * frame_width)
        right = int(xmax * frame_width)
        top = int(ymin * frame_height)
        bottom = int(ymax * frame_height)
        cv2.rectangle(image, (left, top), (right, bottom), (0, 255, 0))

        txt = f'{name} {score:.2%}'
        cv2.putText(image, txt, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 2)


# 1. Load the labels.
labels_dict = load_labels(args_labels)
print('labels_dict: \n ', labels_dict)

# 2. Load the model and allocate its tensors.
interpreter = Interpreter(args_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Resize frames to the model's own input size instead of a hard-coded
# 300x300. Assumes an NHWC input shape [1, height, width, channels] -- true
# for the SSD MobileNet model used here; confirm for other models.
_, input_height, input_width, _ = input_details[0]['shape']

# 3. Open the USB camera and request the capture resolution. Use the named
# property ids: raw id 3 is the frame WIDTH and 4 the frame HEIGHT.
camera = cv2.VideoCapture(0)
if not camera.isOpened():
    raise SystemExit('Could not open camera 0')
camera.set(cv2.CAP_PROP_FRAME_WIDTH, args_camera_width)
camera.set(cv2.CAP_PROP_FRAME_HEIGHT, args_camera_height)

frame_rate_calc = 1.0
freq = cv2.getTickFrequency()

# 4. Detection loop. Press 'q' in the preview window to quit; the camera and
# windows are released however the loop ends (including Ctrl+C).
try:
    while True:
        # 4.1 Start timing this frame for the FPS estimate.
        t1 = cv2.getTickCount()

        # 4.2 Grab a frame and build the model input.
        ret, frame = camera.read()
        if not ret:
            print('Failed to read a frame from the camera, stopping.')
            break
        # OpenCV delivers BGR frames; the TF Lite example this script is
        # adapted from feeds RGB images, so convert before resizing.
        input_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_image = cv2.resize(input_image,
                                 (int(input_width), int(input_height)))
        # Add the batch axis; the quantized model takes uint8 pixels.
        input_image = np.expand_dims(input_image, axis=0).astype(np.uint8)

        # 4.3 Run inference and keep confident detections.
        results = detect_objects(interpreter, input_image, args_threshold)

        print(f'--- {time.strftime("%Y-%m-%d %H:%M:%S")} ---')
        for rst in results:
            name = labels_dict.get(rst['class_id'], rst['class_id'])
            print(f"* {name} : {rst['score']:.2%} @ {rst['box']}")

        # 4.4 Draw the detections on the original (BGR) frame.
        annotate_objects(frame, results)

        # 4.5 Overlay the FPS measured on the previous frame.
        txt = f'FPS: {frame_rate_calc:.2f}'
        cv2.putText(frame, txt, (20, 30), 0, 1, (0, 255, 255), 2)

        # 4.6 Show the frame.
        cv2.imshow('Object detect', frame)

        # 4.7 Update the FPS estimate.
        t2 = cv2.getTickCount()
        frame_rate_calc = freq / (t2 - t1)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    camera.release()
    cv2.destroyAllWindows()