Skip to content

merged #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# db config
DB_HOST=dev-circleo-pg.celp3nik7oaq.ap-northeast-1.rds.amazonaws.com
DB_PORT=5432
DB_USER=postgres
DB_PASSWORD=AeEGDB0b7Z5GK0E2tblt
DB_NAME=team-gpt-dev
DB_CHARSET=utf8
# auth0
AUTH0_DOMAIN=dev-1x5li4ewlxn3t8ed.jp.auth0.com
AUTH0_CLIENT_ID=Y5B6TmFCUR4Q43i9UNnCCbbYVqtaXf5X
AUTH0_SECRET=ZeAmoDfQecWkWC1LZH5GWC-cpN-JFXh8SmEUmZtbrGymmLaiRkh_XD13qPQNuOGi
AUTH0_REDIRECT_URI=http://localhost:3000/auth
AUTH0_LOGOUT_REDIRECT_URI=https://teamgpt.me
AUTH0_API_AUDIENCE=https://teamgpt-dev.felo.me
AUTH0_ADMIN_CLIENT_ID=Y5B6TmFCUR4Q43i9UNnCCbbYVqtaXf5X
AUTH0_ADMIN_CLIENT_SRCRET=ZeAmoDfQecWkWC1LZH5GWC-cpN-JFXh8SmEUmZtbrGymmLaiRkh_XD13qPQNuOGi
AUTH0_ADMIN_API_AUDIENCE=https://dev-1x5li4ewlxn3t8ed.jp.auth0.com/api/v2/
# gpt-key
GPT_KEY=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
GPT_PROXY_URL=http://127.0.0.1:7890
# stripe
STRIPE_API_KEY=sk_test_51MvbmlHOzaxB9ER0GQWFIuKOqVa0dy7AbdFivMkMp7Y5EwLPnCv4wYeTUUcRLM5bUWNPJcXMOaerysTbJ2cfQgPD00j2H31oFR
# domain
DOMAIN=https://teamgpt.me
# midjourney-proxy
MIDJOURNEY_PROXY_URL=http://202.5.26.232:8080
MIDJOURNEY_HOOK=https://dev.teamgpt.me/api/v1/midjourney_proxy/hook
# 过滤模型提供的Chat接口,在调用OpenAI接口前会过滤敏感信息。
FILTER_MODEL_CHAT_URL=http://202.5.26.162:8000/privacyChat/
21 changes: 21 additions & 0 deletions migrations/models/12_20230908115345_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
return """
CREATE TABLE IF NOT EXISTS "maskcontent" (
"id" UUID NOT NULL PRIMARY KEY,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"deleted_at" TIMESTAMPTZ,
"content" VARCHAR(255),
"entities" JSONB,
"mask_content" VARCHAR(255),
"mask_result" VARCHAR(255),
"result" VARCHAR(255)
);;"""


async def downgrade(db: BaseDBAsyncClient) -> str:
return """
DROP TABLE IF EXISTS "maskcontent";"""
11 changes: 11 additions & 0 deletions migrations/models/13_20230909174356_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "conversationsmessage" ADD "privacy_chat_sta" BOOL DEFAULT False;"""


async def downgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "conversationsmessage" DROP COLUMN "privacy_chat_sta";"""
15 changes: 15 additions & 0 deletions migrations/models/14_20230909175824_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "maskcontent" ADD "masked_content" VARCHAR(255);
ALTER TABLE "maskcontent" ADD "masked_result" VARCHAR(255);
ALTER TABLE "maskcontent" ADD "privacy_detected" BOOL DEFAULT False;"""


async def downgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "maskcontent" DROP COLUMN "masked_content";
ALTER TABLE "maskcontent" DROP COLUMN "masked_result";
ALTER TABLE "maskcontent" DROP COLUMN "privacy_detected";"""
13 changes: 13 additions & 0 deletions migrations/models/15_20230909175835_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "maskcontent" DROP COLUMN "mask_content";
ALTER TABLE "maskcontent" DROP COLUMN "mask_result";"""


async def downgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "maskcontent" ADD "mask_content" VARCHAR(255);
ALTER TABLE "maskcontent" ADD "mask_result" VARCHAR(255);"""
13 changes: 7 additions & 6 deletions teamgpt/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import nest_asyncio
import openai
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import (get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html)
from starlette import status
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import JSONResponse
from tortoise.contrib.fastapi import register_tortoise
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.errors import ServerErrorMiddleware

