Skip to content

enforce tenancy on search and repo listing endpoints #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"build": "yarn workspaces run build",
"test": "yarn workspaces run test",
"dev": "cross-env SOURCEBOT_TENANT_MODE=single npm-run-all --print-label dev:start",
"dev:mt": "cross-env SOURCEBOT_TENANT_MODE=multi npm-run-all --print-label dev:start",
"dev:mt": "cross-env SOURCEBOT_TENANT_MODE=multi npm-run-all --print-label dev:start:mt",
"dev:start": "yarn workspace @sourcebot/db prisma:migrate:dev && cross-env npm-run-all --print-label --parallel dev:zoekt dev:backend dev:web",
"dev:start:mt": "yarn workspace @sourcebot/db prisma:migrate:dev && cross-env npm-run-all --print-label --parallel dev:zoekt:mt dev:backend dev:web",
"dev:zoekt": "export PATH=\"$PWD/bin:$PATH\" && export SRC_TENANT_ENFORCEMENT_MODE=none && zoekt-webserver -index .sourcebot/index -rpc",
"dev:zoekt:mt": "export PATH=\"$PWD/bin:$PATH\" && export SRC_TENANT_ENFORCEMENT_MODE=strict && zoekt-webserver -index .sourcebot/index -rpc",
"dev:backend": "yarn workspace @sourcebot/backend dev:watch",
"dev:web": "yarn workspace @sourcebot/web dev"
},
Expand Down
118 changes: 16 additions & 102 deletions packages/web/src/actions.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
'use server';

import Ajv from "ajv";
import { getUser } from "./data/user";
import { auth } from "./auth";
import { notAuthenticated, notFound, ServiceError, unexpectedError } from "./lib/serviceError";
import { auth, getCurrentUserOrg } from "./auth";
import { notAuthenticated, notFound, ServiceError, unexpectedError } from "@/lib/serviceError";
import { prisma } from "@/prisma";
import { StatusCodes } from "http-status-codes";
import { ErrorCode } from "./lib/errorCodes";
import { ErrorCode } from "@/lib/errorCodes";
import { isServiceError } from "@/lib/utils";
import { githubSchema } from "@sourcebot/schemas/v3/github.schema";
import { encrypt } from "@sourcebot/crypto"

Expand All @@ -15,31 +15,9 @@ const ajv = new Ajv({
});

export const createSecret = async (key: string, value: string): Promise<{ success: boolean } | ServiceError> => {
const session = await auth();
if (!session) {
return notAuthenticated();
}

const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}

// @todo: refactor this into a shared function
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

