forked from artnoage/Podcast
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fast_api_app.py
252 lines (204 loc) · 8.82 KB
/
fast_api_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import logging
import os
import random
import json
import asyncio
import base64
from datetime import datetime
from typing import Optional, List, Dict
from uuid import uuid4
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import base64
from pydantic import BaseModel
from openai import OpenAI
from fastapi import Request
from src.utils.utils import add_feedback_to_state, get_all_timestamps
from src.utils.textGDwithWeightClipping import optimize_prompt
from src.paudio import create_podcast_audio
# Logging: INFO and above, each line prefixed with time and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

app = FastAPI()
client = OpenAI()

# Make sure ./static exists before it is mounted, so mounting it at /static
# does not fail on a fresh checkout.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# CORS: allow every origin, method and header; credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)

# Task records keyed by task id, kept only in process memory (lost on restart,
# never evicted). Each value is a dict with at least a "status" key.
# TODO: replace with a proper database in production.
tasks: Dict[str, Dict] = {}
# Request body carrying a single API key string.
# NOTE(review): no endpoint in this file (as visible) uses this model —
# confirm whether it is still needed.
class ApiKeyRequest(BaseModel):
    api_key: str
# JSON body of POST /process_feedback.
class FeedbackRequest(BaseModel):
    # Feedback text; stored via add_feedback_to_state when old_timestamp is set.
    feedback: str
    # Timestamp the feedback refers to; when None/empty the feedback is not
    # stored, but prompt optimization still runs.
    old_timestamp: Optional[str] = None
    # Second timestamp handed to optimize_prompt together with old_timestamp.
    new_timestamp: str
# JSON body of POST /vote.
class VoteRequest(BaseModel):
    # Key under which the vote is counted; None is counted as "original".
    timestamp: Optional[str] = None
# Intended body shape for an experiment-idea submission.
# NOTE(review): submit_experiment_idea reads the raw request body rather than
# this model, so it is currently unused here — confirm before relying on it.
class ExperimentIdeaRequest(BaseModel):
    idea: str
# JSON object mapping vote key -> count; read by load_votes, written by
# save_votes. Relative path, so it resolves against the process's working dir.
VOTES_FILE = "votes.json"
# Markdown file to which submit_experiment_idea appends each submitted idea
# (also relative to the working directory).
EXPERIMENT_IDEAS_FILE = "experiment_ideas.md"
def load_votes():
    """Return the vote tally stored in VOTES_FILE.

    A missing file, or one containing only whitespace, yields an empty dict.
    Otherwise the file's JSON content is returned exactly as parsed; invalid
    JSON propagates json.JSONDecodeError.
    """
    if not os.path.exists(VOTES_FILE):
        return {}
    with open(VOTES_FILE, 'r') as handle:
        text = handle.read().strip()
    return json.loads(text) if text else {}
def save_votes(votes):
    """Persist the vote tally to VOTES_FILE as JSON, atomically.

    The JSON is written to a sibling temporary file first and then moved over
    VOTES_FILE with os.replace. If serialization or writing fails part-way,
    the previous VOTES_FILE is left intact. Writing in place would leave a
    truncated file that makes every later load_votes() call raise.

    Args:
        votes: JSON-serializable mapping of vote key -> count.

    Raises:
        TypeError / ValueError: from json.dump if ``votes`` is not serializable.
        OSError: if the temporary file cannot be written or moved.
    """
    tmp_path = f"{VOTES_FILE}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(votes, f)
        os.replace(tmp_path, VOTES_FILE)
    except BaseException:
        # Remove the half-written temp file; the real file was never touched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
@app.get("/health")
async def health_check():
    """Health-check endpoint; unconditionally reports ``{"status": "OK"}``."""
    return dict(status="OK")
