16 changes: 16 additions & 0 deletions basicsr/models/flow_model.py
@@ -265,6 +265,22 @@ def test(self):
         # Check if model has fixed input size (e.g., MeanFlow with DiT)
         model_input_size = self.opt['network_g'].get('input_size', None)
 
+        # For DiT-like architectures (has x_embedder), avoid outer tiling/cropping.
+        # Delegate to sample_image, which handles DiT fixed input size and padding.
+        if hasattr(self.net_g, 'x_embedder'):
+            if hasattr(self, 'net_g_ema'):
+                model_to_use = self.net_g_ema
+                model_to_use.eval()
+                with torch.no_grad():
+                    self.output = self.sample_image(self.lq, model=model_to_use)
+                model_to_use.train()
+            else:
+                self.net_g.eval()
+                with torch.no_grad():
+                    self.output = self.sample_image(self.lq, model=self.net_g)
+                self.net_g.train()
+            return
+
         if model_input_size is not None:
             # For models with fixed input size, adjust chunk size to match
             # For scale=4 and input_size=128, each LR chunk should be 32x32
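The hunk above is truncated right after the comment on chunk sizing. As a minimal sketch of the arithmetic that comment describes (the helper name `lr_chunk_size` is hypothetical, not part of the repository):

```python
# For a model whose forward pass expects HR tiles of
# model_input_size x model_input_size, each low-resolution chunk must
# shrink by the SR scale factor.
def lr_chunk_size(model_input_size: int, scale: int) -> int:
    return model_input_size // scale

# e.g. scale=4 with input_size=128 gives 32x32 LR chunks, as the comment notes.
assert lr_chunk_size(128, 4) == 32
```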
69 changes: 55 additions & 14 deletions basicsr/models/rectified_flow_model.py
@@ -448,27 +448,53 @@ def sample_image(self, lq=None, model=None, ema=False):
         model_input_size = self.opt['network_g'].get('input_size', None)
 
         if model_input_size is not None:
-            # DiT requires fixed input size: crop or pad to match
+            # For DiT with fixed input_size, tile the input if larger; pad if smaller
             _, _, h, w = lq_norm.shape
-            if h != model_input_size or w != model_input_size:
-                # Center crop if larger, or pad if smaller
-                if h >= model_input_size and w >= model_input_size:
-                    start_h = (h - model_input_size) // 2
-                    start_w = (w - model_input_size) // 2
-                    lq_padded = lq_norm[:, :, start_h:start_h+model_input_size, start_w:start_w+model_input_size]
-                    # Keep original size for later padding back
-                    original_h, original_w = h, w
-                else:
-                    # Pad to model_input_size
+            if h > model_input_size or w > model_input_size:
+                # Pad to divisible by input_size
+                pad_h = (model_input_size - h % model_input_size) % model_input_size
+                pad_w = (model_input_size - w % model_input_size) % model_input_size
+                lq_padded = F.pad(lq_norm, (0, pad_w, 0, pad_h), mode='reflect')
+                _, _, H, W = lq_padded.shape
+
+                # Prepare output tensor
+                output_padded = torch.zeros_like(lq_padded)
+                times = torch.ones(lq_padded.shape[0], device=self.device)
+
+                for i in range(0, H, model_input_size):
+                    for j in range(0, W, model_input_size):
+                        tile = lq_padded[:, :, i:i+model_input_size, j:j+model_input_size]
+                        _, flow = self.predict_flow(model_to_use, tile, times=times)
+                        # Clip flow
+                        if hasattr(self, 'clip_flow_values'):
+                            flow = torch.clamp(flow, self.clip_flow_values[0], self.clip_flow_values[1])
+                        else:
+                            flow = torch.clamp(flow, -2.0, 2.0)
+                        # Apply flow and clip output
+                        tile_out = tile + flow
+                        if hasattr(self, 'clip_values'):
+                            tile_out = torch.clamp(tile_out, self.clip_values[0], self.clip_values[1])
+                        else:
+                            tile_out = torch.clamp(tile_out, -1.0, 1.0)
+                        output_padded[:, :, i:i+model_input_size, j:j+model_input_size] = tile_out
+
+                # Crop back to original size
+                output = output_padded[:, :, :h, :w]
+                output = self.data_unnormalize_fn(output)
+                output = torch.clamp(output, 0.0, 1.0)
+                return output
+            else:
+                # Pad or keep to input_size
+                if h != model_input_size or w != model_input_size:
                     pad_h = (model_input_size - h) // 2
                     pad_w = (model_input_size - w) // 2
                     pad_h2 = model_input_size - h - pad_h
                     pad_w2 = model_input_size - w - pad_w
                     lq_padded = F.pad(lq_norm, (pad_w, pad_w2, pad_h, pad_h2), mode='reflect')
                     original_h, original_w = h, w
-            else:
-                lq_padded = lq_norm
-                original_h, original_w = h, w
+                else:
+                    lq_padded = lq_norm
+                    original_h, original_w = h, w
         else:
             # For other architectures (e.g., FlowUNet), pad to be divisible by 8
             lq_padded, (original_h, original_w) = self._pad_to_divisible(lq_norm, divisor=8)
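For quick verification, here is a self-contained sketch of the tiling path added in the hunk above, with the flow network replaced by an identity stub (`tile_process` and the stub are illustrative names, not repository API):

```python
import torch
import torch.nn.functional as F

def tile_process(lq_norm: torch.Tensor, input_size: int) -> torch.Tensor:
    """Pad to a multiple of input_size, process per tile, crop back."""
    _, _, h, w = lq_norm.shape
    # Reflect-pad so both spatial dims are divisible by input_size.
    pad_h = (input_size - h % input_size) % input_size
    pad_w = (input_size - w % input_size) % input_size
    x = F.pad(lq_norm, (0, pad_w, 0, pad_h), mode='reflect')
    output = torch.zeros_like(x)
    _, _, H, W = x.shape
    for i in range(0, H, input_size):
        for j in range(0, W, input_size):
            tile = x[:, :, i:i + input_size, j:j + input_size]
            flow = torch.zeros_like(tile)   # stand-in for predict_flow(...)
            flow = flow.clamp(-2.0, 2.0)    # default flow clipping, as above
            output[:, :, i:i + input_size, j:j + input_size] = (tile + flow).clamp(-1.0, 1.0)
    return output[:, :, :h, :w]  # crop back to the original size

# A 200x300 input is padded to 256x384, processed as six 128x128 tiles,
# then cropped back to 200x300.
out = tile_process(torch.rand(1, 3, 200, 300) * 2 - 1, 128)
assert out.shape == (1, 3, 200, 300)
```

The non-tiling else-branch instead pads symmetrically up to input_size: for h=100 and input_size=128, pad_h = (128 - 100) // 2 = 14 and pad_h2 = 128 - 100 - 14 = 14, so pad_h + h + pad_h2 recovers exactly 128.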
@@ -528,6 +554,21 @@ def sample_image(self, lq=None, model=None, ema=False):
 
         return output
 
+    def test(self):
+        """Override test to skip extra outer tiling; this class's sample_image already handles DiT tiling/padding internally."""
+        # Select the model to use
+        if hasattr(self, 'net_g_ema') and not self.use_consistency:
+            model_to_use = self.net_g_ema
+        elif self.use_consistency and hasattr(self, 'ema_model'):
+            model_to_use = self.ema_model.ema_model
+        else:
+            model_to_use = self.net_g
+
+        model_to_use.eval()
+        with torch.no_grad():
+            self.output = self.sample_image(self.lq, model=model_to_use)
+        model_to_use.train()
+
 '''
 Add flow-based loss function.
 '''
14 changes: 14 additions & 0 deletions options/train/RectifiedFlow/train_RectifiedFlow_x4.yml
@@ -50,6 +50,20 @@ network_g:
   num_register_tokens: 4
   num_classes: ~
 
+#network_g:
+#  type: FlowUNet
+#  dim: 64
+#  init_dim: 64
+#  channels: 3
+#  dim_mults: [1, 2, 4, 8]
+#  mean_variance_net: false
+#  learned_sinusoidal_cond: false
+#  random_fourier_features: false
+#  dropout: 0.0
+#  attn_dim_head: 32
+#  attn_heads: 4
+#  num_residual_streams: 2
+
 # Rectified Flow specific settings
 rectified_flow:
   time_cond_kwarg: 'times'
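One note on how this config interacts with the code changes above: whether `input_size` appears under `network_g` is what selects the inference path in `sample_image`. A minimal sketch of that lookup, assuming the active (truncated) `network_g` block is the DiT-style generator that declares `input_size` (the option dicts and `type` values below are illustrative, not the full training config):

```python
# A DiT-style config declares input_size; the commented-out FlowUNet
# block above leaves it unset.
dit_opt = {'network_g': {'type': 'DiT', 'input_size': 128}}
unet_opt = {'network_g': {'type': 'FlowUNet', 'dim': 64}}

# Same lookup sample_image() performs: an integer routes to the fixed-size
# tile/pad branch, None routes to the pad-to-divisible-by-8 branch.
assert dit_opt['network_g'].get('input_size', None) == 128
assert unet_opt['network_g'].get('input_size', None) is None
```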