要按类别统计 TensorFlow 目标检测结果的数量，可以使用以下方法（需要已安装 tensorflow、numpy、Pillow 以及 TensorFlow Object Detection API）：
from collections import defaultdict

import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from PIL import Image
# Locations of the exported frozen inference graph and its label map.
# Both values are placeholders and must be replaced with real paths.
model_path = 'path_to_your_model_directory/frozen_inference_graph.pb'
label_map_path = 'path_to_your_label_map_file/label_map.pbtxt'

# Deserialize the frozen GraphDef and import it into a dedicated graph so the
# model's tensors can later be looked up by name.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v2.io.gfile.GFile(model_path, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    # An empty name avoids prefixing the imported tensor names.
    tf.import_graph_def(od_graph_def, name='')

# Build the lookup used to translate numeric class ids into category entries
# (each entry is read through its 'name' key when counting detections).
# NOTE(review): max_num_classes=90 presumably matches a COCO-style label map;
# raise it if the label map contains larger ids -- confirm against the model.
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=90, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
def run_inference_for_single_image(image, graph):
    """Run the detection graph on one image and return its detections.

    Args:
        image: Image array without a batch dimension; a leading batch axis
            of size 1 is added before it is fed to ``image_tensor:0``.
            Presumably an HxWx3 uint8 array -- confirm against the
            exported model.
        graph: ``tf.Graph`` containing the imported detection model. It must
            expose the tensors ``image_tensor:0``, ``detection_boxes:0``,
            ``detection_scores:0``, ``detection_classes:0`` and
            ``num_detections:0``.

    Returns:
        Tuple ``(boxes, scores, classes, num)``. ``boxes``, ``scores`` and
        ``classes`` are the model outputs with the batch axis removed
        (``classes`` cast to ``int32``); ``num`` is the detection count as a
        plain ``int`` so it can be used directly with ``range``.
    """
    with graph.as_default(), tf.compat.v1.Session() as sess:
        # Look up the input and output tensors.
        image_tensor = graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = graph.get_tensor_by_name('detection_scores:0')
        detection_classes = graph.get_tensor_by_name('detection_classes:0')
        num_detections = graph.get_tensor_by_name('num_detections:0')

        # Run inference on a batch containing just this image.
        boxes, scores, classes, num = sess.run(
            [detection_boxes, detection_scores, detection_classes,
             num_detections],
            feed_dict={image_tensor: np.expand_dims(image, axis=0)})

    # Drop only the batch axis. A bare squeeze() would also collapse the
    # per-detection axis whenever the model emits a single detection slot.
    boxes = np.squeeze(boxes, axis=0)
    scores = np.squeeze(scores, axis=0)
    classes = np.squeeze(classes, axis=0).astype(np.int32)

    # The raw output is a one-element float array, which range() rejects;
    # convert it to a Python int.
    num = int(np.squeeze(num))
    return boxes, scores, classes, num
# Placeholder path; replace with the image to analyse.
image_path = 'path_to_your_image/image.jpg'

# Load the image and close the file handle promptly. Converting to RGB
# guarantees three channels even for grayscale, palette or RGBA files
# (the detection graph is presumably exported for 3-channel input -- confirm).
with Image.open(image_path) as image:
    image_np = np.array(image.convert('RGB'))

boxes, scores, classes, num = run_inference_for_single_image(
    image_np, detection_graph)

# Minimum confidence for a detection to be counted; adjust as needed.
score_threshold = 0.5

# The detection count may arrive as a one-element float array or as an int;
# normalise it so it can be used for slicing.
num_detections = int(np.squeeze(num))

# Tally the confident detections per category name.
count_dict = defaultdict(int)
for class_id, score in zip(classes[:num_detections],
                           scores[:num_detections]):
    if score > score_threshold:
        class_name = category_index[class_id]['name']
        count_dict[class_name] += 1

for class_name, count in count_dict.items():
    print(f'{class_name}: {count}')
这样就可以按类别统计 TensorFlow 目标检测结果的数量并打印出来。请确保将代码中的路径替换为实际的模型文件、标签映射文件和图像的路径。
上一篇:按类别计算SQL百分比