d/ht: introduce ht_model

This commit is contained in:
Simon Zeni 2021-11-09 13:46:55 -05:00
parent e84d948fb4
commit 0e53b3b87f
4 changed files with 647 additions and 0 deletions

View file

@ -218,6 +218,7 @@ if(XRT_BUILD_DRIVER_HANDTRACKING)
ht/ht_driver.hpp
ht/ht_interface.h
ht/ht_models.cpp
ht/ht_model.cpp
ht/ht_hand_math.cpp
ht/ht_image_math.cpp
ht/ht_nms.cpp

View file

@ -0,0 +1,594 @@
// Copyright 2021, Collabora, Ltd.
// SPDX-License-Identifier: BSL-1.0
/*!
* @file
* @brief Code to run machine learning models for camera-based hand tracker.
* @author Moses Turner <moses@collabora.com>
* @author Marcus Edel <marcus.edel@collabora.com>
* @author Simon Zeni <simon@bl4ckb0ne.ca>
* @ingroup drv_ht
*/
// Many C api things were stolen from here (MIT license):
// https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c
#include "ht_driver.hpp"
#include "ht_image_math.hpp"
#include "ht_model.hpp"
#include "ht_nms.hpp"
#include <array>
#undef HEAVY_SCRIBBLE
/*
 * Anchor centres for mediapipe's palm detection, used for the single-shot detector model.
 *
 * The table holds 896 entries, in this order:
 *  - a 16x16 grid with 2 identical anchors per cell (512 entries), then
 *  - an 8x8 grid with 6 identical anchors per cell (384 entries).
 * Within a grid the cells are stored row by row (x varies fastest) and every anchor sits
 * at the centre of its cell, expressed as a fraction of the image side.
 *
 * All values are exact binary fractions, so generating them at compile time yields the
 * same floats as the literal table this replaces.
 *
 * See:
 * https://google.github.io/mediapipe/solutions/hands.html#palm-detection-model
 * https://github.com/google/mediapipe/blob/v0.8.8/mediapipe/calculators/tflite/ssd_anchors_calculator.cc#L101
 * https://github.com/google/mediapipe/blob/v0.8.8/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt#L60
 */
struct anchor
{
	float x, y;
};

//! Builds the anchor table described above; evaluated at compile time.
static constexpr std::array<anchor, 896>
make_palm_anchors()
{
	// One row per grid: {cells per side, anchors per cell}.
	constexpr size_t layers[2][2] = {{16, 2}, {8, 6}};

	std::array<anchor, 896> table{};
	size_t next = 0;

	for (const auto &layer : layers) {
		const size_t cells = layer[0];
		const size_t repeats = layer[1];

		for (size_t row = 0; row < cells; row++) {
			for (size_t col = 0; col < cells; col++) {
				const anchor centre = {(col + 0.5f) / cells, (row + 0.5f) / cells};
				for (size_t r = 0; r < repeats; r++) {
					table[next++] = centre;
				}
			}
		}
	}

	return table;
}

static constexpr std::array<anchor, 896> anchors = make_palm_anchors();
/*
 * Calls the OrtApi member given by `expr` through this->api and checks the returned
 * OrtStatus.
 *
 * On failure the ONNX Runtime error message is logged with HT_ERROR together with the
 * file and line of the call site, the status object is released, and assert(false) fires.
 * When assert is compiled out (NDEBUG) execution simply continues after the failed call,
 * so code following an ORT(...) must not assume its output arguments were written.
 *
 * Only usable inside ht_model member functions: the expansion refers to this->api and
 * this->device.
 */
#define ORT(expr) \
do { \
OrtStatus *status = this->api->expr; \
if (status != nullptr) { \
const char *msg = this->api->GetErrorMessage(status); \
HT_ERROR(this->device, "[%s:%d]: %s\n", __FILE__, __LINE__, msg); \
this->api->ReleaseStatus(status); \
assert(false); \
} \
} while (0)
/*!
 * Creates the palm detection session and the input tensor that is reused for every
 * inference.
 *
 * The model file is looked up below startup_config.model_slug. The input tensor does not
 * own its storage: it wraps this->palm_detection_data, which palm_detection() fills before
 * each run.
 *
 * @param opts Session options used to create the ONNX Runtime session.
 */