from teamgpt.endpoints import router
from teamgpt.settings import TORTOISE_ORM, GPT_PROXY_URL
from fastapi.openapi.docs import (get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html)
from fastapi import FastAPI

nest_asyncio.apply()

Expand Down Expand Up @@ -65,6 +66,6 @@ async def oauth2_redirect():
return get_swagger_ui_oauth2_redirect_html()


@app.get('/')
@app.get('/api/v1/test')
def read_root():
return 'hello'
56 changes: 46 additions & 10 deletions teamgpt/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,45 @@
from typing import Union

import openai
import requests
from fastapi import APIRouter, Depends, HTTPException, Security, Query
from fastapi_auth0 import Auth0User
from fastapi_pagination import Page
from sse_starlette import EventSourceResponse
from starlette.status import HTTP_204_NO_CONTENT

from teamgpt import settings
from teamgpt.endpoints.stripe import org_payment_plan
from teamgpt.enums import GptModel, ContentType, AutherUser, GptKeySource
from teamgpt.models import User, Conversations, ConversationsMessage, GPTKey, Organization, SysGPTKey, GptChatMessage
from teamgpt.models import User, Conversations, ConversationsMessage, GPTKey, Organization, SysGPTKey, GptChatMessage, \
MaskContent
from teamgpt.parameters import ListAPIParams, tortoise_paginate
from teamgpt.schemata import ConversationsIn, ConversationsOut, ConversationsMessageIn, ConversationsMessageOut
from teamgpt.schemata import ConversationsIn, ConversationsOut, ConversationsMessageIn, MaskContentInput
from teamgpt.settings import (auth)
from teamgpt.util.entity_detector import EntityDetector
from teamgpt.util.gpt import ask, num_tokens_from_messages, msg_tiktoken_num
from teamgpt.util.gpt import ask, num_tokens_from_messages, msg_tiktoken_num, privacy_ask

router = APIRouter(prefix='/api/v1/conversations', tags=['Conversations'])


@router.post('/message/mask_content')
async def create_mask_content(mask_content_input: MaskContentInput,
user: Auth0User = Security(auth.get_user)):
url = settings.FILTER_MODEL_CHAT_URL + 'maskContent/'
headers = {"Content-Type": "application/json"}
response = requests.post(url, json={
"content": mask_content_input.content,
}, headers=headers)
req_data = response.json()
info = await MaskContent.create(id=uuid.UUID(mask_content_input.id), content=mask_content_input.content,
entities=req_data['entities'],
masked_content=req_data['masked_content'],
masked_result=req_data['masked_result'],
privacy_detected=req_data['privacy_detected'],
result=req_data['result'])
return info


