mirror of
https://gitlab.freedesktop.org/monado/monado.git
synced 2025-02-03 12:28:07 +00:00
d/ht: introduce ht_model
This commit is contained in:
parent
e84d948fb4
commit
0e53b3b87f
src/xrt/drivers
|
@ -218,6 +218,7 @@ if(XRT_BUILD_DRIVER_HANDTRACKING)
|
|||
ht/ht_driver.hpp
|
||||
ht/ht_interface.h
|
||||
ht/ht_models.cpp
|
||||
ht/ht_model.cpp
|
||||
ht/ht_hand_math.cpp
|
||||
ht/ht_image_math.cpp
|
||||
ht/ht_nms.cpp
|
||||
|
|
594
src/xrt/drivers/ht/ht_model.cpp
Normal file
594
src/xrt/drivers/ht/ht_model.cpp
Normal file
|
@ -0,0 +1,594 @@
|
|||
// Copyright 2021, Collabora, Ltd.
|
||||
// SPDX-License-Identifier: BSL-1.0
|
||||
/*!
|
||||
* @file
|
||||
* @brief Code to run machine learning models for camera-based hand tracker.
|
||||
* @author Moses Turner <moses@collabora.com>
|
||||
* @author Marcus Edel <marcus.edel@collabora.com>
|
||||
* @author Simon Zeni <simon@bl4ckb0ne.ca>
|
||||
* @ingroup drv_ht
|
||||
*/
|
||||
|
||||
// Many C api things were stolen from here (MIT license):
|
||||
// https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c
|
||||
|
||||
#include "ht_driver.hpp"
|
||||
#include "ht_image_math.hpp"
|
||||
#include "ht_model.hpp"
|
||||
#include "ht_nms.hpp"
|
||||
|
||||
#include <array>
|
||||
|
||||
#undef HEAVY_SCRIBBLE
|
||||
|
||||
/*
|
||||
* Anchors data taken from mediapipe's palm detection, used for single-shot detector model.
|
||||
*
|
||||
* See:
|
||||
* https://google.github.io/mediapipe/solutions/hands.html#palm-detection-model
|
||||
* https://github.com/google/mediapipe/blob/v0.8.8/mediapipe/calculators/tflite/ssd_anchors_calculator.cc#L101
|
||||
* https://github.com/google/mediapipe/blob/v0.8.8/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt#L60
|
||||
*/
|
||||
/*!
 * Centre of one single-shot-detector anchor, expressed as a fraction of the
 * detector input's width (x) and height (y).
 */
struct anchor
{
	float x, y;
};

/*!
 * One square layer of the anchor set: grid_size x grid_size cells, each cell
 * contributing anchors_per_cell anchors that all share the cell's centre.
 */
struct anchor_layer
{
	size_t grid_size;
	size_t anchors_per_cell;
};

/*
 * Layers in the order they appear in the flat anchor list.
 * NOTE(review): these figures reproduce the literal table this generator
 * replaces; they are presumed to follow from the mediapipe palm detection
 * options referenced above (128x128 input, anchor offset 0.5) - confirm against
 * the .pbtxt before changing them.
 */
static constexpr anchor_layer anchor_layers[] = {
    {16, 2}, // 16 * 16 * 2 = 512 anchors
    {8, 6},  //  8 *  8 * 6 = 384 anchors
};

//! Total number of anchors; palm_detection() indexes the model outputs with the same count.
static constexpr size_t anchor_count = 896;

//! Sums the number of anchors described by anchor_layers.
static constexpr size_t
count_layer_anchors()
{
	size_t total = 0;
	for (const anchor_layer &layer : anchor_layers) {
		total += layer.grid_size * layer.grid_size * layer.anchors_per_cell;
	}
	return total;
}

static_assert(count_layer_anchors() == anchor_count, "anchor_layers must describe exactly anchor_count anchors");

//! Plain wrapper so that a whole C array can be returned from a constexpr function.
struct anchor_table
{
	struct anchor entries[anchor_count];
};

/*!
 * Builds the anchor list at compile time.
 *
 * Within a layer the anchors are emitted row by row (y outermost, x next), and
 * the anchors belonging to one cell are adjacent. Each centre is
 * (index + 0.5) / grid_size; the grid sizes are powers of two, so every value
 * is exactly representable as a float.
 */
