Skip to content

Commit fe09d5a

Browse files
committed
Add tolerations for taints on NVIDIA specific node groups
1 parent 059c8d9 commit fe09d5a

File tree

2 files changed

+89
-24
lines changed

2 files changed

+89
-24
lines changed

pkg/addons/device_plugin.go

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@ package addons
22

33
import (
44
"context"
5+
// For go:embed
6+
_ "embed"
57
"fmt"
68
"time"
79

810
"github.com/kris-nova/logger"
911
"github.com/pkg/errors"
12+
1013
api "github.com/weaveworks/eksctl/pkg/apis/eksctl.io/v1alpha5"
1114
"github.com/weaveworks/eksctl/pkg/kubernetes"
15+
"github.com/weaveworks/eksctl/pkg/utils/instance"
1216

1317
appsv1 "k8s.io/api/apps/v1"
14-
v1 "k8s.io/api/core/v1"
18+
corev1 "k8s.io/api/core/v1"
1519
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1620
"k8s.io/apimachinery/pkg/watch"
1721
clientappsv1 "k8s.io/client-go/kubernetes/typed/apps/v1"
18-
19-
// For go:embed
20-
_ "embed"
2122
)
2223

2324
//go:embed assets/efa-device-plugin.yaml
@@ -29,7 +30,7 @@ var neuronDevicePluginYaml []byte
2930
//go:embed assets/nvidia-device-plugin.yaml
3031
var nvidiaDevicePluginYaml []byte
3132

