
vllm.compilation.fusion

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

FUSED_OPS module-attribute

FUSED_OPS: dict[FusedRMSQuantKey, OpOverload]

QUANT_OPS module-attribute

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: default,
    kFp8DynamicTensorSym: default,
    kFp8DynamicTokenSym: default,
}
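
QUANT_OPS maps a QuantKey describing a quantization scheme to the standalone quantization OpOverload that the fusion patterns match against. A minimal lookup sketch (kFp8StaticTensorSym is the static per-tensor symmetric FP8 key from vLLM's quantization utilities):

quant_op = QUANT_OPS[kFp8StaticTensorSym]  # OpOverload for static per-tensor FP8 quant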

RMS_ADD_OP module-attribute

RMS_ADD_OP = default

RMS_OP module-attribute

RMS_OP = default

logger module-attribute

logger = init_logger(__name__)

FusedAddRMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)

            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
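
A hedged usage sketch, mirroring how RMSNormQuantFusionPass (below) registers this pattern; PatternMatcherPass is torch Inductor's pattern-matcher pass and FP8_DTYPE is the module attribute above:

from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
for epsilon in (1e-5, 1e-6):
    # fused-add RMSNorm + dynamic per-token FP8 quant
    FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(patterns)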

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=True,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    super().__init__(epsilon, key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
        result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
        result, scale = self.quant_matcher(result_rms)

        return result, residual, scale

    def replacement(
        input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
    ):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(input)
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=residual,
        )

        # result, residual, scale
        return at[1], at[3], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

FusedAddRMSNormGroupQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, scale = self.quant_matcher(result_rms)
            return result, residual, scale

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(
                input, transposed=self.quant_matcher.use_col_major_scales
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.quant_matcher.use_col_major_scales,
            )

            # result, residual, scale
            return at[1], at[3], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )
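
A hedged registration sketch for the group-quant variant; in RMSNormQuantFusionPass these patterns are registered only on CUDA, with 1x128 and 1x64 group shapes (patterns as in the sketch above):

if current_platform.is_cuda():
    for group_shape in (GroupShape(1, 128), GroupShape(1, 64)):
        # fused-add RMSNorm + FP8 group quant
        FusedAddRMSNormGroupQuantPattern(
            1e-6, FP8_DTYPE, group_shape=group_shape
        ).register(patterns)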

group_shape instance-attribute

group_shape = group_shape

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape,
    symmetric=True,
):
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=True,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    self.group_shape = group_shape
    super().__init__(epsilon, key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
        result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
        result, scale = self.quant_matcher(result_rms)
        return result, residual, scale

    def replacement(
        input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
    ):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(
            input, transposed=self.quant_matcher.use_col_major_scales
        )
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=residual,
            group_size=self.group_shape[1],
            is_scale_transposed=self.quant_matcher.use_col_major_scales,
        )

        # result, residual, scale
        return at[1], at[3], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

FusedAddRMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        key = FusedRMSQuantKey(
            fused_add=True,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
            result, _ = self.quant_matcher(result_rms, scale)

            return result, residual

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result, residual
            return at[1], at[2]

        inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern,
            replacement,
            inputs,
            pm.fwd_only,
            pm_pass,
        )
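
A hedged registration sketch; unlike the dynamic variants, this static-scale pattern also traces the user-provided scale tensor as a pattern input (see the inputs list above):

# fused-add RMSNorm + static per-tensor FP8 quant
FusedAddRMSNormStaticQuantPattern(1e-6, FP8_DTYPE).register(patterns)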

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    key = FusedRMSQuantKey(
        fused_add=True,
        quant=QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        ),
    )
    super().__init__(epsilon, key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    def pattern(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
        scale: torch.Tensor,
    ):
        result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
        result, _ = self.quant_matcher(result_rms, scale)

        return result, residual

    def replacement(
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
        scale: torch.Tensor,
    ):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty_like(input, dtype=self.quant_dtype)
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            residual=residual,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result, residual
        return at[1], at[2]

    inputs = [
        # input, weight, residual
        *self.rmsnorm_matcher.inputs(),
        self.quant_matcher.inputs()[1],  # scale
    ]

    pm.register_replacement(
        pattern,
        replacement,
        inputs,
        pm.fwd_only,
        pm_pass,
    )

FusedRMSQuantKey

Bases: NamedTuple

Named tuple for identifying the type of RMSNorm + quant fusion.

quant: type of quantization
fused_add: does the op also perform the residual add

Source code in vllm/compilation/fusion.py
class FusedRMSQuantKey(NamedTuple):
    """
    Named tuple for identifying the type of RMSNorm + quant fusion.
    quant: type of quantization
    fused_add: does the op also perform the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self):
        return (
            f"FusedQuantKey({self.quant}, with"
            f"{'' if self.fused_add else 'out'} residual)"
        )
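
An illustrative sketch of constructing and printing a key (kFp8StaticTensorSym is the static per-tensor symmetric FP8 QuantKey listed in QUANT_OPS above):

key = FusedRMSQuantKey(quant=kFp8StaticTensorSym, fused_add=True)
str(key)  # "FusedQuantKey(<quant key>, with residual)"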

fused_add instance-attribute

fused_add: bool

quant instance-attribute

quant: QuantKey

__str__

__str__()
Source code in vllm/compilation/fusion.py
def __str__(self):
    return (
        f"FusedQuantKey({self.quant}, with"
        f"{'' if self.fused_add else 'out'} residual)"
    )

RMSNormDynamicQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            # result, scale
            return self.quant_matcher(result_rms)

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(input)
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape = PER_TOKEN,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape = GroupShape.PER_TOKEN,
    symmetric=True,
):
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=False,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    super().__init__(epsilon, key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    def pattern(input: torch.Tensor, weight: torch.Tensor):
        result_rms = self.rmsnorm_matcher(input, weight)
        # result, scale
        return self.quant_matcher(result_rms)

    def replacement(input: torch.Tensor, weight: torch.Tensor):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(input)
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=None,
        )

        # result, scale
        return at[1], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

RMSNormGroupQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric=True,
    ):
        scale = ScaleDesc(torch.float32, False, group_shape)
        key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
        )
        self.group_shape = group_shape
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            result, scale = self.quant_matcher(result_rms)
            return result, scale

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty_like(input, dtype=self.quant_dtype)
            scale = self.quant_matcher.make_scale(
                input, transposed=self.quant_matcher.use_col_major_scales
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.quant_matcher.use_col_major_scales,
            )

            # result, scale
            return at[1], at[2]

        pm.register_replacement(
            pattern,
            replacement,
            self.rmsnorm_matcher.inputs(),
            pm.fwd_only,
            pm_pass,
        )

group_shape instance-attribute

group_shape = group_shape

__init__

__init__(
    epsilon: float,
    quant_dtype: dtype,
    group_shape: GroupShape,
    symmetric=True,
)
Source code in vllm/compilation/fusion.py
def __init__(
    self,
    epsilon: float,
    quant_dtype: torch.dtype,
    group_shape: GroupShape,
    symmetric=True,
):
    scale = ScaleDesc(torch.float32, False, group_shape)
    key = FusedRMSQuantKey(
        fused_add=False,
        quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
    )
    self.group_shape = group_shape
    super().__init__(epsilon, key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    def pattern(input: torch.Tensor, weight: torch.Tensor):
        result_rms = self.rmsnorm_matcher(input, weight)
        result, scale = self.quant_matcher(result_rms)
        return result, scale

    def replacement(input: torch.Tensor, weight: torch.Tensor):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty_like(input, dtype=self.quant_dtype)
        scale = self.quant_matcher.make_scale(
            input, transposed=self.quant_matcher.use_col_major_scales
        )
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
            scale_ub=None,
            residual=None,
            group_size=self.group_shape[1],
            is_scale_transposed=self.quant_matcher.use_col_major_scales,
        )

        # result, scale
        return at[1], at[2]

    pm.register_replacement(
        pattern,
        replacement,
        self.rmsnorm_matcher.inputs(),
        pm.fwd_only,
        pm_pass,
    )

RMSNormQuantFusionPass

Bases: VllmPatternMatcherPass

This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.

Source code in vllm/compilation/fusion.py
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # Make sure fused-add patterns are registered before plain rms-norm
        # ones, as the latter's op sequence is a subset of the former's
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + fp8 group quant
            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                FusedAddRMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
                ).register(self.patterns)

                # Fuse rms_norm + fp8 group quant
                RMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
                ).register(self.patterns)

                FusedAddRMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
                ).register(self.patterns)

                # Fuse rms_norm + fp8 group quant
                RMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
                ).register(self.patterns)

            # Fuse fused_add_rms_norm + static fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + static fp8 quant
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )

            # Fuse rms_norm + dynamic per-token fp8 quant
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> Any:
        return self.hash_source(
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )
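
A hedged end-to-end sketch; in practice vLLM's compilation pipeline constructs and schedules this pass, but conceptually it applies to a torch.fx.Graph like any other pass (vllm_config and graph are assumed to exist):

fusion_pass = RMSNormQuantFusionPass(vllm_config)
fusion_pass(graph)  # rewrites matched rms_norm(+add) + fp8 quant subgraphs in place
fusion_pass.matched_count  # number of replaced patterns, set by __call__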

patterns instance-attribute

patterns: PatternMatcherPass = PatternMatcherPass(
    pass_name="rmsnorm_quant_fusion_pass"
)

__call__

__call__(graph: Graph)
Source code in vllm/compilation/fusion.py
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
    self.matched_count = self.patterns.apply(graph)
    logger.debug("Replaced %s patterns", self.matched_count)

__init__

__init__(config: VllmConfig)
Source code in vllm/compilation/fusion.py
@enable_fake_mode
def __init__(self, config: VllmConfig):
    super().__init__(config)

    self.patterns: PatternMatcherPass = PatternMatcherPass(
        pass_name="rmsnorm_quant_fusion_pass"
    )

    # Make sure fused-add patterns are registered before plain rms-norm
    # ones, as the latter's op sequence is a subset of the former's
    for epsilon in [1e-5, 1e-6]:
        # Fuse fused_add_rms_norm + fp8 group quant
        # Only register group quant patterns on CUDA where the C++ op exists
        if current_platform.is_cuda():
            FusedAddRMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
            ).register(self.patterns)

            # Fuse rms_norm + fp8 group quant
            RMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
            ).register(self.patterns)

            FusedAddRMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
            ).register(self.patterns)

            # Fuse rms_norm + fp8 group quant
            RMSNormGroupQuantPattern(
                epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
            ).register(self.patterns)

        # Fuse fused_add_rms_norm + static fp8 quant
        FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
            self.patterns
        )

        # Fuse rms_norm + static fp8 quant
        RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

        # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
        FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
            self.patterns
        )

        # Fuse rms_norm + dynamic per-token fp8 quant
        RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

    self.dump_patterns(config, self.patterns)

uuid

uuid() -> Any
Source code in vllm/compilation/fusion.py
def uuid(self) -> Any:
    return self.hash_source(
        self,
        RMSNormGroupQuantPattern,
        RMSNormQuantPattern,
        RMSNormStaticQuantPattern,
        RMSNormDynamicQuantPattern,
        FusedAddRMSNormStaticQuantPattern,
        FusedAddRMSNormDynamicQuantPattern,
        FusedAddRMSNormGroupQuantPattern,
    )

RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormQuantPattern:
    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        # groupwise FP8 linear uses col-major scales if using DeepGEMM or CUTLASS
        using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
            self.model_dtype,
            config.model_config.hf_config.intermediate_size,
            config.model_config.hf_config.hidden_size,
        )
        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        self.rmsnorm_matcher = (
            MatcherRMSNorm(epsilon)
            if not key.fused_add
            else MatcherFusedAddRMSNorm(epsilon)
        )
        self.quant_matcher = MatcherQuantFP8(
            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
        )

FUSED_OP instance-attribute

FUSED_OP = FUSED_OPS[key]

epsilon instance-attribute

epsilon = epsilon

model_dtype instance-attribute

model_dtype = dtype if model_config else None

quant_dtype instance-attribute

quant_dtype = dtype

quant_matcher instance-attribute

quant_matcher = MatcherQuantFP8(
    quant,
    use_col_major_scales=use_col_major_scales,
    use_e8m0=use_e8m0,
)

rmsnorm_matcher instance-attribute

rmsnorm_matcher = (
    MatcherRMSNorm(epsilon)
    if not fused_add
    else MatcherFusedAddRMSNorm(epsilon)
)

__init__

__init__(epsilon: float, key: FusedRMSQuantKey)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
    self.epsilon = epsilon
    self.quant_dtype = key.quant.dtype
    config = get_current_vllm_config()
    self.model_dtype = config.model_config.dtype if config.model_config else None

    # groupwise FP8 linear uses col-major scales if using DeepGEMM or CUTLASS
    using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
        self.model_dtype,
        config.model_config.hf_config.intermediate_size,
        config.model_config.hf_config.hidden_size,
    )
    use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
    use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False

    assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
    self.FUSED_OP = FUSED_OPS[key]

    self.rmsnorm_matcher = (
        MatcherRMSNorm(epsilon)
        if not key.fused_add
        else MatcherFusedAddRMSNorm(epsilon)
    )
    self.quant_matcher = MatcherQuantFP8(
        key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
    )

RMSNormStaticQuantPattern

Bases: RMSNormQuantPattern

Source code in vllm/compilation/fusion.py
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
        fused_key = FusedRMSQuantKey(
            fused_add=False,
            quant=QuantKey(
                dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
            ),
        )
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
        # Cannot use methods, as the self argument affects tracing
        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            result_rms = self.rmsnorm_matcher(input, weight)
            return self.quant_matcher(result_rms, scale)[0]

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            result = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            at = auto_functionalized(
                self.FUSED_OP,
                result=result,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # result
            return at[1]

        inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        pattern(*inputs)

        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

__init__

__init__(
    epsilon: float, quant_dtype: dtype, symmetric=True
)
Source code in vllm/compilation/fusion.py
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
    fused_key = FusedRMSQuantKey(
        fused_add=False,
        quant=QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        ),
    )
    super().__init__(epsilon, fused_key)

register

register(pm_pass: PatternMatcherPass)
Source code in vllm/compilation/fusion.py
def register(self, pm_pass: PatternMatcherPass):
    # Cannot use methods, as the self argument affects tracing
    def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
        result_rms = self.rmsnorm_matcher(input, weight)
        return self.quant_matcher(result_rms, scale)[0]

    def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
        # In case we're matching native rms-norm, conversions might be
        # optimized out. We convert here just to be safe.
        input = input.to(dtype=self.model_dtype)

        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_dtype
        )
        at = auto_functionalized(
            self.FUSED_OP,
            result=result,
            input=input,
            weight=weight,
            scale=scale,
            epsilon=self.epsilon,
        )

        # result
        return at[1]

    inputs = [
        # input, weight
        *self.rmsnorm_matcher.inputs(),
        self.quant_matcher.inputs()[1],  # scale
    ]
    pattern(*inputs)

    pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)

empty_bf16

empty_bf16(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

empty_fp32

empty_fp32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")

empty_i32

empty_i32(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_i32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")

empty_i64

empty_i64(*args, **kwargs)
Source code in vllm/compilation/fusion.py
def empty_i64(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
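
These empty_* helpers allocate uninitialized CUDA tensors of a fixed dtype and are typically used to build example inputs for pattern tracing; a short hedged illustration:

x = empty_bf16(16, 1024)   # uninitialized bfloat16 CUDA tensor
scale = empty_fp32(16, 1)  # uninitialized fp32 CUDA tensor, e.g. a scale placeholder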