static constexpr anchor_table
make_anchor_table()
{
	anchor_table table{};
	size_t next = 0;

	for (const anchor_layer &layer : anchor_layers) {
		const float cells = static_cast<float>(layer.grid_size);

		for (size_t row = 0; row < layer.grid_size; row++) {
			for (size_t col = 0; col < layer.grid_size; col++) {
				for (size_t k = 0; k < layer.anchors_per_cell; k++) {
					table.entries[next].x = (static_cast<float>(col) + 0.5f) / cells;
					table.entries[next].y = (static_cast<float>(row) + 0.5f) / cells;
					next++;
				}
			}
		}
	}

	return table;
}

static constexpr anchor_table anchor_storage = make_anchor_table();

// Exposed as a reference to a plain array so existing uses such as `&anchors[i]` keep working unchanged.
static constexpr const struct anchor (&anchors)[anchor_count] = anchor_storage.entries;
|
||||
|
||||
/*!
 * Calls an OrtApi function through this->api and checks the OrtStatus it returns.
 *
 * @p expr is the call without the receiver, e.g. `ORT(CreateEnv(...))`.
 * A null status means success. Otherwise the error message is logged via
 * HT_ERROR on this->device, the status is released, and assert(false) fires;
 * when assertions are compiled out, execution simply continues after the log.
 *
 * Only usable inside ht_model member functions, since it relies on `this`.
 */
#define ORT(expr)                                                                                                      \
	do {                                                                                                           \
		OrtStatus *ort_status_ = this->api->expr;                                                              \
		if (ort_status_ == nullptr) {                                                                          \
			break;                                                                                         \
		}                                                                                                      \
		const char *ort_message_ = this->api->GetErrorMessage(ort_status_);                                    \
		HT_ERROR(this->device, "[%s:%d]: %s\n", __FILE__, __LINE__, ort_message_);                             \
		this->api->ReleaseStatus(ort_status_);                                                                 \
		assert(false);                                                                                         \
	} while (0)