void
ht_model::init_palm_detection(OrtSessionOptions *opts)
{
	// The two models expect different tensor layouts, so the input shape cannot be constexpr.
	std::array<int64_t, 4> input_shape;

	std::filesystem::path path = this->device->startup_config.model_slug;

	// NOTE(review): the model is picked with keypoint_estimation_use_mediapipe, whereas the
	// preprocessing in palm_detection() switches on palm_detection_use_mediapipe. If the two
	// flags can differ, the loaded model and its input layout disagree - confirm which flag
	// is intended here.
	if (this->device->startup_config.keypoint_estimation_use_mediapipe) {
		path /= "palm_detection_MEDIAPIPE.onnx";
		input_shape = {1, 3, 128, 128};
	} else {
		path /= "palm_detection_COLLABORA.onnx";
		input_shape = {1, 128, 128, 3};
	}

	HT_DEBUG(this->device, "Loading palm detection model from file '%s'", path.c_str());
	ORT(CreateSession(this->env, path.c_str(), opts, &this->palm_detection_session));
	assert(this->palm_detection_session);

	// Both layouts hold the same number of floats; only the dimension order differs.
	constexpr size_t input_size = 3 * 128 * 128;
	ORT(CreateTensorWithDataAsOrtValue(this->palm_detection_meminfo, this->palm_detection_data.data(),
	                                   input_size * sizeof(float), input_shape.data(), input_shape.size(),
	                                   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &this->palm_detection_tensor));
	assert(this->palm_detection_tensor);

	int is_tensor = 0;
	ORT(IsTensor(this->palm_detection_tensor, &is_tensor));
	assert(is_tensor);
}
/*!
 * Creates the hand landmark session and the input tensor that is reused for every
 * inference.
 *
 * The model file is looked up below startup_config.model_slug and chosen by
 * startup_config.keypoint_estimation_use_mediapipe. The input tensor wraps
 * this->hand_landmark_data rather than owning storage of its own.
 *
 * @param opts Session options used to create the ONNX Runtime session.
 */
void
ht_model::init_hand_landmark(OrtSessionOptions *opts)
{
	const char *model_file = this->device->startup_config.keypoint_estimation_use_mediapipe
	                             ? "hand_landmark_MEDIAPIPE.onnx"
	                             : "hand_landmark_COLLABORA.onnx";

	std::filesystem::path path = this->device->startup_config.model_slug;
	path /= model_file;

	HT_DEBUG(this->device, "Loading hand landmark model from file '%s'", path.c_str());
	ORT(CreateSession(this->env, path.c_str(), opts, &this->hand_landmark_session));
	assert(this->hand_landmark_session);

	// Same shape for both models: one image, three planes of 224x224 floats.
	constexpr std::array<int64_t, 4> input_shape = {1, 3, 224, 224};
	constexpr size_t input_bytes = 3 * 224 * 224 * sizeof(float);

	ORT(CreateTensorWithDataAsOrtValue(this->hand_landmark_meminfo, this->hand_landmark_data.data(),
	                                   input_bytes, input_shape.data(), input_shape.size(),
	                                   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &this->hand_landmark_tensor));
	assert(this->hand_landmark_tensor != nullptr);

	int is_tensor;
	ORT(IsTensor(this->hand_landmark_tensor, &is_tensor));
	assert(is_tensor);
}
/*!
 * Sets up ONNX Runtime and loads both models.
 *
 * Creates the environment, one CPU memory-info object per model, and a set of session
 * options shared by both sessions, then initialises the palm detection and hand landmark
 * models. The session options are released again once both sessions exist.
 *
 * @param htd Hand tracking device; used for logging and for its startup configuration.
 */
