ipc: Verify length of received commands on Windows

Co-authored-by: Simon Zeni <simon.zeni@collabora.com>
This commit is contained in:
Jakob Bornecrantz 2023-11-26 13:40:13 +00:00
parent 81246de70a
commit 9070894455

View file

@ -206,6 +206,18 @@ client_loop(volatile struct ipc_client_state *ics)
#else // XRT_OS_WINDOWS #else // XRT_OS_WINDOWS
static void
pipe_print_get_last_error(volatile struct ipc_client_state *ics, const char *func)
{
// This is the error path.
DWORD err = GetLastError();
if (err == ERROR_BROKEN_PIPE) {
IPC_INFO(ics->server, "%s: %d %s", func, err, ipc_winerror(err));
} else {
IPC_ERROR(ics->server, "%s failed: %d %s", func, err, ipc_winerror(err));
}
}
static void static void
client_loop(volatile struct ipc_client_state *ics) client_loop(volatile struct ipc_client_state *ics)
{ {
@ -213,28 +225,53 @@ client_loop(volatile struct ipc_client_state *ics)
IPC_INFO(ics->server, "Client connected"); IPC_INFO(ics->server, "Client connected");
uint8_t buf[IPC_BUF_SIZE];
while (ics->server->running) { while (ics->server->running) {
DWORD len; uint8_t buf[IPC_BUF_SIZE] = {0};
if (!ReadFile(ics->imc.ipc_handle, buf, sizeof(buf), &len, NULL)) { DWORD len = 0;
DWORD err = GetLastError(); BOOL bret = false;
if (err == ERROR_BROKEN_PIPE) {
IPC_INFO(ics->server, "ReadFile from pipe: %d %s", err, ipc_winerror(err)); /*
} else { * The pipe is created in message mode, the client IPC code will
IPC_ERROR(ics->server, "ReadFile from pipe failed: %d %s", err, ipc_winerror(err)); * always send the *_msg structs as one message, and any extra
} * variable length data as a different message. So even if the
* command is a variable length the first message will be sized
* to the command size, this is what we get here, variable
* length data is read in the dispatch function for the command.
*/
bret = ReadFile(ics->imc.ipc_handle, buf, sizeof(buf), &len, NULL);
if (!bret) {
pipe_print_get_last_error(ics, "ReadFile");
IPC_ERROR(ics->server, "ReadFile failed, disconnecting client.");
break; break;
} }
if (len < sizeof(ipc_command_t)) {
// All commands are at least 4 bytes.
if (len < 4) {
IPC_ERROR(ics->server, "Not enough bytes received '%u', disconnecting client.", (uint32_t)len);
break;
}
// Now safe to cast into a command pointer, used for dispatch.
ipc_command_t *cmd_ptr = (ipc_command_t *)buf;
// Read the command, we know we have at least 4 bytes.
ipc_command_t cmd = *cmd_ptr;
// Get the command length.
size_t cmd_size = ipc_command_size(cmd);
if (cmd_size == 0) {
IPC_ERROR(ics->server, "Invalid command '%u', disconnecting client.", cmd);
break;
}
// Check if the read message has the expected length.
if (len != cmd_size) {
IPC_ERROR(ics->server, "Invalid packet received, disconnecting client."); IPC_ERROR(ics->server, "Invalid packet received, disconnecting client.");
break; break;
} else { }
// Check the first 4 bytes of the message and dispatch.
ipc_command_t *ipc_command = (ipc_command_t *)buf;
IPC_TRACE_BEGIN(ipc_dispatch); IPC_TRACE_BEGIN(ipc_dispatch);
xrt_result_t result = ipc_dispatch(ics, ipc_command); xrt_result_t result = ipc_dispatch(ics, cmd_ptr);
IPC_TRACE_END(ipc_dispatch); IPC_TRACE_END(ipc_dispatch);
if (result != XRT_SUCCESS) { if (result != XRT_SUCCESS) {
@ -242,7 +279,6 @@ client_loop(volatile struct ipc_client_state *ics)
break; break;
} }
} }
}
// Following code is same for all platforms. // Following code is same for all platforms.
common_shutdown(ics); common_shutdown(ics);