Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (p *Plugin) WithName(name string) *Plugin {
}

// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
Expand All @@ -183,7 +183,8 @@ func (p *Plugin) Score(ctx context.Context, _ *types.CycleState, request *types.
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
}

p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().Type), state)
cycleState.Write(plugins.StateKey(p.TypedName().String()), state)
p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state)
loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
// calculate the scores of pods
scores := make(map[types.Pod]float64, len(pods))
Expand All @@ -208,7 +209,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, PrefixCachePluginType)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
Expand Down
24 changes: 12 additions & 12 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaaaa",
}
scores := plugin.Score(context.Background(), nil, req1, pods)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
Expand All @@ -79,8 +79,8 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model2",
Prompt: "bbbbbb",
}
scores = plugin.Score(context.Background(), nil, req2, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType)
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
Expand All @@ -105,8 +105,8 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaabbbb",
}
scores = plugin.Score(context.Background(), nil, req3, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType)
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
Expand All @@ -130,8 +130,8 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model-new",
Prompt: "aaaabbbb",
}
scores = plugin.Score(context.Background(), nil, req4, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, PrefixCachePluginType)
scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
Expand All @@ -155,8 +155,8 @@ func TestPrefixPlugin(t *testing.T) {
TargetModel: "test-model1",
Prompt: "aaaabbbbcccc",
}
scores = plugin.Score(context.Background(), nil, req5, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, PrefixCachePluginType)
scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(t, err)
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
Expand Down Expand Up @@ -212,7 +212,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
}

// First cycle: simulate scheduling and insert prefix info into the cache
plugin.Score(context.Background(), nil, req, pods)
plugin.Score(context.Background(), types.NewCycleState(), req, pods)
schedulingResult := &types.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*types.ProfileRunResult{
Expand All @@ -222,7 +222,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
plugin.PreRequest(context.Background(), req, schedulingResult, 0)

// Second cycle: validate internal state
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
assert.NoError(b, err)
expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize)))
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")
Expand Down