try {
Expand All @@ -62,30 +40,9 @@ export const createSecret = async (key: string, value: string): Promise<{ succes
}

export const getSecrets = async (): Promise<{ createdAt: Date; key: string; }[] | ServiceError> => {
const session = await auth();
if (!session) {
return notAuthenticated();
}

const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}

const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

const secrets = await prisma.secret.findMany({
Expand All @@ -105,30 +62,9 @@ export const getSecrets = async (): Promise<{ createdAt: Date; key: string; }[]
}

export const deleteSecret = async (key: string): Promise<{ success: boolean } | ServiceError> => {
const session = await auth();
if (!session) {
return notAuthenticated();
}

const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}

const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

await prisma.secret.delete({
Expand Down Expand Up @@ -206,31 +142,9 @@ export const switchActiveOrg = async (orgId: number): Promise<{ id: number } | S
}

export const createConnection = async (config: string): Promise<{ id: number } | ServiceError> => {
const session = await auth();
if (!session) {
return notAuthenticated();
}

const user = await getUser(session.user.id);
if (!user) {
return unexpectedError("User not found");
}
const orgId = user.activeOrgId;
if (!orgId) {
return unexpectedError("User has no active org");
}

// @todo: refactor this into a shared function
const membership = await prisma.userToOrg.findUnique({
where: {
orgId_userId: {
userId: session.user.id,
orgId,
}
},
});
if (!membership) {
return notFound();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

let parsedConfig;
Expand Down
9 changes: 8 additions & 1 deletion packages/web/src/app/api/(server)/repos/route.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
'use server';

import { listRepositories } from "@/lib/server/searchService";
import { getCurrentUserOrg } from "../../../../auth";
import { isServiceError } from "@/lib/utils";

export const GET = async () => {
const response = await listRepositories();
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

const response = await listRepositories(orgId);
return Response.json(response);
}
20 changes: 9 additions & 11 deletions packages/web/src/app/api/(server)/search/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,25 @@ import { searchRequestSchema } from "@/lib/schemas";
import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { NextRequest } from "next/server";
import { getCurrentUserOrg } from "../../../../auth";

export const POST = async (request: NextRequest) => {
const body = await request.json();
const tenantId = request.headers.get("X-Tenant-ID");

console.log(`Search request received. Tenant ID: ${tenantId}`);
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

const parsed = await searchRequestSchema.safeParseAsync({
...body,
...(tenantId ? {
tenantId: parseInt(tenantId)
} : {}),
});
console.log(`Searching for org ${orgId}`);
const body = await request.json();
const parsed = await searchRequestSchema.safeParseAsync(body);
if (!parsed.success) {
return serviceErrorResponse(
schemaValidationError(parsed.error)
);
}


const response = await search(parsed.data);
const response = await search(parsed.data, orgId);
if (isServiceError(response)) {
return serviceErrorResponse(response);
}
Expand Down
8 changes: 7 additions & 1 deletion packages/web/src/app/api/(server)/source/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ import { getFileSource } from "@/lib/server/searchService";
import { schemaValidationError, serviceErrorResponse } from "@/lib/serviceError";
import { isServiceError } from "@/lib/utils";
import { NextRequest } from "next/server";
import { getCurrentUserOrg } from "@/auth";

export const POST = async (request: NextRequest) => {
const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return orgId;
}

const body = await request.json();
const parsed = await fileSourceRequestSchema.safeParseAsync(body);
if (!parsed.success) {
Expand All @@ -15,7 +21,7 @@ export const POST = async (request: NextRequest) => {
);
}

const response = await getFileSource(parsed.data);
const response = await getFileSource(parsed.data, orgId);
if (isServiceError(response)) {
return serviceErrorResponse(response);
}
Expand Down
17 changes: 15 additions & 2 deletions packages/web/src/app/browse/[...path]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { CodePreview } from "./codePreview";
import { PageNotFound } from "@/app/components/pageNotFound";
import { ErrorCode } from "@/lib/errorCodes";
import { LuFileX2, LuBookX } from "react-icons/lu";
import { getCurrentUserOrg } from "@/auth";

interface BrowsePageProps {
params: {
Expand Down Expand Up @@ -44,9 +45,18 @@ export default async function BrowsePage({
}
})();

const orgId = await getCurrentUserOrg();
if (isServiceError(orgId)) {
return (
<>
Error: {orgId.message}
</>
)
}

// @todo (bkellam) : We should probably have a endpoint to fetch repository metadata
// given it's name or id.
const reposResponse = await listRepositories();
const reposResponse = await listRepositories(orgId);
if (isServiceError(reposResponse)) {
// @todo : proper error handling
return (
Expand Down Expand Up @@ -98,6 +108,7 @@ export default async function BrowsePage({
path={path}
repoName={repoName}
revisionName={revisionName ?? 'HEAD'}
orgId={orgId}
/>
)}
</div>
Expand All @@ -108,19 +119,21 @@ interface CodePreviewWrapper {
path: string,
repoName: string,
revisionName: string,
orgId: number,
}

const CodePreviewWrapper = async ({
path,
repoName,
revisionName,
orgId,
}: CodePreviewWrapper) => {
// @todo: this will depend on `pathType`.
const fileSourceResponse = await getFileSource({
fileName: path,
repository: repoName,
branch: revisionName,
});
}, orgId);

if (isServiceError(fileSourceResponse)) {
if (fileSourceResponse.errorCode === ErrorCode.FILE_NOT_FOUND) {
Expand Down
Loading