|
||||
|
||||
void
|
||||
ht_model::init_palm_detection(OrtSessionOptions *opts)
|
||||
{
|
||||
// Both models have slightly different shapes, preventing us to constexpr the input shape
|
||||
std::array<int64_t, 4> input_shape;
|
||||
std::array<std::string, 1> input_names;
|
||||
|
||||
std::filesystem::path path = this->device->startup_config.model_slug;
|
||||
if (this->device->startup_config.keypoint_estimation_use_mediapipe) {
|
||||
path /= "palm_detection_MEDIAPIPE.onnx";
|
||||
|
||||
input_shape = {1, 3, 128, 128};
|
||||
input_names = {"input"};
|
||||
} else {
|
||||
path /= "palm_detection_COLLABORA.onnx";
|
||||
|
||||
input_shape = {1, 128, 128, 3};
|
||||
input_names = {"input:0"};
|
||||
}
|
||||
|
||||
HT_DEBUG(this->device, "Loading palm detection model from file '%s'", path.c_str());
|
||||
ORT(CreateSession(this->env, path.c_str(), opts, &this->palm_detection_session));
|
||||
assert(this->palm_detection_session);
|
||||
|
||||
constexpr size_t input_size = 3 * 128 * 128;
|
||||
|
||||
ORT(CreateTensorWithDataAsOrtValue(this->palm_detection_meminfo, this->palm_detection_data.data(),
|
||||
input_size * sizeof(float), input_shape.data(), input_shape.size(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &this->palm_detection_tensor));
|
||||
|
||||
assert(this->palm_detection_tensor);
|
||||
int is_tensor;
|
||||
ORT(IsTensor(this->palm_detection_tensor, &is_tensor));
|
||||
assert(is_tensor);
|
||||
}
|
||||
|
||||
void
|
||||
ht_model::init_hand_landmark(OrtSessionOptions *opts)
|
||||
{
|
||||
std::filesystem::path path = this->device->startup_config.model_slug;
|
||||
if (this->device->startup_config.keypoint_estimation_use_mediapipe) {
|
||||
path /= "hand_landmark_MEDIAPIPE.onnx";
|
||||
} else {
|
||||
path /= "hand_landmark_COLLABORA.onnx";
|
||||
}
|
||||
|
||||
HT_DEBUG(this->device, "Loading hand landmark model from file '%s'", path.c_str());
|
||||
ORT(CreateSession(this->env, path.c_str(), opts, &this->hand_landmark_session));
|
||||
assert(this->hand_landmark_session);
|
||||
|
||||
constexpr size_t input_size = 3 * 224 * 224;
|
||||
|
||||
constexpr std::array<int64_t, 4> input_shape = {1, 3, 224, 224};
|
||||
ORT(CreateTensorWithDataAsOrtValue(this->hand_landmark_meminfo, this->hand_landmark_data.data(),
|
||||
input_size * sizeof(float), input_shape.data(), input_shape.size(),
|
||||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &this->hand_landmark_tensor));
|
||||
|
||||
assert(this->hand_landmark_tensor != nullptr);
|
||||
int is_tensor;
|
||||
ORT(IsTensor(hand_landmark_tensor, &is_tensor));
|
||||
assert(is_tensor);
|
||||
}
|
||||
|
||||
/*!
 * Sets up the ONNX Runtime environment and loads both models.
 *
 * Creates the environment and one CPU memory-info object per model, then builds
 * a temporary set of session options that is shared by both sessions and
 * released again before the constructor returns.
 *
 * @param htd Hand tracking device; used for logging and for its startup configuration.
 */
ht_model::ht_model(struct ht_device *htd) : device(htd), api(OrtGetApiBase()->GetApi(ORT_API_VERSION))
{
	ORT(CreateEnv(ORT_LOGGING_LEVEL_WARNING, "monado_ht", &this->env));

	// One memory info per model; each is used when that model's input tensor is created.
	ORT(CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &this->palm_detection_meminfo));
	ORT(CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &this->hand_landmark_meminfo));

	OrtSessionOptions *session_options = nullptr;
	ORT(CreateSessionOptions(&session_options));

	// TODO review options, config for threads?
	ORT(SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_ALL));
	ORT(SetIntraOpNumThreads(session_options, 1));

	// Palm detection first, then hand landmark - both sessions are built from the same options.
	init_palm_detection(session_options);
	init_hand_landmark(session_options);

	// The options are only needed while the sessions are being created.
	this->api->ReleaseSessionOptions(session_options);
}
|
||||
|
||||
|
||||
/*!
 * Releases every ONNX Runtime object acquired by the constructor and the
 * init_* helpers: per model the memory info, the session and the input tensor,
 * and finally the environment.
 */
