Skip to content

Commit

Permalink
configurable filter chains
Browse files Browse the repository at this point in the history
Signed-off-by: Kuromesi <[email protected]>
  • Loading branch information
Kuromesi committed Jan 9, 2025
1 parent 38cddf0 commit 87aeaac
Show file tree
Hide file tree
Showing 9 changed files with 668 additions and 41 deletions.
14 changes: 14 additions & 0 deletions pkg/ext-proc/backend/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,24 @@ type K8sDatastore struct {
inferencePool *v1alpha1.InferencePool
InferenceModels *sync.Map
pods *sync.Map

filterConfigMap *corev1.ConfigMap
}

type K8sDatastoreOption func(*K8sDatastore)

// GetFilterConfigMap returns the ConfigMap currently stored as the filter
// configuration, or nil when none is stored. The field is read while holding
// poolMu for reading.
func (ds *K8sDatastore) GetFilterConfigMap() *corev1.ConfigMap {
	ds.poolMu.RLock()
	current := ds.filterConfigMap
	ds.poolMu.RUnlock()
	return current
}

// WithFilterConfigMap returns an option that stores filterConfigMap on the
// datastore the option is applied to. The pointer is stored as given; no copy
// is made.
func WithFilterConfigMap(filterConfigMap *corev1.ConfigMap) K8sDatastoreOption {
	apply := func(ds *K8sDatastore) {
		ds.filterConfigMap = filterConfigMap
	}
	return apply
}