ht_model::ht_model(struct ht_device *htd) : device(htd), api(OrtGetApiBase()->GetApi(ORT_API_VERSION))
{
	ORT(CreateEnv(ORT_LOGGING_LEVEL_WARNING, "monado_ht", &this->env));

	ORT(CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &this->palm_detection_meminfo));
	ORT(CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &this->hand_landmark_meminfo));

	OrtSessionOptions *session_options = nullptr;
	ORT(CreateSessionOptions(&session_options));

	// TODO review options, config for threads?
	ORT(SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_ALL));
	ORT(SetIntraOpNumThreads(session_options, 1));

	this->init_palm_detection(session_options);
	this->init_hand_landmark(session_options);

	// Only needed while the sessions are being created.
	this->api->ReleaseSessionOptions(session_options);
}
/*!
 * Releases every ONNX Runtime object created by the constructor and the init_* helpers:
 * per model the memory info, session and input tensor, and finally the environment.
 */
ht_model::~ht_model()
{
// Palm detection objects.
this->api->ReleaseMemoryInfo(this->palm_detection_meminfo);
this->api->ReleaseSession(this->palm_detection_session);
this->api->ReleaseValue(this->palm_detection_tensor);
// Hand landmark objects.
this->api->ReleaseMemoryInfo(this->hand_landmark_meminfo);
this->api->ReleaseSession(this->hand_landmark_session);
this->api->ReleaseValue(this->hand_landmark_tensor);
// The environment is released last, after both sessions.
this->api->ReleaseEnv(this->env);
}
/*!
 * Runs the palm detection model on one camera image.
 *
 * blackbar() produces the 128x128 network input and returns the 2x3 transform that is
 * used to map results back onto @p input. The pixels are normalised with
 * (value - 128) / 128 and written into this->palm_detection_data, the storage behind the
 * input tensor. Every anchor whose sigmoid score exceeds dynamic_config.nms_threshold
 * becomes a candidate; candidates are merged by filterBoxesWeightedAvg() and the keypoints
 * of the survivors are transformed with the matrix returned by blackbar().
 *
 * Optionally draws raw and merged detections into htv->debug_out_to_this, depending on
 * the debug_scribble / scribble_* settings.
 *
 * @param htv   View being processed; supplies configuration and the debug image.
 * @param input Camera image to search for palms.
 * @return One Palm7KP per merged detection; empty when nothing was detected or when
 *         inference produced no output.
 */
