Segmented Prefill test¶
Source https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/kv_segmented_prefill.
This example builds upon the disaggregated-prefill-v1 example in examples/offline_inference.
It demonstrates vLLM's ability to perform segmented prefill, in case the KV Connector reports "gaps" in the external cache. The goal is to verify that vLLM correctly recalculates the tokens missing in cache (the gaps). Correctness is tested by comparing the generation output using the full cache and the output using a cache with gaps.
Files¶
segmented_prefill_example_connector.py – defines SegmentedPrefillExampleConnector, a subclass of ExampleConnector, that simulates missing external KV blocks by creating gaps in the cache — intentionally failing to load blocks of tokens in the middle of each prompt.
run.sh – orchestrates the test: runs a prefill stage which generates the external KV cache, then two decode stages: - Normal decode (baseline).
- Decode with simulated gaps in the KV cache.
It then compares the two outputs to verify correctness.
How It Works¶
- The test dynamically loads
SegmentedPrefillExampleConnector via KVTransferConfig.kv_connector_module_path, enabling controlled simulation of cache gaps without modifying the original connector. - The decode stage that simulates gaps is expected to trigger Segmented Prefill in vLLM, resulting in the same output as the baseline decode.
- If the outputs differ, the script prints a unified diff of the mismatch and exits with an error.
Usage¶
Example materials¶
decode_example.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
    """Load the prompts written by the prefill stage from prefill_output.txt.

    Exits the process if the prefill stage has not been run yet.
    """
    try:
        with open("prefill_output.txt") as f:
            prompts = [line.strip() for line in f]
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)
    print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
    return prompts
def main():
    """Decode stage: generate from the prompts produced by the prefill stage.

    Without flags, uses ExampleConnector to load the full external KV cache
    (the baseline run). With --segmented-prefill, uses
    SegmentedPrefillExampleConnector (loaded dynamically via
    kv_connector_module_path) to simulate gaps in the cache. Output goes to
    decode_output.txt or segmented_prefill_decode_output.txt respectively,
    so run.sh can compare the two.
    """
    # Parse CLI arguments before any file I/O so that --help and bad-flag
    # errors work even when prefill_output.txt does not exist yet.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--segmented-prefill", action="store_true", help="Simulate gaps in KV cache"
    )
    args = parser.parse_args()

    prompts = read_prompts()
    # temperature=0 (greedy) keeps both decode runs deterministic/comparable
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    if args.segmented_prefill:
        ktc = KVTransferConfig(
            kv_connector="SegmentedPrefillExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
            kv_connector_module_path="segmented_prefill_example_connector",
        )
        out_file = "segmented_prefill_decode_output.txt"
    else:
        ktc = KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,  # no CUDA graphs, so layer-wise API will be called
        gpu_memory_utilization=0.8,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    # Print and persist each (prompt, generation) pair for later comparison.
    sep_str = "-" * 30 + "\n"
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
prefill_example.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
    """Return four test prompts with long repeated filler prefixes.

    The filler makes each prompt span many KV blocks, which is what allows
    the connector to later punch block-sized gaps into the external cache.
    """
    hi_block = "Hi " * 1000
    hey_block = "Hey " * 500
    prompts = [
        hi_block + "Hello, my name is",
        hi_block * 2 + "The capital of France is",
        hey_block + "Your name is",
        hey_block * 2 + "The capital of China is",
    ]
    return prompts
