Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VC function complete #76

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 14 additions & 137 deletions app/backend/app.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,15 @@
import os
import datetime
import pkg_resources

from scipy.io.wavfile import write, read
from flask import Flask, request, send_file, jsonify
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_cors import CORS

from tts import synthesize
from vc import convert_voice
from svc import convert_singing_voice
from chat import chat_response


# Directory the request handlers below write their temporary audio files into,
# resolved relative to this module through pkg_resources.
CACHE_DIR = pkg_resources.resource_filename(__name__, "cache")

app = Flask(__name__)
# Allow cross-origin requests to the /tts* and /chat* routes from this single origin only.
CORS(app, resources={
r"/tts*": {"origins": "https://6a4e5b50.r10.cpolar.top"},
r"/chat*": {"origins": "https://6a4e5b50.r10.cpolar.top"}
})

# Rate limiter bound to the Flask app; used by the @limiter.limit decorators on the routes.
limiter = Limiter(
app=app,
key_func=get_remote_address # use the request's remote address as the identifier
from gradio_client import Client

# Module-level script: sends one previously recorded WAV file to the
# "zomehwh/sovits-teio" app through gradio_client and keeps the returned value in `result`.
file = 'recordings/d0b14fcb-288a-4439-ab74-ee7a92b535b0.wav'
cache_dir = 'cache'
# output_dir is handed to the client; presumably returned files are stored there — confirm.
client = Client("zomehwh/sovits-teio", output_dir=cache_dir)
# NOTE(review): the original per-argument comments match the ones on the whisper
# transcription call elsewhere in this project and may not describe this app's inputs;
# verify each positional argument against the app's API description.
result = client.predict(
file, # str (filepath or URL to file) in 'inputs' Audio component
0, # originally annotated "str in 'Task' Radio component", but an int is passed — TODO confirm meaning
True, # undocumented flag — TODO confirm
'', # originally annotated "str in 'Model' Radio component" — TODO confirm
'', # originally annotated "str in 'Model' Radio component" — TODO confirm
False, # undocumented flag — TODO confirm
fn_index=0, # selects the endpoint by index rather than by api_name
)

# index page
@app.route("/")
def index():
    """Serve the landing page: a single HTML paragraph with the welcome text."""
    greeting = "<p>Welcome to lemon5!</p>"
    return greeting

# text-to-speech
@app.route("/tts", methods=["POST", "GET"])
@limiter.limit("50 per minute")
def tts():
    """Synthesize speech for the ``text`` query parameter and return it as a WAV attachment.

    Query parameters:
        text: the text to synthesize (required).
        lang: language code forwarded to ``synthesize``; defaults to ``"ja"``.

    Returns:
        The WAV file as an attachment named ``tts_<lang>.wav``, or a plain-text
        error message with HTTP 400 when ``text`` is missing or empty.
    """
    import uuid

    text = request.args.get("text")
    lang = request.args.get("lang", "ja")
    if not text:
        # Report bad input with a client-error status instead of 200.
        return "Input text is invalid.", 400

    audio = synthesize(text=text, lang=lang)

    # Use a per-request file name on disk: a name derived only from ``lang`` is
    # shared by concurrent requests (one request could delete or overwrite the
    # file another is sending) and would let a crafted ``lang`` escape CACHE_DIR.
    file_path = os.path.join(CACHE_DIR, "tts_{}.wav".format(uuid.uuid4().hex))
    write(file_path, 22050, audio)
    try:
        return send_file(
            file_path,
            mimetype="audio/wav",
            as_attachment=True,
            # Keep the download name clients saw before the on-disk name changed.
            download_name="tts_{}.wav".format(lang),
        )
    finally:
        # The cache file is only needed for the duration of this response.
        os.remove(file_path)

# voice conversion
@app.route("/vc", methods=["POST", "GET"])
@limiter.limit("50 per minute")
def vc():
    """Convert the uploaded ``src_audio`` to the voice of ``tgt_audio`` and return a WAV.

    Expects two multipart file fields, ``src_audio`` and ``tgt_audio``.

    Returns:
        The converted audio as an attachment named ``<src>-to-<tgt>.wav``
        (16 kHz), or a plain-text error message with HTTP 400 when either
        upload is missing.
    """
    import uuid

    if "src_audio" not in request.files:
        return "No src_audio file uploaded.", 400

    if "tgt_audio" not in request.files:
        return "No tgt_audio file uploaded.", 400

    src_audio = request.files["src_audio"]
    tgt_audio = request.files["tgt_audio"]

    # Client-supplied names are untrusted: keep only the final path component
    # and prefix a per-request token so uploads cannot escape CACHE_DIR or
    # collide with files of concurrent requests.
    src_name = os.path.basename(src_audio.filename or "src")
    tgt_name = os.path.basename(tgt_audio.filename or "tgt")
    token = uuid.uuid4().hex
    src_audio_path = os.path.join(CACHE_DIR, "{}_src_{}".format(token, src_name))
    tgt_audio_path = os.path.join(CACHE_DIR, "{}_tgt_{}".format(token, tgt_name))
    out_path = os.path.join(CACHE_DIR, "{}_out.wav".format(token))
    out_filename = "{}-to-{}.wav".format(src_name.split('.')[0], tgt_name.split('.')[0])

    try:
        src_audio.save(src_audio_path)
        tgt_audio.save(tgt_audio_path)
        out_audio = convert_voice(src_audio_path, tgt_audio_path)
        write(out_path, 16000, out_audio)
        return send_file(
            out_path,
            mimetype="audio/wav",
            as_attachment=True,
            download_name=out_filename,
        )
    finally:
        # Remove the uploads as well as the result; previously only the result
        # was deleted and every upload stayed in the cache directory.
        for path in (src_audio_path, tgt_audio_path, out_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # an earlier step failed before this file was created

# singing voice conversion
@app.route("/svc", methods=["POST", "GET"])
@limiter.limit("50 per minute")
def svc():
    """Run singing voice conversion on the uploaded ``src_audio`` and return a WAV.

    Expects two multipart file fields, ``src_audio`` and ``tgt_audio``.
    This handler used to be registered on ``/vc``, the same rule as the plain
    voice-conversion handler, and called ``convert_voice`` instead of the
    imported ``convert_singing_voice``; it now has its own ``/svc`` route.

    Returns:
        The converted audio as an attachment named ``<src>-to-<tgt>.wav``
        (44.1 kHz), or a plain-text error message with HTTP 400 when either
        upload is missing.
    """
    import uuid

    if "src_audio" not in request.files:
        return "No src_audio file uploaded.", 400

    if "tgt_audio" not in request.files:
        return "No tgt_audio file uploaded.", 400

    src_audio = request.files["src_audio"]
    tgt_audio = request.files["tgt_audio"]

    # Client-supplied names are untrusted: keep only the final path component
    # and prefix a per-request token so uploads cannot escape CACHE_DIR or
    # collide with files of concurrent requests.
    src_name = os.path.basename(src_audio.filename or "src")
    tgt_name = os.path.basename(tgt_audio.filename or "tgt")
    token = uuid.uuid4().hex
    src_audio_path = os.path.join(CACHE_DIR, "{}_src_{}".format(token, src_name))
    tgt_audio_path = os.path.join(CACHE_DIR, "{}_tgt_{}".format(token, tgt_name))
    out_path = os.path.join(CACHE_DIR, "{}_out.wav".format(token))
    out_filename = "{}-to-{}.wav".format(src_name.split('.')[0], tgt_name.split('.')[0])

    try:
        src_audio.save(src_audio_path)
        tgt_audio.save(tgt_audio_path)
        # NOTE(review): the original passed only the source path to the
        # converter; the same single-argument call is kept here. Confirm the
        # signature of convert_singing_voice (tgt_audio is saved but unused).
        out_audio = convert_singing_voice(src_audio_path)
        write(out_path, 44100, out_audio)
        return send_file(
            out_path,
            mimetype="audio/wav",
            as_attachment=True,
            download_name=out_filename,
        )
    finally:
        # Remove the uploads as well as the result so nothing accumulates in the cache.
        for path in (src_audio_path, tgt_audio_path, out_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # an earlier step failed before this file was created


# chatting
@app.route("/chat", methods=['POST'])
def chat():
    """Answer a chat request and return the reply together with the updated history.

    The JSON body is passed unchanged to ``chat_response``. The optional
    ``character`` key selects the character name echoed back in the response;
    the ``prompt`` key is only read here for logging.

    Returns:
        A JSON object with ``response``, ``history``, ``status``, ``time`` and
        ``character``.
    """
    json_post_list = request.json
    character = json_post_list.get('character', '派蒙')  # character name from the request; defaults to 派蒙
    response, history = chat_response(json_post_list)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": timestamp,
        "character": character  # echo the character name back in the response
    }
    # Build the log line with formatting rather than ``+`` so a missing or
    # non-string prompt/character no longer raises TypeError after the reply
    # has already been computed. Output is unchanged for string values.
    prompt = json_post_list.get('prompt')
    log = f'[{timestamp}] ", prompt:"{prompt}", response:"{response!r}", character:"{character}"'
    print(log)

    return jsonify(answer)

# error handler
@app.errorhandler(429)
def ratelimit_error(e):
    """Handle HTTP 429: log the limit description and return it as a JSON error body."""
    detail = str(e.description)
    print("Ratelimit exceeded: ", detail)
    body = jsonify(error="ratelimit exceeded", message=detail)
    return body, 429


if __name__ == "__main__":
app.run(debug=True)
print(result)
Binary file modified app/backend/instance/backend.db
Binary file not shown.
12 changes: 10 additions & 2 deletions app/backend/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ def populate_database():
# Optional: Update existing record with new data
existing_voice.avatar = voice_data['avatar']
existing_voice.audio = voice_data['audio']

voice = existing_voice
else:
# Create a new Voice instance if it doesn't exist
voice = Voice(name=voice_data['name'], avatar=voice_data['avatar'], audio=voice_data['audio'])
voice = Voice(name=voice_data['name'], avatar=voice_data['avatar'], audio=voice_data['audio'], page=voice_data['page'])
db.session.add(voice)

db.session.flush() # This will assign an ID to voice without committing the transaction
Expand Down Expand Up @@ -46,7 +47,14 @@ def remove_duplicate_attributes():

db.session.commit()

def show_all():
    """Print each stored voice's name, followed by the element and style of each of its attributes."""
    for entry in Voice.query.all():
        print(entry.name)
        for attr in entry.attributes:
            print(attr.element, attr.style)


if __name__ == '__main__':
with app.app_context():
remove_duplicate_attributes()
populate_database()
93 changes: 54 additions & 39 deletions app/backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import json
from flask import Flask, request, send_file, jsonify
from flask_socketio import SocketIO
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_cors import CORS, cross_origin
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import and_, or_
Expand Down Expand Up @@ -38,11 +36,6 @@
API_KEY = "tvGX7YwZGsaT3Vez8IMDSl8i",
SECRET_KEY = "HO6dIOw4duyPQQULQ71ug3y6xnPF4OVM"

# Rate limiter bound to the Flask app.
limiter = Limiter(
app=app,
key_func=get_remote_address # use the request's remote address as the identifier
)


def get_access_token():
"""
Expand Down Expand Up @@ -158,41 +151,55 @@ def handle_audio():
os.makedirs(cache_dir)
file_path = os.path.join(cache_dir, filename)
file.save(file_path)
client = Client("https://hf-audio-whisper-large-v3.hf.space/")
client = Client("https://sanchit-gandhi-whisper-jax.hf.space/")
result = client.predict(
file_path, # str (filepath or URL to file) in 'inputs' Audio component
"transcribe", # str in 'Task' Radio component
False,
api_name="/predict_1"
)
print(result)
result = result[0]
return jsonify({'text': result}), 200

# # voice conversion
# @app.route("/api/vc", methods=["POST", "GET"])
# @limiter.limit("50 per minute")
# def vc():
# if "src_audio" not in request.files:
# return "No src_audio file uploaded."
#
# if "tgt_audio" not in request.files:
# return "No tgt_audio file uploaded."
#
# src_audio = request.files["src_audio"]
# src_audio_path = os.path.join(CACHE_DIR, src_audio.filename)
# src_audio.save(src_audio_path)
#
# tgt_audio = request.files["tgt_audio"]
# tgt_audio_path = os.path.join(CACHE_DIR, tgt_audio.filename)
# tgt_audio.save(tgt_audio_path)
#
# out_audio = convert_voice(src_audio_path, tgt_audio_path)
# out_filename = "{}-to-{}.wav".format(src_audio.filename.split('.')[0], tgt_audio.filename.split('.')[0])
# out_path = os.path.join(CACHE_DIR, out_filename)
# write(out_path, 16000, out_audio)
#
# try:
# return send_file(out_path, mimetype="audio/wav", as_attachment=True)
# finally:
# os.remove(out_path)
@app.route("/api/vc", methods=["POST", "GET"])
def vc():
    """Convert an uploaded recording with the remote "zomehwh/sovits-teio" model.

    Expects a multipart upload in the ``file`` field. The recording is stored
    under ``recordings/`` with a random name, sent to the remote model through
    ``gradio_client``, and the returned audio file is moved to
    ``cache/<id>.wav``. A ``notification`` socket event is emitted when the
    conversion is done.

    Returns:
        JSON ``{"filename": <name of the converted file inside cache/>}``, or
        ``('No file part', 400)`` when no file was uploaded.
    """
    if 'file' not in request.files:
        return 'No file part', 400
    upload = request.files['file']

    # Store the upload under a random name; exist_ok avoids the
    # check-then-create race between concurrent requests.
    recordings_dir = 'recordings'
    os.makedirs(recordings_dir, exist_ok=True)
    file_path = os.path.join(recordings_dir, f"{uuid.uuid4()}.wav")
    upload.save(file_path)

    cache_dir = 'cache'
    os.makedirs(cache_dir, exist_ok=True)
    client = Client("zomehwh/sovits-teio", output_dir=cache_dir)
    # NOTE(review): only the first positional argument (the audio path) is
    # self-explanatory; the remaining values are passed exactly as before and
    # should be checked against the remote app's API description.
    result = client.predict(
        file_path,
        0,
        True,
        '',
        '',
        False,
        fn_index=0,
    )
    print(result)

    # result[1] is the path of the returned audio, located in a per-call
    # sub-directory. Split it with os.path instead of a hand-maintained
    # delimiter so it works with whatever separator the platform uses.
    converted_path = result[1]
    dirname = os.path.basename(os.path.dirname(converted_path))
    filename = os.path.basename(converted_path)
    newfilename = dirname + '.wav'
    # Move the file one level up, named after its sub-directory ...
    os.rename(os.path.join(cache_dir, dirname, filename), os.path.join(cache_dir, newfilename))
    # ... then drop the now-empty sub-directory.
    os.rmdir(os.path.join(cache_dir, dirname))

    socketio.emit('notification', {'message': '语音转换完成',
                                   'time': str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))})
    return jsonify({'filename': newfilename})

@app.route("/api/chat", methods=["POST"])
@cross_origin()
Expand Down Expand Up @@ -315,6 +322,7 @@ class Voice(db.Model):
name = db.Column(db.String(50), unique=True, nullable=False)
avatar = db.Column(db.String(200), nullable=False)
audio = db.Column(db.String(200), nullable=False)
page = db.Column(db.String(200), nullable=False)

# Relationship with attributes
attributes = db.relationship('Attribute', backref='voice', lazy=True)
Expand All @@ -335,14 +343,18 @@ def get_filtered_voices():
name_query = request.args.get('name')
element_query = request.args.get('element')
style_query = request.args.get('style')

page_query = request.args.get('page')
# Base query
query = Voice.query

# Apply filters based on query parameters
if name_query:
query = query.filter(Voice.name.contains(name_query))

if page_query:
# find exact match
query = query.filter(Voice.page == page_query)

if any([element_query, style_query]):
query = query.join(Attribute).filter(
and_(
Expand Down Expand Up @@ -402,10 +414,13 @@ def login():
@app.route('/api/signout', methods=['POST'])
def logout():
    """Sign the user out: delete cached audio and reset the chat history.

    Removes the ``cache`` and ``recordings`` directories with everything in
    them, then clears ``chat_history``. Cleanup is best-effort: a directory
    that is missing or cannot be removed is reported on stdout and does not
    fail the request.

    Returns:
        JSON confirmation message with HTTP 200.
    """
    import shutil

    # os.remove() only deletes files and always raises for a directory, so the
    # previous cleanup never removed anything; rmtree deletes the whole tree.
    for directory in ('cache', 'recordings'):
        try:
            shutil.rmtree(directory)
        except OSError as exc:
            print(f'could not remove {directory}: {exc}')

    # Clear the chat history regardless of whether the directories existed.
    chat_history.clear()
    return jsonify({'message': 'User logout successfully'}), 200

# Database model for Notification
Expand Down
Loading
Loading