Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Zod schemas for improved API input validation #538

Merged
merged 10 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
"format": "biome format src",
"format:fix": "biome format src --write",
"start": "node .next/standalone/server.js",
"test:dev": "jest --config jest.pages.config.ts && jest --config jest.api.config.ts",
"test": "jest --config jest.pages.config.ts && jest --config jest.api.config.ts --ci --coverage",
"test:dev": "jest --detectOpenHandles --verbose --config jest.pages.config.ts && jest --config jest.api.config.ts",
"test": "jest --config jest.pages.config.ts && jest --verbose --config jest.api.config.ts --ci --coverage",
"studio": "prisma studio"
},
"dependencies": {
Expand Down Expand Up @@ -109,4 +109,4 @@
"prisma": {
"seed": "ts-node --compiler-options {\"module\":\"CommonJS\"} prisma/seed.ts"
}
}
}
15 changes: 13 additions & 2 deletions src/pages/api/__tests__/v1/application/statistic.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@ import { NextApiRequest, NextApiResponse } from "next";
describe("/api/stats", () => {
it("should allow only GET method", async () => {
const methods = ["DELETE", "POST", "PUT", "PATCH", "OPTIONS", "HEAD"];
const req = {} as NextApiRequest;
const req = {
method: "GET",
headers: {
"x-ztnet-auth": "validApiKey",
},
query: {},
body: {},
} as unknown as NextApiRequest;

const res = {
status: jest.fn().mockReturnThis(),
end: jest.fn(),
json: jest.fn().mockReturnThis(),
setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it
setHeader: jest.fn(),
} as unknown as NextApiResponse;

for (const method of methods) {
Expand All @@ -29,7 +37,10 @@ describe("/api/stats", () => {
const req = {
method: "GET",
headers: { "x-ztnet-auth": "invalidApiKey" },
query: {},
body: {},
} as unknown as NextApiRequest;

const res = {
status: jest.fn().mockReturnThis(),
end: jest.fn(),
Expand Down
8 changes: 5 additions & 3 deletions src/pages/api/__tests__/v1/network/network.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { NextApiRequest, NextApiResponse } from "next";

describe("/api/createNetwork", () => {
it("should respond 405 to unsupported methods", async () => {
const req = { method: "PUT" } as NextApiRequest;
const req = { method: "PUT", query: {} } as NextApiRequest;
const res = {
status: jest.fn().mockReturnThis(),
end: jest.fn(),
Expand All @@ -20,12 +20,13 @@ describe("/api/createNetwork", () => {
const req = {
method: "POST",
headers: { "x-ztnet-auth": "invalidApiKey" },
query: {},
} as unknown as NextApiRequest;
const res = {
status: jest.fn().mockReturnThis(),
end: jest.fn(),
json: jest.fn().mockReturnThis(),
setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it
setHeader: jest.fn(),
} as unknown as NextApiResponse;

await apiNetworkHandler(req, res);
Expand All @@ -37,12 +38,13 @@ describe("/api/createNetwork", () => {
const req = {
method: "GET",
headers: { "x-ztnet-auth": "invalidApiKey" },
query: {},
} as unknown as NextApiRequest;
const res = {
status: jest.fn().mockReturnThis(),
end: jest.fn(),
json: jest.fn().mockReturnThis(),
setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it
setHeader: jest.fn(),
} as unknown as NextApiResponse;

await apiNetworkHandler(req, res);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ describe("Update Network Members", () => {
method: "POST",
headers: { "x-ztnet-auth": "validApiKey" },
query: { id: "networkId", memberId: "memberId" },
body: { name: "New Name", authorized: "true" },
body: { name: "New Name", authorized: true },
} as unknown as NextApiRequest;

// Mock the database to return a network
Expand Down Expand Up @@ -114,7 +114,7 @@ describe("Update Network Members", () => {
method: "POST",
headers: { "x-ztnet-auth": "validApiKey" },
query: { id: "networkId", memberId: "memberId" },
body: { name: "New Name", authorized: "true" },
body: { name: "New Name", authorized: true },
} as unknown as NextApiRequest;

const res = createMockRes();
Expand Down Expand Up @@ -160,7 +160,7 @@ describe("Update Network Members", () => {
method: "POST",
headers: { "x-ztnet-auth": "invalidApiKey" },
query: { id: "networkId", memberId: "memberId" },
body: { name: "New Name", authorized: "true" },
body: { name: "New Name", authorized: true },
} as unknown as NextApiRequest;

const res = createMockRes();
Expand Down
1 change: 1 addition & 0 deletions src/pages/api/__tests__/v1/org/org.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ describe("organization api validation", () => {
.mockResolvedValue({ id: "newUserId", name: "Ztnet", email: "post@ztnet.network" });

mockRequest.headers["x-ztnet-auth"] = "not valid token";
mockRequest.query = {};

await GET_userOrganization(
mockRequest as NextApiRequest,
Expand Down
4 changes: 2 additions & 2 deletions src/pages/api/__tests__/v1/org/orgid.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ describe("organization api validation", () => {
const validToken = encrypt(validTokenData, generateInstanceSecret(API_TOKEN_SECRET));
mockRequest.headers["x-ztnet-auth"] = validToken;

// add organizationId to the request
mockRequest.query = undefined;
// add empty query
mockRequest.query = {};
await apiNetworkHandler(
mockRequest as NextApiRequest,
mockResponse as NextApiResponse,
Expand Down
26 changes: 22 additions & 4 deletions src/pages/api/__tests__/v1/user/user.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { NextApiRequest, NextApiResponse } from "next";
import createUserHandler, { POST_createUser } from "~/pages/api/v1/user";
import createUserHandler from "~/pages/api/v1/user";
import { prisma } from "~/server/db";
import { appRouter } from "~/server/api/root";
import { API_TOKEN_SECRET, encrypt, generateInstanceSecret } from "~/utils/encryption";
Expand All @@ -18,7 +18,12 @@ jest.mock("~/server/api/root", () => ({
})),
},
}));

jest.mock("~/utils/rateLimit", () => ({
__esModule: true,
default: () => ({
check: jest.fn().mockResolvedValue(true),
}),
}));
jest.mock("~/server/api/trpc");

jest.mock("~/server/db", () => ({
Expand Down Expand Up @@ -126,9 +131,19 @@ describe("createUserHandler", () => {
}),
}));

mockRequest.method = "POST";
mockRequest.headers["x-ztnet-auth"] = "not defined";
mockRequest.body = {
email: "ztnet@example.com",
password: "password123",
name: "Ztnet",
};

await createUserHandler(
mockRequest as NextApiRequest,
mockResponse as NextApiResponse,
);

await POST_createUser(mockRequest as NextApiRequest, mockResponse as NextApiResponse);
expect(mockResponse.status).toHaveBeenCalledWith(200);

// Check if the response is as expected
Expand Down Expand Up @@ -166,6 +181,7 @@ describe("createUserHandler", () => {
method: "POST",
headers: { "x-ztnet-auth": tokenWithIdHash },
body: { email: "test@example.com", password: "password123", name: "Test User" },
query: {},
} as unknown as NextApiRequest;

const res = {
Expand Down Expand Up @@ -208,7 +224,9 @@ describe("createUserHandler", () => {

it("should allow only POST method", async () => {
const methods = ["GET", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"];
const req = {} as NextApiRequest;
const req = {
query: {},
} as NextApiRequest;
const res = createMockRes();

for (const method of methods) {
Expand Down
40 changes: 40 additions & 0 deletions src/pages/api/v1/network/[id]/member/[memberId]/_schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { z } from "zod";

// Schema for updateable fields metadata
export const updateableFieldsMetaSchema = z
.object({
name: z.string().optional(),
authorized: z.boolean().optional(),
})
.strict();

// Schema for the context passed to the handler
export const handlerContextSchema = z.object({
body: z.record(z.unknown()),
userId: z.string(),
networkId: z.string(),
memberId: z.string(),
ctx: z.object({
prisma: z.any(),
session: z.object({
user: z.object({
id: z.string(),
}),
}),
}),
});

// Schema for the context passed to the DELETE handler
export const deleteHandlerContextSchema = z.object({
userId: z.string(),
networkId: z.string(),
memberId: z.string(),
ctx: z.object({
prisma: z.any(),
session: z.object({
user: z.object({
id: z.string(),
}),
}),
}),
});
54 changes: 24 additions & 30 deletions src/pages/api/v1/network/[id]/member/[memberId]/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import { SecuredPrivateApiRoute } from "~/utils/apiRouteAuth";
import { handleApiErrors } from "~/utils/errors";
import rateLimit from "~/utils/rateLimit";
import * as ztController from "~/utils/ztApi";
import {
deleteHandlerContextSchema,
handlerContextSchema,
updateableFieldsMetaSchema,
} from "./_schema";

// Number of allowed requests per minute
const limiter = rateLimit({
Expand All @@ -15,20 +20,6 @@ const limiter = rateLimit({

const REQUEST_PR_MINUTE = 50;

// Function to parse and validate fields based on the expected type
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
const parseField = (key: string, value: any, expectedType: string) => {
if (expectedType === "string") {
return value; // Assume all strings are valid
}
if (expectedType === "boolean") {
if (value === "true" || value === "false") {
return value === "true";
}
throw new Error(`Field '${key}' expected to be boolean, got: ${value}`);
}
};

export default async function apiNetworkUpdateMembersHandler(
req: NextApiRequest,
res: NextApiResponse,
Expand Down Expand Up @@ -66,39 +57,39 @@ const POST_updateNetworkMember = SecuredPrivateApiRoute(
requireNetworkId: true,
requireMemberId: true,
},
async (_req, res, { body, userId, networkId, memberId, ctx }) => {
if (Object.keys(body).length === 0) {
return res.status(400).json({ error: "No data provided for update" });
}
async (_req, res, context) => {
const validatedContext = handlerContextSchema.parse(context);
const { body, userId, networkId, memberId, ctx } = validatedContext;

// Validate the input data
const validatedInput = updateableFieldsMetaSchema.parse(body);

// structure of the updateableFields object:
const updateableFields = {
name: { type: "string", destinations: ["database"] },
name: { type: "string", destinations: ["controller", "database"] },
authorized: { type: "boolean", destinations: ["controller"] },
};

if (Object.keys(body).length === 0) {
return res.status(400).json({ error: "No data provided for update" });
}

const databasePayload: Partial<network_members> = {};
const controllerPayload: Partial<network_members> = {};

// Iterate over keys in the request body
for (const key in body) {
// Check if the key is not in updateableFields
if (!(key in updateableFields)) {
return res.status(400).json({ error: `Invalid field: ${key}` });
}

for (const [key, value] of Object.entries(validatedInput)) {
try {
const parsedValue = parseField(key, body[key], updateableFields[key].type);
if (updateableFields[key].destinations.includes("database")) {
databasePayload[key] = parsedValue;
databasePayload[key] = value;
}
if (updateableFields[key].destinations.includes("controller")) {
controllerPayload[key] = parsedValue;
controllerPayload[key] = value;
}
} catch (error) {
return res.status(400).json({ error: error.message });
}
}

try {
// make sure the member is valid
const network = await prisma.network.findUnique({
Expand Down Expand Up @@ -184,7 +175,10 @@ const DELETE_deleteNetworkMember = SecuredPrivateApiRoute(
requireNetworkId: true,
requireMemberId: true,
},
async (_req, res, { userId, networkId, memberId, ctx }) => {
async (_req, res, context) => {
const validatedContext = deleteHandlerContextSchema.parse(context);
const { userId, networkId, memberId, ctx } = validatedContext;

try {
// make sure the member is valid
const network = await prisma.network.findUnique({
Expand Down
21 changes: 21 additions & 0 deletions src/pages/api/v1/network/_schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { z } from "zod";

// Schema for the request body when creating a new network
export const createNetworkBodySchema = z
.object({
name: z.string().optional(),
})
.strict();

// Schema for the context passed to the handler
export const createNetworkContextSchema = z.object({
body: createNetworkBodySchema,
ctx: z.object({
prisma: z.any(),
session: z.object({
user: z.object({
id: z.string(),
}),
}),
}),
});
9 changes: 7 additions & 2 deletions src/pages/api/v1/network/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { SecuredPrivateApiRoute } from "~/utils/apiRouteAuth";
import { handleApiErrors } from "~/utils/errors";
import rateLimit from "~/utils/rateLimit";
import * as ztController from "~/utils/ztApi";
import { createNetworkContextSchema } from "./_schema";

// Number of allowed requests per minute
const limiter = rateLimit({
Expand Down Expand Up @@ -42,9 +43,13 @@ const POST_createNewNetwork = SecuredPrivateApiRoute(
{
requireNetworkId: false,
},
async (_req, res, { body, ctx }) => {
// If there are users, verify the API key
async (_req, res, context) => {
try {
// Validate the context (which includes the body)
const validatedContext = createNetworkContextSchema.parse(context);
const { body, ctx } = validatedContext;

// If there are users, verify the API key
const { name } = body;

const newNetworkId = await networkProvisioningFactory({
Expand Down
Loading
Loading