Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 140 additions & 1 deletion schemainspect/pg/obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
COLLATIONS_QUERY = resource_text("sql/collations.sql")
COLLATIONS_QUERY_9 = resource_text("sql/collations9.sql")
RLSPOLICIES_QUERY = resource_text("sql/rlspolicies.sql")
ROLES_QUERY = resource_text("sql/roles.sql")
MEMBERSHIPS_QUERY = resource_text("sql/memberships.sql")


class InspectedSelectable(BaseInspectedSelectable):
Expand Down Expand Up @@ -862,6 +864,104 @@ def __eq__(self, other):
return all(equalities)


class InspectedRole(Inspected):
def __init__(self, name, superuser, createdb, inherit, login, replication, bypassrls,
connection_limit, password, valid_until):
self.name = name
self.superuser = superuser
self.createdb = createdb
self.inherit = inherit
self.login = login
self.replication = replication
self.bypassrls = bypassrls
self.connection_limit = connection_limit
self.password = password
self.valid_until = valid_until

@property
def drop_statement(self):
return "drop role {};".format(self.name)

@property
def create_statement(self):
return "create role {} with {} {} {} {} {} {} connection limit {} password {} {};".format(
self.name,
self.superuser,
self.createdb,
self.inherit,
self.login,
self.replication,
self.bypassrls,
self.connection_limit,
("'" + self.password + "'") if self.password else 'NULL',
" valid until {}".format(self.valid_until) if self.valid_until else "",
)

@property
def update_statement(self):
return "alter role {} with {} {} {} {} {} {} connection limit {} password {} {};".format(
self.name,
self.superuser,
self.createdb,
self.inherit,
self.login,
self.replication,
self.bypassrls,
self.connection_limit,
("'" + self.password + "'") if self.password else 'NULL',
" valid until {}".format(self.valid_until) if self.valid_until else "",
)

def __eq__(self, other):
equalities = (
self.name == other.name,
self.superuser == other.superuser,
self.createdb == other.createdb,
self.inherit == other.inherit,
self.login == other.login,
self.replication == other.replication,
self.bypassrls == other.bypassrls,
self.connection_limit == other.connection_limit,
self.password == other.password,
self.valid_until == other.valid_until,
)
return all(equalities)


class InspectedMembership(Inspected):
def __init__(self, roleid, member, admin_option, grantor):
self.roleid = roleid
self.member = member
self.admin_option = admin_option
self.grantor = grantor

@property
def create_statement(self):
return "grant {} to {} {} granted by {};".format(
self.roleid, self.member,
" with admin option " if self.admin_option else "", self.grantor
)

@property
def drop_statement(self):
return "revoke {} from {};".format(
self.roleid, self.member
)

@property
def key(self):
return self.roleid, self.member, self.admin_option

def __eq__(self, other):
equalities = (
self.roleid == other.roleid,
self.member == other.member,
self.admin_option == other.admin_option,
self.grantor == other.grantor,
)
return all(equalities)


class PostgreSQL(DBInspector):
def __init__(self, c, include_internal=False):
pg_version = c.dialect.server_version_info[0]
Expand Down Expand Up @@ -893,7 +993,8 @@ def processed(q):
self.SCHEMAS_QUERY = processed(SCHEMAS_QUERY)
self.PRIVILEGES_QUERY = processed(PRIVILEGES_QUERY)
self.TRIGGERS_QUERY = processed(TRIGGERS_QUERY)

self.ROLES_QUERY = processed(ROLES_QUERY)
self.MEMBERSHIPS_QUERY = processed(MEMBERSHIPS_QUERY)
super(PostgreSQL, self).__init__(c, include_internal)

def load_all(self):
Expand All @@ -911,6 +1012,8 @@ def load_all(self):
self.load_rlspolicies()
self.load_types()
self.load_domains()
self.load_roles()
self.load_memberships()

def load_schemas(self):
q = self.c.execute(self.SCHEMAS_QUERY)
Expand Down Expand Up @@ -940,6 +1043,42 @@ def load_rlspolicies(self):

self.rlspolicies = od((p.key, p) for p in rlspolicies)

def load_roles(self):
q = self.c.execute(self.ROLES_QUERY)

roles = [
InspectedRole(
name=r.name,
superuser=r.superuser,
createdb=r.createdb,
inherit=r.inherit,
login=r.login,
replication=r.replication,
bypassrls=r.bypassrls,
connection_limit=r.connection_limit,
password=r.password,
valid_until=r.valid_until,
)
for r in q
]

self.roles = od((r.name, r) for r in roles)

def load_memberships(self):
q = self.c.execute(self.MEMBERSHIPS_QUERY)

memberships = [
InspectedMembership(
roleid=m.roleid,
member=m.member,
admin_option=m.admin_option,
grantor=m.grantor,
)
for m in q
]

self.memberships = od((m.key, m) for m in memberships)

def load_collations(self):
q = self.c.execute(self.COLLATIONS_QUERY)
collations = [
Expand Down
12 changes: 12 additions & 0 deletions schemainspect/pg/sql/memberships.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
select
ur.rolname as roleid,
um.rolname as member,
a.admin_option,
ug.rolname as grantor
from pg_auth_members a
left join pg_authid ur on ur.oid = a.roleid
left join pg_authid ug on ug.oid = a.grantor
left join pg_authid um on um.oid = a.member
where
not (ur.rolname ~ '^pg_' and um.rolname ~ '^pg_')
order by 1, 2, 3;
46 changes: 46 additions & 0 deletions schemainspect/pg/sql/roles.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
select
oid,
rolname as name,
case rolsuper when true then
'SUPERUSER'
else
'NOSUPERUSER'
end as superuser,
case rolinherit when true then
'INHERIT'
else
'NOINHERIT'
end as inherit,
case rolcreaterole when true then
'CREATEROLE'
else
'NOCREATEROLE'
end as createrole,
case rolcreatedb when true then
'CREATEDB'
else
'NOCREATEDB'
end as createdb,
case rolcanlogin when true then
'LOGIN'
else
'NOLOGIN'
end as login,
case rolreplication when true then
'REPLICATION'
else
'NOREPLICATION'
end as replication,
case rolbypassrls when true then
'BYPASSRLS'
else
'NOBYPASSRLS'
end as bypassrls,
rolconnlimit as connection_limit,
rolpassword as password,
rolvaliduntil as valid_until
from pg_authid
where
rolsuper = false
and rolname not like 'pg_%'
order by rolname;