cqlsh_rs/
completer.rs

1//! Tab completion for the CQL shell.
2//!
3//! Implements rustyline's `Completer`, `Helper`, `Hinter`, `Highlighter`, and
4//! `Validator` traits to provide context-aware tab completion in the REPL.
5//! Uses the unified CQL lexer for grammar-aware context detection.
6//! Completions include CQL keywords, shell commands, schema objects (keyspaces,
7//! tables, columns), consistency levels, DESCRIBE sub-commands, and file paths.
8
9use std::borrow::Cow;
10use std::sync::Arc;
11
12use rustyline::completion::{Completer, Pair};
13use rustyline::highlight::Highlighter;
14use rustyline::hint::Hinter;
15use rustyline::validate::Validator;
16use rustyline::{Context, Helper};
17use tokio::runtime::Handle;
18use tokio::sync::RwLock;
19
20use crate::colorizer::CqlColorizer;
21use crate::cql_lexer::{self, GrammarContext, TokenKind};
22use crate::schema_cache::SchemaCache;
23
/// Keywords that may open a CQL statement; offered when completing at the start of input.
const CQL_KEYWORDS: &[&str] = &[
    "ALTER", "APPLY", "BATCH", "BEGIN", "CREATE", "DELETE", "DESCRIBE", "DROP",
    "GRANT", "INSERT", "LIST", "REVOKE", "SELECT", "TRUNCATE", "UPDATE", "USE",
];
29
/// Keywords that appear inside a CQL statement (clauses, options, modifiers).
///
/// Kept in alphabetical order, one row per initial letter.
const CQL_CLAUSE_KEYWORDS: &[&str] = &[
    "ADD", "AGGREGATE", "ALL", "ALLOW", "AND", "AS", "ASC", "AUTHORIZE",
    "BATCH", "BY",
    "CALLED", "CLUSTERING", "COLUMN", "COMPACT", "CONTAINS", "COUNT", "CUSTOM",
    "DELETE", "DESC", "DESCRIBE", "DISTINCT", "DROP",
    "ENTRIES", "EXECUTE", "EXISTS",
    "FILTERING", "FINALFUNC", "FROM", "FROZEN", "FULL", "FUNCTION", "FUNCTIONS",
    "IF", "IN", "INDEX", "INITCOND", "INPUT", "INSERT", "INTO", "IS",
    "JSON",
    "KEY", "KEYS", "KEYSPACE", "KEYSPACES",
    "LANGUAGE", "LIKE", "LIMIT", "LIST", "LOGIN",
    "MAP", "MATERIALIZED", "MODIFY",
    "NAMESPACE", "NORECURSIVE", "NOT", "NULL",
    "OF", "ON", "OR", "ORDER",
    "PARTITION", "PASSWORD", "PER", "PERMISSION", "PERMISSIONS", "PRIMARY",
    "RENAME", "REPLACE", "RETURNS", "REVOKE",
    "SCHEMA", "SELECT", "SET", "SFUNC", "STATIC", "STORAGE", "STYPE", "SUPERUSER",
    "TABLE", "TABLES", "TEXT", "TIMESTAMP", "TO", "TOKEN", "TRIGGER", "TRUNCATE", "TTL",
    "TUPLE", "TYPE",
    "UNLOGGED", "UPDATE", "USER", "USERS", "USING",
    "VALUES", "VIEW",
    "WHERE", "WITH", "WRITETIME",
];
133
/// Commands handled by the shell itself rather than sent to the cluster as CQL.
const SHELL_COMMANDS: &[&str] = &[
    "CAPTURE", "CLEAR", "CLS", "CONSISTENCY", "COPY", "DESCRIBE", "DESC", "EXIT", "EXPAND",
    "HELP", "LOGIN", "PAGING", "QUIT", "SERIAL", "SHOW", "SOURCE", "TRACING",
];
154
/// Consistency level names offered after `CONSISTENCY` / `SERIAL CONSISTENCY`.
const CONSISTENCY_LEVELS: &[&str] = &[
    "ALL", "ANY", "EACH_QUORUM",
    "LOCAL_ONE", "LOCAL_QUORUM", "LOCAL_SERIAL",
    "ONE", "QUORUM", "SERIAL", "THREE", "TWO",
];
169
/// Words that may follow `DESCRIBE` / `DESC` (singular and plural object kinds).
const DESCRIBE_SUB_COMMANDS: &[&str] = &[
    "AGGREGATE", "AGGREGATES", "CLUSTER", "FULL", "FUNCTION", "FUNCTIONS", "INDEX",
    "KEYSPACE", "KEYSPACES", "MATERIALIZED", "SCHEMA", "TABLE", "TABLES", "TYPE", "TYPES",
];
188
/// Native CQL type names (lower-case), intended for column definitions in `CREATE TABLE`.
#[allow(dead_code)] // Will be used when CqlType completion context is implemented
const CQL_TYPES: &[&str] = &[
    "ascii", "bigint", "blob", "boolean", "counter", "date", "decimal",
    "double", "duration", "float", "frozen", "inet", "int", "list",
    "map", "set", "smallint", "text", "time", "timestamp", "timeuuid",
    "tinyint", "tuple", "uuid", "varchar", "varint",
];
219
/// What kind of word is expected at the cursor, as far as completion is concerned.
///
/// Each variant selects the candidate source used to build the completion list.
#[derive(Debug, PartialEq)]
enum CompletionContext {
    /// Start of a statement: CQL statement keywords plus shell commands.
    Empty,
    /// Somewhere inside a statement with no more specific slot recognised: clause keywords.
    ClauseKeyword,
    /// A table name is expected; `keyspace` is the explicit `ks.` qualifier when one was typed.
    TableName { keyspace: Option<String> },
    /// A column of `table` is expected (e.g. in a WHERE or SET clause); `keyspace` is the
    /// table's keyspace when one could be determined.
    ColumnName {
        keyspace: Option<String>,
        table: String,
    },
    /// A consistency level is expected (after `CONSISTENCY` or `SERIAL CONSISTENCY`).
    ConsistencyLevel,
    /// A DESCRIBE sub-command is expected (after `DESCRIBE` / `DESC`).
    DescribeTarget,
    /// A file system path is expected (e.g. after `SOURCE` or `CAPTURE`).
    FilePath,
    /// A keyspace name is expected (e.g. after `USE`).
    KeyspaceName,
}
243
244/// Tab completer for the CQL shell REPL.
245pub struct CqlCompleter {
246    /// Shared schema cache for keyspace/table/column lookups.
247    cache: Arc<RwLock<SchemaCache>>,
248    /// Current keyspace (shared with session via USE command).
249    current_keyspace: Arc<RwLock<Option<String>>>,
250    /// Tokio runtime handle for blocking cache reads inside sync complete().
251    rt_handle: Handle,
252    /// Syntax colorizer for highlighting.
253    colorizer: CqlColorizer,
254}
255
256impl CqlCompleter {
257    /// Create a new completer with shared cache and keyspace state.
258    pub fn new(
259        cache: Arc<RwLock<SchemaCache>>,
260        current_keyspace: Arc<RwLock<Option<String>>>,
261        rt_handle: Handle,
262        color_enabled: bool,
263    ) -> Self {
264        Self {
265            cache,
266            current_keyspace,
267            rt_handle,
268            colorizer: CqlColorizer::new(color_enabled),
269        }
270    }
271
272    /// Detect completion context from the input line up to the cursor position.
273    ///
274    /// Uses the unified CQL lexer for grammar-aware context detection.
275    fn detect_context(&self, line: &str, pos: usize) -> CompletionContext {
276        let before_cursor = &line[..pos];
277        let tokens = cql_lexer::tokenize(before_cursor);
278        let sig: Vec<_> = cql_lexer::significant_tokens(&tokens);
279
280        if sig.is_empty() {
281            return CompletionContext::Empty;
282        }
283
284        // Special case: SOURCE/CAPTURE always means file path completion
285        let first_upper = sig[0].text.to_uppercase();
286        if first_upper == "SOURCE" || first_upper == "CAPTURE" {
287            return CompletionContext::FilePath;
288        }
289
290        let grammar_ctx = cql_lexer::grammar_context_at_end(before_cursor);
291
292        match grammar_ctx {
293            GrammarContext::Start => CompletionContext::Empty,
294            GrammarContext::ExpectTable => {
295                // Check if the user is typing a qualified name (ks.)
296                let keyspace = self.extract_qualifying_keyspace(&sig);
297                CompletionContext::TableName { keyspace }
298            }
299            GrammarContext::ExpectKeyspace => CompletionContext::KeyspaceName,
300            GrammarContext::ExpectColumn | GrammarContext::ExpectSetClause => {
301                // Find the table name from the token stream
302                let (ks, table) = self.extract_table_from_tokens(&sig);
303                match table {
304                    Some(t) => CompletionContext::ColumnName {
305                        keyspace: ks,
306                        table: t,
307                    },
308                    None => CompletionContext::ClauseKeyword,
309                }
310            }
311            GrammarContext::ExpectConsistencyLevel => CompletionContext::ConsistencyLevel,
312            GrammarContext::ExpectDescribeTarget => CompletionContext::DescribeTarget,
313            GrammarContext::ExpectFilePath => CompletionContext::FilePath,
314            GrammarContext::ExpectQualifiedPart => {
315                // After a dot — offer tables from the keyspace before the dot
316                let keyspace = self.extract_qualifying_keyspace(&sig);
317                CompletionContext::TableName { keyspace }
318            }
319            GrammarContext::ExpectColumnList => {
320                // In SELECT column list — if only first keyword and no space yet, complete keywords
321                if sig.len() == 1 && !before_cursor.ends_with(' ') {
322                    CompletionContext::Empty
323                } else {
324                    CompletionContext::ClauseKeyword
325                }
326            }
327            _ => {
328                // For Start-like contexts: if only one word and still typing, complete keywords
329                if sig.len() == 1 && !before_cursor.ends_with(' ') {
330                    CompletionContext::Empty
331                } else {
332                    CompletionContext::ClauseKeyword
333                }
334            }
335        }
336    }
337
338    /// Extract the keyspace qualifier from a dot-qualified name in the token stream.
339    fn extract_qualifying_keyspace(&self, sig: &[&cql_lexer::Token]) -> Option<String> {
340        // Look for pattern: identifier . (at end of tokens)
341        let len = sig.len();
342        if len >= 2 && sig[len - 1].text == "." {
343            return Some(sig[len - 2].text.clone());
344        }
345        None
346    }
347
348    /// Extract table name from the token stream by finding FROM/INTO/UPDATE <table>.
349    fn extract_table_from_tokens(
350        &self,
351        sig: &[&cql_lexer::Token],
352    ) -> (Option<String>, Option<String>) {
353        for (i, tok) in sig.iter().enumerate() {
354            let upper = tok.text.to_uppercase();
355            if matches!(upper.as_str(), "FROM" | "INTO" | "UPDATE" | "TABLE")
356                && i + 1 < sig.len()
357                && matches!(
358                    sig[i + 1].kind,
359                    TokenKind::Identifier | TokenKind::QuotedIdentifier
360                )
361            {
362                let table = sig[i + 1].text.clone();
363                // Check for qualified name (ks.table)
364                if i + 3 < sig.len() && sig[i + 2].text == "." {
365                    let ks = table;
366                    let tbl = sig[i + 3].text.clone();
367                    return (Some(ks), Some(tbl));
368                }
369                let ks = tokio::task::block_in_place(|| {
370                    self.rt_handle
371                        .block_on(async { self.current_keyspace.read().await.clone() })
372                });
373                return (ks, Some(table));
374            }
375        }
376        (None, None)
377    }
378
379    /// Generate completions for the detected context.
380    fn complete_for_context(&self, ctx: &CompletionContext, prefix: &str) -> Vec<Pair> {
381        let prefix_upper = prefix.to_uppercase();
382
383        match ctx {
384            CompletionContext::Empty => {
385                let mut candidates: Vec<&str> = Vec::new();
386                candidates.extend_from_slice(CQL_KEYWORDS);
387                candidates.extend_from_slice(SHELL_COMMANDS);
388                filter_candidates(&candidates, &prefix_upper, true)
389            }
390            CompletionContext::ClauseKeyword => {
391                filter_candidates(CQL_CLAUSE_KEYWORDS, &prefix_upper, true)
392            }
393            CompletionContext::ConsistencyLevel => {
394                filter_candidates(CONSISTENCY_LEVELS, &prefix_upper, true)
395            }
396            CompletionContext::DescribeTarget => {
397                filter_candidates(DESCRIBE_SUB_COMMANDS, &prefix_upper, true)
398            }
399            CompletionContext::KeyspaceName => {
400                let cache =
401                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
402                let names = cache.keyspace_names();
403                filter_candidates(&names, prefix, false)
404            }
405            CompletionContext::TableName { keyspace } => {
406                let cache =
407                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
408                let ks = keyspace.clone().or_else(|| {
409                    tokio::task::block_in_place(|| {
410                        self.rt_handle
411                            .block_on(async { self.current_keyspace.read().await.clone() })
412                    })
413                });
414                match ks {
415                    Some(ref ks_name) => {
416                        let names = cache.table_names(ks_name);
417                        filter_candidates(&names, prefix, false)
418                    }
419                    None => {
420                        // No keyspace context — offer keyspace names for qualification
421                        let names = cache.keyspace_names();
422                        filter_candidates(&names, prefix, false)
423                    }
424                }
425            }
426            CompletionContext::ColumnName { keyspace, table } => {
427                let cache =
428                    tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
429                let ks = keyspace.clone().or_else(|| {
430                    tokio::task::block_in_place(|| {
431                        self.rt_handle
432                            .block_on(async { self.current_keyspace.read().await.clone() })
433                    })
434                });
435                match ks {
436                    Some(ref ks_name) => {
437                        let names = cache.column_names(ks_name, table);
438                        filter_candidates(&names, prefix, false)
439                    }
440                    None => vec![],
441                }
442            }
443            CompletionContext::FilePath => complete_file_path(prefix),
444        }
445    }
446}
447
448/// Filter candidates by prefix, returning matching `Pair`s.
449fn filter_candidates(candidates: &[&str], prefix: &str, uppercase: bool) -> Vec<Pair> {
450    candidates
451        .iter()
452        .filter(|c| {
453            if uppercase {
454                c.to_uppercase().starts_with(&prefix.to_uppercase())
455            } else {
456                c.starts_with(prefix)
457            }
458        })
459        .map(|c| {
460            let display = if uppercase {
461                c.to_uppercase()
462            } else {
463                c.to_string()
464            };
465            Pair {
466                display: display.clone(),
467                replacement: display,
468            }
469        })
470        .collect()
471}
472
473/// Complete file paths for SOURCE and CAPTURE commands.
474fn complete_file_path(prefix: &str) -> Vec<Pair> {
475    // Strip surrounding quotes if present
476    let path_str = prefix
477        .strip_prefix('\'')
478        .or_else(|| prefix.strip_prefix('"'))
479        .unwrap_or(prefix);
480
481    // Expand ~ to home directory
482    let expanded = if path_str.starts_with('~') {
483        if let Some(home) = dirs::home_dir() {
484            path_str.replacen('~', &home.to_string_lossy(), 1)
485        } else {
486            path_str.to_string()
487        }
488    } else {
489        path_str.to_string()
490    };
491
492    let (dir, file_prefix) = if expanded.ends_with('/') {
493        (expanded.as_str(), "")
494    } else {
495        let path = std::path::Path::new(&expanded);
496        let parent = path
497            .parent()
498            .map(|p| p.to_str().unwrap_or("."))
499            .unwrap_or(".");
500        let file = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
501        (parent, file)
502    };
503
504    let dir_to_read = if dir.is_empty() { "." } else { dir };
505
506    let Ok(entries) = std::fs::read_dir(dir_to_read) else {
507        return vec![];
508    };
509
510    entries
511        .filter_map(|entry| entry.ok())
512        .filter_map(|entry| {
513            let name = entry.file_name().to_string_lossy().to_string();
514            if name.starts_with(file_prefix) {
515                let is_dir = entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false);
516                let suffix = if is_dir { "/" } else { "" };
517                let full = if dir.is_empty() || dir == "." {
518                    format!("{name}{suffix}")
519                } else if dir.ends_with('/') {
520                    format!("{dir}{name}{suffix}")
521                } else {
522                    format!("{dir}/{name}{suffix}")
523                };
524                Some(Pair {
525                    display: name + suffix,
526                    replacement: full,
527                })
528            } else {
529                None
530            }
531        })
532        .collect()
533}
534
535impl Completer for CqlCompleter {
536    type Candidate = Pair;
537
538    fn complete(
539        &self,
540        line: &str,
541        pos: usize,
542        _ctx: &Context<'_>,
543    ) -> rustyline::Result<(usize, Vec<Pair>)> {
544        // block_in_place: complete() is called from within the Tokio runtime (sync rustyline trait)
545        let needs_refresh = tokio::task::block_in_place(|| {
546            self.rt_handle
547                .block_on(async { self.cache.read().await.is_stale() })
548        });
549        if needs_refresh {
550            // Best-effort refresh — don't block on errors
551            tokio::task::block_in_place(|| {
552                self.rt_handle.block_on(async {
553                    // Try to get write lock without blocking other completions
554                    if let Ok(mut cache) = self.cache.try_write() {
555                        // Re-check staleness after acquiring lock
556                        if cache.is_stale() {
557                            // We can't refresh without a session reference here.
558                            // The REPL pre-refreshes the cache; this is a fallback mark.
559                            cache.invalidate();
560                        }
561                    }
562                })
563            });
564        }
565
566        let context = self.detect_context(line, pos);
567
568        // Find the start of the word being completed
569        let before_cursor = &line[..pos];
570        let word_start = before_cursor
571            .rfind(|c: char| c.is_whitespace() || c == '.' || c == '\'' || c == '"')
572            .map(|i| i + 1)
573            .unwrap_or(0);
574        let prefix = &line[word_start..pos];
575
576        let completions = self.complete_for_context(&context, prefix);
577
578        Ok((word_start, completions))
579    }
580}
581
582impl Hinter for CqlCompleter {
583    type Hint = String;
584
585    fn hint(&self, _line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
586        None
587    }
588}
589
590impl Highlighter for CqlCompleter {
591    fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
592        let colored = self.colorizer.colorize_line(line);
593        if colored == line {
594            Cow::Borrowed(line)
595        } else {
596            Cow::Owned(colored)
597        }
598    }
599
600    fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
601        &'s self,
602        prompt: &'p str,
603        _default: bool,
604    ) -> Cow<'b, str> {
605        Cow::Borrowed(prompt)
606    }
607
608    fn highlight_char(
609        &self,
610        _line: &str,
611        _pos: usize,
612        _forced: rustyline::highlight::CmdKind,
613    ) -> bool {
614        // Return true to trigger re-highlighting on every keystroke
615        true
616    }
617}
618
/// Input validation uses the `Validator` trait's default methods unchanged (the impl body is
/// empty). NOTE(review): presumably present only to satisfy rustyline's `Helper` trait
/// bounds, so `CqlCompleter` can be installed as the editor helper.
impl Validator for CqlCompleter {}
620
/// Marks `CqlCompleter` as a rustyline helper. It bundles the `Completer`, `Hinter`,
/// `Highlighter`, and `Validator` implementations for this type; no methods are overridden here.
impl Helper for CqlCompleter {}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime shared by every test.
    ///
    /// `CqlCompleter` keeps only a `Handle`, so the runtime must outlive every completer.
    /// A `Runtime` built inside `make_completer` would be dropped, and thereby shut down, on
    /// return, leaving the completer with a handle to a dead runtime.
    fn test_runtime() -> &'static tokio::runtime::Runtime {
        static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();
        RUNTIME.get_or_init(|| {
            tokio::runtime::Runtime::new().expect("failed to build Tokio runtime for tests")
        })
    }

    fn make_completer() -> CqlCompleter {
        let cache = Arc::new(RwLock::new(SchemaCache::new()));
        let current_ks = Arc::new(RwLock::new(None::<String>));
        CqlCompleter::new(cache, current_ks, test_runtime().handle().clone(), false)
    }

    #[test]
    fn completer_can_be_created() {
        let _c = make_completer();
    }

    #[test]
    fn detect_empty_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("", 0), CompletionContext::Empty);
    }

    #[test]
    fn detect_keyword_prefix() {
        let c = make_completer();
        assert_eq!(c.detect_context("SEL", 3), CompletionContext::Empty);
    }

    #[test]
    fn detect_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("CONSISTENCY ", 12),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_serial_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SERIAL CONSISTENCY ", 19),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_use_keyspace_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("USE ", 4), CompletionContext::KeyspaceName);
    }

    #[test]
    fn detect_describe_sub_command() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE ", 9),
            CompletionContext::DescribeTarget
        );
    }

    #[test]
    fn detect_describe_table_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE TABLE ", 15),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn detect_describe_keyspace_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE KEYSPACE ", 18),
            CompletionContext::KeyspaceName
        );
    }

    #[test]
    fn detect_source_file_path() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SOURCE '/tmp/", 13),
            CompletionContext::FilePath
        );
    }

    #[test]
    fn detect_capture_file_path() {
        let c = make_completer();
        assert_eq!(c.detect_context("CAPTURE ", 8), CompletionContext::FilePath);
    }

    #[test]
    fn detect_from_table_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SELECT * FROM ", 14),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn complete_keyword_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::Empty, "SEL");
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn complete_consistency_level_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::ConsistencyLevel, "QU");
        assert!(pairs.iter().any(|p| p.replacement == "QUORUM"));
    }

    #[test]
    fn complete_describe_sub_command() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::DescribeTarget, "KEY");
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACE"));
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACES"));
    }

    #[test]
    fn column_completion_without_any_keyspace_is_empty() {
        // Exercises the runtime-handle path (cache + current-keyspace reads).
        let c = make_completer();
        let ctx = CompletionContext::ColumnName {
            keyspace: None,
            table: "t".to_string(),
        };
        assert!(c.complete_for_context(&ctx, "").is_empty());
    }

    #[test]
    fn filter_is_case_insensitive_for_keywords() {
        let pairs = filter_candidates(CQL_KEYWORDS, "sel", true);
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn filter_uppercases_keyword_display_and_replacement() {
        let pairs = filter_candidates(CONSISTENCY_LEVELS, "local_", true);
        let names: Vec<_> = pairs.iter().map(|p| p.replacement.as_str()).collect();
        assert_eq!(names, ["LOCAL_ONE", "LOCAL_QUORUM", "LOCAL_SERIAL"]);
        assert!(pairs.iter().all(|p| p.display == p.replacement));
    }

    #[test]
    fn filter_is_case_sensitive_for_schema_names() {
        let pairs = filter_candidates(&["users", "Users", "orders"], "us", false);
        let names: Vec<_> = pairs.iter().map(|p| p.replacement.as_str()).collect();
        assert_eq!(names, ["users"]);
    }

    #[test]
    fn file_path_completion_tmp() {
        // /tmp should exist on all Unix systems
        let pairs = complete_file_path("/tmp/");
        // Should return entries — exact count varies
        assert!(
            !pairs.is_empty() || std::fs::read_dir("/tmp").map(|d| d.count()).unwrap_or(0) == 0
        );
    }

    #[test]
    fn file_path_completion_lists_directory_entries() {
        let dir = std::env::temp_dir().join(format!(
            "cqlsh_rs_completer_test_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(dir.join("beta")).unwrap();
        std::fs::write(dir.join("alpha.cql"), "").unwrap();

        let prefix = format!("{}/", dir.display());
        let pairs = complete_file_path(&prefix);
        let _ = std::fs::remove_dir_all(&dir);

        assert!(pairs
            .iter()
            .any(|p| p.display == "alpha.cql" && p.replacement == format!("{prefix}alpha.cql")));
        assert!(pairs
            .iter()
            .any(|p| p.display == "beta/" && p.replacement == format!("{prefix}beta/")));
    }
}