From 0f70df5de3da99293741a668786b076bb68ff548 Mon Sep 17 00:00:00 2001 From: wh201906 Date: Mon, 6 Mar 2023 19:55:41 +0800 Subject: [PATCH] Add TCP connection support on Windows The Windows Sockets 2 API is similar to the BSD Sockets API, so I can reuse a lot of code in uart_posix.c --- client/src/uart/uart_win32.c | 258 ++++++++++++++++++++++++++++++++--- 1 file changed, 241 insertions(+), 17 deletions(-) diff --git a/client/src/uart/uart_win32.c b/client/src/uart/uart_win32.c index 1e57daeb2..cd610c8fa 100644 --- a/client/src/uart/uart_win32.c +++ b/client/src/uart/uart_win32.c @@ -27,14 +27,24 @@ // The windows serial port implementation #ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN #include +#include +#include typedef struct { HANDLE hPort; // Serial port handle DCB dcb; // Device control settings COMMTIMEOUTS ct; // Serial port time-out configuration + SOCKET hSocket; // Socket handle } serial_port_windows_t; +// this is for TCP connection +struct timeval timeout = { + .tv_sec = 0, // 0 second + .tv_usec = UART_TCP_CLIENT_RX_TIMEOUT_MS * 1000 +}; + uint32_t newtimeout_value = 0; bool newtimeout_pending = false; @@ -69,11 +79,111 @@ static int uart_reconfigure_timeouts_polling(serial_port sp) { serial_port uart_open(const char *pcPortName, uint32_t speed) { char acPortName[255] = {0}; serial_port_windows_t *sp = calloc(sizeof(serial_port_windows_t), sizeof(uint8_t)); + sp->hSocket = INVALID_SOCKET; // default: serial port if (sp == 0) { PrintAndLogEx(WARNING, "UART failed to allocate memory\n"); return INVALID_SERIAL_PORT; } + + char *prefix = strdup(pcPortName); + if (prefix == NULL) { + PrintAndLogEx(ERR, "error: string duplication"); + free(sp); + return INVALID_SERIAL_PORT; + } + str_lower(prefix); + + if (memcmp(prefix, "tcp:", 4) == 0) { + free(prefix); + + if (strlen(pcPortName) <= 4) { + free(sp); + return INVALID_SERIAL_PORT; + } + + struct addrinfo *addr = NULL, *rp; + + char *addrstr = strdup(pcPortName + 4); + if (addrstr == NULL) { + PrintAndLogEx(ERR, "error: string duplication"); + free(sp); + return INVALID_SERIAL_PORT; + } + + timeout.tv_usec = UART_TCP_CLIENT_RX_TIMEOUT_MS * 1000; + + char *colon = strrchr(addrstr, ':'); + const char *portstr; + if (colon) { + portstr = colon + 1; + *colon = '\0'; + } else { + portstr = "18888"; + } + + WSADATA wsaData; + struct addrinfo info; + int iResult; + + iResult = WSAStartup(MAKEWORD(2,2), &wsaData); + if (iResult != 0) { + PrintAndLogEx(ERR, "error: WSAStartup failed with error: %d", iResult); + free(sp); + return INVALID_SERIAL_PORT; + } + + memset(&info, 0, sizeof(info)); + info.ai_socktype = SOCK_STREAM; + info.ai_protocol = IPPROTO_TCP; + + int s = getaddrinfo(addrstr, portstr, &info, &addr); + if (s != 0) { + PrintAndLogEx(ERR, "error: getaddrinfo: %s", gai_strerror(s)); + freeaddrinfo(addr); + free(addrstr); + free(sp); + WSACleanup(); + return INVALID_SERIAL_PORT; + } + + SOCKET hSocket = INVALID_SOCKET; + for (rp = addr; rp != NULL; rp = rp->ai_next) { + hSocket = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + + if (hSocket == INVALID_SOCKET) + continue; + + if (connect(hSocket, rp->ai_addr, (int)rp->ai_addrlen) != INVALID_SOCKET) + break; + + closesocket(hSocket); + hSocket = INVALID_SOCKET; + } + + freeaddrinfo(addr); + free(addrstr); + + if (rp == NULL) { /* No address succeeded */ + PrintAndLogEx(ERR, "error: Could not connect"); + WSACleanup(); + free(sp); + return INVALID_SERIAL_PORT; + } + + sp->hSocket = hSocket; + + int one = 1; + int res = setsockopt(sp->hSocket, IPPROTO_TCP, TCP_NODELAY, (char *)&one, sizeof(one)); + if (res != 0) { + closesocket(hSocket); + WSACleanup(); + free(sp); + return INVALID_SERIAL_PORT; + } + return sp; + } + // Copy the input "com?" to "\\.\COM?" format snprintf(acPortName, sizeof(acPortName), "\\\\.\\%s", pcPortName); _strupr(acPortName); @@ -120,8 +230,14 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) { } void uart_close(const serial_port sp) { - if (((serial_port_windows_t *)sp)->hPort != INVALID_HANDLE_VALUE) - CloseHandle(((serial_port_windows_t *)sp)->hPort); + serial_port_windows_t *spw = (serial_port_windows_t *)sp; + if (spw->hSocket != INVALID_SOCKET){ + shutdown(spw->hSocket, SD_BOTH); + closesocket(spw->hSocket); + WSACleanup(); + } + if (spw->hPort != INVALID_HANDLE_VALUE) + CloseHandle(spw->hPort); free(sp); } @@ -163,31 +279,139 @@ uint32_t uart_get_speed(const serial_port sp) { } int uart_receive(const serial_port sp, uint8_t *pbtRx, uint32_t pszMaxRxLen, uint32_t *pszRxLen) { - uart_reconfigure_timeouts_polling(sp); - int res = ReadFile(((serial_port_windows_t *)sp)->hPort, pbtRx, pszMaxRxLen, (LPDWORD)pszRxLen, NULL); - if (res) - return PM3_SUCCESS; + serial_port_windows_t *spw = (serial_port_windows_t *)sp; + if (spw->hSocket == INVALID_SOCKET) { // serial port + uart_reconfigure_timeouts_polling(sp); - int errorcode = GetLastError(); + int res = ReadFile(((serial_port_windows_t *)sp)->hPort, pbtRx, pszMaxRxLen, (LPDWORD)pszRxLen, NULL); + if (res) + return PM3_SUCCESS; - if (res == 0 && errorcode == 2) { - return PM3_EIO; + int errorcode = GetLastError(); + + if (res == 0 && errorcode == 2) { + return PM3_EIO; + } + + return PM3_ENOTTY; } + else { // TCP + uint32_t byteCount; // FIONREAD returns size on 32b + fd_set rfds; + struct timeval tv; - return PM3_ENOTTY; + if (newtimeout_pending) { + timeout.tv_usec = newtimeout_value * 1000; + newtimeout_pending = false; + } + // Reset the output count + *pszRxLen = 0; + do { + // Reset file descriptor + FD_ZERO(&rfds); + FD_SET(spw->hSocket, &rfds); + tv = timeout; + // the first argument nfds is ignored in Windows + int res = select(0, &rfds, NULL, NULL, &tv); + + // Read error + if (res == SOCKET_ERROR) { + return PM3_EIO; + } + + // Read time-out + if (res == 0) { + if (*pszRxLen == 0) { + // We received no data + return PM3_ENODATA; + } else { + // We received some data, but nothing more is available + return PM3_SUCCESS; + } + } + + // Retrieve the count of the incoming bytes + res = ioctlsocket(spw->hSocket, FIONREAD, (u_long *)&byteCount); + // PrintAndLogEx(ERR, "UART:: RX ioctl res %d byteCount %u", res, byteCount); + if (res == SOCKET_ERROR) return PM3_ENOTTY; + + // Cap the number of bytes, so we don't overrun the buffer + if (pszMaxRxLen - (*pszRxLen) < byteCount) { + // PrintAndLogEx(ERR, "UART:: RX prevent overrun (have %u, need %u)", pszMaxRxLen - (*pszRxLen), byteCount); + byteCount = pszMaxRxLen - (*pszRxLen); + } + + // There is something available, read the data + res = recv(spw->hSocket, (char *)pbtRx + (*pszRxLen), byteCount, 0); + + // Stop if the OS has some troubles reading the data + if (res <= 0) { // includes 0(gracefully closed) and -1(SOCKET_ERROR) + return PM3_EIO; + } + + *pszRxLen += res; + + if (*pszRxLen == pszMaxRxLen) { + // We have all the data we wanted. + return PM3_SUCCESS; + } + } while (byteCount); + + return PM3_SUCCESS; + } } int uart_send(const serial_port sp, const uint8_t *p_tx, const uint32_t len) { - DWORD txlen = 0; - int res = WriteFile(((serial_port_windows_t *)sp)->hPort, p_tx, len, &txlen, NULL); - if (res) + serial_port_windows_t *spw = (serial_port_windows_t *)sp; + if (spw->hSocket == INVALID_SOCKET) { // serial port + DWORD txlen = 0; + int res = WriteFile(((serial_port_windows_t *)sp)->hPort, p_tx, len, &txlen, NULL); + if (res) + return PM3_SUCCESS; + + int errorcode = GetLastError(); + if (res == 0 && errorcode == 2) { + return PM3_EIO; + } + return PM3_ENOTTY; + } + else { // TCP + uint32_t pos = 0; + fd_set wfds; + struct timeval tv; + + while (pos < len) { + // Reset file descriptor + FD_ZERO(&wfds); + FD_SET(spw->hSocket, &wfds); + tv = timeout; + // the first argument nfds is ignored in Windows + int res = select(0, NULL, &wfds, NULL, &tv); + + // Write error + if (res == SOCKET_ERROR) { + PrintAndLogEx(ERR, "UART:: write error (%d)", res); + return PM3_ENOTTY; + } + + // Write time-out + if (res == 0) { + PrintAndLogEx(ERR, "UART:: write time-out"); + return PM3_ETIMEOUT; + } + + // Send away the bytes + res = send(spw->hSocket, (const char *)p_tx + pos, len - pos, 0); + + // Stop if the OS has some troubles sending the data + if (res <= 0) + return PM3_EIO; + + pos += res; + } return PM3_SUCCESS; - int errorcode = GetLastError(); - if (res == 0 && errorcode == 2) { - return PM3_EIO; } - return PM3_ENOTTY; } #endif