6
6
from fastapi import FastAPI , Body
7
7
from fastapi .exceptions import HTTPException
8
8
from PIL import Image
9
+ from itertools import tee
9
10
10
11
import gradio as gr
11
12
12
13
from modules .api .models import List , Dict
13
14
from modules .api import api
14
15
15
- from src .core import core_generation_funnel
16
+ from src .core import core_generation_funnel , CoreGenerationFunnelInp
16
17
from src .misc import SCRIPT_VERSION
17
18
from src import backbone
18
19
from src .common_constants import GenerationOptions as go
19
-
20
+ from src . api_constants import Api_Defaults , Api_Forced , Api_options
20
21
21
22
def encode_to_base64 (image ):
22
23
if type (image ) is str :
@@ -28,48 +29,77 @@ def encode_to_base64(image):
28
29
else :
29
30
return ""
30
31
31
-
32
32
def encode_np_to_base64 (image ):
33
33
pil = Image .fromarray (image )
34
34
return api .encode_pil_to_base64 (pil )
35
35
36
-
37
36
def to_base64_PIL (encoding : str ):
38
37
return Image .fromarray (np .array (api .decode_base64_to_image (encoding )).astype ('uint8' ))
39
38
40
39
40
+ def api_gen (depth_input_images , options ):
41
+
42
+ default_options = CoreGenerationFunnelInp ({Api_Defaults }).values
43
+
44
+ #TODO try-catch type errors here
45
+ for key , value in options .items ():
46
+ default_options [key ] = value
47
+
48
+ for key , value in Api_Forced .items ():
49
+ default_options [key .lower ()] = value
50
+
51
+ if len (depth_input_images ) == 0 :
52
+ raise HTTPException (status_code = 422 , detail = "No images supplied" )
53
+
54
+ print (f"Processing { str (len (depth_input_images ))} images through the API" )
55
+
56
+ pil_images = []
57
+ for input_image in depth_input_images :
58
+ pil_images .append (to_base64_PIL (input_image ))
59
+ outpath = backbone .get_outpath ()
60
+ gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
61
+ return gen_obj
62
+
41
63
def depth_api (_ : gr .Blocks , app : FastAPI ):
42
64
@app .get ("/depth/version" )
43
65
async def version ():
44
66
return {"version" : SCRIPT_VERSION }
45
67
46
68
@app .get ("/depth/get_options" )
47
69
async def get_options ():
48
- return {"options" : sorted ([x .name .lower () for x in go ])}
70
+ return {
71
+ "api_options" : Api_options ,
72
+ "gen_options" : [x .name .lower () for x in go ]
73
+ }
49
74
50
- # TODO: some potential inputs not supported (like custom depthmaps)
51
75
@app .post ("/depth/generate" )
52
76
async def process (
53
77
depth_input_images : List [str ] = Body ([], title = 'Input Images' ),
54
- options : Dict [str , object ] = Body ("options" , title = 'Generation options' ),
78
+ api_options : Dict [str , object ] = Body ({'outputs' : ["depth" ]}, title = 'Api options' , options = Api_options ),
79
+ gen_options : Dict [str , object ] = Body ({}, title = 'Generation options' , options = [x .name .lower () for x in go ])
55
80
):
56
- # TODO: restrict mesh options
57
-
58
- if len (depth_input_images ) == 0 :
59
- raise HTTPException (status_code = 422 , detail = "No images supplied" )
60
- print (f"Processing { str (len (depth_input_images ))} images trough the API" )
61
-
62
- pil_images = []
63
- for input_image in depth_input_images :
64
- pil_images .append (to_base64_PIL (input_image ))
65
- outpath = backbone .get_outpath ()
66
- gen_obj = core_generation_funnel (outpath , pil_images , None , None , options )
67
-
68
- results_based = []
69
- for count , type , result in gen_obj :
70
- if type not in ['simple_mesh' , 'inpainted_mesh' ]:
71
- results_based += [encode_to_base64 (result )]
72
- return {"images" : results_based , "info" : "Success" }
81
+ gen_obj = api_gen (depth_input_images , gen_options )
82
+ #NOTE Work around yield. (Might not be necessary, not sure if yield caches)
83
+ _ , gen_obj = tee (gen_obj )
84
+
85
+ if len (api_options ["outputs" ])> 1 :
86
+ results_based = {}
87
+
88
+ for type in api_options ["outputs" ]:
89
+ result_per_type = []
90
+
91
+ for count , img_type , result in gen_obj :
92
+ if img_type == type :
93
+ result_per_type += result
94
+
95
+ if len (result_per_type )== 0 :
96
+ results_based [type ] = "Check options. no img-type of " + str (type ) + " where generated"
97
+ else :
98
+ results_based [type ] = map (encode_to_base64 , result_per_type )
99
+
100
+ return {"images" : results_based , "info" : "Success" }
101
+ else :
102
+ return {"images" : {}, "info" : "api_options.output is empty" }
73
103
74
104
75
105
try :
0 commit comments