Skip to content

Commit e12909f

Browse files
api changes + notes
typo updating depth api modified: scripts/depthmap_api.py
1 parent a232eb9 commit e12909f

File tree

3 files changed

+74
-25
lines changed

3 files changed

+74
-25
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
__pycache__/
1+
__pycache__/
2+
models/
3+
ouputs/

scripts/depthmap_api.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@
66
from fastapi import FastAPI, Body
77
from fastapi.exceptions import HTTPException
88
from PIL import Image
9+
from itertools import tee
910

1011
import gradio as gr
1112

1213
from modules.api.models import List, Dict
1314
from modules.api import api
1415

15-
from src.core import core_generation_funnel
16+
from src.core import core_generation_funnel, CoreGenerationFunnelInp
1617
from src.misc import SCRIPT_VERSION
1718
from src import backbone
1819
from src.common_constants import GenerationOptions as go
19-
20+
from src.api_constants import Api_Defaults, Api_Forced, Api_options
2021

2122
def encode_to_base64(image):
2223
if type(image) is str:
@@ -28,48 +29,77 @@ def encode_to_base64(image):
2829
else:
2930
return ""
3031

31-
3232
def encode_np_to_base64(image):
3333
pil = Image.fromarray(image)
3434
return api.encode_pil_to_base64(pil)
3535

36-
3736
def to_base64_PIL(encoding: str):
3837
return Image.fromarray(np.array(api.decode_base64_to_image(encoding)).astype('uint8'))
3938

4039

40+
def api_gen(depth_input_images, options):
41+
42+
default_options = CoreGenerationFunnelInp({Api_Defaults}).values
43+
44+
#TODO try-catch type errors here
45+
for key, value in options.items():
46+
default_options[key] = value
47+
48+
for key, value in Api_Forced.items():
49+
default_options[key.lower()] = value
50+
51+
if len(depth_input_images) == 0:
52+
raise HTTPException(status_code=422, detail="No images supplied")
53+
54+
print(f"Processing {str(len(depth_input_images))} images through the API")
55+
56+
pil_images = []
57+
for input_image in depth_input_images:
58+
pil_images.append(to_base64_PIL(input_image))
59+
outpath = backbone.get_outpath()
60+
gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
61+
return gen_obj
62+
4163
def depth_api(_: gr.Blocks, app: FastAPI):
4264
@app.get("/depth/version")
4365
async def version():
4466
return {"version": SCRIPT_VERSION}
4567

4668
@app.get("/depth/get_options")
4769
async def get_options():
48-
return {"options": sorted([x.name.lower() for x in go])}
70+
return {
71+
"api_options": Api_options,
72+
"gen_options": [x.name.lower() for x in go]
73+
}
4974

50-
# TODO: some potential inputs not supported (like custom depthmaps)
5175
@app.post("/depth/generate")
5276
async def process(
5377
depth_input_images: List[str] = Body([], title='Input Images'),
54-
options: Dict[str, object] = Body("options", title='Generation options'),
78+
api_options: Dict[str, object] = Body({'outputs': ["depth"]}, title='Api options', options= Api_options),
79+
gen_options: Dict[str, object] = Body({}, title='Generation options', options= [x.name.lower() for x in go])
5580
):
56-
# TODO: restrict mesh options
57-
58-
if len(depth_input_images) == 0:
59-
raise HTTPException(status_code=422, detail="No images supplied")
60-
print(f"Processing {str(len(depth_input_images))} images trough the API")
61-
62-
pil_images = []
63-
for input_image in depth_input_images:
64-
pil_images.append(to_base64_PIL(input_image))
65-
outpath = backbone.get_outpath()
66-
gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)
67-
68-
results_based = []
69-
for count, type, result in gen_obj:
70-
if type not in ['simple_mesh', 'inpainted_mesh']:
71-
results_based += [encode_to_base64(result)]
72-
return {"images": results_based, "info": "Success"}
81+
gen_obj = api_gen(depth_input_images, gen_options)
82+
#NOTE Work around yield. (Might not be necessary, not sure if yield caches)
83+
_, gen_obj = tee (gen_obj)
84+
85+
if len(api_options["outputs"])>1:
86+
results_based = {}
87+
88+
for type in api_options["outputs"]:
89+
result_per_type = []
90+
91+
for count, img_type, result in gen_obj:
92+
if img_type == type:
93+
result_per_type += result
94+
95+
if len(result_per_type)==0:
96+
results_based[type] = "Check options. no img-type of " + str(type) + " where generated"
97+
else:
98+
results_based[type] = map(encode_to_base64, result_per_type)
99+
100+
return {"images": results_based, "info": "Success"}
101+
else:
102+
return {"images": {}, "info": "api_options.output is empty"}
73103

74104

75105
try:

src/api_constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# TODO Maybe these have a better home
2+
Api_options = {
3+
'outputs': ["depth"], # list of outputs to send in response. examples ["depth", "normalmap", 'heatmap', "normal", 'background_removed'] etc
4+
#'conversions': "", #TODO implement. it's a good idea to give some options serverside for because often that's challenging in js/clientside
5+
'save':"" #TODO implement. To save on local machine. Can be very helpful for debugging.
6+
}
7+
8+
# TODO: These two are intended to be temporary
9+
Api_Defaults={
10+
"BOOST": False,
11+
"NET_SIZE_MATCH": True
12+
}
13+
#These are enforced after user inputs
14+
Api_Forced={
15+
"GEN_SIMPLE_MESH": False,
16+
"GEN_INPAINTED_MESH": False
17+
}

0 commit comments

Comments
 (0)