std::vector<Palm7KP>
ht_model::palm_detection(ht_view *htv, const cv::Mat &input)
{
	// TODO use opencv to handle input preprocessing
	constexpr int hd_size = 128;
	constexpr size_t nb_planes = 3;
	constexpr size_t size = hd_size * hd_size * nb_planes;

	// Model output layout: one score and 18 regressor values per anchor.
	constexpr size_t nb_anchors = 896; // must match the size of the anchors table
	constexpr size_t regressor_stride = 18;
	constexpr int nb_keypoints = 7;

	cv::Mat img;
	cv::Matx23f back_from_blackbar = blackbar(input, img, {hd_size, hd_size});
	float scale_factor = back_from_blackbar(0, 0); // 960/128

	// img.data is read as one flat buffer below.
	assert(img.isContinuous());

	constexpr float mean = 128.0f;
	constexpr float std_dev = 128.0f;

	if (htv->htd->startup_config.palm_detection_use_mediapipe) {
		// This model takes planar input, so split the channels first.
		std::vector<uint8_t> combined_planes(size);
		planarize(img, combined_planes.data());
		for (size_t i = 0; i < size; i++) {
			float val = (float)combined_planes[i];
			this->palm_detection_data[i] = (val - mean) / std_dev;
		}
	} else {
		for (size_t i = 0; i < size; i++) {
			int val = img.data[i];
			this->palm_detection_data[i] = (val - mean) / std_dev;
		}
	}

	// NOTE(review): the input name follows keypoint_estimation_use_mediapipe (as does the
	// model choice in init_palm_detection) while the preprocessing above follows
	// palm_detection_use_mediapipe. Confirm that the two flags cannot differ.
	const char *input_names[1];
	if (this->device->startup_config.keypoint_estimation_use_mediapipe) {
		input_names[0] = "input";
	} else {
		input_names[0] = "input:0";
	}

	static const char *const output_names[] = {"classificators", "regressors"};
	OrtValue *output_tensor[] = {nullptr, nullptr};

	const auto release_outputs = [&] {
		for (OrtValue *value : output_tensor) {
			if (value != nullptr) {
				this->api->ReleaseValue(value);
			}
		}
	};

	ORT(Run(this->palm_detection_session, nullptr, input_names, &this->palm_detection_tensor, 1, output_names, 2,
	        output_tensor));

	// ORT() only asserts on failure; with asserts compiled out we get here without outputs.
	if (output_tensor[0] == nullptr || output_tensor[1] == nullptr) {
		release_outputs();
		return {};
	}

	// TODO define types to handle data
	float *classificators = nullptr;
	float *regressors = nullptr;

	// Output is 896 floats
	ORT(GetTensorMutableData(output_tensor[0], (void **)&classificators));

	// Output is 896 * 18 floats
	ORT(GetTensorMutableData(output_tensor[1], (void **)&regressors));

	if (classificators == nullptr || regressors == nullptr) {
		release_outputs();
		return {};
	}

	const bool scribble_raw = htv->htd->debug_scribble && (htv->htd->dynamic_config.scribble_raw_detections);

	std::vector<NMSPalm> detections;
	for (size_t i = 0; i < nb_anchors; ++i) {
		const float score = 1.0 / (1.0 + exp(-classificators[i]));

		// Let a lot of detections in - they'll be slowly rejected later
		if (score <= this->device->dynamic_config.nms_threshold.val) {
			continue;
		}

		// Regressor layout per anchor: centre shift (2), box size (2), 7 keypoints (14).
		const float *reg = &regressors[i * regressor_stride];

		// Anchor centre in pixels of the 128x128 network input.
		const float anchx = anchors[i].x * 128;
		const float anchy = anchors[i].y * 128;

		// Boundary box.
		NMSPalm det;
		det.bbox.w = reg[2];
		det.bbox.h = reg[3];
		det.bbox.cx = reg[0] + anchx;
		det.bbox.cy = reg[1] + anchy;
		det.confidence = score;

		// Keypoints are regressed relative to the anchor centre.
		struct xrt_vec2 *kps = det.keypoints;
		for (int k = 0; k < nb_keypoints; k++) {
			kps[k] = {reg[4 + 2 * k] + anchx, reg[5 + 2 * k] + anchy};
		}

		detections.push_back(det);

		if (scribble_raw) {
			xrt_vec2 center = transformVecBy2x3(xrt_vec2{det.bbox.cx, det.bbox.cy}, back_from_blackbar);

			float sz = det.bbox.w * scale_factor;

			cv::rectangle(htv->debug_out_to_this,
			              {(int)(center.x - (sz / 2)), (int)(center.y - (sz / 2)), (int)sz, (int)sz},
			              hsv2rgb(0.0f, math_map_ranges(det.confidence, 0.0f, 1.0f, 1.5f, -0.1f),
			                      math_map_ranges(det.confidence, 0.0f, 1.0f, 0.2f, 1.4f)),
			              1);

			for (int k = 0; k < nb_keypoints; k++) {
				handDot(htv->debug_out_to_this, transformVecBy2x3(kps[k], back_from_blackbar),
				        det.confidence * 7, ((float)k) * (360.0f / 7.0f), det.confidence, 1);
			}
		}
	}

	// classificators / regressors point into these tensors and are not used past here.
	release_outputs();

	std::vector<Palm7KP> output;

	if (detections.empty()) {
		return output;
	}

	std::vector<NMSPalm> nms_palms = filterBoxesWeightedAvg(detections, htv->htd->dynamic_config.nms_iou.val);
	output.reserve(nms_palms.size());

	const bool scribble_nms = htv->htd->debug_scribble && htv->htd->dynamic_config.scribble_nms_detections;

	for (const NMSPalm &palm : nms_palms) {
		if (scribble_nms) {
			// Display box
			struct xrt_vec2 tl = {palm.bbox.cx - palm.bbox.w / 2, palm.bbox.cy - palm.bbox.h / 2};
			struct xrt_vec2 tl_out = transformVecBy2x3(tl, back_from_blackbar);
			float sz = palm.bbox.w * scale_factor;

			cv::rectangle(htv->debug_out_to_this, {(int)tl_out.x, (int)tl_out.y, (int)sz, (int)sz},
			              hsv2rgb(180.0f, math_map_ranges(palm.confidence, 0.0f, 1.0f, 0.8f, -0.1f),
			                      math_map_ranges(palm.confidence, 0.0f, 1.0f, 0.2f, 1.4f)),
			              2);

			for (int k = 0; k < nb_keypoints; k++) {
				handDot(htv->debug_out_to_this,
				        transformVecBy2x3(palm.keypoints[k], back_from_blackbar),
				        palm.confidence * 14, ((float)k) * (360.0f / 7.0f), palm.confidence, 3);
			}
		}

		Palm7KP this_element;
		for (int k = 0; k < nb_keypoints; k++) {
			this_element.kps[k] = transformVecBy2x3(palm.keypoints[k], back_from_blackbar);
		}
		this_element.confidence = palm.confidence;

		output.push_back(this_element);
	}

	return output;
}
/*!
 * Runs the hand landmark model on one image and returns its 21 keypoints.
 *
 * The image is split into planes with planarize(), scaled to [0, 1] and written into
 * this->hand_landmark_data, the storage behind the input tensor. The whole call runs with
 * hand_landmark_lock held, because that buffer and the session are shared.
 *
 * @param input Image of the hand. The staging buffer is sized for 224x224 pixels with
 *              3 channels - presumably the caller crops and resizes to that; TODO confirm.
 * @return The keypoints as x, y, z triples taken from the first model output, or a
 *         zero-initialised Hand2D when inference produced no output.
 */
