support for marigold #385

Merged
merged 8 commits on Dec 14, 2023

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,6 @@
## Changelog
### 0.4.5
* Support for [Marigold](https://marigoldmonodepth.github.io). [PR #385](https://github.com/thygate/stable-diffusion-webui-depthmap-script/pull/385).
### 0.4.4
* Compatibility with stable-diffusion-webui 1.6.0
### 0.4.3 video processing tab
13 changes: 13 additions & 0 deletions README.md
@@ -198,3 +198,16 @@ ZoeDepth :
copyright = {arXiv.org perpetual, non-exclusive license}
}
```

Marigold - Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation:

```
@misc{ke2023repurposing,
title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
year={2023},
eprint={2312.02145},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
7 changes: 6 additions & 1 deletion install.py
@@ -38,6 +38,11 @@ def ensure(module_name, min_version=None):
launch.run_pip('install "moviepy==1.0.2"', "moviepy requirement for depthmap script")
ensure('transforms3d', '0.4.1')

ensure('transformers', '4.32.1')
ensure('xformers', '0.0.21')
ensure('accelerate', '0.22.0')
ensure('diffusers', '0.20.1')

ensure('imageio') # 2.4.1
try: # Dirty hack to not reinstall every time
importlib_metadata.version('imageio-ffmpeg')
@@ -53,4 +58,4 @@ def ensure(module_name, min_version=None):
if platform.system() == 'Darwin':
ensure('pyqt6')

launch.git_clone("https://github.com/prs-eth/Marigold", "repositories/Marigold", "Marigold", "cc78ff3")
launch.git_clone("https://github.com/prs-eth/Marigold", "Marigold", "Marigold", "cc78ff3")
4 changes: 2 additions & 2 deletions pix2pix/models/pix2pix4depth_model.py
@@ -94,8 +94,8 @@ def set_input_train(self, input):
self.real_A = torch.cat((self.outer, self.inner), 1)

def set_input(self, outer, inner):
inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)
inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0).float()
Collaborator:
Are you super-duper sure this does not break anything?

@affromero (Contributor Author), Dec 14, 2023:
Not sure whether this breaks anything, to be honest. Are there tests/unit tests I can run?
I modified this part because, for some reason, selecting Marigold mode raised an error: inner and outer were double tensors while the network weights are in float. Not sure why this error is specific to Marigold mode and not the others.
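
A minimal sketch of the dtype mismatch described above (standalone PyTorch, with a toy Conv2d standing in for the pix2pix network): numpy arrays default to float64, so torch.from_numpy yields double tensors, and pushing those through float32 weights raises a runtime error unless they are cast first.

```
import numpy as np
import torch

# Depth maps arrive as numpy float64, so torch.from_numpy yields a double tensor
inner = torch.from_numpy(np.zeros((32, 32))).unsqueeze(0).unsqueeze(0)
print(inner.dtype)  # torch.float64

# A freshly constructed layer has float32 weights (toy stand-in for pix2pix)
conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

try:
    conv(inner)  # dtype mismatch: double input vs. float weights
except RuntimeError as err:
    print(err)

out = conv(inner.float())  # casting first, as in this change, avoids the error
print(out.dtype)  # torch.float32
```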

outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0).float()

inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner))
outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer))
4 changes: 4 additions & 0 deletions requirements.txt
@@ -16,5 +16,9 @@ transforms3d>=0.4.1
imageio>=2.4.1,<3.0
imageio-ffmpeg
networkx>=2.5
transformers>=4.32.1 # For Marigold
xformers==0.0.21 # For Marigold
accelerate>=0.22.0 # For Marigold
diffusers>=0.20.1 # For Marigold
pyqt5; sys_platform == 'windows'
pyqt6; sys_platform != 'windows'
37 changes: 21 additions & 16 deletions src/depthmap_generation.py
@@ -74,8 +74,6 @@ def load_models(self, model_type, device: torch.device, boost: bool):
model_dir = "./models/midas"
if model_type == 0:
model_dir = "./models/leres"
if model_type == 10:
"./models/marigold"
# create paths to model if not present
os.makedirs(model_dir, exist_ok=True)
os.makedirs('./models/pix2pix', exist_ok=True)
@@ -197,9 +195,9 @@ def load_models(self, model_type, device: torch.device, boost: bool):
model = build_model(conf)

elif model_type == 10: # Marigold v1
# TODO: pass more parameters
model_path = f"{model_dir}/marigold_v1/"
from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
model_path = "Bingxin/Marigold"
print(model_path)
from Marigold.src.model.marigold_pipeline import MarigoldPipeline
model = MarigoldPipeline.from_pretrained(model_path)

model.eval() # prepare for evaluation
@@ -301,11 +299,11 @@ def get_raw_prediction(self, input, net_width, net_height):
self.resize_mode, self.normalization, self.no_half,
self.precision == "autocast")
elif self.depth_model_type == 10:
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height, self.device)
raw_prediction = estimatemarigold(img, self.depth_model, net_width, net_height)
else:
raw_prediction = estimateboost(img, self.depth_model, self.depth_model_type, self.pix2pix_model,
self.boost_whole_size_threshold)
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9]
raw_prediction_invert = self.depth_model_type in [0, 7, 8, 9, 10]
return raw_prediction, raw_prediction_invert


@@ -405,11 +403,11 @@ def estimatemidas(img, model, w, h, resize_mode, normalization, no_half, precisi
return prediction


def estimatemarigold(image, model, w, h, device):
from repositories.Marigold.src.model.marigold_pipeline import MarigoldPipeline
from repositories.Marigold.src.util.ensemble import ensemble_depths
from repositories.Marigold.src.util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
from repositories.Marigold.src.util.seed_all import seed_all
def estimatemarigold(image, model, w, h):
from Marigold.src.model.marigold_pipeline import MarigoldPipeline
from Marigold.src.util.ensemble import ensemble_depths
from Marigold.src.util.image_util import chw2hwc, colorize_depth_maps, resize_max_res
from Marigold.src.util.seed_all import seed_all

n_repeat = 10
denoise_steps = 10
@@ -418,13 +416,18 @@ def estimatemarigold(image, model, w, h, device):
tol = 1e-3
reduction_method = "median"
merging_max_res = None
resize_to_max_res = None

# From Marigold repository run.py
with torch.no_grad():
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
rgb_norm = rgb / 255.0
if resize_to_max_res is not None:
image = (image * 255).astype(np.uint8)
image = np.asarray(resize_max_res(
Image.fromarray(image), max_edge_resolution=resize_to_max_res
)) / 255.0
rgb_norm = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
rgb_norm = torch.from_numpy(rgb_norm).unsqueeze(0).float()
rgb_norm = rgb_norm.to(device)
rgb_norm = rgb_norm.to(depthmap_device)

model.unet.eval()
depth_pred_ls = []
Expand All @@ -445,7 +448,7 @@ def estimatemarigold(image, model, w, h, device):
tol=tol,
reduction=reduction_method,
max_res=merging_max_res,
device=device,
device=depthmap_device,
)
else:
depth_pred = depth_preds
@@ -942,6 +945,8 @@ def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel
def singleestimate(img, msize, model, net_type):
if net_type == 0:
return estimateleres(img, model, msize, msize)
elif net_type == 10:
return estimatemarigold(img, model, msize, msize)
elif net_type >= 7:
# np to PIL
return estimatezoedepth(Image.fromarray(np.uint8(img * 255)).convert('RGB'), model, msize, msize)
2 changes: 1 addition & 1 deletion src/misc.py
@@ -15,7 +15,7 @@ def get_commit_hash():

REPOSITORY_NAME = "stable-diffusion-webui-depthmap-script"
SCRIPT_NAME = "DepthMap"
SCRIPT_VERSION = "v0.4.4"
SCRIPT_VERSION = "v0.4.5"
SCRIPT_FULL_NAME = f"{SCRIPT_NAME} {SCRIPT_VERSION} ({get_commit_hash()})"

