From f32f624916d8bac8a0ca04e375c0f0fa04f95254 Mon Sep 17 00:00:00 2001
From: dumbmoron <log@riseup.net>
Date: Thu, 8 Aug 2024 16:34:54 +0000
Subject: [PATCH] api: use zod for request schema validation

---
 api/package.json              |  3 +-
 api/src/core/api.js           |  4 +--
 api/src/misc/run-test.js      |  4 +--
 api/src/processing/request.js | 65 ++---------------------------------
 api/src/processing/schema.js  | 42 ++++++++++++++++++++++
 api/src/util/test.js          |  6 ++--
 pnpm-lock.yaml                |  8 +++++
 7 files changed, 62 insertions(+), 70 deletions(-)
 create mode 100644 api/src/processing/schema.js

diff --git a/api/package.json b/api/package.json
index 7336447a..8ebe4c7e 100644
--- a/api/package.json
+++ b/api/package.json
@@ -40,7 +40,8 @@
         "set-cookie-parser": "2.6.0",
         "undici": "^5.19.1",
         "url-pattern": "1.0.3",
-        "youtubei.js": "^10.3.0"
+        "youtubei.js": "^10.3.0",
+        "zod": "^3.23.8"
     },
     "optionalDependencies": {
         "freebind": "^0.2.2"
diff --git a/api/src/core/api.js b/api/src/core/api.js
index 1de90033..0bb6347a 100644
--- a/api/src/core/api.js
+++ b/api/src/core/api.js
@@ -138,8 +138,8 @@ export function runAPI(express, app, __dirname) {
             request.youtubeDubLang = lang;
         }
 
-        const normalizedRequest = normalizeRequest(request);
-        if (!normalizedRequest) {
+        const { success, data: normalizedRequest } = await normalizeRequest(request);
+        if (!success) {
             return fail('ErrorCantProcess');
         }
 
diff --git a/api/src/misc/run-test.js b/api/src/misc/run-test.js
index ef3d5f51..2f8fb3c4 100644
--- a/api/src/misc/run-test.js
+++ b/api/src/misc/run-test.js
@@ -3,8 +3,8 @@ import match from "../processing/match.js";
 import { extract } from "../processing/url.js";
 
 export async function runTest(url, params, expect) {
-    const normalized = normalizeRequest({ url, ...params });
-    if (!normalized) {
+    const { success, data: normalized } = await normalizeRequest({ url, ...params });
+    if (!success) {
         throw "invalid request";
     }
 
diff --git a/api/src/processing/request.js b/api/src/processing/request.js
index c503788c..4cd4d478 100644
--- a/api/src/processing/request.js
+++ b/api/src/processing/request.js
@@ -1,26 +1,7 @@
 import ipaddr from "ipaddr.js";
 
-import { normalizeURL } from "./url.js";
 import { createStream } from "../stream/manage.js";
-import { verifyLanguageCode } from "../misc/utils.js";
-
-const apiRequest = {
-    option: {
-        audioFormat: ["best", "mp3", "ogg", "wav", "opus"],
-        downloadMode: ["auto", "audio", "mute"],
-        filenameStyle: ["classic", "pretty", "basic", "nerdy"],
-        videoQuality: ["max", "4320", "2160", "1440", "1080", "720", "480", "360", "240", "144"],
-        youtubeVideoCodec: ["h264", "av1", "vp9"],
-    },
-    boolean: [
-        "disableMetadata",
-        "tiktokFullAudio",
-        "tiktokH265",
-        "twitterGif",
-        "youtubeDubBrowserLang",
-        "youtubeDubLang"
-    ]
-}
+import { apiSchema } from "./schema.js";
 
 export function createResponse(responseType, responseData) {
     const internalError = (code) => {
@@ -91,49 +72,7 @@ export function createResponse(responseType, responseData) {
 }
 
 export function normalizeRequest(request) {
-    try {
-        let template = {
-            audioFormat: "mp3",
-            url: normalizeURL(decodeURIComponent(request.url)),
-            youtubeVideoCodec: "h264",
-            videoQuality: "720",
-            filenameStyle: "classic",
-            downloadMode: "auto",
-            tiktokFullAudio: false,
-            disableMetadata: false,
-            youtubeDubBrowserLang: false,
-            youtubeDubLang: false,
-            twitterGif: false,
-            tiktokH265: false
-        }
-
-        const requestKeys = Object.keys(request);
-        const templateKeys = Object.keys(template);
-
-        if (requestKeys.length > templateKeys.length + 1 || !request.url) {
-            return false;
-        }
-
-        for (const i in requestKeys) {
-            const key = requestKeys[i];
-            const item = request[key];
-
-            if (String(key) !== "url" && templateKeys.includes(key)) {
-                if (apiRequest.boolean.includes(key)) {
-                    template[key] = !!item;
-                } else if (apiRequest.option[key] && apiRequest.option[key].includes(item)) {
-                    template[key] = String(item)
-                }
-            }
-        }
-
-        if (template.youtubeDubBrowserLang)
-            template.youtubeDubLang = verifyLanguageCode(request.youtubeDubLang);
-
-        return template
-    } catch {
-        return false
-    }
+    return apiSchema.safeParseAsync(request).catch(() => ({ success: false }));
 }
 
 export function getIP(req) {
diff --git a/api/src/processing/schema.js b/api/src/processing/schema.js
new file mode 100644
index 00000000..5b8a93f0
--- /dev/null
+++ b/api/src/processing/schema.js
@@ -0,0 +1,42 @@
+import { z } from "zod";
+
+import { normalizeURL } from "./url.js";
+import { verifyLanguageCode } from "../misc/utils.js";
+
+export const apiSchema = z.object({
+    url: z.string()
+          .min(1)
+          .transform((url) => normalizeURL(decodeURIComponent(url))),
+
+    audioFormat: z.enum(
+        ["best", "mp3", "ogg", "wav", "opus"]
+    ).default("mp3"),
+
+    downloadMode: z.enum(
+        ["auto", "audio", "mute"]
+    ).default("auto"),
+
+    filenameStyle: z.enum(
+        ["classic", "pretty", "basic", "nerdy"]
+    ).default("classic"),
+
+    youtubeVideoCodec: z.enum(
+        ["h264", "av1", "vp9"]
+    ).default("h264"),
+
+    videoQuality: z.enum([
+        "max", "4320", "2160", "1440", "1080", "720", "480", "360", "240", "144"
+    ]).default("720"),
+
+    youtubeDubLang: z.string()
+                     .length(2)
+                     .transform(verifyLanguageCode)
+                     .optional(),
+
+    disableMetadata: z.boolean().default(false),
+    tiktokFullAudio: z.boolean().default(false),
+    tiktokH265: z.boolean().default(false),
+    twitterGif: z.boolean().default(false),
+    youtubeDubBrowserLang: z.boolean().default(false),
+})
+.strict();
diff --git a/api/src/util/test.js b/api/src/util/test.js
index 704473f4..c0f06498 100644
--- a/api/src/util/test.js
+++ b/api/src/util/test.js
@@ -34,8 +34,10 @@ for (let i in services) {
             let params = {...{url: test.url}, ...test.params};
             console.log(params);
 
-            let chck = normalizeRequest(params);
-            if (chck) {
+            let chck = await normalizeRequest(params);
+            if (chck.success) {
+                chck = chck.data;
+
                 const parsed = extract(chck.url);
                 if (parsed === null) {
                     throw `Invalid URL: ${chck.url}`
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 1d6d724e..04a0c6d6 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -61,6 +61,9 @@ importers:
       youtubei.js:
         specifier: ^10.3.0
         version: 10.3.0
+      zod:
+        specifier: ^3.23.8
+        version: 3.23.8
     optionalDependencies:
       freebind:
         specifier: ^0.2.2
@@ -1834,6 +1837,9 @@ packages:
   youtubei.js@10.3.0:
     resolution: {integrity: sha512-tLmeJCECK2xF2hZZtF2nEqirdKVNLFSDpa0LhTaXY3tngtL7doQXyy7M2CLueramDTlmCnFaW+rctHirTPFaRQ==}
 
+  zod@3.23.8:
+    resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
+
 snapshots:
 
   '@ampproject/remapping@2.3.0':
@@ -3405,3 +3411,5 @@ snapshots:
       jintr: 2.1.1
       tslib: 2.6.3
       undici: 5.28.4
+
+  zod@3.23.8: {}