def main():
    """Prefill stage: run each prompt for a single token so the external
    KV cache (shared storage) gets populated, then save the continued
    prompts for the decode stage to reuse."""
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )

    # 1ST generation (prefill instance)
    outputs = llm.generate(prompts, sampling_params)

    separator = "-" * 30
    continued = []
    print(separator)
    for request_output in outputs:
        text_in = request_output.prompt
        text_out = request_output.outputs[0].text
        continued.append(text_in + text_out)
        print(f"Prompt: {text_in!r}\nGenerated text: {text_out!r}")
        print(separator)

    # Persist prompt+generation pairs for decode_example.py to consume.
    with open("prefill_output.txt", "w") as f:
        f.writelines(p + "\n" for p in continued)
    print(f"Saved {len(continued)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
run.sh
#!/bin/bash
# Segmented-prefill correctness test:
#   1. prefill stage populates the external KV cache,
#   2. baseline decode (full cache),
#   3. decode with simulated gaps in the cache,
# then the two decode outputs are compared and must match exactly.

# Abort immediately if any stage fails (and on unset vars / pipe failures),
# so a crashed python run cannot be mistaken for a test result.
set -euo pipefail

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SEGMENTED_PREFILL_OUTPUT="segmented_prefill_decode_output.txt"

# Cleanup from previous runs
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --segmented-prefill

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT"; then
    echo "❌ Outputs differ: segmented prefill output differs from regular prefill."
    # diff exits non-zero when files differ; don't let set -e preempt exit 1
    diff -u "$DECODE_OUTPUT" "$SEGMENTED_PREFILL_OUTPUT" || true
    exit 1
fi
echo "✅ Outputs match: segmented prefill test successful."
segmented_prefill_example_connector.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
ExampleConnector,
ExampleConnectorMetadata,
)
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
# NOTE(review): this grabs the *root* logger and calls basicConfig() at import
# time, which affects the whole process — acceptable for a standalone example
# loaded via kv_connector_module_path, but not a pattern for library code.
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@dataclass
class SegmentedPrefillExampleConnectorMetadata(ExampleConnectorMetadata):
    """Connector metadata extended with the per-request simulated gaps.

    _req_to_gaps maps request id -> list of (start, end) token-index ranges
    that should be treated as missing from the external KV cache. The ranges
    are produced on the scheduler side and consumed on the worker side in
    SegmentedPrefillExampleConnector.bind_connector_metadata.
    """

    _req_to_gaps: dict[str, list[tuple[int, int]]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: ExampleConnectorMetadata):
        """Wrap base metadata; the caller fills in _req_to_gaps afterwards."""
        return cls(requests=base.requests)
