Git clone直後の場合
COCO APIのインストール
```shell
$ git clone https://github.com/cocodataset/cocoapi.git
$ cd cocoapi/PythonAPI
$ make
# ../../models/research assumes cocoapi was cloned next to the tensorflow/models repository
$ cp -r pycocotools ../../models/research/
```
```shell
$ cd models/research
$ protoc object_detection/protos/*.proto --python_out=.
```

docker内で実行
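```shell
# run from models/research so that `pwd` and `pwd`/slim point at the Object Detection API sources
$ export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
```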
モデル
```shell
$ cd /tmp
$ curl -O http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz
$ tar xzf ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz
$ gsutil cp /tmp/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03/model.ckpt.* gs://${YOUR_GCS_BUCKET}/data/
```

# TPUで学習

```shell
# run from models/research (the previous block changed into /tmp)
$ vim object_detection/samples/configs/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync.config
$ gsutil cp object_detection/samples/configs/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync.config gs://${YOUR_GCS_BUCKET}/data/pipeline.config
```
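```shell
# The tarballs passed to --packages must be built first (from models/research):
#   $ bash object_detection/dataset_tools/create_pycocotools_package.sh /tmp/pycocotools
#   $ python setup.py sdist
#   $ (cd slim && python setup.py sdist)
$ gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
    --job-dir=gs://${YOUR_GCS_BUCKET}/train_ppn \
    --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
    --module-name object_detection.model_tpu_main \
    --runtime-version 1.8 \
    --scale-tier BASIC_TPU \
    --region us-central1 \
    -- \
    --model_dir=gs://${YOUR_GCS_BUCKET}/train_ppn \
    --tpu_zone us-central1 \
    --pipeline_config_path=gs://${YOUR_GCS_BUCKET}/data/pipeline.config
```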
TFLite用グラフの取得（エクスポート）
```shell
$ export CONFIG_FILE=gs://${YOUR_GCS_BUCKET}/data/pipeline.config
# replace 2000 with the step number of a checkpoint that actually exists under train_ppn/
$ export CHECKPOINT_PATH=gs://${YOUR_GCS_BUCKET}/train_ppn/model.ckpt-2000
$ export OUTPUT_DIR=/tmp/tflite
```
```shell
$ python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path="$CONFIG_FILE" \
    --trained_checkpoint_prefix="$CHECKPOINT_PATH" \
    --output_directory="$OUTPUT_DIR" \
    --add_postprocessing_op=true
```
tocoでTFLiteモデルに変換
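```shell
# run from the root of a TensorFlow source checkout (bazel target tensorflow/contrib/lite/toco)
# NOTE(review): --default_ranges_min/max give a fallback (dummy) quantization range for tensors
# without recorded min/max; presumably needed because this checkpoint was not trained with
# quantization-aware training — verify detection accuracy of the resulting model.
$ bazel run -c opt tensorflow/contrib/lite/toco:toco -- \
    --input_file=$OUTPUT_DIR/tflite_graph.pb \
    --output_file=$OUTPUT_DIR/detect.tflite \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_values=128 \
    --change_concat_input_ranges=false \
    --allow_custom_ops \
    --default_ranges_min=0 \
    --default_ranges_max=6
```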
後処理
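```shell
# WARNING: discards every uncommitted change in the working tree (e.g. the edited sample .config)
$ git reset --hard
```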