The legacy NET TFTP code calls store_block(), which validates the
destination buffer (LMB reserved regions, address wrap-around) before
writing to it. Do the same with NET_LWIP.

Signed-off-by: Jerome Forissier <jerome.foriss...@linaro.org>
Suggested-by: Sughosh Ganu <sughosh.g...@linaro.org>
---
 net/lwip/tftp.c | 45 ++++++++++++++++++++++++++++++++++-----------
 1 file changed, 34 insertions(+), 11 deletions(-)

diff --git a/net/lwip/tftp.c b/net/lwip/tftp.c
index 123d66b5dba..38aa647df5c 100644
--- a/net/lwip/tftp.c
+++ b/net/lwip/tftp.c
@@ -8,6 +8,8 @@
 #include <efi_loader.h>
 #include <image.h>
 #include <linux/delay.h>
+#include <linux/kconfig.h>
+#include <lmb.h>
 #include <lwip/apps/tftp_client.h>
 #include <lwip/timeouts.h>
 #include <mapmem.h>
@@ -31,6 +32,56 @@ struct tftp_ctx {
        enum done_state done;
 };
 
+/**
+ * store_block() - copy one chunk of received data to the load address
+ *
+ * @ctx: transfer context; the chunk is written at @ctx->daddr
+ * @src: received payload
+ * @len: number of bytes in @src
+ *
+ * On success, advances @ctx->daddr by @len, adds @len to @ctx->size,
+ * increments @ctx->block_count and prints '#' every 10th chunk.
+ *
+ * Return: 0 on success, -1 if the destination range wraps around the end
+ * of the address space or (with LMB) overlaps reserved memory
+ */
+static int store_block(struct tftp_ctx *ctx, void *src, u16_t len)
+{
+       ulong store_addr = ctx->daddr;
+       void *ptr;
+
+       /*
+        * Refuse a destination that wraps past the top of the address space,
+        * whether or not LMB is enabled.
+        */
+       if (store_addr + len < store_addr) {
+               puts("\nTFTP error: ");
+               puts("load address range wraps around\n");
+               return -1;
+       }
+
+       if (CONFIG_IS_ENABLED(LMB) && lmb_read_check(store_addr, len)) {
+               puts("\nTFTP error: ");
+               puts("trying to overwrite reserved memory...\n");
+               return -1;
+       }
+
+       ptr = map_sysmem(store_addr, len);
+       memcpy(ptr, src, len);
+       unmap_sysmem(ptr);
+
+       ctx->daddr += len;
+       ctx->size += len;
+       ctx->block_count++;
+       if (ctx->block_count % 10 == 0) {
+               putc('#');
+               if (ctx->block_count % (65 * 10) == 0)
+                       puts("\n\t ");
+       }
+
+       return 0;
+}
+
 static void *tftp_open(const char *fname, const char *mode, u8_t is_write)
 {
        return NULL;
@@ -71,17 +102,11 @@ static int tftp_write(void *handle, struct pbuf *p)
        struct tftp_ctx *ctx = handle;
        struct pbuf *q;
 
-       for (q = p; q; q = q->next) {
-               memcpy((void *)ctx->daddr, q->payload, q->len);
-               ctx->daddr += q->len;
-               ctx->size += q->len;
-               ctx->block_count++;
-               if (ctx->block_count % 10 == 0) {
-                       putc('#');
-                       if (ctx->block_count % (65 * 10) == 0)
-                               puts("\n\t ");
-               }
-       }
+       /* Give up on the first chunk that store_block() refuses to store */
+       for (q = p; q; q = q->next) {
+               if (store_block(ctx, q->payload, q->len) < 0)
+                       return -1;
+       }
 
        return 0;
 }
-- 
2.43.0

Reply via email to