Skip to content

Commit

Permalink
Red Knot - Add symbol flags (#11134)
Browse files Browse the repository at this point in the history
* Adds `Symbol.flag` bitfield. Populates it from (the three renamed)
`add_or_update_symbol*` methods.
* Currently there are these flags supported:
  * `IS_DEFINED` is set in a scope where a variable is defined.
* `IS_USED` is set in a scope where a variable is referenced. (To have
both this and `IS_DEFINED` would require two separate appearances of a
variable in the same scope-- one def and one use.)
* `MARKED_GLOBAL` and `MARKED_NONLOCAL` are **not yet implemented**.
(*TODO: While traversing, if you find these declarations, add these
flags to the variable.*)
* Adds `Symbol.kind` field (commented) and the data structure which will
populate it: `Kind` which is an enum of freevar, cellvar,
implicit_global, and implicit_local. **Not yet populated**. (*TODO: a
second pass over the scope (or the ast?) will observe the
`MARKED_GLOBAL` and `MARKED_NONLOCAL` flags to populate this field. When
that's added, we'll uncomment the field.*)
* Adds a few tests that the `IS_DEFINED` and `IS_USED` fields are
correctly set and/or merged:
* Unit test that subsequent calls to `add_or_update_symbol` will merge
the flag arguments.
* Unit test that in the statement `x = foo`, the variable `foo` is
considered used but not defined.
* Unit test that in the statement `from bar import foo`, the variable
`foo` is considered defined but not used.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
  • Loading branch information
plredmond and carljm authored Apr 30, 2024
1 parent ce030a4 commit c391c8b
Showing 1 changed file with 104 additions and 23 deletions.
127 changes: 104 additions & 23 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::num::NonZeroU32;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use bitflags::bitflags;
use hashbrown::hash_map::{Keys, RawEntryMut};
use rustc_hash::{FxHashMap, FxHasher};

Expand Down Expand Up @@ -81,15 +82,52 @@ impl Scope {
}
}

#[derive(Debug)]
pub(crate) enum Kind {
FreeVar,
CellVar,
CellVarAssigned,
ExplicitGlobal,
ImplicitGlobal,
}

bitflags! {
#[derive(Copy,Clone,Debug)]
pub(crate) struct SymbolFlags: u8 {
const IS_USED = 1 << 0;
const IS_DEFINED = 1 << 1;
/// TODO: This flag is not yet set by anything
const MARKED_GLOBAL = 1 << 2;
/// TODO: This flag is not yet set by anything
const MARKED_NONLOCAL = 1 << 3;
}
}

#[derive(Debug)]
pub(crate) struct Symbol {
name: Name,
flags: SymbolFlags,
// kind: Kind,
}

impl Symbol {
pub(crate) fn name(&self) -> &str {
self.name.as_str()
}

/// Is the symbol used in its containing scope?
pub(crate) fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
}

/// Is the symbol defined in its containing scope?
pub(crate) fn is_defined(&self) -> bool {
self.flags.contains(SymbolFlags::IS_DEFINED)
}

// TODO: implement Symbol.kind 2-pass analysis to categorize as: free-var, cell-var,
// explicit-global, implicit-global and implement Symbol.kind by modifying the preorder
// traversal code
}

// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
Expand Down Expand Up @@ -271,7 +309,12 @@ impl SymbolTable {
.flat_map(|(sym_id, defs)| defs.iter().map(move |def| (*sym_id, def)))
}

fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId {
fn add_or_update_symbol(
&mut self,
scope_id: ScopeId,
name: &str,
flags: SymbolFlags,
) -> SymbolId {
let hash = SymbolTable::hash_name(name);
let scope = &mut self.scopes_by_id[scope_id];
let name = Name::new(name);
Expand All @@ -282,9 +325,14 @@ impl SymbolTable {
.from_hash(hash, |existing| self.symbols_by_id[*existing].name == name);

match entry {
RawEntryMut::Occupied(entry) => *entry.key(),
RawEntryMut::Occupied(entry) => {
if let Some(symbol) = self.symbols_by_id.get_mut(*entry.key()) {
symbol.flags.insert(flags);
};
*entry.key()
}
RawEntryMut::Vacant(entry) => {
let id = self.symbols_by_id.push(Symbol { name });
let id = self.symbols_by_id.push(Symbol { name, flags });
entry.insert_with_hasher(hash, id, (), |_| hash);
id
}
Expand Down Expand Up @@ -392,12 +440,17 @@ struct SymbolTableBuilder {
}

impl SymbolTableBuilder {
fn add_symbol(&mut self, identifier: &str) -> SymbolId {
self.table.add_symbol_to_scope(self.cur_scope(), identifier)
fn add_or_update_symbol(&mut self, identifier: &str, flags: SymbolFlags) -> SymbolId {
self.table
.add_or_update_symbol(self.cur_scope(), identifier, flags)
}

fn add_symbol_with_def(&mut self, identifier: &str, definition: Definition) -> SymbolId {
let symbol_id = self.add_symbol(identifier);
fn add_or_update_symbol_with_def(
&mut self,
identifier: &str,
definition: Definition,
) -> SymbolId {
let symbol_id = self.add_or_update_symbol(identifier, SymbolFlags::IS_DEFINED);
self.table
.defs
.entry(symbol_id)
Expand Down Expand Up @@ -439,7 +492,7 @@ impl SymbolTableBuilder {
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
};
self.add_symbol(name);
self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED);
}
}
nested(self);
Expand All @@ -452,10 +505,16 @@ impl SymbolTableBuilder {
impl PreorderVisitor<'_> for SymbolTableBuilder {
fn visit_expr(&mut self, expr: &ast::Expr) {
if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr {
self.add_symbol(id);
if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) {
let flags = match ctx {
ast::ExprContext::Load => SymbolFlags::IS_USED,
ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
ast::ExprContext::Invalid => SymbolFlags::empty(),
};
self.add_or_update_symbol(id, flags);
if flags.contains(SymbolFlags::IS_DEFINED) {
if let Some(curdef) = self.current_definition.clone() {
self.add_symbol_with_def(id, curdef);
self.add_or_update_symbol_with_def(id, curdef);
}
}
}
Expand All @@ -467,7 +526,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
match stmt {
ast::Stmt::ClassDef(node) => {
let def = Definition::ClassDef(TypedNodeKey::from_node(node));
self.add_symbol_with_def(&node.name, def);
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class);
ast::visitor::preorder::walk_stmt(builder, stmt);
Expand All @@ -476,7 +535,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
}
ast::Stmt::FunctionDef(node) => {
let def = Definition::FunctionDef(TypedNodeKey::from_node(node));
self.add_symbol_with_def(&node.name, def);
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function);
ast::visitor::preorder::walk_stmt(builder, stmt);
Expand All @@ -496,7 +555,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
let def = Definition::Import(ImportDefinition {
module: module.clone(),
});
self.add_symbol_with_def(symbol_name, def);
self.add_or_update_symbol_with_def(symbol_name, def);
self.table.dependencies.push(Dependency::Module(module));
}
}
Expand All @@ -519,7 +578,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
name: Name::new(&alias.name.id),
level: *level,
});
self.add_symbol_with_def(symbol_name, def);
self.add_or_update_symbol_with_def(symbol_name, def);
}

let dependency = if let Some(module) = module {
Expand Down Expand Up @@ -578,7 +637,7 @@ mod tests {
use crate::parse::Parsed;
use crate::symbols::ScopeKind;

use super::{SymbolId, SymbolIterator, SymbolTable};
use super::{SymbolFlags, SymbolId, SymbolIterator, SymbolTable};

mod from_ast {
use super::*;
Expand Down Expand Up @@ -662,6 +721,13 @@ mod tests {
.len(),
1
);
assert!(
table.root_symbol_id_by_name("foo").is_some_and(|sid| {
let s = sid.symbol(&table);
s.is_defined() || !s.is_used()
}),
"symbols that are defined get the defined flag"
);
}

#[test]
Expand All @@ -675,6 +741,13 @@ mod tests {
.len(),
1
);
assert!(
table.root_symbol_id_by_name("foo").is_some_and(|sid| {
let s = sid.symbol(&table);
!s.is_defined() && s.is_used()
}),
"a symbol used but not defined in a scope should have only the used flag"
);
}

#[test]
Expand Down Expand Up @@ -800,6 +873,12 @@ mod tests {
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope.name(), "C");
assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]);
assert!(
table
.symbol_by_name(ann_scope_id, "T")
.is_some_and(|s| s.is_defined() && !s.is_used()),
"type parameters are defined by the scope that introduces them"
);
let scopes = table.child_scope_ids_of(ann_scope_id);
assert_eq!(scopes.len(), 1);
let func_scope_id = scopes[0];
Expand All @@ -814,27 +893,29 @@ mod tests {
fn insert_same_name_symbol_twice() {
let mut table = SymbolTable::new();
let root_scope_id = SymbolTable::root_scope_id();
let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo");
let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "foo");
let symbol_id_1 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_DEFINED);
let symbol_id_2 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_USED);
assert_eq!(symbol_id_1, symbol_id_2);
assert!(symbol_id_1.symbol(&table).is_used(), "flags must merge");
assert!(symbol_id_1.symbol(&table).is_defined(), "flags must merge");
}

#[test]
fn insert_different_named_symbols() {
let mut table = SymbolTable::new();
let root_scope_id = SymbolTable::root_scope_id();
let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo");
let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "bar");
let symbol_id_1 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty());
let symbol_id_2 = table.add_or_update_symbol(root_scope_id, "bar", SymbolFlags::empty());
assert_ne!(symbol_id_1, symbol_id_2);
}

#[test]
fn add_child_scope_with_symbol() {
let mut table = SymbolTable::new();
let root_scope_id = SymbolTable::root_scope_id();
let foo_symbol_top = table.add_symbol_to_scope(root_scope_id, "foo");
let foo_symbol_top = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty());
let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class);
let foo_symbol_inner = table.add_symbol_to_scope(c_scope, "foo");
let foo_symbol_inner = table.add_or_update_symbol(c_scope, "foo", SymbolFlags::empty());
assert_ne!(foo_symbol_top, foo_symbol_inner);
}

Expand All @@ -851,7 +932,7 @@ mod tests {
fn symbol_from_id() {
let mut table = SymbolTable::new();
let root_scope_id = SymbolTable::root_scope_id();
let foo_symbol_id = table.add_symbol_to_scope(root_scope_id, "foo");
let foo_symbol_id = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty());
let symbol = foo_symbol_id.symbol(&table);
assert_eq!(symbol.name.as_str(), "foo");
}
Expand Down

0 comments on commit c391c8b

Please sign in to comment.