Skip to content

Commit

Permalink
Merge pull request braindotai#22 from edwin-yan/master
Browse files Browse the repository at this point in the history
Added MPS for Mac Users
  • Loading branch information
braindotai authored Oct 25, 2023
2 parents c646ff7 + 4a90101 commit b3051db
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,25 @@
from helper import *
from model.generator import SkipEncoderDecoder, input_noise

def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length = 100):
DTYPE = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
if not torch.cuda.is_available():
print('\nSetting device to "cpu", since torch is not built with "cuda" support...')

def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr, show_step, training_steps, tqdm_length=100):
DTYPE = torch.FloatTensor
has_set_device = False
if torch.cuda.is_available():
device = 'cuda'
has_set_device = True
print("Setting Device to CUDA...")
try:
if torch.backends.mps.is_available():
device = 'mps'
has_set_device = True
print("Setting Device to MPS...")
except Exception as e:
print(f"Your version of pytorch might be too old, which does not support MPS. Error: \n{e}")
pass
if not has_set_device:
device = 'cpu'
print('\nSetting device to "cpu", since torch is not built with "cuda" or "mps" support...')
print('It is recommended to use GPU if possible...')

image_np, mask_np = preprocess_images(image_path, mask_path, max_dim)
Expand All @@ -17,43 +32,43 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr,
num_channels_down = [128] * 5,
num_channels_up = [128] * 5,
num_channels_skip = [128] * 5
).type(DTYPE)
).type(DTYPE).to(device)

objective = torch.nn.MSELoss().type(DTYPE)
objective = torch.nn.MSELoss().type(DTYPE).to(device)
optimizer = optim.Adam(generator.parameters(), lr)

image_var = np_to_torch_array(image_np).type(DTYPE)
mask_var = np_to_torch_array(mask_np).type(DTYPE)
image_var = np_to_torch_array(image_np).type(DTYPE).to(device)
mask_var = np_to_torch_array(mask_np).type(DTYPE).to(device)

generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE)
generator_input = input_noise(input_depth, image_np.shape[1:]).type(DTYPE).to(device)

generator_input_saved = generator_input.detach().clone()
noise = generator_input.detach().clone()

print('\nStarting training...\n')

progress_bar = tqdm(range(training_steps), desc = 'Completed', ncols = tqdm_length)
progress_bar = tqdm(range(training_steps), desc='Completed', ncols=tqdm_length)

for step in progress_bar:
optimizer.zero_grad()
generator_input = generator_input_saved

if reg_noise > 0:
generator_input = generator_input_saved + (noise.normal_() * reg_noise)

output = generator(generator_input)

loss = objective(output * mask_var, image_var * mask_var)
loss.backward()

if step % show_step == 0:
output_image = torch_to_np_array(output)
visualize_sample(image_np, output_image, nrow = 2, size_factor = 10)

progress_bar.set_postfix(Loss = loss.item())

optimizer.step()

output_image = torch_to_np_array(output)
visualize_sample(output_image, nrow = 1, size_factor = 10)

Expand All @@ -62,4 +77,4 @@ def remove_watermark(image_path, mask_path, max_dim, reg_noise, input_depth, lr,
output_path = image_path.split('/')[-1].split('.')[-2] + '-output.jpg'
print(f'\nSaving final output image to: "{output_path}"\n')

pil_image.save(output_path)
pil_image.save(output_path)

0 comments on commit b3051db

Please sign in to comment.