32-
func useRegionalImage(spec *v1.PodTemplateSpec, region string, account string) error {
33+
func useRegionalImage(spec *corev1.PodTemplateSpec, region string, account string) error {
3334
imageFormat := spec.Spec.Containers[0].Image
3435
dnsSuffix, err := awsDNSSuffixForRegion(region)
3536
if err != nil {
@@ -72,13 +73,14 @@ func watchDaemonSetReady(dsClientSet clientappsv1.DaemonSetInterface, dsName str
7273
}
7374
}
7475

75-
type MkDevicePlugin func(rawClient kubernetes.RawClientInterface, region string, planMode bool) DevicePlugin
76+
type MkDevicePlugin func(rawClient kubernetes.RawClientInterface, region string, planMode bool, spec *api.ClusterConfig) DevicePlugin
7677

7778
type DevicePlugin interface {
7879
RawClient() kubernetes.RawClientInterface
7980
PlanMode() bool
8081
Manifest() []byte
81-
SetImage(t *v1.PodTemplateSpec) error
82+
SetImage(t *corev1.PodTemplateSpec) error
83+
SetTolerations(t *corev1.PodTemplateSpec) error
8284
Deploy() error
8385
}
8486

@@ -103,7 +105,9 @@ func applyDevicePlugin(dp DevicePlugin) error {
103105
if err := dp.SetImage(&daemonSet.Spec.Template); err != nil {
104106
return errors.Wrap(err, "setting image of device plugin daemonset")
105107
}
106-
108+
if err := dp.SetTolerations(&daemonSet.Spec.Template); err != nil {
109+
return errors.Wrap(err, "adding tolerations to device plugin daemonset")
110+
}
107111
msg, err := rawResource.CreateOrReplace(dp.PlanMode())
108112
if err != nil {
109113
return errors.Wrap(err, "calling create or replace on raw device plugin daemonset")
@@ -124,11 +128,12 @@ func applyDevicePlugin(dp DevicePlugin) error {
124128
}
125129

126130
// NewNeuronDevicePlugin creates a new NeuronDevicePlugin
127-
func NewNeuronDevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool) DevicePlugin {
131+
func NewNeuronDevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool, spec *api.ClusterConfig) DevicePlugin {
128132
return &NeuronDevicePlugin{
129-
rawClient,
130-
region,
131-
planMode,
133+
rawClient: rawClient,
134+
region: region,
135+
planMode: planMode,
136+
spec: spec,
132137
}
133138
}
134139

@@ -137,6 +142,7 @@ type NeuronDevicePlugin struct {
137142
rawClient kubernetes.RawClientInterface
138143
region string
139144
planMode bool
145+
spec *api.ClusterConfig
140146
}
141147

142148
func (n *NeuronDevicePlugin) RawClient() kubernetes.RawClientInterface {
@@ -151,7 +157,11 @@ func (n *NeuronDevicePlugin) Manifest() []byte {
151157
return neuronDevicePluginYaml
152158
}
153159

154-
func (n *NeuronDevicePlugin) SetImage(t *v1.PodTemplateSpec) error {
160+
func (n *NeuronDevicePlugin) SetImage(t *corev1.PodTemplateSpec) error {
161+
return nil
162+
}
163+
164+
// SetTolerations is a no-op for the Neuron device plugin: the pod template is
// left exactly as it was read from the manifest and nil is always returned.
func (n *NeuronDevicePlugin) SetTolerations(t *corev1.PodTemplateSpec) error {
	return nil
}
157167

@@ -161,11 +171,12 @@ func (n *NeuronDevicePlugin) Deploy() error {
161171
}
162172

163173
// NewNvidiaDevicePlugin creates a new NvidiaDevicePlugin
164-
func NewNvidiaDevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool) DevicePlugin {
174+
func NewNvidiaDevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool, spec *api.ClusterConfig) DevicePlugin {
165175
return &NvidiaDevicePlugin{
166-
rawClient,
167-
region,
168-
planMode,
176+
rawClient: rawClient,
177+
region: region,
178+
planMode: planMode,
179+
spec: spec,
169180
}
170181
}
171182

@@ -174,6 +185,7 @@ type NvidiaDevicePlugin struct {
174185
rawClient kubernetes.RawClientInterface
175186
region string
176187
planMode bool
188+
spec *api.ClusterConfig
177189
}
178190

179191
func (n *NvidiaDevicePlugin) RawClient() kubernetes.RawClientInterface {
@@ -184,7 +196,7 @@ func (n *NvidiaDevicePlugin) PlanMode() bool {
184196
return n.planMode
185197
}
186198

187-
func (n *NvidiaDevicePlugin) SetImage(t *v1.PodTemplateSpec) error {
199+
func (n *NvidiaDevicePlugin) SetImage(t *corev1.PodTemplateSpec) error {
188200
return nil
189201
}
190202

@@ -197,11 +209,59 @@ func (n *NvidiaDevicePlugin) Deploy() error {
197209
return applyDevicePlugin(n)
198210
}
199211

212+
// SetTolerations sets given tolerations on the DaemonSet if they don't already exist.
213+
// We check the taints on each node which is an NVIDIA instance type and apply
214+
// tolerations for all the taints defined on the node.
215+
func (n *NvidiaDevicePlugin) SetTolerations(spec *corev1.PodTemplateSpec) error {
216+
contains := func(list []corev1.Toleration, key string) bool {
217+
for _, t := range list {
218+
if t.Key == key {
219+
return true
220+
}
221+
}
222+
return false
223+
}
224+
// don't duplicate taints from other nodes or overwrite them with
225+
// different values ( shouldn't happen in general... )
226+
taints := make(map[string]api.NodeGroupTaint)
227+
for _, ng := range n.spec.NodeGroups {
228+
if api.HasInstanceType(ng, instance.IsNvidiaInstanceType) &&
229+
ng.GetAMIFamily() == api.NodeImageFamilyAmazonLinux2 {
230+
for _, taint := range ng.Taints {
231+
if _, ok := taints[taint.Key]; !ok {
232+
taints[taint.Key] = taint
233+
}
234+
}
235+
}
236+
}
237+
for _, ng := range n.spec.ManagedNodeGroups {
238+
if api.HasInstanceTypeManaged(ng, instance.IsNvidiaInstanceType) &&
239+
ng.GetAMIFamily() == api.NodeImageFamilyAmazonLinux2 {
240+
for _, taint := range ng.Taints {
241+
if _, ok := taints[taint.Key]; !ok {
242+
taints[taint.Key] = taint
243+
}
244+
}
245+
}
246+
}
247+
for _, t := range taints {
248+
// only add toleration if it doesn't already exist. In that case, we don't overwrite it.
249+
if !contains(spec.Spec.Tolerations, t.Key) {
250+
spec.Spec.Tolerations = append(spec.Spec.Tolerations, corev1.Toleration{
251+
Key: t.Key,
252+
Value: t.Value,
253+
})
254+
}
255+
}
256+
return nil
257+
}
258+
200259
// An EFADevicePlugin deploys the EFA Device Plugin to a cluster
201260
type EFADevicePlugin struct {
202261
rawClient kubernetes.RawClientInterface
203262
region string
204263
planMode bool
264+
spec *api.ClusterConfig
205265
}
206266

207267
func (n *EFADevicePlugin) RawClient() kubernetes.RawClientInterface {
@@ -216,17 +276,22 @@ func (n *EFADevicePlugin) Manifest() []byte {
216276
return efaDevicePluginYaml
217277
}
218278

219-
func (n *EFADevicePlugin) SetImage(t *v1.PodTemplateSpec) error {
279+
func (n *EFADevicePlugin) SetImage(t *corev1.PodTemplateSpec) error {
220280
account := api.EKSResourceAccountID(n.region)
221281
return useRegionalImage(t, n.region, account)
222282
}
223283

284+
func (n *EFADevicePlugin) SetTolerations(spec *corev1.PodTemplateSpec) error {
285+
return nil
286+
}
287+
224288
// NewEFADevicePlugin creates a new EFADevicePlugin
225-
func NewEFADevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool) DevicePlugin {
289+
func NewEFADevicePlugin(rawClient kubernetes.RawClientInterface, region string, planMode bool, spec *api.ClusterConfig) DevicePlugin {
226290
return &EFADevicePlugin{
227-
rawClient,
228-
region,
229-
planMode,
291+
rawClient: rawClient,
292+
region: region,
293+
planMode: planMode,
294+
spec: spec,
230295
}
231296
}
232297

pkg/eks/tasks.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func (n *devicePluginTask) Do(errCh chan error) error {
128128
if err != nil {
129129
return err
130130
}
131-
devicePlugin := n.mkPlugin(rawClient, n.clusterProvider.AWSProvider.Region(), false)
131+
devicePlugin := n.mkPlugin(rawClient, n.clusterProvider.AWSProvider.Region(), false, n.spec)
132132
if err := devicePlugin.Deploy(); err != nil {
133133
return errors.Wrap(err, "error installing device plugin")
134134
}

0 commit comments

Comments
 (0)