Skip to content

Commit 895e14c

Browse files
feat(completions): complete roles (#410)
1 parent 9144ea1 commit 895e14c

File tree

16 files changed

+1458
-245
lines changed

16 files changed

+1458
-245
lines changed

crates/pgt_completions/src/complete.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use crate::{
55
context::CompletionContext,
66
item::CompletionItem,
77
providers::{
8-
complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables,
8+
complete_columns, complete_functions, complete_policies, complete_roles, complete_schemas,
9+
complete_tables,
910
},
1011
sanitization::SanitizedCompletionParams,
1112
};
@@ -36,6 +37,7 @@ pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
3637
complete_columns(&ctx, &mut builder);
3738
complete_schemas(&ctx, &mut builder);
3839
complete_policies(&ctx, &mut builder);
40+
complete_roles(&ctx, &mut builder);
3941

4042
builder.finish()
4143
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
use std::iter::Peekable;
2+
3+
use pgt_text_size::{TextRange, TextSize};
4+
5+
pub(crate) struct TokenNavigator {
6+
tokens: Peekable<std::vec::IntoIter<WordWithIndex>>,
7+
pub previous_token: Option<WordWithIndex>,
8+
pub current_token: Option<WordWithIndex>,
9+
}
10+
11+
impl TokenNavigator {
12+
pub(crate) fn next_matches(&mut self, options: &[&str]) -> bool {
13+
self.tokens
14+
.peek()
15+
.is_some_and(|c| options.contains(&c.get_word_without_quotes().as_str()))
16+
}
17+
18+
pub(crate) fn prev_matches(&self, options: &[&str]) -> bool {
19+
self.previous_token
20+
.as_ref()
21+
.is_some_and(|t| options.contains(&t.get_word_without_quotes().as_str()))
22+
}
23+
24+
pub(crate) fn advance(&mut self) -> Option<WordWithIndex> {
25+
// we can't peek back n an iterator, so we'll have to keep track manually.
26+
self.previous_token = self.current_token.take();
27+
self.current_token = self.tokens.next();
28+
self.current_token.clone()
29+
}
30+
}
31+
32+
impl From<Vec<WordWithIndex>> for TokenNavigator {
33+
fn from(tokens: Vec<WordWithIndex>) -> Self {
34+
TokenNavigator {
35+
tokens: tokens.into_iter().peekable(),
36+
previous_token: None,
37+
current_token: None,
38+
}
39+
}
40+
}
41+
42+
pub(crate) trait CompletionStatementParser: Sized {
43+
type Context: Default;
44+
const NAME: &'static str;
45+
46+
fn looks_like_matching_stmt(sql: &str) -> bool;
47+
fn parse(self) -> Self::Context;
48+
fn make_parser(tokens: Vec<WordWithIndex>, cursor_position: usize) -> Self;
49+
50+
fn get_context(sql: &str, cursor_position: usize) -> Self::Context {
51+
assert!(
52+
Self::looks_like_matching_stmt(sql),
53+
"Using {} for a wrong statement! Developer Error!",
54+
Self::NAME
55+
);
56+
57+
match sql_to_words(sql) {
58+
Ok(tokens) => {
59+
let parser = Self::make_parser(tokens, cursor_position);
60+
parser.parse()
61+
}
62+
Err(_) => Self::Context::default(),
63+
}
64+
}
65+
}
66+
67+
pub(crate) fn schema_and_table_name(token: &WordWithIndex) -> (String, Option<String>) {
68+
let word = token.get_word_without_quotes();
69+
let mut parts = word.split('.');
70+
71+
(
72+
parts.next().unwrap().into(),
73+
parts.next().map(|tb| tb.into()),
74+
)
75+
}
76+
77+
#[derive(Clone, Debug, PartialEq, Eq)]
78+
pub(crate) struct WordWithIndex {
79+
word: String,
80+
start: usize,
81+
end: usize,
82+
}
83+
84+
impl WordWithIndex {
85+
pub(crate) fn is_under_cursor(&self, cursor_pos: usize) -> bool {
86+
self.start <= cursor_pos && self.end > cursor_pos
87+
}
88+
89+
pub(crate) fn get_range(&self) -> TextRange {
90+
let start: u32 = self.start.try_into().expect("Text too long");
91+
let end: u32 = self.end.try_into().expect("Text too long");
92+
TextRange::new(TextSize::from(start), TextSize::from(end))
93+
}
94+
95+
pub(crate) fn get_word_without_quotes(&self) -> String {
96+
self.word.replace('"', "")
97+
}
98+
99+
pub(crate) fn get_word(&self) -> String {
100+
self.word.clone()
101+
}
102+
}
103+
104+
/// Note: A policy name within quotation marks will be considered a single word.
105+
pub(crate) fn sql_to_words(sql: &str) -> Result<Vec<WordWithIndex>, String> {
106+
let mut words = vec![];
107+
108+
let mut start_of_word: Option<usize> = None;
109+
let mut current_word = String::new();
110+
let mut in_quotation_marks = false;
111+
112+
for (current_position, current_char) in sql.char_indices() {
113+
if (current_char.is_ascii_whitespace() || current_char == ';')
114+
&& !current_word.is_empty()
115+
&& start_of_word.is_some()
116+
&& !in_quotation_marks
117+
{
118+
words.push(WordWithIndex {
119+
word: current_word,
120+
start: start_of_word.unwrap(),
121+
end: current_position,
122+
});
123+
124+
current_word = String::new();
125+
start_of_word = None;
126+
} else if (current_char.is_ascii_whitespace() || current_char == ';')
127+
&& current_word.is_empty()
128+
{
129+
// do nothing
130+
} else if current_char == '"' && start_of_word.is_none() {
131+
in_quotation_marks = true;
132+
current_word.push(current_char);
133+
start_of_word = Some(current_position);
134+
} else if current_char == '"' && start_of_word.is_some() {
135+
current_word.push(current_char);
136+
in_quotation_marks = false;
137+
} else if start_of_word.is_some() {
138+
current_word.push(current_char)
139+
} else {
140+
start_of_word = Some(current_position);
141+
current_word.push(current_char);
142+
}
143+
}
144+
145+
if let Some(start_of_word) = start_of_word {
146+
if !current_word.is_empty() {
147+
words.push(WordWithIndex {
148+
word: current_word,
149+
start: start_of_word,
150+
end: sql.len(),
151+
});
152+
}
153+
}
154+
155+
if in_quotation_marks {
156+
Err("String was not closed properly.".into())
157+
} else {
158+
Ok(words)
159+
}
160+
}
161+
162+
#[cfg(test)]
163+
mod tests {
164+
use crate::context::base_parser::{WordWithIndex, sql_to_words};
165+
166+
#[test]
167+
fn determines_positions_correctly() {
168+
let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string();
169+
170+
let words = sql_to_words(query.as_str()).unwrap();
171+
172+
assert_eq!(words[0], to_word("create", 1, 7));
173+
assert_eq!(words[1], to_word("policy", 8, 14));
174+
assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28));
175+
assert_eq!(words[3], to_word("on", 30, 32));
176+
assert_eq!(words[4], to_word("auth.users", 33, 43));
177+
assert_eq!(words[5], to_word("as", 45, 47));
178+
assert_eq!(words[6], to_word("permissive", 48, 58));
179+
assert_eq!(words[7], to_word("for", 60, 63));
180+
assert_eq!(words[8], to_word("select", 64, 70));
181+
assert_eq!(words[9], to_word("to", 73, 75));
182+
assert_eq!(words[10], to_word("public", 78, 84));
183+
assert_eq!(words[11], to_word("using", 87, 92));
184+
assert_eq!(words[12], to_word("(true)", 93, 99));
185+
}
186+
187+
#[test]
188+
fn handles_schemas_in_quotation_marks() {
189+
let query = r#"grant select on "public"."users""#.to_string();
190+
191+
let words = sql_to_words(query.as_str()).unwrap();
192+
193+
assert_eq!(words[0], to_word("grant", 0, 5));
194+
assert_eq!(words[1], to_word("select", 6, 12));
195+
assert_eq!(words[2], to_word("on", 13, 15));
196+
assert_eq!(words[3], to_word(r#""public"."users""#, 16, 32));
197+
}
198+
199+
fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex {
200+
WordWithIndex {
201+
word: word.into(),
202+
start,
203+
end,
204+
}
205+
}
206+
}

0 commit comments

Comments
 (0)