
Segmented Prefill test

Source https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/kv_segmented_prefill.

This example builds upon the disaggregated-prefill-v1 example in examples/offline_inference.

It demonstrates vLLM's ability to perform segmented prefill when the KV connector reports "gaps" in the external cache. The goal is to verify that vLLM correctly recomputes the tokens missing from the cache (the gaps). Correctness is tested by comparing the generation output produced with the full cache against the output produced with a cache containing gaps.
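
A "gap" here is simply a half-open (start, end) range of token positions whose KV blocks are missing from the external cache and therefore must be recomputed. A minimal illustration, with made-up numbers (the real gaps are chosen by the connector described below):

# Illustrative only: gaps are half-open (start, end) token ranges missing
# from the external KV cache; vLLM must recompute exactly these tokens.
gaps = [(16, 32), (48, 80)]  # two block-aligned gaps (block size 16 assumed)

tokens_to_recompute = sum(end - start for start, end in gaps)
print(tokens_to_recompute)  # 48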

Files

  • segmented_prefill_example_connector.py – defines SegmentedPrefillExampleConnector, a subclass of ExampleConnector that simulates missing external KV blocks by creating gaps in the cache, i.e. intentionally failing to load blocks of tokens in the middle of each prompt.
  • run.sh – orchestrates the test: it runs a prefill stage that generates the external KV cache, followed by two decode stages:

    1. Normal decode (baseline).
    2. Decode with simulated gaps in the KV cache.

    It then compares the two outputs to verify correctness.

How It Works

  • The test dynamically loads SegmentedPrefillExampleConnector via KVTransferConfig.kv_connector_module_path, enabling controlled simulation of cache gaps without modifying the original connector (see the configuration sketch after this list).
  • The decode stage that simulates gaps is expected to trigger Segmented Prefill in vLLM, resulting in the same output as the baseline decode.
  • If the outputs differ, the script prints a unified diff of the mismatch and exits with an error.
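
The key piece of configuration is shown in decode_example.py below; in isolation, wiring the custom connector in looks roughly like this (a minimal sketch that mirrors the arguments used in that file):

from vllm import LLM
from vllm.config import KVTransferConfig

# Minimal sketch mirroring decode_example.py: the connector class is resolved
# dynamically from kv_connector_module_path, so the stock ExampleConnector
# code never needs to be modified.
ktc = KVTransferConfig(
    kv_connector="SegmentedPrefillExampleConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"shared_storage_path": "local_storage"},
    kv_connector_module_path="segmented_prefill_example_connector",
)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    kv_transfer_config=ktc,
)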

Usage

./run.sh

Example materials

.gitignore
*.txt
local_storage/
decode_example.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--segmented-prefill", action="store_true", help="Simulate gaps in KV cache"
    )

    args = parser.parse_args()

    if args.segmented_prefill:
        ktc = KVTransferConfig(
            kv_connector="SegmentedPrefillExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
            kv_connector_module_path="segmented_prefill_example_connector",
        )
        out_file = "segmented_prefill_decode_output.txt"
    else:
        ktc = KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,  # no CUDA graphs, so layer-wise API will be called
        gpu_memory_utilization=0.8,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30 + "\n"
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
prefill_example.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + context + "The capital of France is",
        context2 + "Your name is",
        context2 + context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048)

    # 1ST generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
run.sh
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SEGMENTED_PREFILL_OUTPUT="segmented_prefill_decode_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --segmented-prefill

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT"; then
    echo "❌ Outputs differ: segmented prefill output differs from regular prefill."
    diff -u "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT"
    exit 1
fi


echo "✅ Outputs match: segmented prefill test successful."
segmented_prefill_example_connector.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
    ExampleConnector,
    ExampleConnectorMetadata,
)
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class SegmentedPrefillExampleConnectorMetadata(ExampleConnectorMetadata):
    _req_to_gaps: dict[str, list[tuple[int, int]]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: ExampleConnectorMetadata):
        return cls(requests=base.requests)


class SegmentedPrefillExampleConnector(ExampleConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._req_to_gaps: dict[str, list[tuple[int, int]]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, SegmentedPrefillExampleConnectorMetadata)
        for req in connector_metadata.requests:
            if not req.is_store and req.req_id in connector_metadata._req_to_gaps:
                gaps = connector_metadata._req_to_gaps[req.req_id]
                self.override_slot_mapping_gaps(req.slot_mapping, gaps)
                total_gap_tokens = sum(end - start for start, end in gaps)
                logger.info(
                    "Simulating gaps in KV token blocks for the "
                    "first load request. Total tokens: %d",
                    total_gap_tokens,
                )
        super().bind_connector_metadata(connector_metadata)

    def _choose_gaps(
        self, num_computed_tokens: int, num_external_tokens: int
    ) -> list[tuple[int, int]]:
        # Simulate gaps in the external tokens, at block_size granularity.
        # Create gaps of growing size (1, 2, 3 blocks etc.) in the last num_external tokens,
        # with non-gap sections of the same growing size between them,
        # ensuring the last block is not a gap, and all aligned to block_size.
        block_size = self._block_size
        external_start = num_computed_tokens
        external_end = num_computed_tokens + num_external_tokens
        if external_end - external_start < block_size:
            return []
        gaps = []
        current_pos = external_start
        size = 1
        is_gap = True
        while current_pos < external_end:
            segment_size_tokens = size * block_size
            segment_end = min(current_pos + segment_size_tokens, external_end)
            if is_gap:
                if segment_end == external_end:
                    # If this gap would be the last segment, skip it to ensure last is non-gap
                    break
                gaps.append((current_pos, segment_end))
            current_pos = segment_end
            is_gap = not is_gap
            if is_gap:
                size += 1
        return gaps

    def _print_gaps_representation(
        self,
        gaps: list[tuple[int, int]],
        num_external_tokens: int,
        num_computed_tokens: int,
    ) -> None:
        """Print a human-readable representation of the tokens and gaps for debugging."""
        total_tokens = num_computed_tokens + num_external_tokens
        block_size = self._block_size
        representation = []
        for block_start in range(0, total_tokens, block_size):
            block_end = min(block_start + block_size, total_tokens)
            block_chars = []
            for i in range(block_start, block_end):
                if i < num_computed_tokens:
                    block_chars.append("C")  # Computed token
                else:
                    # Check if in gap
                    in_gap = any(start <= i < end for start, end in gaps)
                    block_chars.append("-" if in_gap else "E")  # Gap or External token
            # Determine the character for this block
            unique_chars = set(block_chars)
            # print 'X' if mixed token types in block
            char = unique_chars.pop() if len(unique_chars) == 1 else "X"
            representation.append(char)

        print("Cache status per block (C=computed, E=external, -=gap, X=mixed):")
        print("".join(representation))
        print("Gaps: ", gaps)
        print(
            "Total tokens: ",
            total_tokens,
            ", computed tokens: ",
            num_computed_tokens,
            ", external tokens: ",
            num_external_tokens,
        )

    @staticmethod
    def override_slot_mapping_gaps(
        slot_mapping: torch.Tensor, gaps: list[tuple[int, int]]
    ) -> None:
        """create gaps in slot_mapping by mapping them to an incorrect value"""
        if not gaps:
            return
        gap_value = slot_mapping[-1].item()  # use last value
        for start, end in gaps:
            slot_mapping[start:end] = gap_value

    def get_computed_token_gaps(
        self,
        request: "Request",
    ) -> list[tuple[int, int]] | None:
        return self._req_to_gaps.get(request.request_id)

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        num_external_tokens, _ = super().get_num_new_matched_tokens(
            request, num_computed_tokens
        )
        num_external_tokens = (
            0 if num_external_tokens is None else num_external_tokens
    )  # don't simulate async lookup for now

        # pick requests with at least 2*block_size external tokens to simulate gaps
        if num_external_tokens >= self._block_size * 2:
            gaps = self._choose_gaps(num_computed_tokens, num_external_tokens)
            self._req_to_gaps[request.request_id] = gaps
            self._print_gaps_representation(
                gaps, num_external_tokens, num_computed_tokens
            )

        return num_external_tokens, False

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        base = super().build_connector_meta(scheduler_output)
        assert isinstance(base, ExampleConnectorMetadata)
        meta = SegmentedPrefillExampleConnectorMetadata.from_base(base)
        meta._req_to_gaps = self._req_to_gaps.copy()

        self._req_to_gaps.clear()
        return meta
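
For intuition, the gap pattern produced by _choose_gaps can be reproduced standalone. The snippet below is a simplified sketch of that logic, assuming block_size=16, num_computed_tokens=0, and a prompt with 10 external blocks; the numbers are illustrative only.

# Simplified standalone sketch of the _choose_gaps logic above, assuming
# block_size=16 and num_computed_tokens=0. Gap and non-gap segments alternate
# with growing sizes (1, 2, 3, ... blocks), and the final segment is always
# left intact so the last block is never a gap.
BLOCK_SIZE = 16


def choose_gaps(num_external_tokens: int) -> list[tuple[int, int]]:
    if num_external_tokens < BLOCK_SIZE:
        return []
    gaps: list[tuple[int, int]] = []
    pos, size, is_gap = 0, 1, True
    while pos < num_external_tokens:
        end = min(pos + size * BLOCK_SIZE, num_external_tokens)
        if is_gap:
            if end == num_external_tokens:
                break  # never turn the last segment into a gap
            gaps.append((pos, end))
        pos = end
        is_gap = not is_gap
        if is_gap:
            size += 1
    return gaps


print(choose_gaps(10 * BLOCK_SIZE))
# [(0, 16), (32, 64), (96, 144)]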