// WithPods can be used in tests to override the pods.
func WithPods(pods []*PodMetrics) K8sDatastoreOption {
return func(store *K8sDatastore) {
Expand Down
53 changes: 53 additions & 0 deletions pkg/ext-proc/backend/filterconfig_reconciler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package backend

import (
"context"

corev1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/predicate"
)

// FilterConfigReconciler reconciles ConfigMap objects and mirrors the fetched
// ConfigMap into Datastore as the filter configuration.
type FilterConfigReconciler struct {
// Client is embedded; Reconcile uses its Get method to fetch the ConfigMap.
client.Client
// Datastore has its filterConfigMap field set (or reset to nil) by Reconcile.
Datastore *K8sDatastore
}

// Reconcile synchronises the datastore's filter configuration with the
// ConfigMap named by req.
//
// The stored configuration is reset to nil when the ConfigMap cannot be found
// or carries a deletion timestamp; otherwise a deep copy of the fetched object
// is stored. Any lookup error other than "not found" is logged and returned.
//
// The replacement value is computed and logged before poolMu is taken, so the
// write lock is held only for the pointer assignment and readers going through
// GetFilterConfigMap are not blocked by logging or copying.
func (c *FilterConfigReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	cm := &corev1.ConfigMap{}
	err := c.Get(ctx, req.NamespacedName, cm)
	if client.IgnoreNotFound(err) != nil {
		klog.Errorf("unable to get ConfigMap, err: %v", err)
		return ctrl.Result{}, err
	}

	// next stays nil for the "deleted" and "deleting" cases, which both reset
	// the stored configuration.
	var next *corev1.ConfigMap
	switch {
	case err != nil:
		// Only a not-found error can reach this point.
		klog.V(1).Info("filter config deleted, reset filter config")
	case cm.DeletionTimestamp != nil:
		klog.V(1).Info("filter config deleting, reset filter config")
	default:
		klog.V(1).Infof("update filter config to: %++v", cm.Data)
		next = cm.DeepCopy()
	}

	c.Datastore.poolMu.Lock()
	c.Datastore.filterConfigMap = next
	c.Datastore.poolMu.Unlock()

	return ctrl.Result{}, nil
}

// SetupWithManager registers the reconciler with mgr as a controller for
// ConfigMap objects. An event filter lets through only events whose object is
// named "filter-config" in the "default" namespace.
func (c *FilterConfigReconciler) SetupWithManager(mgr ctrl.Manager) error {
	isFilterConfig := func(object client.Object) bool {
		if object.GetName() != "filter-config" {
			return false
		}
		return object.GetNamespace() == "default"
	}

	b := ctrl.NewControllerManagedBy(mgr).For(&corev1.ConfigMap{})
	b = b.WithEventFilter(predicate.NewPredicateFuncs(isFilterConfig))
	return b.Complete(c)
}
19 changes: 18 additions & 1 deletion pkg/ext-proc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ var (
"refreshMetricsInterval",
50*time.Millisecond,
"interval to refresh metrics")
enableFilterConfiguration = flag.Bool(
"enableFilterConfiguration",
false,
"Whether to enable configuring filters in `default/filter-config` configmap, ONLY FOR DEV NOW.",
)

scheme = runtime.NewScheme()
)
Expand Down Expand Up @@ -146,6 +151,17 @@ func main() {
klog.Error(err, "Error setting up EndpointSliceReconciler")
}

var orchestrator *scheduling.FilterOrchestratorImpl
if *enableFilterConfiguration {
if err := (&backend.FilterConfigReconciler{
Datastore: datastore,
Client: mgr.GetClient(),
}).SetupWithManager(mgr); err != nil {
klog.Error(err, "Error setting up FilterConfigReconciler")
}
orchestrator = scheduling.NewFilterOrchestrator(datastore)
}

errChan := make(chan error)
go func() {
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
Expand All @@ -164,7 +180,8 @@ func main() {
s,
handlers.NewServer(
pp,
scheduling.NewScheduler(pp),
// when orchestrator is nil, default filter will be returned
scheduling.NewScheduler(pp, scheduling.WithOrchestrator(orchestrator)),
*targetPodHeader,
datastore))
healthPb.RegisterHealthServer(s, &healthServer{})
Expand Down
35 changes: 22 additions & 13 deletions pkg/ext-proc/scheduling/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,46 @@ import (
"errors"
"math"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"

klog "k8s.io/klog/v2"
)

type Filter interface {
type FilterChain interface {
Name() string
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
}

// filter applies current filterFunc, and then recursively applies next filters depending success or
// filterChainImpl applies the current filter, and then recursively applies the next filters depending on the success or
// failure of the current filter.
// It can be used to construct a flow chart algorithm.
type filter struct {
type filterChainImpl struct {
name string
filter filterFunc
filter filter
// nextOnSuccess filter will be applied after successfully applying the current filter.
// The filtered results will be passed to the next filter.
nextOnSuccess *filter
nextOnSuccess *filterChainImpl
// nextOnFailure filter will be applied if current filter fails.
// The original input will be passed to the next filter.
nextOnFailure *filter
nextOnFailure *filterChainImpl
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
// success or failure of the current filter.
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
nextOnSuccessOrFailure *filter
nextOnSuccessOrFailure *filterChainImpl
}

func (f *filter) Name() string {
func (f *filterChainImpl) Name() string {
if f == nil {
return "nil"
}
return f.name
}

func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
func (f *filterChainImpl) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, req, len(pods))

filtered, err := f.filter(req, pods)
Expand Down Expand Up @@ -71,11 +74,11 @@ func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend
}
}

// filterFunc filters a set of input pods to a subset.
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
// filter filters a set of input pods to a subset.
type filter func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)

// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
func toFilterFunc(pp podPredicate) filterFunc {
// toFilter is a helper function that converts a per-pod predicate into a filter.
func toFilter(pp podPredicate) filter {
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
Expand Down Expand Up @@ -152,6 +155,12 @@ func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*bac
return filtered, nil
}

// dropRequestFilterFunc is a filter that rejects the request outright: it logs
// the request and returns an empty pod list together with a gRPC
// ResourceExhausted status error. The pods argument is not consulted.
func dropRequestFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
	klog.Infof("Dropping request %v", req)

	none := []*backend.PodMetrics{}
	err := status.Errorf(codes.ResourceExhausted, "dropping request due to limited backend resources")
	return none, err
}

// podPredicate is a filter function to check whether a pod is desired.
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool

Expand Down
10 changes: 5 additions & 5 deletions pkg/ext-proc/scheduling/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ func TestFilter(t *testing.T) {
input []*backend.PodMetrics
output []*backend.PodMetrics
err bool
filter *filter
filter *filterChainImpl
}{
{
name: "simple filter without successor, failure",
filter: &filter{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
filter: &filterChainImpl{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
return nil, errors.New("filter error")
}},
err: true,
Expand Down Expand Up @@ -216,7 +216,7 @@ func TestFilter(t *testing.T) {
func TestFilterFunc(t *testing.T) {
tests := []struct {
name string
f filterFunc
f filter
req *LLMRequest
input []*backend.PodMetrics
output []*backend.PodMetrics
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestFilterFunc(t *testing.T) {
},
{
name: "noQueueAndLessThanKVCacheThresholdPredicate",
f: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
f: toFilter(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
input: []*backend.PodMetrics{
{
// This pod should be returned.
Expand Down Expand Up @@ -337,7 +337,7 @@ func TestFilterFunc(t *testing.T) {
},
{
name: "low LoRA cost",
f: toFilterFunc(lowLoRACostPredicate),
f: toFilter(lowLoRACostPredicate),
req: &LLMRequest{
Model: "model",
ResolvedTargetModel: "model",
Expand Down
147 changes: 147 additions & 0 deletions pkg/ext-proc/scheduling/filtergen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package scheduling

// Names of the built-in filter generators. Each constant is used both as the
// generator's own name and as its key in filterMap.
const (
FilterCriticalRequestName = "critical_request"
FilterLeastQueuingName = "least_queuing"
FilterLowCostLoraName = "low_cost_lora"
FilterLowLatencyName = "low_latency"
FilterAffinityLoraName = "affinity_lora"
FilterSheddableRequestName = "sheddable_request"
FilterLeastKvCacheName = "least_kv_cache"
FilterDropRequestName = "drop_request"
FilterCanAcceptNewLoraName = "can_accept_new_lora"
)

// String keys naming a pod metric.
// NOTE(review): these are not referenced anywhere in the visible source;
// presumably they select the metric a top-k style filter ranks pods by —
// confirm against the consumer.
const (
TopKByWaitingQueueSize = "waiting_queue_size"
TopKByKVCacheUsagePercent = "kv_cache_usage_percent"
)

// filterMap is the registry of built-in filter generators, keyed by the name
// constant of each generator. Entries follow the order in which the name
// constants are declared.
var filterMap = map[string]FilterGen{
	FilterCriticalRequestName:  FilterCriticalRequest,
	FilterLeastQueuingName:     FilterLeastQueuing,
	FilterLowCostLoraName:      FilterLowCostLora,
	FilterLowLatencyName:       FilterLowLatency,
	FilterAffinityLoraName:     FilterAffinityLora,
	FilterSheddableRequestName: FilterSheddableRequest,
	FilterLeastKvCacheName:     FilterLeastKvCache,
	FilterDropRequestName:      FilterDropRequest,
	FilterCanAcceptNewLoraName: FilterCanAcceptNewLora,
}

// FilterGen generates a filter from a filter option.
type FilterGen interface {
// Name returns the name identifying the generator.
Name() string
// Get builds a filter from the given option.
Get(*FilterOption) filter
// Validate checks the given option and returns a non-nil error when the
// option is rejected.
Validate(*FilterOption) error
}

// FilterOption carries optional tuning values handed to a FilterGen. All
// fields are pointers; a nil field means the value was not set.
type FilterOption struct {
// KvCacheThreshold, when set, replaces the package-level kvCacheThreshold in
// the filter produced by FilterSheddableRequest.
KvCacheThreshold *float64 `json:"kvCacheThreshold,omitempty"`

// QueueThresholdCritical, when set, replaces the package-level
// queueThresholdCritical in the filter produced by FilterSheddableRequest.
QueueThresholdCritical *int `json:"queueThresholdCritical,omitempty"`
// QueueingThresholdLoRA is not read by any generator in this file.
// NOTE(review): presumably meant for a LoRA queueing filter — confirm.
QueueingThresholdLoRA *int `json:"queueingThresholdLoRA,omitempty"`
}

// filterGenImpl is a FilterGen backed by a name and two function fields.
type filterGenImpl struct {
// name is the value returned by Name.
name string
// getter is called by Get to build the filter.
getter func(*FilterOption) filter
// validator is called by Validate to check the option.
validator func(*FilterOption) error
}

// Compile-time check that *filterGenImpl satisfies FilterGen.
var _ FilterGen = (*filterGenImpl)(nil)

// Name returns the name stored in the generator.
func (fg *filterGenImpl) Name() string {
return fg.name
}

// Get builds a filter for fo by calling the generator's getter function and
// returning its result. The getter must be non-nil; calling a nil func panics.
func (fg *filterGenImpl) Get(fo *FilterOption) filter {
return fg.getter(fo)
}

// Validate checks fo by delegating to the generator's validator function and
// returns its result.
//
// A generator without a validator (nil), such as the zero value of
// filterGenImpl, accepts every option and returns nil rather than panicking on
// a nil function call.
func (fg *filterGenImpl) Validate(fo *FilterOption) error {
	if fg.validator == nil {
		return nil
	}
	return fg.validator(fo)
}

var (
FilterCriticalRequest FilterGen = &filterGenImpl{
name: FilterCriticalRequestName,
getter: func(fo *FilterOption) filter {
return toFilter(criticalRequestPredicate)
},
validator: func(fo *FilterOption) error { return nil },
}

FilterLeastQueuing FilterGen = &filterGenImpl{
name: FilterLeastQueuingName,
getter: func(fo *FilterOption) filter {
return leastQueuingFilterFunc
},
validator: func(fo *FilterOption) error { return nil },
}

FilterLowCostLora FilterGen = &filterGenImpl{
name: FilterLowCostLoraName,
getter: func(fo *FilterOption) filter {
return toFilter(lowLoRACostPredicate)
},
validator: func(fo *FilterOption) error { return nil },
}

FilterLowLatency FilterGen = &filterGenImpl{
name: FilterLowLatencyName,
getter: func(fo *FilterOption) filter {
return toFilter(lowQueueingPodPredicate)
},
validator: func(fo *FilterOption) error { return nil },
}

FilterAffinityLora FilterGen = &filterGenImpl{
name: FilterAffinityLoraName,
getter: func(fo *FilterOption) filter {
return toFilter(loRAAffinityPredicate)
},
validator: func(fo *FilterOption) error { return nil },
}

FilterSheddableRequest FilterGen = &filterGenImpl{
name: FilterSheddableRequestName,
getter: func(opt *FilterOption) filter {
qtc, kct := queueThresholdCritical, kvCacheThreshold
if opt != nil {
if opt.KvCacheThreshold != nil {
kct = *opt.KvCacheThreshold
}
if opt.QueueThresholdCritical != nil {
qtc = *opt.QueueThresholdCritical
}
}
return toFilter(noQueueAndLessThanKVCacheThresholdPredicate(qtc, kct))
},
validator: func(fo *FilterOption) error { return nil },
}

FilterLeastKvCache FilterGen = &filterGenImpl{
name: FilterLeastKvCacheName,
getter: func(fo *FilterOption) filter {
return leastKVCacheFilterFunc
},
validator: func(fo *FilterOption) error { return nil },
}

FilterDropRequest FilterGen = &filterGenImpl{
name: FilterDropRequestName,
getter: func(fo *FilterOption) filter {
return dropRequestFilterFunc
},
validator: func(fo *FilterOption) error { return nil },
}

FilterCanAcceptNewLora FilterGen = &filterGenImpl{
name: FilterCanAcceptNewLoraName,
getter: func(fo *FilterOption) filter {
return toFilter(canAcceptNewLoraPredicate)
},
validator: func(fo *FilterOption) error { return nil },
}
)
Loading

0 comments on commit 87aeaac

Please sign in to comment.