Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions specforge/modeling/target/eagle3_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,65 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple

import sglang.srt.managers.mm_utils as mm_utils
import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
init_mm_embedding_cache,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
Req,
ScheduleBatch,
)

# - prepare_mlp_sync_batch_raw is now a module-level function, not a Scheduler method
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.distributed import get_tp_device_mesh, get_tp_group
from specforge.utils import padding

from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module
from .sglang_backend.utils import LogitsProcessorForEAGLE3
# SGLang internals back the *sglang* target backend only. Keep these imports
# optional so `import specforge` (and the HF / offline / draft paths) still works
# when sglang is not installed, when the installed sglang version does not
# expose the exact symbols this file pins, or when importing them raises any
# other exception. On failure the exception is kept in `_SGLANG_IMPORT_ERROR`
# and every name imported below is bound to None instead. The SGLang backend
# then surfaces a clear error at construction time
# (see SGLangEagle3TargetModel.from_pretrained). This keeps the engine behind a
# replaceable boundary rather than a hard, version-locked import dependency.
try:
import sglang.srt.managers.mm_utils as mm_utils
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
init_mm_embedding_cache,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
Req,
ScheduleBatch,
)

# prepare_mlp_sync_batch_raw is a module-level function, not a Scheduler method
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
)
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather

from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module
from .sglang_backend.utils import LogitsProcessorForEAGLE3

_SGLANG_IMPORT_ERROR = None
except Exception as _exc: # pragma: no cover - depends on installed sglang version
_SGLANG_IMPORT_ERROR = _exc
mm_utils = ModelConfig = MRotaryEmbedding = None
MultiModalityDataPaddingPatternMultimodalTokens = init_mm_embedding_cache = None
Modality = MultimodalDataItem = MultimodalInputs = Req = ScheduleBatch = None
prepare_mlp_sync_batch_raw = CacheInitParams = RadixCache = None
CaptureHiddenMode = ForwardBatch = BaseMultimodalProcessor = None
SamplingParams = ServerArgs = SpeculativeAlgorithm = None
require_mlp_sync = require_mlp_tp_gather = None
SGLangRunner = wrap_eagle3_logits_processors_in_module = None
LogitsProcessorForEAGLE3 = None

logger = logging.getLogger(__name__)

Expand Down
Loading