@router.get("/message/test/{key}")
async def test(key: str):
async def event_generator():
Expand Down Expand Up @@ -114,7 +135,8 @@ async def create_conversations_message(
model: Union[GptModel, None] = Query(default=GptModel.GPT3TURBO),
context_number: Union[int, None] = Query(default=5),
encrypt_sensitive_data: Union[bool, None] = Query(default=False),
user: Auth0User = Security(auth.get_user)
user: Auth0User = Security(auth.get_user),
privacy_chat_sta: Union[bool, None] = Query(default=False),
):
user_info = await User.get_or_none(user_id=user.id, deleted_at__isnull=True)
# 查询gpt-key配置信息,判断是否是系统用户
Expand Down Expand Up @@ -204,7 +226,8 @@ async def create_conversations_message(
content_type=conversations_input.content_type,
key=key,
shown_message=conversations_input.shown_message,
model=model
model=model,
privacy_chat_sta=privacy_chat_sta
)
else:
await ConversationsMessage.create(id=uuid.UUID(str(conversations_input.id)), user=user_info,
Expand All @@ -214,7 +237,8 @@ async def create_conversations_message(
content_type=conversations_input.content_type,
shown_message=conversations_input.shown_message,
key=key,
model=model
model=model,
privacy_chat_sta=privacy_chat_sta
)
# 查询前5条消息
con_org = await ConversationsMessage.filter(user=user_info, conversation_id=conversation_id,
Expand Down Expand Up @@ -243,7 +267,10 @@ async def send_gpt():
while prompt_tokens > 4000:
message_log.pop(-1)
prompt_tokens = await num_tokens_from_messages(message_log[::-1], model=model)
agen = ask(key, message_log[::-1], model, conversation_id)
if privacy_chat_sta is True:
agen = privacy_ask(message_log[::-1], model, conversation_id)
else:
agen = ask(key, message_log[::-1], model, conversation_id)
async for event in agen:
event_data = json.loads(event['data'])
if event_data['sta'] == 'run':
Expand All @@ -258,7 +285,7 @@ async def send_gpt():
content_type=ContentType.TEXT,
key=key,
prompt_tokens=prompt_tokens,
model=model
model=model,
)
new_msg_obj_id = str(new_msg_obj.id)
event_data['msg_id'] = new_msg_obj_id
Expand Down Expand Up @@ -288,7 +315,7 @@ async def send_gpt():

# get conversations message
@router.get('/message/{conversations_id}',
response_model=Page[ConversationsMessageOut], dependencies=[Depends(auth.implicit_scheme)])
dependencies=[Depends(auth.implicit_scheme)])
async def get_conversations_message(
conversations_id: str,
user: Auth0User = Security(auth.get_user),
Expand All @@ -297,4 +324,13 @@ async def get_conversations_message(
user_info = await User.get_or_none(user_id=user.id, deleted_at__isnull=True)
con_org = ConversationsMessage.filter(user=user_info, conversation_id=conversations_id,
deleted_at__isnull=True)
return await tortoise_paginate(con_org, params)
req_list = await tortoise_paginate(con_org, params)
for i in range(len(req_list.items)):
obj = req_list.items[i].__dict__
del obj['_partial']
del obj['_custom_generated_pk']
if req_list.items[i].privacy_chat_sta:
obj['mask_content'] = await MaskContent.get_or_none(id=req_list.items[i].id,
deleted_at__isnull=True)
req_list.items[i] = obj
return req_list
1 change: 1 addition & 0 deletions teamgpt/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class GptModel(str, Enum):
GPT4_0613= 'gpt-4-0613'
GPT4_32K= 'gpt-4-32k'
GPT4_32K_0613= 'gpt-4-32k-0613'
GPT4_0125= 'gpt-4-0125-preview'


class GptKeySource(str, Enum):
Expand Down
16 changes: 16 additions & 0 deletions teamgpt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class ConversationsMessage(AbstractBaseModelWithDeletedAt):
completion_tokens = fields.IntField(null=True)
total_tokens = fields.IntField(null=True)
model = fields.CharEnumField(GptModel, max_length=100, null=True)
privacy_chat_sta = fields.BooleanField(null=True, default=False)

class PydanticMeta:
exclude = (
Expand Down Expand Up @@ -415,3 +416,18 @@ class PydanticMeta:
'updated_at',
'deleted_at',
)


class MaskContent(AbstractBaseModelWithDeletedAt):
content = fields.CharField(max_length=255, null=True)
entities = fields.JSONField(null=True)
masked_content = fields.CharField(max_length=255, null=True)
masked_result = fields.CharField(max_length=255, null=True)
privacy_detected = fields.BooleanField(null=True, default=False)
result = fields.CharField(max_length=255, null=True)

class PydanticMeta:
exclude = (
'updated_at',
'deleted_at',
)
21 changes: 21 additions & 0 deletions teamgpt/schemata.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,27 @@ class MidjourneyProxySubmitResponse(BaseModel):
),
)

MaskContentIn = pydantic_model_creator(
models.MaskContent,
name='MaskContentIn',
exclude=(
'id',
'created_at',
'updated_at',
),
)


class MaskContentInput(BaseModel):
content: Optional[str] = None
id: Optional[str] = None


MaskContentOut = pydantic_model_creator(
models.MaskContent,
name='MaskContentOut',
)


class MidjourneyProxyHookToIn(BaseModel):
id: Optional[str] = None
Expand Down
7 changes: 7 additions & 0 deletions teamgpt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def import_temp_env():
},
}

# 过滤模型提供的Chat接口,在调用OpenAI接口前会过滤敏感信息。
FILTER_MODEL_CHAT_URL = os.getenv(
'FILTER_MODEL_CHAT_URL', '')

OPENAI_API_BASE = os.getenv(
'OPENAI_API_BASE', 'https://api.openai.com')


def api_key() -> str:
alphabet = string.ascii_letters + string.digits
Expand Down
Loading