class SegmentedPrefillExampleConnector(ExampleConnector):
    """ExampleConnector variant that simulates gaps in the external KV cache.

    Scheduler side: get_num_new_matched_tokens picks block-aligned token
    ranges to pretend are missing (_choose_gaps) and build_connector_meta
    ships them to the worker via SegmentedPrefillExampleConnectorMetadata.
    Worker side: bind_connector_metadata corrupts the slot mapping for the
    gap ranges so those blocks are not usefully loaded, which is expected to
    force vLLM to recompute them via segmented prefill.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        # request_id -> list of (start, end) token ranges to drop from the
        # external cache. Populated in get_num_new_matched_tokens, drained
        # (copied then cleared) by build_connector_meta.
        self._req_to_gaps: dict[str, list[tuple[int, int]]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        """Worker side: apply the simulated gaps, then defer to the base class.

        For each load (non-store) request that has recorded gaps, rewrite its
        slot mapping so the gap ranges point at a bogus slot before the base
        implementation performs the actual KV load.
        """
        assert isinstance(connector_metadata, SegmentedPrefillExampleConnectorMetadata)
        for req in connector_metadata.requests:
            if not req.is_store and req.req_id in connector_metadata._req_to_gaps:
                gaps = connector_metadata._req_to_gaps[req.req_id]
                self.override_slot_mapping_gaps(req.slot_mapping, gaps)
                total_gap_tokens = sum(end - start for start, end in gaps)
                logger.info(
                    "Simulating gaps in KV token blocks for the "
                    "first load request. Total tokens: %d",
                    total_gap_tokens,
                )
        super().bind_connector_metadata(connector_metadata)

    def _choose_gaps(
        self, num_computed_tokens: int, num_external_tokens: int
    ) -> list[tuple[int, int]]:
        # Simulate gaps in the external tokens, at block_size granularity.
        # Create gaps of growing size (1, 2, 3 blocks etc.) in the last num_external tokens,
        # with non-gap sections of the same growing size between them,
        # ensuring the last block is not a gap, and all aligned to block_size.
        # Resulting pattern (in blocks): gap(1), keep(1), gap(2), keep(2), ...
        block_size = self._block_size
        external_start = num_computed_tokens
        external_end = num_computed_tokens + num_external_tokens
        if external_end - external_start < block_size:
            # Less than one full block of external tokens: nothing to punch out.
            return []
        gaps = []
        current_pos = external_start
        size = 1  # current segment size, in blocks
        is_gap = True  # alternates: first segment is a gap
        while current_pos < external_end:
            segment_size_tokens = size * block_size
            segment_end = min(current_pos + segment_size_tokens, external_end)
            if is_gap:
                if segment_end == external_end:
                    # If this gap would be the last segment, skip it to ensure last is non-gap
                    break
                gaps.append((current_pos, segment_end))
            current_pos = segment_end
            is_gap = not is_gap
            if is_gap:
                # Grow the segment size after each (gap, keep) pair.
                size += 1
        return gaps

    def _print_gaps_representation(
        self,
        gaps: list[tuple[int, int]],
        num_external_tokens: int,
        num_computed_tokens: int,
    ) -> None:
        """Print a human-readable representation of the tokens and gaps for debugging."""
        total_tokens = num_computed_tokens + num_external_tokens
        block_size = self._block_size
        representation = []
        # One character per KV block, derived from its tokens' cache status.
        for block_start in range(0, total_tokens, block_size):
            block_end = min(block_start + block_size, total_tokens)
            block_chars = []
            for i in range(block_start, block_end):
                if i < num_computed_tokens:
                    block_chars.append("C")  # Computed token
                else:
                    # Check if in gap
                    in_gap = any(start <= i < end for start, end in gaps)
                    block_chars.append("-" if in_gap else "E")  # Gap or External token
            # Determine the character for this block
            unique_chars = set(block_chars)
            # print 'X' if mixed token types in block
            char = unique_chars.pop() if len(unique_chars) == 1 else "X"
            representation.append(char)
        print("Cache status per block (C=computed, E=external, -=gap, X=mixed):")
        print("".join(representation))
        print("Gaps: ", gaps)
        print(
            "Total tokens: ",
            total_tokens,
            ", computed tokens: ",
            num_computed_tokens,
            ", external tokens: ",
            num_external_tokens,
        )

    @staticmethod
    def override_slot_mapping_gaps(
        slot_mapping: torch.Tensor, gaps: list[tuple[int, int]]
    ) -> None:
        """create gaps in slot_mapping by mapping them to an incorrect value"""
        if not gaps:
            return
        # Reuse the last slot index as the bogus target: still a valid slot,
        # but wrong for the gap positions, so the loaded KV there is unusable.
        gap_value = slot_mapping[-1].item()  # use last value
        for start, end in gaps:
            slot_mapping[start:end] = gap_value

    def get_computed_token_gaps(
        self,
        request: "Request",
    ) -> list[tuple[int, int]] | None:
        """Scheduler side: report the simulated gaps for this request (or None)."""
        return self._req_to_gaps.get(request.request_id)

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Scheduler side: report externally-cached tokens and record gaps.

        Delegates to the base connector for the external-token count, then,
        for requests with enough external tokens, chooses gap ranges to
        simulate and remembers them for get_computed_token_gaps /
        build_connector_meta.
        """
        num_external_tokens, _ = super().get_num_new_matched_tokens(
            request, num_computed_tokens
        )
        num_external_tokens = (
            0 if num_external_tokens is None else num_external_tokens
        )  # don't simulate async lookup for now
        # pick requests with at least 2*block_size external tokens to simulate gaps
        if num_external_tokens >= self._block_size * 2:
            gaps = self._choose_gaps(num_computed_tokens, num_external_tokens)
            self._req_to_gaps[request.request_id] = gaps
            self._print_gaps_representation(
                gaps, num_external_tokens, num_computed_tokens
            )
        return num_external_tokens, False

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        """Scheduler side: attach the pending gap map to the base metadata.

        Gaps are copied into the metadata and the local map is cleared, so
        each request's gaps are shipped to the worker exactly once.
        """
        base = super().build_connector_meta(scheduler_output)
        assert isinstance(base, ExampleConnectorMetadata)
        meta = SegmentedPrefillExampleConnectorMetadata.from_base(base)
        meta._req_to_gaps = self._req_to_gaps.copy()
        self._req_to_gaps.clear()
        return meta