ht_model::~ht_model()
{
	// Palm detection resources.
	api->ReleaseMemoryInfo(palm_detection_meminfo);
	api->ReleaseSession(palm_detection_session);
	api->ReleaseValue(palm_detection_tensor);

	// Hand landmark resources.
	api->ReleaseMemoryInfo(hand_landmark_meminfo);
	api->ReleaseSession(hand_landmark_session);
	api->ReleaseValue(hand_landmark_tensor);

	// The environment goes last, after everything created within it.
	api->ReleaseEnv(env);
}
|
||||
|
||||
std::vector<Palm7KP>
|
||||
ht_model::palm_detection(ht_view *htv, const cv::Mat &input)
|
||||
{
|
||||
// TODO use opencv to handle input preprocessing
|
||||
constexpr int hd_size = 128;
|
||||
constexpr size_t nb_planes = 3;
|
||||
constexpr size_t size = hd_size * hd_size * nb_planes;
|
||||
|
||||
cv::Mat img;
|
||||
cv::Matx23f back_from_blackbar = blackbar(input, img, {hd_size, hd_size});
|
||||
|
||||
float scale_factor = back_from_blackbar(0, 0); // 960/128
|
||||
assert(img.isContinuous());
|
||||
constexpr float mean = 128.0f;
|
||||
constexpr float std = 128.0f;
|
||||
|
||||
if (htv->htd->startup_config.palm_detection_use_mediapipe) {
|
||||
std::vector<uint8_t> combined_planes(size);
|
||||
planarize(img, combined_planes.data());
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
float val = (float)combined_planes[i];
|
||||
this->palm_detection_data[i] = (val - mean) / std;
|
||||
}
|
||||
} else {
|
||||
|
||||
assert(img.isContinuous());
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
int val = img.data[i];
|
||||
|
||||
this->palm_detection_data[i] = (val - mean) / std;
|
||||
}
|
||||
}
|
||||
|
||||
const char *input_names[1];
|
||||
if (this->device->startup_config.keypoint_estimation_use_mediapipe) {
|
||||
input_names[0] = "input";
|
||||
} else {
|
||||
input_names[0] = "input:0";
|
||||
}
|
||||
|
||||
static const char *const output_names[] = {"classificators", "regressors"};
|
||||
|
||||
OrtValue *output_tensor[] = {nullptr, nullptr};
|
||||
ORT(Run(this->palm_detection_session, nullptr, input_names, &this->palm_detection_tensor, 1, output_names, 2,
|
||||
output_tensor));
|
||||
|
||||
// TODO define types to handle data
|
||||
float *classificators = nullptr;
|
||||
float *regressors = nullptr;
|
||||
|
||||
// Output is 896 floats
|
||||
ORT(GetTensorMutableData(output_tensor[0], (void **)&classificators));
|
||||
|
||||
// Output is 896 * 18 floats
|
||||
ORT(GetTensorMutableData(output_tensor[1], (void **)®ressors));
|
||||
|
||||
std::vector<NMSPalm> detections;
|
||||
for (size_t i = 0; i < 896; ++i) {
|
||||
const float score = 1.0 / (1.0 + exp(-classificators[i]));
|
||||
|
||||
// Let a lot of detections in - they'll be slowly rejected later
|
||||
if (score <= this->device->dynamic_config.nms_threshold.val) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const struct anchor *anchor = &anchors[i];
|
||||
|
||||
// Boundary box.
|
||||
NMSPalm det;
|
||||
|
||||
float anchx = anchor->x * 128;
|
||||
float anchy = anchor->y * 128;
|
||||
|
||||
float shiftx = regressors[i * 18];
|
||||
float shifty = regressors[i * 18 + 1];
|
||||
|
||||
float w = regressors[i * 18 + 2];
|
||||
float h = regressors[i * 18 + 3];
|
||||
|
||||
float cx = shiftx + anchx;
|
||||
float cy = shifty + anchy;
|
||||
|
||||
struct xrt_vec2 *kps = det.keypoints;
|
||||
|
||||
kps[0] = {regressors[i * 18 + 4], regressors[i * 18 + 5]};
|
||||
kps[1] = {regressors[i * 18 + 6], regressors[i * 18 + 7]};
|
||||
kps[2] = {regressors[i * 18 + 8], regressors[i * 18 + 9]};
|
||||
kps[3] = {regressors[i * 18 + 10], regressors[i * 18 + 11]};
|
||||
kps[4] = {regressors[i * 18 + 12], regressors[i * 18 + 13]};
|
||||
kps[5] = {regressors[i * 18 + 14], regressors[i * 18 + 15]};
|
||||
kps[6] = {regressors[i * 18 + 16], regressors[i * 18 + 17]};
|
||||
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
struct xrt_vec2 *b = &kps[i];
|
||||
b->x += anchx;
|
||||
b->y += anchy;
|
||||
}
|
||||
|
||||
det.bbox.w = w;
|
||||
det.bbox.h = h;
|
||||
det.bbox.cx = cx;
|
||||
det.bbox.cy = cy;
|
||||
det.confidence = score;
|
||||
detections.push_back(det);
|
||||
|
||||
if (htv->htd->debug_scribble && (htv->htd->dynamic_config.scribble_raw_detections)) {
|
||||
xrt_vec2 center = transformVecBy2x3(xrt_vec2{cx, cy}, back_from_blackbar);
|
||||
|
||||
float sz = det.bbox.w * scale_factor;
|
||||
|
||||
cv::rectangle(htv->debug_out_to_this,
|
||||
{(int)(center.x - (sz / 2)), (int)(center.y - (sz / 2)), (int)sz, (int)sz},
|
||||
hsv2rgb(0.0f, math_map_ranges(det.confidence, 0.0f, 1.0f, 1.5f, -0.1f),
|
||||
math_map_ranges(det.confidence, 0.0f, 1.0f, 0.2f, 1.4f)),
|
||||
1);
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
handDot(htv->debug_out_to_this, transformVecBy2x3(kps[i], back_from_blackbar),
|
||||
det.confidence * 7, ((float)i) * (360.0f / 7.0f), det.confidence, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this->api->ReleaseValue(output_tensor[0]);
|
||||
this->api->ReleaseValue(output_tensor[1]);
|
||||
|
||||
std::vector<Palm7KP> output;
|
||||
if (detections.empty()) {
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<NMSPalm> nms_palms = filterBoxesWeightedAvg(detections, htv->htd->dynamic_config.nms_iou.val);
|
||||
|
||||
for (const NMSPalm &cooler : nms_palms) {
|
||||
|
||||
// Display box
|
||||
|
||||
struct xrt_vec2 tl = {cooler.bbox.cx - cooler.bbox.w / 2, cooler.bbox.cy - cooler.bbox.h / 2};
|
||||
struct xrt_vec2 bob = transformVecBy2x3(tl, back_from_blackbar);
|
||||
float sz = cooler.bbox.w * scale_factor;
|
||||
|
||||
if (htv->htd->debug_scribble && htv->htd->dynamic_config.scribble_nms_detections) {
|
||||
cv::rectangle(htv->debug_out_to_this, {(int)bob.x, (int)bob.y, (int)sz, (int)sz},
|
||||
hsv2rgb(180.0f, math_map_ranges(cooler.confidence, 0.0f, 1.0f, 0.8f, -0.1f),
|
||||
math_map_ranges(cooler.confidence, 0.0f, 1.0f, 0.2f, 1.4f)),
|
||||
2);
|
||||
for (int i = 0; i < 7; i++) {
|
||||
handDot(htv->debug_out_to_this,
|
||||
transformVecBy2x3(cooler.keypoints[i], back_from_blackbar),
|
||||
cooler.confidence * 14, ((float)i) * (360.0f / 7.0f), cooler.confidence, 3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Palm7KP this_element;
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
struct xrt_vec2 b = cooler.keypoints[i];
|
||||
this_element.kps[i] = transformVecBy2x3(b, back_from_blackbar);
|
||||
}
|
||||
this_element.confidence = cooler.confidence;
|
||||
|
||||
output.push_back(this_element);
|
||||
}
|
||||
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
/*!
 * Run the hand landmark model on one image and return the keypoints it outputs.
 *
 * The image is handed to planarize(), which fills a byte staging buffer of
 * 224 * 224 * 3 elements. Those bytes are divided by 255 into
 * this->hand_landmark_data, and the session is run with
 * this->hand_landmark_tensor as its single input ("input_1").
 * NOTE(review): presumably that tensor was created over hand_landmark_data in
 * init_hand_landmark() - not visible from here, confirm.
 *
 * Three outputs are requested ("Identity", "Identity_1", "Identity_2") but
 * only the first one is read: 21 keypoints stored as consecutive x, y, z
 * floats. All three output values are released before returning.
 *
 * this->hand_landmark_lock is held for the whole call, so calls on the same
 * ht_model instance run one at a time.
 *
 * @param input Image to run the model on.
 *              NOTE(review): the staging buffer assumes planarize() writes
 *              exactly 224 * 224 * 3 bytes, i.e. a 224x224 3-channel 8-bit
 *              image - nothing here checks that, verify against callers.
 * @return A value-initialized Hand2D whose kps[0..20] are filled with the raw
 *         x, y, z values of the first model output.
 */
Hand2D
ht_model::hand_landmark(const cv::Mat input)
{
	std::scoped_lock lock(this->hand_landmark_lock);

	// TODO use opencv to handle input preprocessing
	constexpr size_t lix = 224;
	constexpr size_t liy = 224;
	constexpr size_t nb_planes = 3;
	constexpr size_t size = lix * liy * nb_planes;

	// The normalization loop below writes `size` floats into the member buffer.
	static_assert(size == std::tuple_size<decltype(this->hand_landmark_data)>::value,
	              "hand_landmark_data must hold exactly one planarized input image");

	std::vector<uint8_t> combined_planes(size);
	planarize(input, combined_planes.data());

	// Normalize - supposedly, the keypoint estimator wants pixel values in [0,1]
	for (size_t i = 0; i < size; i++) {
		this->hand_landmark_data[i] = (float)combined_planes[i] / 255.0;
	}

	static const char *const input_names[] = {"input_1"};
	static const char *const output_names[] = {"Identity", "Identity_1", "Identity_2"};

	OrtValue *output_tensor[] = {nullptr, nullptr, nullptr};
	ORT(Run(this->hand_landmark_session, nullptr, input_names, &this->hand_landmark_tensor, 1, output_names, 3,
	        output_tensor));

	Hand2D hand{};

	float *landmarks = nullptr;

	// Should give a pointer to data that is freed on this->api->ReleaseValue(output_tensor[0]),
	// so everything needed is copied into `hand` before the outputs are released below.
	ORT(GetTensorMutableData(output_tensor[0], (void **)&landmarks));

	// Each keypoint occupies three consecutive floats: x, y, z.
	constexpr size_t num_keypoints = 21;
	constexpr size_t stride = 3;
	for (size_t i = 0; i < num_keypoints; i++) {
		const float *kp = &landmarks[i * stride];
		hand.kps[i].x = kp[0];
		hand.kps[i].y = kp[1];
		hand.kps[i].z = kp[2];
	}

	// Release every requested output, including the two that were never read.
	for (OrtValue *value : output_tensor) {
		this->api->ReleaseValue(value);
	}

	return hand;
}
|
51
src/xrt/drivers/ht/ht_model.hpp
Normal file
51
src/xrt/drivers/ht/ht_model.hpp
Normal file
|
@ -0,0 +1,51 @@
|
|||
// Copyright 2021, Collabora, Ltd.
|
||||
// SPDX-License-Identifier: BSL-1.0
|
||||
/*!
|
||||
* @file
|
||||
* @brief Code to run machine learning models for camera-based hand tracker.
|
||||
* @author Moses Turner <moses@collabora.com>
|
||||
* @author Marcus Edel <marcus.edel@collabora.com>
|
||||
* @author Simon Zeni <simon@bl4ckb0ne.ca>
|
||||
* @ingroup drv_ht
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <core/session/onnxruntime_c_api.h>
#include <opencv2/core/mat.hpp>

#include <array>
#include <filesystem>
#include <mutex>
#include <vector>
|
||||
|
||||
class ht_model
|
||||
{
|
||||
struct ht_device *device = nullptr;
|
||||
|
||||
const OrtApi *api = nullptr;
|
||||
OrtEnv *env = nullptr;
|
||||
|
||||
OrtMemoryInfo *palm_detection_meminfo = nullptr;
|
||||
OrtSession *palm_detection_session = nullptr;
|
||||
OrtValue *palm_detection_tensor = nullptr;
|
||||
std::array<float, 3 * 128 * 128> palm_detection_data;
|
||||
|
||||
std::mutex hand_landmark_lock;
|
||||
OrtMemoryInfo *hand_landmark_meminfo = nullptr;
|
||||
OrtSession *hand_landmark_session = nullptr;
|
||||
OrtValue *hand_landmark_tensor = nullptr;
|
||||
std::array<float, 3 * 224 * 224> hand_landmark_data;
|
||||
|
||||
void
|
||||
init_palm_detection(OrtSessionOptions *opts);
|
||||
void
|
||||
init_hand_landmark(OrtSessionOptions *opts);
|
||||
|
||||
public:
|
||||
ht_model(struct ht_device *htd);
|
||||
~ht_model();
|
||||
|
||||
std::vector<Palm7KP>
|
||||
palm_detection(ht_view *htv, const cv::Mat &input);
|
||||
Hand2D
|
||||
hand_landmark(const cv::Mat input);
|
||||
};
|
|
@ -92,6 +92,7 @@ lib_drv_ht = static_library(
|
|||
'ht/ht_driver.hpp',
|
||||
'ht/ht_interface.h',
|
||||
'ht/ht_models.cpp',
|
||||
'ht/ht_model.cpp',
|
||||
'ht/ht_hand_math.cpp',
|
||||
'ht/ht_image_math.cpp',
|
||||
'ht/ht_nms.cpp',
|
||||
|
|
Loading…
Reference in a new issue