tcp_comm.c 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. /**
  2. * Copyright (c) 2022 Brian Starkey <stark3y@gmail.com>
  3. *
  4. * Parts based on the Pico W tcp_server example:
  5. * Copyright (c) 2022 Raspberry Pi (Trading) Ltd.
  6. *
  7. * SPDX-License-Identifier: BSD-3-Clause
  8. */
  9. #include <stdlib.h>
  10. #include "pico/cyw43_arch.h"
  11. #include "lwip/pbuf.h"
  12. #include "lwip/tcp.h"
  13. #include "tcp_comm.h"
  /* Debug output goes straight to stdio. */
  #define DEBUG_printf printf

  /* Client poll interval in seconds; tcp_poll() is given this value * 2. */
  #define POLL_TIME_S 5

  /* Capacity limits for ctx->buf: argument words per message and body bytes. */
  #define COMM_MAX_NARG 5
  #define COMM_MAX_DATA_LEN 1024

  /* Pack four characters into a little-endian 32-bit word (first char in the low byte). */
  #define TCP_COMM_FOURCC(c0, c1, c2, c3) \
      (((c0) << 0) | ((c1) << 8) | ((c2) << 16) | ((c3) << 24))

  /* Status words placed in the opcode slot of a response. */
  #define COMM_RSP_OK TCP_COMM_FOURCC('O', 'K', 'O', 'K')
  #define COMM_RSP_ERR TCP_COMM_FOURCC('E', 'R', 'R', '!')
  /*
   * Steps of the per-connection state machine. The READ_* states receive
   * into ctx->buf; the WRITE_* states wait for a response to be sent.
   * NOTE(review): CONN_STATE_HANDLE is never entered by the code in this file.
   */
  enum conn_state {
      CONN_STATE_WAIT_FOR_SYNC = 0,
      CONN_STATE_READ_OPCODE = 1,
      CONN_STATE_READ_ARGS = 2,
      CONN_STATE_READ_DATA = 3,
      CONN_STATE_HANDLE = 4,
      CONN_STATE_WRITE_RESP = 5,
      CONN_STATE_WRITE_ERROR = 6,
      CONN_STATE_CLOSED = 7,
  };
  30. struct tcp_comm_ctx {
  31. struct tcp_pcb *serv_pcb;
  32. volatile bool serv_done;
  33. enum conn_state conn_state;
  34. struct tcp_pcb *client_pcb;
  35. uint8_t buf[(sizeof(uint32_t) * (1 + COMM_MAX_NARG)) + COMM_MAX_DATA_LEN];
  36. uint16_t rx_bytes_received;
  37. uint16_t rx_bytes_remaining;
  38. uint16_t tx_bytes_sent;
  39. uint16_t tx_bytes_remaining;
  40. uint32_t resp_data_len;
  41. const struct comm_command *cmd;
  42. const struct comm_command *const *cmds;
  43. unsigned int n_cmds;
  44. uint32_t sync_opcode;
  45. };
  /*
   * Typed views into a message buffer laid out as
   * [opcode word][n argument words][body bytes].
   */
  #define COMM_BUF_OPCODE(b_) ((uint32_t *)(uint8_t *)(b_))
  #define COMM_BUF_ARGS(b_) ((uint32_t *)((uint8_t *)(b_) + sizeof(uint32_t)))
  #define COMM_BUF_BODY(b_, n_) ((uint8_t *)(b_) + sizeof(uint32_t) * (1 + (n_)))
  49. static const struct comm_command *find_command_desc(struct tcp_comm_ctx *ctx, uint32_t opcode)
  50. {
  51. unsigned int i;
  52. for (i = 0; i < ctx->n_cmds; i++) {
  53. if (ctx->cmds[i]->opcode == opcode) {
  54. return ctx->cmds[i];
  55. }
  56. }
  57. return NULL;
  58. }
  59. static bool is_error(uint32_t status)
  60. {
  61. return status == COMM_RSP_ERR;
  62. }
  /*
   * Connection state machine. Each *_begin() enters a state; the matching
   * *_complete() runs once that state's bytes have been received or sent.
   * The steps call one another, so they are all declared up front.
   */
  static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len);
  static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx);
  static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx);

  static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx);
  static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx);
  74. static int tcp_comm_sync_begin(struct tcp_comm_ctx *ctx)
  75. {
  76. ctx->conn_state = CONN_STATE_WAIT_FOR_SYNC;
  77. ctx->rx_bytes_received = 0;
  78. ctx->rx_bytes_remaining = sizeof(uint32_t);
  79. DEBUG_printf("sync_begin %d\n", ctx->rx_bytes_remaining);
  80. }
  81. static int tcp_comm_sync_complete(struct tcp_comm_ctx *ctx)
  82. {
  83. if (ctx->sync_opcode != *COMM_BUF_OPCODE(ctx->buf)) {
  84. DEBUG_printf("sync not correct: %c%c%c%c\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  85. return tcp_comm_error_begin(ctx);
  86. }
  87. return tcp_comm_opcode_complete(ctx);
  88. }
  89. static int tcp_comm_opcode_begin(struct tcp_comm_ctx *ctx)
  90. {
  91. ctx->conn_state = CONN_STATE_READ_OPCODE;
  92. ctx->rx_bytes_received = 0;
  93. ctx->rx_bytes_remaining = sizeof(uint32_t);
  94. return 0;
  95. }
  96. static int tcp_comm_opcode_complete(struct tcp_comm_ctx *ctx)
  97. {
  98. ctx->cmd = find_command_desc(ctx, *COMM_BUF_OPCODE(ctx->buf));
  99. if (!ctx->cmd) {
  100. DEBUG_printf("no command for '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  101. return tcp_comm_error_begin(ctx);
  102. } else {
  103. DEBUG_printf("got command '%c%c%c%c'\n", ctx->buf[0], ctx->buf[1], ctx->buf[2], ctx->buf[3]);
  104. }
  105. return tcp_comm_args_begin(ctx);
  106. }
  107. static int tcp_comm_args_begin(struct tcp_comm_ctx *ctx)
  108. {
  109. ctx->conn_state = CONN_STATE_READ_ARGS;
  110. ctx->rx_bytes_received = 0;
  111. ctx->rx_bytes_remaining = ctx->cmd->nargs * sizeof(uint32_t);
  112. if (ctx->cmd->nargs == 0) {
  113. return tcp_comm_args_complete(ctx);
  114. }
  115. return 0;
  116. }
  117. static int tcp_comm_args_complete(struct tcp_comm_ctx *ctx)
  118. {
  119. const struct comm_command *cmd = ctx->cmd;
  120. uint32_t data_len = 0;
  121. if (cmd->size) {
  122. uint32_t status = cmd->size(COMM_BUF_ARGS(ctx->buf),
  123. &data_len,
  124. &ctx->resp_data_len);
  125. if (is_error(status)) {
  126. return tcp_comm_error_begin(ctx);
  127. }
  128. }
  129. return tcp_comm_data_begin(ctx, data_len);
  130. }
  131. static int tcp_comm_data_begin(struct tcp_comm_ctx *ctx, uint32_t data_len)
  132. {
  133. const struct comm_command *cmd = ctx->cmd;
  134. ctx->conn_state = CONN_STATE_READ_DATA;
  135. ctx->rx_bytes_received = 0;
  136. ctx->rx_bytes_remaining = data_len;
  137. if (data_len == 0) {
  138. return tcp_comm_data_complete(ctx);
  139. }
  140. return 0;
  141. }
  142. static int tcp_comm_data_complete(struct tcp_comm_ctx *ctx)
  143. {
  144. const struct comm_command *cmd = ctx->cmd;
  145. if (cmd->handle) {
  146. uint32_t status = cmd->handle(COMM_BUF_ARGS(ctx->buf),
  147. COMM_BUF_BODY(ctx->buf, cmd->nargs),
  148. COMM_BUF_ARGS(ctx->buf),
  149. COMM_BUF_BODY(ctx->buf, cmd->resp_nargs));
  150. if (is_error(status)) {
  151. return tcp_comm_error_begin(ctx);
  152. }
  153. *COMM_BUF_OPCODE(ctx->buf) = status;
  154. } else {
  155. // TODO: Should we just assert(desc->handle)?
  156. *COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_OK;
  157. }
  158. return tcp_comm_response_begin(ctx);
  159. }
  160. static int tcp_comm_response_begin(struct tcp_comm_ctx *ctx)
  161. {
  162. ctx->conn_state = CONN_STATE_WRITE_RESP;
  163. ctx->tx_bytes_sent = 0;
  164. ctx->tx_bytes_remaining = ctx->resp_data_len + ((ctx->cmd->resp_nargs + 1) * sizeof(uint32_t));
  165. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  166. if (err != ERR_OK) {
  167. return -1;
  168. }
  169. return 0;
  170. }
  171. static int tcp_comm_error_begin(struct tcp_comm_ctx *ctx)
  172. {
  173. ctx->conn_state = CONN_STATE_WRITE_ERROR;
  174. ctx->tx_bytes_sent = 0;
  175. ctx->tx_bytes_remaining = sizeof(uint32_t);
  176. *COMM_BUF_OPCODE(ctx->buf) = COMM_RSP_ERR;
  177. err_t err = tcp_write(ctx->client_pcb, ctx->buf, ctx->tx_bytes_remaining, 0);
  178. if (err != ERR_OK) {
  179. return -1;
  180. }
  181. return 0;
  182. }
  /* The whole response has been reported sent: go back to waiting for the next opcode. */
  static int tcp_comm_response_complete(struct tcp_comm_ctx *ctx)
  {
      int ret = tcp_comm_opcode_begin(ctx);
      return ret;
  }
  187. static int tcp_comm_rx_complete(struct tcp_comm_ctx *ctx)
  188. {
  189. switch (ctx->conn_state) {
  190. case CONN_STATE_WAIT_FOR_SYNC:
  191. return tcp_comm_sync_complete(ctx);
  192. case CONN_STATE_READ_OPCODE:
  193. return tcp_comm_opcode_complete(ctx);
  194. case CONN_STATE_READ_ARGS:
  195. return tcp_comm_args_complete(ctx);
  196. case CONN_STATE_READ_DATA:
  197. return tcp_comm_data_complete(ctx);
  198. default:
  199. return -1;
  200. }
  201. }
  202. static int tcp_comm_tx_complete(struct tcp_comm_ctx *ctx)
  203. {
  204. switch (ctx->conn_state) {
  205. case CONN_STATE_WRITE_RESP:
  206. return tcp_comm_response_complete(ctx);
  207. case CONN_STATE_WRITE_ERROR:
  208. return -1;
  209. default:
  210. return -1;
  211. }
  212. }
  213. static err_t tcp_comm_client_close(struct tcp_comm_ctx *ctx)
  214. {
  215. err_t err = ERR_OK;
  216. cyw43_arch_gpio_put (0, false);
  217. ctx->conn_state = CONN_STATE_CLOSED;
  218. if (!ctx->client_pcb) {
  219. return err;
  220. }
  221. tcp_arg(ctx->client_pcb, NULL);
  222. tcp_poll(ctx->client_pcb, NULL, 0);
  223. tcp_sent(ctx->client_pcb, NULL);
  224. tcp_recv(ctx->client_pcb, NULL);
  225. tcp_err(ctx->client_pcb, NULL);
  226. err = tcp_close(ctx->client_pcb);
  227. if (err != ERR_OK) {
  228. DEBUG_printf("close failed %d, calling abort\n", err);
  229. tcp_abort(ctx->client_pcb);
  230. err = ERR_ABRT;
  231. }
  232. ctx->client_pcb = NULL;
  233. return err;
  234. }
  235. err_t tcp_comm_server_close(struct tcp_comm_ctx *ctx)
  236. {
  237. err_t err = ERR_OK;
  238. err = tcp_comm_client_close(ctx);
  239. if ((err != ERR_OK) && ctx->serv_pcb) {
  240. tcp_arg(ctx->serv_pcb, NULL);
  241. tcp_abort(ctx->serv_pcb);
  242. ctx->serv_pcb = NULL;
  243. return ERR_ABRT;
  244. }
  245. if (!ctx->serv_pcb) {
  246. return err;
  247. }
  248. tcp_arg(ctx->serv_pcb, NULL);
  249. err = tcp_close(ctx->serv_pcb);
  250. if (err != ERR_OK) {
  251. tcp_abort(ctx->serv_pcb);
  252. err = ERR_ABRT;
  253. }
  254. ctx->serv_pcb = NULL;
  255. return err;
  256. }
  257. static void tcp_comm_server_complete(void *arg, int status)
  258. {
  259. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  260. if (status == 0) {
  261. DEBUG_printf("server completed normally\n");
  262. } else {
  263. DEBUG_printf("server error %d\n", status);
  264. }
  265. tcp_comm_server_close(ctx);
  266. ctx->serv_done = true;
  267. }
  268. static err_t tcp_comm_client_complete(void *arg, int status)
  269. {
  270. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  271. if (status == 0) {
  272. DEBUG_printf("conn completed normally\n");
  273. } else {
  274. DEBUG_printf("conn error %d\n", status);
  275. }
  276. return tcp_comm_client_close(ctx);
  277. }
  278. static err_t tcp_comm_client_sent(void *arg, struct tcp_pcb *tpcb, u16_t len)
  279. {
  280. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  281. DEBUG_printf("tcp_comm_server_sent %u\n", len);
  282. cyw43_arch_lwip_check();
  283. if (len > ctx->tx_bytes_remaining) {
  284. DEBUG_printf("tx len %d > remaining %d\n", len, ctx->tx_bytes_remaining);
  285. return tcp_comm_client_complete(ctx, ERR_ARG);
  286. }
  287. ctx->tx_bytes_remaining -= len;
  288. ctx->tx_bytes_sent += len;
  289. if (ctx->tx_bytes_remaining == 0) {
  290. int res = tcp_comm_tx_complete(ctx);
  291. if (res) {
  292. return tcp_comm_client_complete(ctx, ERR_ARG);
  293. }
  294. }
  295. return ERR_OK;
  296. }
  297. static err_t tcp_comm_client_recv(void *arg, struct tcp_pcb *tpcb, struct pbuf *p, err_t err)
  298. {
  299. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  300. if (!p) {
  301. DEBUG_printf("no pbuf\n");
  302. return tcp_comm_client_complete(ctx, 0);
  303. }
  304. // this method is callback from lwIP, so cyw43_arch_lwip_begin is not required, however you
  305. // can use this method to cause an assertion in debug mode, if this method is called when
  306. // cyw43_arch_lwip_begin IS needed
  307. cyw43_arch_lwip_check();
  308. if (p->tot_len > 0) {
  309. DEBUG_printf("tcp_comm_server_recv %d err %d\n", p->tot_len, err);
  310. size_t to_copy = p->tot_len > ctx->rx_bytes_remaining ? ctx->rx_bytes_remaining : p->tot_len;
  311. // Receive the buffer
  312. if (pbuf_copy_partial(p, ctx->buf + ctx->rx_bytes_received, to_copy, 0) != to_copy) {
  313. DEBUG_printf("wrong copy len\n");
  314. return tcp_comm_client_complete(ctx, ERR_ARG);
  315. }
  316. ctx->rx_bytes_received += to_copy;
  317. ctx->rx_bytes_remaining -= to_copy;
  318. tcp_recved(tpcb, p->tot_len);
  319. if (ctx->rx_bytes_remaining == 0) {
  320. int res = tcp_comm_rx_complete(ctx);
  321. if (res) {
  322. return tcp_comm_client_complete(ctx, ERR_ARG);
  323. }
  324. }
  325. }
  326. pbuf_free(p);
  327. return ERR_OK;
  328. }
  329. static err_t tcp_comm_client_poll(void *arg, struct tcp_pcb *tpcb)
  330. {
  331. DEBUG_printf("tcp_comm_server_poll_fn\n");
  332. return ERR_OK;
  333. }
  334. static void tcp_comm_client_err(void *arg, err_t err)
  335. {
  336. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  337. DEBUG_printf("tcp_comm_err %d\n", err);
  338. ctx->client_pcb = NULL;
  339. ctx->conn_state = CONN_STATE_CLOSED;
  340. ctx->rx_bytes_remaining = 0;
  341. cyw43_arch_gpio_put (0, false);
  342. }
  343. static void tcp_comm_client_init(struct tcp_comm_ctx *ctx, struct tcp_pcb *pcb)
  344. {
  345. ctx->client_pcb = pcb;
  346. tcp_arg(pcb, ctx);
  347. cyw43_arch_gpio_put (0, true);
  348. tcp_comm_sync_begin(ctx);
  349. tcp_sent(pcb, tcp_comm_client_sent);
  350. tcp_recv(pcb, tcp_comm_client_recv);
  351. tcp_poll(pcb, tcp_comm_client_poll, POLL_TIME_S * 2);
  352. tcp_err(pcb, tcp_comm_client_err);
  353. }
  354. static err_t tcp_comm_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err)
  355. {
  356. struct tcp_comm_ctx *ctx = (struct tcp_comm_ctx *)arg;
  357. if (err != ERR_OK || client_pcb == NULL) {
  358. DEBUG_printf("Failure in accept\n");
  359. tcp_comm_server_complete(ctx, err);
  360. return ERR_VAL;
  361. }
  362. DEBUG_printf("Connection opened\n");
  363. if (ctx->client_pcb) {
  364. DEBUG_printf("Already have a connection\n");
  365. tcp_abort(client_pcb);
  366. return ERR_ABRT;
  367. }
  368. tcp_comm_client_init(ctx, client_pcb);
  369. return ERR_OK;
  370. }
  371. err_t tcp_comm_listen(struct tcp_comm_ctx *ctx, uint16_t port)
  372. {
  373. DEBUG_printf("Starting server at %s on port %u\n", ip4addr_ntoa(netif_ip4_addr(netif_list)), port);
  374. ctx->serv_done = false;
  375. struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY);
  376. if (!pcb) {
  377. DEBUG_printf("failed to create pcb\n");
  378. return ERR_MEM;
  379. }
  380. err_t err = tcp_bind(pcb, NULL, port);
  381. if (err) {
  382. DEBUG_printf("failed to bind to port %d\n", port);
  383. tcp_abort(pcb);
  384. return err;
  385. }
  386. ctx->serv_pcb = tcp_listen_with_backlog_and_err(pcb, 1, &err);
  387. if (!ctx->serv_pcb) {
  388. DEBUG_printf("failed to listen: %d\n", err);
  389. return err;
  390. }
  391. tcp_arg(ctx->serv_pcb, ctx);
  392. tcp_accept(ctx->serv_pcb, tcp_comm_server_accept);
  393. return ERR_OK;
  394. }
  395. struct tcp_comm_ctx *tcp_comm_new(const struct comm_command *const *cmds,
  396. unsigned int n_cmds, uint32_t sync_opcode)
  397. {
  398. struct tcp_comm_ctx *ctx = calloc(1, sizeof(struct tcp_comm_ctx));
  399. if (!ctx) {
  400. return NULL;
  401. }
  402. unsigned int i;
  403. for (i = 0; i < n_cmds; i++) {
  404. assert(cmds[i]->nargs <= MAX_NARG);
  405. assert(cmds[i]->resp_nargs <= MAX_NARG);
  406. }
  407. ctx->cmds = cmds;
  408. ctx->n_cmds = n_cmds;
  409. ctx->sync_opcode = sync_opcode;
  410. return ctx;
  411. }
  /*
   * Close any open connections and free the context.
   * Accepts NULL as a no-op, like free(). The original dereferenced it inside
   * tcp_comm_server_close().
   */
  void tcp_comm_delete(struct tcp_comm_ctx *ctx)
  {
      if (!ctx) {
          return;
      }
      tcp_comm_server_close(ctx);
      free(ctx);
  }
  417. bool tcp_comm_server_done(struct tcp_comm_ctx *ctx)
  418. {
  419. return ctx->serv_done;
  420. }