Hand2D
ht_model::hand_landmark(const cv::Mat input)
{
	std::scoped_lock lock(this->hand_landmark_lock);

	// TODO use opencv to handle input preprocessing
	constexpr size_t lix = 224;
	constexpr size_t liy = 224;
	constexpr size_t nb_planes = 3;
	constexpr size_t size = lix * liy * nb_planes;

	std::vector<uint8_t> combined_planes(size);
	planarize(input, combined_planes.data());

	// Normalize - supposedly, the keypoint estimator wants pixel values in [0,1]
	for (size_t i = 0; i < size; i++) {
		this->hand_landmark_data[i] = (float)combined_planes[i] / 255.0;
	}

	static const char *const input_names[] = {"input_1"};
	static const char *const output_names[] = {"Identity", "Identity_1", "Identity_2"};

	OrtValue *output_tensor[] = {nullptr, nullptr, nullptr};

	const auto release_outputs = [&] {
		for (OrtValue *value : output_tensor) {
			if (value != nullptr) {
				this->api->ReleaseValue(value);
			}
		}
	};

	ORT(Run(this->hand_landmark_session, nullptr, input_names, &this->hand_landmark_tensor, 1, output_names, 3,
	        output_tensor));

	Hand2D hand{};

	// ORT() only asserts on failure; with asserts compiled out we get here without outputs.
	if (output_tensor[0] == nullptr) {
		release_outputs();
		return hand;
	}

	// The returned pointer refers to memory owned by output_tensor[0]; it must not be
	// used after that tensor has been released.
	float *landmarks = nullptr;
	ORT(GetTensorMutableData(output_tensor[0], (void **)&landmarks));

	if (landmarks == nullptr) {
		release_outputs();
		return hand;
	}

	// Only the first output is read: 21 keypoints stored as consecutive x, y, z floats.
	constexpr size_t nb_keypoints = 21;
	constexpr size_t stride = 3;
	for (size_t i = 0; i < nb_keypoints; i++) {
		const float *kp = &landmarks[i * stride];
		hand.kps[i].x = kp[0];
		hand.kps[i].y = kp[1];
		hand.kps[i].z = kp[2];
	}

	release_outputs();

	return hand;
}

