Skip to content
This repository was archived by the owner on May 12, 2021. It is now read-only.

Commit 8856b3a

Browse files
author
Sebastien Boeuf
authored
Merge pull request #227 from amshinde/pci-addr
Use PCI Addresses to determine the device names for virtio-blk devices
2 parents 04d58dd + d29bf53 commit 8856b3a

File tree

7 files changed

+284
-49
lines changed

7 files changed

+284
-49
lines changed

agent.go

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"os"
1616
"os/exec"
1717
"os/signal"
18+
"path/filepath"
1819
"runtime"
1920
"runtime/debug"
2021
"strings"
@@ -23,6 +24,7 @@ import (
2324
"time"
2425

2526
"github.com/gogo/protobuf/proto"
27+
"github.com/kata-containers/agent/pkg/uevent"
2628
pb "github.com/kata-containers/agent/protocols/grpc"
2729
"github.com/opencontainers/runc/libcontainer"
2830
"github.com/opencontainers/runc/libcontainer/configs"
@@ -74,6 +76,8 @@ type sandbox struct {
7476
mounts []string
7577
subreaper reaper
7678
server *grpc.Server
79+
pciDeviceMap map[string]string
80+
deviceWatchers map[string](chan string)
7781
}
7882

7983
type namespace struct {
@@ -305,6 +309,55 @@ func (s *sandbox) teardownSharedPidNs() error {
305309
return nil
306310
}
307311

312+
func (s *sandbox) listenToUdevEvents() {
313+
fieldLogger := agentLog.WithField("subsystem", "udevlistener")
314+
315+
uEvHandler, err := uevent.NewHandler()
316+
if err != nil {
317+
fieldLogger.Warnf("Error starting uevent listening loop %s", err)
318+
return
319+
}
320+
defer uEvHandler.Close()
321+
322+
for {
323+
uEv, err := uEvHandler.Read()
324+
if err != nil {
325+
fieldLogger.Error(err)
326+
continue
327+
}
328+
329+
fieldLogger = fieldLogger.WithFields(logrus.Fields{
330+
"uevent-action": uEv.Action,
331+
"uevent-devpath": uEv.DevPath,
332+
"uevent-subsystem": uEv.SubSystem,
333+
"uevent-seqnum": uEv.SeqNum,
334+
"uevent-devname": uEv.DevName,
335+
})
336+
337+
// Check if device hotplug event results in a device node being created.
338+
if uEv.DevName != "" && uEv.Action == "add" && strings.HasPrefix(uEv.DevPath, rootBusPath) {
339+
// Lock is needed to safey read and modify the pciDeviceMap and deviceWatchers.
340+
// This makes sure that watchers do not access the map while it is being updated.
341+
s.Lock()
342+
343+
// Add the device node name to the pci device map.
344+
s.pciDeviceMap[uEv.DevPath] = uEv.DevName
345+
346+
// Notify watchers that are interested in the udev event.
347+
// Close the channel after watcher has been notified.
348+
for devPCIAddress, ch := range s.deviceWatchers {
349+
if ch != nil && strings.HasPrefix(uEv.DevPath, filepath.Join(rootBusPath, devPCIAddress)) {
350+
ch <- uEv.DevName
351+
close(ch)
352+
delete(s.deviceWatchers, uEv.DevName)
353+
}
354+
}
355+
356+
s.Unlock()
357+
}
358+
}
359+
}
360+
308361
// This loop is meant to be run inside a separate Go routine.
309362
func (s *sandbox) reaperLoop(sigCh chan os.Signal) {
310363
for sig := range sigCh {
@@ -643,8 +696,10 @@ func main() {
643696
running: false,
644697
// pivot_root won't work for init, see
645698
// Documentation/filesystems/ramfs-rootfs-initramfs.txt
646-
noPivotRoot: os.Getpid() == 1,
647-
subreaper: r,
699+
noPivotRoot: os.Getpid() == 1,
700+
subreaper: r,
701+
pciDeviceMap: make(map[string]string),
702+
deviceWatchers: make(map[string](chan string)),
648703
}
649704

650705
if err = s.initLogger(); err != nil {
@@ -665,5 +720,7 @@ func main() {
665720
// Start gRPC server.
666721
s.startGRPC()
667722

723+
go s.listenToUdevEvents()
724+
668725
s.wg.Wait()
669726
}

device.go

Lines changed: 116 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ const (
2929
driverSCSIType = "scsi"
3030
)
3131

32+
const rootBusPath = "/devices/pci0000:00"
33+
34+
var (
35+
sysBusPrefix = "/sys/bus/pci/devices"
36+
pciBusPathFormat = "%s/%s/pci_bus/"
37+
systemDevPath = "/dev"
38+
)
39+
3240
// SCSI variables
3341
var (
3442
// Here in "0:0", the first number is the SCSI host number because
@@ -42,31 +50,123 @@ var (
4250
scsiHostPath = filepath.Join(sysClassPrefix, "scsi_host")
4351
)
4452

45-
type deviceHandler func(device pb.Device, spec *pb.Spec) error
53+
type deviceHandler func(device pb.Device, spec *pb.Spec, s *sandbox) error
4654

4755
var deviceHandlerList = map[string]deviceHandler{
4856
driverBlkType: virtioBlkDeviceHandler,
4957
driverSCSIType: virtioSCSIDeviceHandler,
5058
}
5159

52-
func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec) error {
53-
// First need to make sure the expected device shows up properly,
54-
// and then we need to retrieve its device info (such as major and
55-
// minor numbers), useful to update the device provided
56-
// through the OCI specification.
57-
devName := strings.TrimPrefix(device.VmPath, devPrefix)
58-
checkUevent := func(uEv *uevent.Uevent) bool {
59-
return (uEv.Action == "add" &&
60-
filepath.Base(uEv.DevPath) == devName)
60+
// getDevicePCIAddress fetches the complete PCI address in sysfs, based on the PCI
61+
// identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
62+
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
63+
// while deviceAddr is the address at which the device is attached on the bridge.
64+
func getDevicePCIAddress(pciID string) (string, error) {
65+
tokens := strings.Split(pciID, "/")
66+
67+
if len(tokens) != 2 {
68+
return "", fmt.Errorf("PCI Identifier for device should be of format [bridgeAddr/deviceAddr], got %s", pciID)
69+
}
70+
71+
bridgeID := tokens[0]
72+
deviceID := tokens[1]
73+
74+
// Deduce the complete bridge address based on the bridge address identifier passed
75+
// and the fact that bridges are attached on the main bus with function 0.
76+
pciBridgeAddr := fmt.Sprintf("0000:00:%s.0", bridgeID)
77+
78+
// Find out the bus exposed by bridge
79+
bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, pciBridgeAddr)
80+
81+
files, err := ioutil.ReadDir(bridgeBusPath)
82+
if err != nil {
83+
return "", fmt.Errorf("Error with getting bridge pci bus : %s", err)
84+
}
85+
86+
busNum := len(files)
87+
if busNum != 1 {
88+
return "", fmt.Errorf("Expected an entry for bus in %s, got %d entries instead", bridgeBusPath, busNum)
6189
}
62-
if err := waitForDevice(device.VmPath, devName, checkUevent); err != nil {
90+
91+
bus := files[0].Name()
92+
93+
// Device address is based on the bus of the bridge to which it is attached.
94+
// We do not pass devices as multifunction, hence the trailing 0 in the address.
95+
pciDeviceAddr := fmt.Sprintf("%s:%s.0", bus, deviceID)
96+
97+
bridgeDevicePCIAddr := fmt.Sprintf("%s/%s", pciBridgeAddr, pciDeviceAddr)
98+
agentLog.WithField("completePCIAddr", bridgeDevicePCIAddr).Info("Fetched PCI address for device")
99+
100+
return bridgeDevicePCIAddr, nil
101+
}
102+
103+
func getBlockDeviceNodeName(s *sandbox, pciID string) (string, error) {
104+
pciAddr, err := getDevicePCIAddress(pciID)
105+
if err != nil {
106+
return "", err
107+
}
108+
109+
var devName string
110+
var notifyChan chan string
111+
112+
fieldLogger := agentLog.WithField("pciID", pciID)
113+
114+
// Check if the PCI identifier is in PCI device map.
115+
s.Lock()
116+
for key, value := range s.pciDeviceMap {
117+
if strings.Contains(key, pciAddr) {
118+
devName = value
119+
fieldLogger.Info("Device found in pci device map")
120+
break
121+
}
122+
}
123+
124+
// If device is not found in the device map, hotplug event has not
125+
// been received yet, create and add channel to the watchers map.
126+
// The key of the watchers map is the device we are interested in.
127+
// Note this is done inside the lock, not to miss any events from the
128+
// global udev listener.
129+
if devName == "" {
130+
notifyChan := make(chan string, 1)
131+
s.deviceWatchers[pciAddr] = notifyChan
132+
}
133+
s.Unlock()
134+
135+
if devName == "" {
136+
fieldLogger.Info("Waiting on channel for device notification")
137+
select {
138+
case devName = <-notifyChan:
139+
case <-time.After(time.Duration(timeoutHotplug) * time.Second):
140+
s.Lock()
141+
delete(s.deviceWatchers, pciAddr)
142+
close(notifyChan)
143+
s.Unlock()
144+
145+
return "", grpcStatus.Errorf(codes.DeadlineExceeded,
146+
"Timeout reached after %ds waiting for device %s",
147+
timeoutHotplug, pciAddr)
148+
}
149+
}
150+
151+
return filepath.Join(systemDevPath, devName), nil
152+
}
153+
154+
// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
155+
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
156+
// while deviceAddr is the address at which the device is attached on the bridge.
157+
func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
158+
// Get the device node path based on the PCI device address
159+
devPath, err := getBlockDeviceNodeName(s, device.Id)
160+
if err != nil {
63161
return err
64162
}
163+
device.VmPath = devPath
65164

66165
return updateSpecDeviceList(device, spec)
67166
}
68167

69-
func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec) error {
168+
// device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
169+
func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
70170
// Retrieve the device path from SCSI address.
71171
devPath, err := getSCSIDevPath(device.Id)
72172
if err != nil {
@@ -270,13 +370,13 @@ func getSCSIDevPath(scsiAddr string) (string, error) {
270370
return filepath.Join(devPrefix, scsiDiskName), nil
271371
}
272372

273-
func addDevices(devices []*pb.Device, spec *pb.Spec) error {
373+
func addDevices(devices []*pb.Device, spec *pb.Spec, s *sandbox) error {
274374
for _, device := range devices {
275375
if device == nil {
276376
continue
277377
}
278378

279-
err := addDevice(device, spec)
379+
err := addDevice(device, spec, s)
280380
if err != nil {
281381
return err
282382
}
@@ -286,7 +386,7 @@ func addDevices(devices []*pb.Device, spec *pb.Spec) error {
286386
return nil
287387
}
288388

289-
func addDevice(device *pb.Device, spec *pb.Spec) error {
389+
func addDevice(device *pb.Device, spec *pb.Spec, s *sandbox) error {
290390
if device == nil {
291391
return grpcStatus.Error(codes.InvalidArgument, "invalid device")
292392
}
@@ -326,5 +426,5 @@ func addDevice(device *pb.Device, spec *pb.Spec) error {
326426
"Unknown device type %q", device.Type)
327427
}
328428

329-
return devHandler(*device, spec)
429+
return devHandler(*device, spec, s)
330430
}

device_test.go

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func testVirtioBlkDeviceHandlerFailure(t *testing.T, device pb.Device, spec *pb.
4040
device.VmPath = devPath
4141
device.ContainerPath = "some-not-empty-path"
4242

43-
err = virtioBlkDeviceHandler(device, spec)
43+
err = virtioBlkDeviceHandler(device, spec, &sandbox{})
4444
assert.NotNil(t, err, "blockDeviceHandler() should have failed")
4545
}
4646

@@ -73,6 +73,49 @@ func TestVirtioBlkDeviceHandlerEmptyLinuxDevicesSpecFailure(t *testing.T) {
7373
testVirtioBlkDeviceHandlerFailure(t, device, spec)
7474
}
7575

76+
func TestGetPCIAddress(t *testing.T) {
77+
testDir, err := ioutil.TempDir("", "kata-agent-tmp-")
78+
if err != nil {
79+
t.Fatal(t, err)
80+
}
81+
defer os.RemoveAll(testDir)
82+
83+
pciID := "02"
84+
_, err = getDevicePCIAddress(pciID)
85+
assert.NotNil(t, err)
86+
87+
pciID = "02/03/04"
88+
_, err = getDevicePCIAddress(pciID)
89+
assert.NotNil(t, err)
90+
91+
bridgeID := "02"
92+
deviceID := "03"
93+
pciBus := "0000:01"
94+
expectedPCIAddress := "0000:00:02.0/0000:01:03.0"
95+
pciID = fmt.Sprintf("%s/%s", bridgeID, deviceID)
96+
97+
// Set sysBusPrefix to test directory for unit tests.
98+
sysBusPrefix = testDir
99+
bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0")
100+
101+
_, err = getDevicePCIAddress(pciID)
102+
assert.NotNil(t, err)
103+
104+
err = os.MkdirAll(bridgeBusPath, mountPerm)
105+
assert.Nil(t, err)
106+
107+
_, err = getDevicePCIAddress(pciID)
108+
assert.NotNil(t, err)
109+
110+
err = os.MkdirAll(filepath.Join(bridgeBusPath, pciBus), mountPerm)
111+
assert.Nil(t, err)
112+
113+
addr, err := getDevicePCIAddress(pciID)
114+
assert.Nil(t, err)
115+
116+
assert.Equal(t, addr, expectedPCIAddress)
117+
}
118+
76119
func TestScanSCSIBus(t *testing.T) {
77120
testDir, err := ioutil.TempDir("", "kata-agent-tmp-")
78121
if err != nil {
@@ -112,7 +155,7 @@ func TestScanSCSIBus(t *testing.T) {
112155
}
113156

114157
func testAddDevicesSuccessful(t *testing.T, devices []*pb.Device, spec *pb.Spec) {
115-
err := addDevices(devices, spec)
158+
err := addDevices(devices, spec, &sandbox{})
116159
assert.Nil(t, err, "addDevices() failed: %v", err)
117160
}
118161

@@ -133,11 +176,11 @@ func TestAddDevicesNilMountsSuccessful(t *testing.T) {
133176
testAddDevicesSuccessful(t, devices, spec)
134177
}
135178

136-
func noopDeviceHandlerReturnNil(device pb.Device, spec *pb.Spec) error {
179+
func noopDeviceHandlerReturnNil(device pb.Device, spec *pb.Spec, s *sandbox) error {
137180
return nil
138181
}
139182

140-
func noopDeviceHandlerReturnError(device pb.Device, spec *pb.Spec) error {
183+
func noopDeviceHandlerReturnError(device pb.Device, spec *pb.Spec, s *sandbox) error {
141184
return fmt.Errorf("Noop handler failure")
142185
}
143186

@@ -159,7 +202,7 @@ func TestAddDevicesNoopHandlerSuccessful(t *testing.T) {
159202
}
160203

161204
func testAddDevicesFailure(t *testing.T, devices []*pb.Device, spec *pb.Spec) {
162-
err := addDevices(devices, spec)
205+
err := addDevices(devices, spec, &sandbox{})
163206
assert.NotNil(t, err, "addDevices() should have failed")
164207
}
165208

@@ -319,8 +362,10 @@ func TestAddDevice(t *testing.T) {
319362
},
320363
}
321364

365+
s := &sandbox{}
366+
322367
for i, d := range data {
323-
err := addDevice(d.device, d.spec)
368+
err := addDevice(d.device, d.spec, s)
324369
if d.expectError {
325370
assert.Errorf(err, "test %d (%+v)", i, d)
326371
} else {

0 commit comments

Comments
 (0)