diff --git a/daemon/api_test.go b/daemon/api_test.go index f8f1e1c3a60..13426dbf449 100644 --- a/daemon/api_test.go +++ b/daemon/api_test.go @@ -5451,7 +5451,7 @@ func (s *postCreateUserSuite) SetUpTest(c *check.C) { s.apiBaseSuite.SetUpTest(c) s.daemon(c) - postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) { + postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) { return 100, 0, dirs.SnapdSocket, nil } s.mockUserHome = c.MkDir() @@ -5858,7 +5858,7 @@ func (s *postCreateUserSuite) TestPostCreateUserFromAssertionAllKnownClassicErro s.makeSystemUsers(c, []map[string]interface{}{goodUser}) - postCreateUserUcrednetGet = func(string) (uint32, uint32, string, error) { + postCreateUserUcrednetGet = func(string) (int32, uint32, string, error) { return 100, 0, dirs.SnapdSocket, nil } defer func() { @@ -6596,7 +6596,7 @@ func (s *apiSuite) TestSnapctlGetNoUID(c *check.C) { func (s *apiSuite) TestSnapctlForbiddenError(c *check.C) { _ = s.daemon(c) - runSnapctlUcrednetGet = func(string) (uint32, uint32, string, error) { + runSnapctlUcrednetGet = func(string) (int32, uint32, string, error) { return 100, 9999, dirs.SnapSocket, nil } defer func() { runSnapctlUcrednetGet = ucrednetGet }() diff --git a/daemon/daemon_test.go b/daemon/daemon_test.go index e3494d569fa..b1050d0a830 100644 --- a/daemon/daemon_test.go +++ b/daemon/daemon_test.go @@ -66,7 +66,7 @@ type daemonSuite struct { var _ = check.Suite(&daemonSuite{}) -func (s *daemonSuite) checkAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) { +func (s *daemonSuite) checkAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) { s.lastPolkitFlags = flags return s.authorized, s.err } @@ -372,7 +372,7 @@ func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) { // for UserOK commands, polkit is not consulted cmd.UserOK = true - polkitCheckAuthorization = func(pid uint32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) { + polkitCheckAuthorization = func(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) { panic("polkit.CheckAuthorization called") } c.Check(cmd.canAccess(get, nil), check.Equals, accessOK) diff --git a/daemon/ucrednet.go b/daemon/ucrednet.go index 5e981e1157b..7ef32b47cab 100644 --- a/daemon/ucrednet.go +++ b/daemon/ucrednet.go @@ -31,22 +31,23 @@ import ( var errNoID = errors.New("no pid/uid found") const ( - ucrednetNoProcess = uint32(0) + ucrednetNoProcess = int32(0) ucrednetNobody = uint32((1 << 32) - 1) ) -func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err error) { +func ucrednetGet(remoteAddr string) (pid int32, uid uint32, socket string, err error) { pid = ucrednetNoProcess uid = ucrednetNobody for _, token := range strings.Split(remoteAddr, ";") { - var v uint64 if strings.HasPrefix(token, "pid=") { - if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil { - pid = uint32(v) + var v int64 + if v, err = strconv.ParseInt(token[4:], 10, 32); err == nil { + pid = int32(v) } else { break } } else if strings.HasPrefix(token, "uid=") { + var v uint64 if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil { uid = uint32(v) } else { @@ -65,26 +66,35 @@ func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err return pid, uid, socket, err } +type ucrednet struct { + pid int32 + uid uint32 + socket string +} + +func (un *ucrednet) String() string { + if un == nil { + return "pid=;uid=;socket=;" + } + return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.pid, un.uid, un.socket) +} + type ucrednetAddr struct { net.Addr - pid string - uid string - socket string + *ucrednet } func (wa *ucrednetAddr) String() string { - return fmt.Sprintf("pid=%s;uid=%s;socket=%s;%s", wa.pid, wa.uid, wa.socket, wa.Addr) + return wa.ucrednet.String() } type ucrednetConn struct { net.Conn - pid string - uid string - socket string + *ucrednet } func (wc *ucrednetConn) RemoteAddr() net.Addr { - return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.pid, wc.uid, wc.socket} + return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.ucrednet} } type ucrednetListener struct{ net.Listener } @@ -97,7 +107,7 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) { return nil, err } - var pid, uid, socket string + var unet *ucrednet if ucon, ok := con.(*net.UnixConn); ok { f, err := ucon.File() if err != nil { @@ -111,10 +121,12 @@ func (wl *ucrednetListener) Accept() (net.Conn, error) { return nil, err } - pid = strconv.FormatUint(uint64(ucred.Pid), 10) - uid = strconv.FormatUint(uint64(ucred.Uid), 10) - socket = ucon.LocalAddr().String() + unet = &ucrednet{ + pid: ucred.Pid, + uid: ucred.Uid, + socket: ucon.LocalAddr().String(), + } } - return &ucrednetConn{con, pid, uid, socket}, err + return &ucrednetConn{con, unet}, nil } diff --git a/daemon/ucrednet_test.go b/daemon/ucrednet_test.go index 83511a84527..aa3080cd01a 100644 --- a/daemon/ucrednet_test.go +++ b/daemon/ucrednet_test.go @@ -75,7 +75,7 @@ func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) { remoteAddr := conn.RemoteAddr().String() c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*") pid, uid, _, err := ucrednetGet(remoteAddr) - c.Check(pid, check.Equals, uint32(100)) + c.Check(pid, check.Equals, int32(100)) c.Check(uid, check.Equals, uint32(42)) c.Check(err, check.IsNil) } @@ -146,14 +146,14 @@ func (s *ucrednetSuite) TestUcredErrors(c *check.C) { func (s *ucrednetSuite) TestGetNoUid(c *check.C) { pid, uid, _, err := ucrednetGet("pid=100;uid=;") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, uint32(100)) + c.Check(pid, check.Equals, int32(100)) c.Check(uid, check.Equals, ucrednetNobody) } func (s *ucrednetSuite) TestGetBadUid(c *check.C) { pid, uid, _, err := ucrednetGet("pid=100;uid=hello;") c.Check(err, check.NotNil) - c.Check(pid, check.Equals, uint32(100)) + c.Check(pid, check.Equals, int32(100)) c.Check(uid, check.Equals, ucrednetNobody) } @@ -174,7 +174,7 @@ func (s *ucrednetSuite) TestGetNothing(c *check.C) { func (s *ucrednetSuite) TestGet(c *check.C) { pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket") c.Check(err, check.IsNil) - c.Check(pid, check.Equals, uint32(100)) + c.Check(pid, check.Equals, int32(100)) c.Check(uid, check.Equals, uint32(42)) c.Check(socket, check.Equals, "/run/snap.socket") } diff --git a/polkit/authority.go b/polkit/authority.go index 8f5f797ed72..db08f6d365c 100644 --- a/polkit/authority.go +++ b/polkit/authority.go @@ -64,7 +64,7 @@ func checkAuthorization(subject authSubject, actionId string, details map[string // CheckAuthorization queries polkit to determine whether a process is // authorized to perform an action. -func CheckAuthorization(pid uint32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) { +func CheckAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags CheckFlags) (bool, error) { subject := authSubject{ Kind: "unix-process", Details: make(map[string]dbus.Variant), diff --git a/polkit/pid_start_time.go b/polkit/pid_start_time.go index c873c90f374..93aca0fffba 100644 --- a/polkit/pid_start_time.go +++ b/polkit/pid_start_time.go @@ -28,7 +28,7 @@ import ( ) // getStartTimeForPid determines the start time for a given process ID -func getStartTimeForPid(pid uint32) (uint64, error) { +func getStartTimeForPid(pid int32) (uint64, error) { filename := fmt.Sprintf("/proc/%d/stat", pid) return getStartTimeForProcStatFile(filename) } diff --git a/polkit/pid_start_time_test.go b/polkit/pid_start_time_test.go index 96faafe8aaa..7bc1107d1b7 100644 --- a/polkit/pid_start_time_test.go +++ b/polkit/pid_start_time_test.go @@ -39,7 +39,7 @@ var _ = check.Suite(&polkitSuite{}) func (s *polkitSuite) TestGetStartTime(c *check.C) { pid := os.Getpid() - startTime, err := getStartTimeForPid(uint32(pid)) + startTime, err := getStartTimeForPid(int32(pid)) c.Assert(err, check.IsNil) c.Check(startTime, check.Not(check.Equals), uint64(0)) } @@ -54,7 +54,7 @@ func (s *polkitSuite) TestGetStartTimeBadPid(c *check.C) { pid += 1 } - startTime, err := getStartTimeForPid(uint32(pid)) + startTime, err := getStartTimeForPid(int32(pid)) c.Assert(err, check.ErrorMatches, "open .*: no such file or directory") c.Check(startTime, check.Equals, uint64(0)) }