해결된 질문
작성
·
131
0
강사님 ssd_mobile_net 코드를 따라서 작성하던 중 아래와 같은 오류가 나와서 강사님의 코드를 복사 붙여 넣기 해도 같은 오류가 나와서 어떻게 해야할지 모르겠습니다.
def get_tensor_detected_image(sess, img_array, use_copied_array):
rows = img_array.shape[0]
cols = img_array.shape[1]
if use_copied_array:
draw_img = img_array.copy()
else:
draw_img = img_array
inp = cv2.resize(img_array, (300,300))
inp = inp[:,:,[2,1,0]]
start = time.time()
out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
sess.graph.get_tensor_by_name('detection_scores:0'),
sess.graph.get_tensor_by_name('detection_boxes:0'),
sess.graph.get_tensor_by_name('detection_classes:0')],
feed_dict={'image_tensor:0':inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
green_color = (0,255,0)
red_color = (0,0,255)
num_detections = int(out[0][0])
for i in range(num_detections):
classId = int(out[3][0][i])
score = float(out[1][0][i])
bbox = [float(v) for v in out[2][0][i]]
if score > 0.3:
left = bbox[1] * cols
top = bbox[0] * rows
right = bbox[3] * cols
bottom = bbox[2] * rows
cv2.rectangle(draw_img, (int(left), int(top)), (int(right), int(bottom)), green_color, thickness = 2)
caption = "{}:{:.4f}".format(labels_to_names[classId], score)
cv2.putText(draw_img, caption, (int(left), int(top - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.4, red_color, 1)
print('Detection 수행시간:', round(time.time() - start,3),"초")
return draw_img
import numpy as np
import tensorflow as tf
import cv2
import time
import matplotlib.pyplot as plt
%matplotlib inline
video_input_path = '../../data/video/Night_Day_Chase.mp4'
video_output_path = '../../data/output/Night_Day_Chase_tensor_ssd_mobile_01.mp4'
cap = cv2.VideoCapture(video_input_path)
codec = cv2.VideoWriter_fourcc(*'XVID')
vid_size = (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
vid_fps = cap.get(cv2.CAP_PROP_FPS)
vid_writer = cv2.VideoWriter(video_output_path, codec, vid_fps, vid_size)
frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print('총 Frame의 갯수:', frame_cnt, 'FPS:',vid_fps)
with tf.gfile.FastGFile('/home/bgw2001/DLCV/Detection/ssd/pretrained/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb','rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Session() as sess:
sess.graph.as_default()
tf.import_graph_def(graph_def, name = '')
index = 0
while True:
hasFrame, img_frame = cap.read()
if not hasFrame:
print('더 이상 처리할 frame이 없습니다.')
break
draw_img_frame = get_tensor_detected_image(sess=sess, img_array=img_frame, use_copied_array=False)
vid_writer.write(draw_img_frame)
vid_writer.release()
cap.release()
--------------------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1333 try: -> 1334 return fn(*args) 1335 except errors.OpError as e: ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata) 1316 # Ensure any changes to the graph are reflected in the runtime. -> 1317 self._extend_graph() 1318 return self._call_tf_sessionrun( ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _extend_graph(self) 1351 with self._graph._session_run_lock(): # pylint: disable=protected-access -> 1352 tf_session.ExtendSession(self._session) 1353 InvalidArgumentError: Input 1 of node Preprocessor/map/while/Merge_2_1 was passed int32 from Preprocessor/map/while/NextIteration_2:0 incompatible with expected float. During handling of the above exception, another exception occurred: InvalidArgumentError Traceback (most recent call last) <ipython-input-48-071e2e179d8f> in <module> 12 print('더 이상 처리할 frame이 없습니다.') 13 break ---> 14 draw_img_frame = get_tensor_detected_image(sess=sess, img_array=img_frame, use_copied_array=False) 15 vid_writer.write(draw_img_frame) 16 vid_writer.release() <ipython-input-46-52db48ba4bca> in get_tensor_detected_image(sess, img_array, use_copied_array) 17 sess.graph.get_tensor_by_name('detection_boxes:0'), 18 sess.graph.get_tensor_by_name('detection_classes:0')], ---> 19 feed_dict={'image_tensor:0':inp.reshape(1, inp.shape[0], inp.shape[1], 3)}) 20 green_color = (0,255,0) 21 red_color = (0,0,255) ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 927 try: 928 result = self._run(None, fetches, feed_dict, options_ptr, --> 929 run_metadata_ptr) 930 if run_metadata: 931 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 1150 if final_fetches or final_targets or (handle and feed_dict_tensor): 1151 results = self._do_run(handle, final_targets, final_fetches, -> 1152 feed_dict_tensor, options, run_metadata) 1153 else: 1154 results = [] ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1326 if handle is None: 1327 return self._do_call(_run_fn, feeds, fetches, targets, options, -> 1328 run_metadata) 1329 else: 1330 return self._do_call(_prun_fn, handle, feeds, fetches) ~/anaconda3/envs/tf113/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1346 pass 1347 message = error_interpolation.interpolate(message, self._graph) -> 1348 raise type(e)(node_def, op, message) 1349 1350 def _extend_graph(self): InvalidArgumentError: Input 1 of node Preprocessor/map/while/Merge_2_1 was passed int32 from Preprocessor/map/while/NextIteration_2:0 incompatible with expected float.