View file

@ -0,0 +1,51 @@
// Copyright 2021, Collabora, Ltd.
// SPDX-License-Identifier: BSL-1.0
/*!
* @file
* @brief Code to run machine learning models for camera-based hand tracker.
* @author Moses Turner <moses@collabora.com>
* @author Marcus Edel <marcus.edel@collabora.com>
* @author Simon Zeni <simon@bl4ckb0ne.ca>
* @ingroup drv_ht
*/
#pragma once
#include <core/session/onnxruntime_c_api.h>
#include <opencv2/core/mat.hpp>
#include <array>
#include <filesystem>
#include <mutex>
#include <vector>
/*!
 * Owns the ONNX Runtime sessions for the palm detection and hand landmark models and
 * runs inference on them.
 *
 * Each model has a fixed-size float buffer that backs its input tensor; the inference
 * functions fill the buffer and then run the session.
 */
class ht_model
{
//! Device this model belongs to; used for logging and configuration.
struct ht_device *device = nullptr;
//! ONNX Runtime C API table, fetched once in the constructor.
const OrtApi *api = nullptr;
//! ONNX Runtime environment the sessions are created in.
OrtEnv *env = nullptr;
//! Memory info passed when creating the palm detection input tensor.
OrtMemoryInfo *palm_detection_meminfo = nullptr;
OrtSession *palm_detection_session = nullptr;
//! Input tensor wrapping palm_detection_data.
OrtValue *palm_detection_tensor = nullptr;
//! Backing storage of palm_detection_tensor: 3 planes of 128x128 floats.
std::array<float, 3 * 128 * 128> palm_detection_data;
//! Held for the whole of hand_landmark().
//! NOTE(review): palm_detection() takes no lock - confirm it is never called concurrently.
std::mutex hand_landmark_lock;
//! Memory info passed when creating the hand landmark input tensor.
OrtMemoryInfo *hand_landmark_meminfo = nullptr;
OrtSession *hand_landmark_session = nullptr;
//! Input tensor wrapping hand_landmark_data.
OrtValue *hand_landmark_tensor = nullptr;
//! Backing storage of hand_landmark_tensor: 3 planes of 224x224 floats.
std::array<float, 3 * 224 * 224> hand_landmark_data;
//! Loads the palm detection model and creates its input tensor.
void
init_palm_detection(OrtSessionOptions *opts);
//! Loads the hand landmark model and creates its input tensor.
void
init_hand_landmark(OrtSessionOptions *opts);
public:
//! Creates the ONNX Runtime environment and loads both models for @p htd.
ht_model(struct ht_device *htd);
//! Releases all ONNX Runtime objects held by this instance.
~ht_model();
//! Runs palm detection on @p input and returns one entry per detected palm.
std::vector<Palm7KP>
palm_detection(ht_view *htv, const cv::Mat &input);
//! Runs the hand landmark model on @p input and returns the hand keypoints.
Hand2D
hand_landmark(const cv::Mat input);
};

View file

@ -92,6 +92,7 @@ lib_drv_ht = static_library(
'ht/ht_driver.hpp',
'ht/ht_interface.h',
'ht/ht_models.cpp',
'ht/ht_model.cpp',
'ht/ht_hand_math.cpp',
'ht/ht_image_math.cpp',
'ht/ht_nms.cpp',