@@ -29,6 +29,14 @@ const (
2929 driverSCSIType = "scsi"
3030)
3131
32+ const rootBusPath = "/devices/pci0000:00"
33+
34+ var (
35+ sysBusPrefix = "/sys/bus/pci/devices"
36+ pciBusPathFormat = "%s/%s/pci_bus/"
37+ systemDevPath = "/dev"
38+ )
39+
3240// SCSI variables
3341var (
3442 // Here in "0:0", the first number is the SCSI host number because
@@ -42,31 +50,123 @@ var (
4250 scsiHostPath = filepath .Join (sysClassPrefix , "scsi_host" )
4351)
4452
45- type deviceHandler func (device pb.Device , spec * pb.Spec ) error
53+ type deviceHandler func (device pb.Device , spec * pb.Spec , s * sandbox ) error
4654
4755var deviceHandlerList = map [string ]deviceHandler {
4856 driverBlkType : virtioBlkDeviceHandler ,
4957 driverSCSIType : virtioSCSIDeviceHandler ,
5058}
5159
52- func virtioBlkDeviceHandler (device pb.Device , spec * pb.Spec ) error {
53- // First need to make sure the expected device shows up properly,
54- // and then we need to retrieve its device info (such as major and
55- // minor numbers), useful to update the device provided
56- // through the OCI specification.
57- devName := strings .TrimPrefix (device .VmPath , devPrefix )
58- checkUevent := func (uEv * uevent.Uevent ) bool {
59- return (uEv .Action == "add" &&
60- filepath .Base (uEv .DevPath ) == devName )
60+ // getDevicePCIAddress fetches the complete PCI address in sysfs, based on the PCI
61+ // identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
62+ // Here, bridgeAddr is the address at which the brige is attached on the root bus,
63+ // while deviceAddr is the address at which the device is attached on the bridge.
64+ func getDevicePCIAddress (pciID string ) (string , error ) {
65+ tokens := strings .Split (pciID , "/" )
66+
67+ if len (tokens ) != 2 {
68+ return "" , fmt .Errorf ("PCI Identifier for device should be of format [bridgeAddr/deviceAddr], got %s" , pciID )
69+ }
70+
71+ bridgeID := tokens [0 ]
72+ deviceID := tokens [1 ]
73+
74+ // Deduce the complete bridge address based on the bridge address identifier passed
75+ // and the fact that bridges are attached on the main bus with function 0.
76+ pciBridgeAddr := fmt .Sprintf ("0000:00:%s.0" , bridgeID )
77+
78+ // Find out the bus exposed by bridge
79+ bridgeBusPath := fmt .Sprintf (pciBusPathFormat , sysBusPrefix , pciBridgeAddr )
80+
81+ files , err := ioutil .ReadDir (bridgeBusPath )
82+ if err != nil {
83+ return "" , fmt .Errorf ("Error with getting bridge pci bus : %s" , err )
84+ }
85+
86+ busNum := len (files )
87+ if busNum != 1 {
88+ return "" , fmt .Errorf ("Expected an entry for bus in %s, got %d entries instead" , bridgeBusPath , busNum )
6189 }
62- if err := waitForDevice (device .VmPath , devName , checkUevent ); err != nil {
90+
91+ bus := files [0 ].Name ()
92+
93+ // Device address is based on the bus of the bridge to which it is attached.
94+ // We do not pass devices as multifunction, hence the trailing 0 in the address.
95+ pciDeviceAddr := fmt .Sprintf ("%s:%s.0" , bus , deviceID )
96+
97+ bridgeDevicePCIAddr := fmt .Sprintf ("%s/%s" , pciBridgeAddr , pciDeviceAddr )
98+ agentLog .WithField ("completePCIAddr" , bridgeDevicePCIAddr ).Info ("Fetched PCI address for device" )
99+
100+ return bridgeDevicePCIAddr , nil
101+ }
102+
103+ func getBlockDeviceNodeName (s * sandbox , pciID string ) (string , error ) {
104+ pciAddr , err := getDevicePCIAddress (pciID )
105+ if err != nil {
106+ return "" , err
107+ }
108+
109+ var devName string
110+ var notifyChan chan string
111+
112+ fieldLogger := agentLog .WithField ("pciID" , pciID )
113+
114+ // Check if the PCI identifier is in PCI device map.
115+ s .Lock ()
116+ for key , value := range s .pciDeviceMap {
117+ if strings .Contains (key , pciAddr ) {
118+ devName = value
119+ fieldLogger .Info ("Device found in pci device map" )
120+ break
121+ }
122+ }
123+
124+ // If device is not found in the device map, hotplug event has not
125+ // been received yet, create and add channel to the watchers map.
126+ // The key of the watchers map is the device we are interested in.
127+ // Note this is done inside the lock, not to miss any events from the
128+ // global udev listener.
129+ if devName == "" {
130+ notifyChan := make (chan string , 1 )
131+ s .deviceWatchers [pciAddr ] = notifyChan
132+ }
133+ s .Unlock ()
134+
135+ if devName == "" {
136+ fieldLogger .Info ("Waiting on channel for device notification" )
137+ select {
138+ case devName = <- notifyChan :
139+ case <- time .After (time .Duration (timeoutHotplug ) * time .Second ):
140+ s .Lock ()
141+ delete (s .deviceWatchers , pciAddr )
142+ close (notifyChan )
143+ s .Unlock ()
144+
145+ return "" , grpcStatus .Errorf (codes .DeadlineExceeded ,
146+ "Timeout reached after %ds waiting for device %s" ,
147+ timeoutHotplug , pciAddr )
148+ }
149+ }
150+
151+ return filepath .Join (systemDevPath , devName ), nil
152+ }
153+
154+ // device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
155+ // Here, bridgeAddr is the address at which the brige is attached on the root bus,
156+ // while deviceAddr is the address at which the device is attached on the bridge.
157+ func virtioBlkDeviceHandler (device pb.Device , spec * pb.Spec , s * sandbox ) error {
158+ // Get the device node path based on the PCI device address
159+ devPath , err := getBlockDeviceNodeName (s , device .Id )
160+ if err != nil {
63161 return err
64162 }
163+ device .VmPath = devPath
65164
66165 return updateSpecDeviceList (device , spec )
67166}
68167
69- func virtioSCSIDeviceHandler (device pb.Device , spec * pb.Spec ) error {
168+ // device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
169+ func virtioSCSIDeviceHandler (device pb.Device , spec * pb.Spec , s * sandbox ) error {
70170 // Retrieve the device path from SCSI address.
71171 devPath , err := getSCSIDevPath (device .Id )
72172 if err != nil {
@@ -270,13 +370,13 @@ func getSCSIDevPath(scsiAddr string) (string, error) {
270370 return filepath .Join (devPrefix , scsiDiskName ), nil
271371}
272372
273- func addDevices (devices []* pb.Device , spec * pb.Spec ) error {
373+ func addDevices (devices []* pb.Device , spec * pb.Spec , s * sandbox ) error {
274374 for _ , device := range devices {
275375 if device == nil {
276376 continue
277377 }
278378
279- err := addDevice (device , spec )
379+ err := addDevice (device , spec , s )
280380 if err != nil {
281381 return err
282382 }
@@ -286,7 +386,7 @@ func addDevices(devices []*pb.Device, spec *pb.Spec) error {
286386 return nil
287387}
288388
289- func addDevice (device * pb.Device , spec * pb.Spec ) error {
389+ func addDevice (device * pb.Device , spec * pb.Spec , s * sandbox ) error {
290390 if device == nil {
291391 return grpcStatus .Error (codes .InvalidArgument , "invalid device" )
292392 }
@@ -326,5 +426,5 @@ func addDevice(device *pb.Device, spec *pb.Spec) error {
326426 "Unknown device type %q" , device .Type )
327427 }
328428
329- return devHandler (* device , spec )
429+ return devHandler (* device , spec , s )
330430}
0 commit comments