@app.post("/create_podcasts")
async def create_podcasts_endpoint(
    background_tasks: BackgroundTasks,
    pdf_content: UploadFile = File(...),
    summarizer_model: str = Form("gpt-4o-mini"),
    scriptwriter_model: str = Form("gpt-4o-mini"),
    enhancer_model: str = Form("gpt-4o-mini"),
    provider: str = Form("OpenAI")
):
    """Accept an uploaded PDF and schedule podcast creation in the background.

    The whole upload is read into memory. A new task record
    ``{"status": "processing", "result": None}`` is stored in ``tasks``, and
    ``process_podcast_creation`` is queued as a background task with the
    chosen models and provider.

    Returns:
        ``{"task_id": <uuid4 string>}``; poll ``/podcast_status/{task_id}``.

    Raises:
        HTTPException(400): no file was provided, or the uploaded file is empty.
        HTTPException(500): any other error while reading or scheduling.
    """
    # Validate before touching attributes, so a missing upload gives a 400
    # rather than an AttributeError from reading .filename.
    if not pdf_content:
        logger.error("No PDF file provided")
        raise HTTPException(status_code=400, detail="No PDF file provided")
    logger.info(f"Starting podcast creation. PDF file name: {pdf_content.filename}")
    try:
        pdf_bytes = await pdf_content.read()
        logger.info(f"PDF content read successfully. Size: {len(pdf_bytes)} bytes")
        if not pdf_bytes:
            # An empty upload cannot be a PDF; reject it now instead of
            # queueing a background task that has nothing to process.
            logger.error("Uploaded PDF file is empty")
            raise HTTPException(status_code=400, detail="Uploaded PDF file is empty")
        task_id = str(uuid4())
        tasks[task_id] = {"status": "processing", "result": None}
        background_tasks.add_task(
            process_podcast_creation,
            task_id,
            pdf_bytes,
            summarizer_model,
            scriptwriter_model,
            enhancer_model,
            provider
        )
        return {"task_id": task_id}
    except HTTPException:
        # Deliberate client errors must not be rewrapped as 500s below.
        raise
    except Exception as e:
        logger.error(f"Error in create_podcasts_endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
@app.get("/podcast_status/{task_id}")
async def get_podcast_status(task_id: str):
    """Return the stored record for a podcast-creation task.

    The record is the dict kept in ``tasks``:
    ``{"status": "processing", "result": None}`` while the job runs, then
    ``{"status": "completed", "result": {"podcasts": [...]}}`` (which includes
    the base64 audio) or ``{"status": "failed", "error": ...}``.

    Raises:
        HTTPException(404): unknown task id.
    """
    record = tasks.get(task_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Task not found")
@app.get("/get_podcast_audio/{task_id}/{podcast_type}")
async def get_podcast_audio(task_id: str, podcast_type: str):
    """Serve the decoded audio of one podcast variant of a completed task.

    Args:
        task_id: id returned by ``/create_podcasts``.
        podcast_type: variant to fetch (``process_podcast_creation`` produces
            ``"random"`` and ``"last"``).

    Returns:
        The base64-decoded audio bytes with media type ``audio/mpeg``.

    Raises:
        HTTPException(404): the task is unknown or not completed, the variant
            does not exist, or that variant has no audio.
    """
    task = tasks.get(task_id)
    if not task or task["status"] != "completed":
        raise HTTPException(status_code=404, detail="Audio not found or task not completed")
    podcasts = task["result"]["podcasts"]
    podcast = next((p for p in podcasts if p["type"] == podcast_type), None)
    if not podcast:
        raise HTTPException(status_code=404, detail=f"Podcast of type {podcast_type} not found")
    encoded_audio = podcast.get("audio")
    if encoded_audio is None:
        # The creation job stores None when no audio bytes were produced;
        # b64decode(None) would raise TypeError and surface as a 500.
        raise HTTPException(status_code=404, detail=f"No audio available for podcast of type {podcast_type}")
    audio_data = base64.b64decode(encoded_audio)
    return Response(content=audio_data, media_type="audio/mpeg")
async def process_podcast_creation(
    task_id: str,
    pdf_bytes: bytes,
    summarizer_model: str,
    scriptwriter_model: str,
    enhancer_model: str,
    provider: str
):
    """Background job: build two podcast variants from one PDF and record the outcome.

    Variants:
      * ``"last"``   - built from ``max(get_all_timestamps())`` (None when there
        are no timestamps);
      * ``"random"`` - built from a randomly chosen timestamp other than that
        maximum (None when there is no other timestamp).

    Both are created concurrently with ``asyncio.gather``. If either variant
    fails, ``tasks[task_id]`` becomes ``{"status": "failed", "error": ...}``
    with one combined message. Otherwise it becomes
    ``{"status": "completed", "result": {"podcasts": [random, last]}}``, with
    each podcast's audio base64-encoded. Any Exception raised while doing this
    is logged and recorded as a failed task instead of propagating.
    """

    async def build_variant(timestamp, podcast_type):
        # Produce one variant; failures become an {"error": ...} dict so the
        # sibling variant in the gather() call is not affected.
        try:
            logger.info(f"Creating podcast for timestamp {timestamp}")
            audio_bytes, dialogue, saved_timestamp = await create_podcast_audio(
                pdf_bytes,
                timestamp=timestamp,
                summarizer_model=summarizer_model,
                scriptwriter_model=scriptwriter_model,
                enhancer_model=enhancer_model,
                provider=provider,
            )
            logger.info(f"Podcast created successfully for timestamp {timestamp}")
            logger.info(f"New timestamp for saved podcast state: {saved_timestamp}")
            encoded_audio = None
            if audio_bytes:
                encoded_audio = base64.b64encode(audio_bytes).decode('utf-8')
            return {
                "timestamp": timestamp,
                "new_timestamp": saved_timestamp,
                "type": podcast_type,
                "audio": encoded_audio,
                "dialogue": dialogue,
            }
        except Exception as exc:
            logger.error(f"Error in create_podcast_subtask for timestamp {timestamp}: {str(exc)}", exc_info=True)
            return {"error": str(exc), "timestamp": timestamp, "type": podcast_type}

    try:
        logger.info(f"Processing podcast creation for task {task_id}")
        logger.info(f"Using models - Summarizer: {summarizer_model}, Scriptwriter: {scriptwriter_model}, Enhancer: {enhancer_model}")
        timestamps = get_all_timestamps()
        logger.info(f"All timestamps: {timestamps}")
        newest = max(timestamps) if timestamps else None
        candidates = [ts for ts in timestamps if ts != newest]
        sampled = random.choice(candidates) if candidates else None

        logger.info("Creating both podcasts concurrently")
        results = await asyncio.gather(
            build_variant(sampled, "random"),
            build_variant(newest, "last"),
        )

        failures = [item for item in results if "error" in item]
        if not failures:
            logger.info("Podcasts created successfully")
            tasks[task_id] = {"status": "completed", "result": {"podcasts": results}}
        else:
            summary = "; ".join(f"{item['type']} podcast: {item['error']}" for item in failures)
            tasks[task_id] = {"status": "failed", "error": f"Failed to create podcasts: {summary}"}
    except Exception as exc:
        logger.error(f"Error in process_podcast_creation: {str(exc)}", exc_info=True)
        tasks[task_id] = {"status": "failed", "error": str(exc)}
@app.post("/process_feedback")
async def process_feedback(request: FeedbackRequest):
    """Record listener feedback and re-optimize the three pipeline prompts.

    When ``old_timestamp`` is truthy the feedback is stored with
    ``add_feedback_to_state``. ``optimize_prompt`` is then run in turn for the
    summarizer, scriptwriter and enhancer roles (also when ``old_timestamp`` is
    None). The two trailing ``"gpt-4o-mini"`` arguments are model names whose
    exact roles are defined by ``optimize_prompt``.

    Raises:
        HTTPException(500): if any optimize_prompt call fails.
    """
    feedback = request.feedback
    old_ts = request.old_timestamp
    new_ts = request.new_timestamp
    logger.info(f"Received feedback: {feedback}")
    logger.info(f"Old timestamp: {old_ts}")
    logger.info(f"New timestamp: {new_ts}")

    if old_ts:
        add_feedback_to_state(old_ts, feedback)

    optimizer_model = "gpt-4o-mini"
    # NOTE(review): optimize_prompt is called synchronously inside an async
    # endpoint, so it runs on the event-loop thread; if it performs network
    # I/O it will stall other requests while it runs — confirm and consider
    # offloading once its thread-safety is known.
    try:
        for role in ("summarizer", "scriptwriter", "enhancer"):
            optimize_prompt(role, old_ts, new_ts, optimizer_model, optimizer_model)
    except Exception as e:
        logger.error(f"Error optimizing prompts: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error optimizing prompts: {str(e)}")
    return {"message": "Feedback processed and prompts optimized"}
@app.post("/vote")
async def vote(request: VoteRequest):
    """Increment the vote count for a timestamp and persist it.

    A request without a timestamp is counted under the key ``"original"``.
    The whole tally is reloaded from and saved back to VOTES_FILE on every
    call.
    """
    tally = load_votes()
    key = "original" if request.timestamp is None else request.timestamp
    tally[key] = tally[key] + 1 if key in tally else 1
    save_votes(tally)
    logger.info(f"Vote recorded for timestamp: {key}")
    return {"message": "Vote recorded successfully", "timestamp": key}
@app.post("/submit_experiment_idea")
async def submit_experiment_idea(request: Request):
    """Append the raw request body, as an experiment idea, to EXPERIMENT_IDEAS_FILE.

    The body is treated as UTF-8 text. Each idea is appended as its own
    section, separated by a Markdown horizontal rule and stamped with the
    local submission time.

    Raises:
        HTTPException(400): the body is not valid UTF-8 or is blank.
    """
    raw_body = await request.body()
    try:
        idea_text = raw_body.decode('utf-8')
    except UnicodeDecodeError as e:
        # Malformed client input: report it instead of failing with a 500.
        raise HTTPException(status_code=400, detail="Experiment idea must be UTF-8 encoded text") from e
    if not idea_text.strip():
        raise HTTPException(status_code=400, detail="Experiment idea must not be empty")
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Explicit encoding: the text was decoded as UTF-8, and the platform
    # default encoding may not be able to represent it.
    with open(EXPERIMENT_IDEAS_FILE, "a", encoding="utf-8") as f:
        f.write(f"\n\n---\n\nNew Experiment Idea (submitted on {timestamp}):\n\n{idea_text}\n")
    return {"message": "Experiment idea submitted successfully"}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every interface,
    # port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app, host="0.0.0.0", port=8000)