1use std::borrow::Cow;
10use std::sync::Arc;
11
12use rustyline::completion::{Completer, Pair};
13use rustyline::highlight::Highlighter;
14use rustyline::hint::Hinter;
15use rustyline::validate::Validator;
16use rustyline::{Context, Helper};
17use tokio::runtime::Handle;
18use tokio::sync::RwLock;
19
20use crate::colorizer::CqlColorizer;
21use crate::cql_lexer::{self, GrammarContext, TokenKind};
22use crate::schema_cache::SchemaCache;
23
/// Words that can begin a CQL statement; offered (together with the shell commands)
/// when completing at the start of a statement.
const CQL_KEYWORDS: &[&str] = &[
    "ALTER",
    "APPLY",
    "BATCH",
    "BEGIN",
    "CREATE",
    "DELETE",
    "DESCRIBE",
    "DROP",
    "GRANT",
    "INSERT",
    "LIST",
    "REVOKE",
    "SELECT",
    "TRUNCATE",
    "UPDATE",
    "USE",
];
29
/// Keywords suggested inside a statement (clauses, modifiers, object kinds).
///
/// Kept in strictly ascending ASCII order, which also guarantees there are no duplicates.
/// `GROUP` (for `GROUP BY`) and `COUNTER` (for `BEGIN COUNTER BATCH`) are included so
/// those constructs complete like their `ORDER BY` / `UNLOGGED` counterparts.
const CQL_CLAUSE_KEYWORDS: &[&str] = &[
    "ADD", "AGGREGATE", "ALL", "ALLOW", "AND", "AS", "ASC", "AUTHORIZE", "BATCH", "BY",
    "CALLED", "CLUSTERING", "COLUMN", "COMPACT", "CONTAINS", "COUNT", "COUNTER", "CUSTOM",
    "DELETE", "DESC", "DESCRIBE", "DISTINCT", "DROP", "ENTRIES", "EXECUTE", "EXISTS",
    "FILTERING", "FINALFUNC", "FROM", "FROZEN", "FULL", "FUNCTION", "FUNCTIONS", "GROUP",
    "IF", "IN", "INDEX", "INITCOND", "INPUT", "INSERT", "INTO", "IS", "JSON", "KEY", "KEYS",
    "KEYSPACE", "KEYSPACES", "LANGUAGE", "LIKE", "LIMIT", "LIST", "LOGIN", "MAP",
    "MATERIALIZED", "MODIFY", "NAMESPACE", "NORECURSIVE", "NOT", "NULL", "OF", "ON", "OR",
    "ORDER", "PARTITION", "PASSWORD", "PER", "PERMISSION", "PERMISSIONS", "PRIMARY",
    "RENAME", "REPLACE", "RETURNS", "REVOKE", "SCHEMA", "SELECT", "SET", "SFUNC", "STATIC",
    "STORAGE", "STYPE", "SUPERUSER", "TABLE", "TABLES", "TEXT", "TIMESTAMP", "TO", "TOKEN",
    "TRIGGER", "TRUNCATE", "TTL", "TUPLE", "TYPE", "UNLOGGED", "UPDATE", "USER", "USERS",
    "USING", "VALUES", "VIEW", "WHERE", "WITH", "WRITETIME",
];
133
/// cqlsh shell commands (handled by the shell rather than sent as CQL); offered alongside
/// the statement keywords at the start of a line.
const SHELL_COMMANDS: &[&str] = &[
    "CAPTURE", "CLEAR", "CLS", "CONSISTENCY", "COPY", "DESCRIBE", "DESC", "EXIT", "EXPAND",
    "HELP", "LOGIN", "PAGING", "QUIT", "SERIAL", "SHOW", "SOURCE", "TRACING",
];
154
/// Consistency level names offered after `CONSISTENCY` / `SERIAL CONSISTENCY`.
const CONSISTENCY_LEVELS: &[&str] = &[
    "ALL", "ANY", "EACH_QUORUM", "LOCAL_ONE", "LOCAL_QUORUM", "LOCAL_SERIAL", "ONE", "QUORUM",
    "SERIAL", "THREE", "TWO",
];
169
/// Object kinds offered right after `DESCRIBE` / `DESC`.
const DESCRIBE_SUB_COMMANDS: &[&str] = &[
    "AGGREGATE", "AGGREGATES", "CLUSTER", "FULL", "FUNCTION", "FUNCTIONS", "INDEX",
    "KEYSPACE", "KEYSPACES", "MATERIALIZED", "SCHEMA", "TABLE", "TABLES", "TYPE", "TYPES",
];
188
/// CQL data type names, in lower case.
///
/// Not referenced by the completion logic in this module, hence `allow(dead_code)`;
/// presumably kept for future type-name completion in DDL statements (TODO confirm).
#[allow(dead_code)]
const CQL_TYPES: &[&str] = &[
    "ascii", "bigint", "blob", "boolean", "counter", "date", "decimal", "double",
    "duration", "float", "frozen", "inet", "int", "list", "map", "set", "smallint", "text",
    "time", "timestamp", "timeuuid", "tinyint", "tuple", "uuid", "varchar", "varint",
];
219
/// What the token under the cursor is expected to be, which selects the candidate source.
#[derive(PartialEq, Debug)]
enum CompletionContext {
    /// Start of a statement: CQL statement keywords and shell commands.
    Empty,
    /// Somewhere inside a statement where clause keywords are suggested.
    ClauseKeyword,
    /// A table name; `keyspace` is the explicit `ks.` qualifier when one was typed.
    TableName { keyspace: Option<String> },
    /// A column of `table`; `keyspace` is the qualifier or resolved session keyspace.
    ColumnName {
        keyspace: Option<String>,
        table: String,
    },
    /// A consistency level name.
    ConsistencyLevel,
    /// The object kind following `DESCRIBE`.
    DescribeTarget,
    /// A filesystem path (e.g. for `SOURCE` / `CAPTURE`).
    FilePath,
    /// A keyspace name.
    KeyspaceName,
}
243
244pub struct CqlCompleter {
246 cache: Arc<RwLock<SchemaCache>>,
248 current_keyspace: Arc<RwLock<Option<String>>>,
250 rt_handle: Handle,
252 colorizer: CqlColorizer,
254}
255
256impl CqlCompleter {
257 pub fn new(
259 cache: Arc<RwLock<SchemaCache>>,
260 current_keyspace: Arc<RwLock<Option<String>>>,
261 rt_handle: Handle,
262 color_enabled: bool,
263 ) -> Self {
264 Self {
265 cache,
266 current_keyspace,
267 rt_handle,
268 colorizer: CqlColorizer::new(color_enabled),
269 }
270 }
271
272 fn detect_context(&self, line: &str, pos: usize) -> CompletionContext {
276 let before_cursor = &line[..pos];
277 let tokens = cql_lexer::tokenize(before_cursor);
278 let sig: Vec<_> = cql_lexer::significant_tokens(&tokens);
279
280 if sig.is_empty() {
281 return CompletionContext::Empty;
282 }
283
284 let first_upper = sig[0].text.to_uppercase();
286 if first_upper == "SOURCE" || first_upper == "CAPTURE" {
287 return CompletionContext::FilePath;
288 }
289
290 let grammar_ctx = cql_lexer::grammar_context_at_end(before_cursor);
291
292 match grammar_ctx {
293 GrammarContext::Start => CompletionContext::Empty,
294 GrammarContext::ExpectTable => {
295 let keyspace = self.extract_qualifying_keyspace(&sig);
297 CompletionContext::TableName { keyspace }
298 }
299 GrammarContext::ExpectKeyspace => CompletionContext::KeyspaceName,
300 GrammarContext::ExpectColumn | GrammarContext::ExpectSetClause => {
301 let (ks, table) = self.extract_table_from_tokens(&sig);
303 match table {
304 Some(t) => CompletionContext::ColumnName {
305 keyspace: ks,
306 table: t,
307 },
308 None => CompletionContext::ClauseKeyword,
309 }
310 }
311 GrammarContext::ExpectConsistencyLevel => CompletionContext::ConsistencyLevel,
312 GrammarContext::ExpectDescribeTarget => CompletionContext::DescribeTarget,
313 GrammarContext::ExpectFilePath => CompletionContext::FilePath,
314 GrammarContext::ExpectQualifiedPart => {
315 let keyspace = self.extract_qualifying_keyspace(&sig);
317 CompletionContext::TableName { keyspace }
318 }
319 GrammarContext::ExpectColumnList => {
320 if sig.len() == 1 && !before_cursor.ends_with(' ') {
322 CompletionContext::Empty
323 } else {
324 CompletionContext::ClauseKeyword
325 }
326 }
327 _ => {
328 if sig.len() == 1 && !before_cursor.ends_with(' ') {
330 CompletionContext::Empty
331 } else {
332 CompletionContext::ClauseKeyword
333 }
334 }
335 }
336 }
337
338 fn extract_qualifying_keyspace(&self, sig: &[&cql_lexer::Token]) -> Option<String> {
340 let len = sig.len();
342 if len >= 2 && sig[len - 1].text == "." {
343 return Some(sig[len - 2].text.clone());
344 }
345 None
346 }
347
348 fn extract_table_from_tokens(
350 &self,
351 sig: &[&cql_lexer::Token],
352 ) -> (Option<String>, Option<String>) {
353 for (i, tok) in sig.iter().enumerate() {
354 let upper = tok.text.to_uppercase();
355 if matches!(upper.as_str(), "FROM" | "INTO" | "UPDATE" | "TABLE")
356 && i + 1 < sig.len()
357 && matches!(
358 sig[i + 1].kind,
359 TokenKind::Identifier | TokenKind::QuotedIdentifier
360 )
361 {
362 let table = sig[i + 1].text.clone();
363 if i + 3 < sig.len() && sig[i + 2].text == "." {
365 let ks = table;
366 let tbl = sig[i + 3].text.clone();
367 return (Some(ks), Some(tbl));
368 }
369 let ks = tokio::task::block_in_place(|| {
370 self.rt_handle
371 .block_on(async { self.current_keyspace.read().await.clone() })
372 });
373 return (ks, Some(table));
374 }
375 }
376 (None, None)
377 }
378
379 fn complete_for_context(&self, ctx: &CompletionContext, prefix: &str) -> Vec<Pair> {
381 let prefix_upper = prefix.to_uppercase();
382
383 match ctx {
384 CompletionContext::Empty => {
385 let mut candidates: Vec<&str> = Vec::new();
386 candidates.extend_from_slice(CQL_KEYWORDS);
387 candidates.extend_from_slice(SHELL_COMMANDS);
388 filter_candidates(&candidates, &prefix_upper, true)
389 }
390 CompletionContext::ClauseKeyword => {
391 filter_candidates(CQL_CLAUSE_KEYWORDS, &prefix_upper, true)
392 }
393 CompletionContext::ConsistencyLevel => {
394 filter_candidates(CONSISTENCY_LEVELS, &prefix_upper, true)
395 }
396 CompletionContext::DescribeTarget => {
397 filter_candidates(DESCRIBE_SUB_COMMANDS, &prefix_upper, true)
398 }
399 CompletionContext::KeyspaceName => {
400 let cache =
401 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
402 let names = cache.keyspace_names();
403 filter_candidates(&names, prefix, false)
404 }
405 CompletionContext::TableName { keyspace } => {
406 let cache =
407 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
408 let ks = keyspace.clone().or_else(|| {
409 tokio::task::block_in_place(|| {
410 self.rt_handle
411 .block_on(async { self.current_keyspace.read().await.clone() })
412 })
413 });
414 match ks {
415 Some(ref ks_name) => {
416 let names = cache.table_names(ks_name);
417 filter_candidates(&names, prefix, false)
418 }
419 None => {
420 let names = cache.keyspace_names();
422 filter_candidates(&names, prefix, false)
423 }
424 }
425 }
426 CompletionContext::ColumnName { keyspace, table } => {
427 let cache =
428 tokio::task::block_in_place(|| self.rt_handle.block_on(self.cache.read()));
429 let ks = keyspace.clone().or_else(|| {
430 tokio::task::block_in_place(|| {
431 self.rt_handle
432 .block_on(async { self.current_keyspace.read().await.clone() })
433 })
434 });
435 match ks {
436 Some(ref ks_name) => {
437 let names = cache.column_names(ks_name, table);
438 filter_candidates(&names, prefix, false)
439 }
440 None => vec![],
441 }
442 }
443 CompletionContext::FilePath => complete_file_path(prefix),
444 }
445 }
446}
447
448fn filter_candidates(candidates: &[&str], prefix: &str, uppercase: bool) -> Vec<Pair> {
450 candidates
451 .iter()
452 .filter(|c| {
453 if uppercase {
454 c.to_uppercase().starts_with(&prefix.to_uppercase())
455 } else {
456 c.starts_with(prefix)
457 }
458 })
459 .map(|c| {
460 let display = if uppercase {
461 c.to_uppercase()
462 } else {
463 c.to_string()
464 };
465 Pair {
466 display: display.clone(),
467 replacement: display,
468 }
469 })
470 .collect()
471}
472
473fn complete_file_path(prefix: &str) -> Vec<Pair> {
475 let path_str = prefix
477 .strip_prefix('\'')
478 .or_else(|| prefix.strip_prefix('"'))
479 .unwrap_or(prefix);
480
481 let expanded = if path_str.starts_with('~') {
483 if let Some(home) = dirs::home_dir() {
484 path_str.replacen('~', &home.to_string_lossy(), 1)
485 } else {
486 path_str.to_string()
487 }
488 } else {
489 path_str.to_string()
490 };
491
492 let (dir, file_prefix) = if expanded.ends_with('/') {
493 (expanded.as_str(), "")
494 } else {
495 let path = std::path::Path::new(&expanded);
496 let parent = path
497 .parent()
498 .map(|p| p.to_str().unwrap_or("."))
499 .unwrap_or(".");
500 let file = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
501 (parent, file)
502 };
503
504 let dir_to_read = if dir.is_empty() { "." } else { dir };
505
506 let Ok(entries) = std::fs::read_dir(dir_to_read) else {
507 return vec![];
508 };
509
510 entries
511 .filter_map(|entry| entry.ok())
512 .filter_map(|entry| {
513 let name = entry.file_name().to_string_lossy().to_string();
514 if name.starts_with(file_prefix) {
515 let is_dir = entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false);
516 let suffix = if is_dir { "/" } else { "" };
517 let full = if dir.is_empty() || dir == "." {
518 format!("{name}{suffix}")
519 } else if dir.ends_with('/') {
520 format!("{dir}{name}{suffix}")
521 } else {
522 format!("{dir}/{name}{suffix}")
523 };
524 Some(Pair {
525 display: name + suffix,
526 replacement: full,
527 })
528 } else {
529 None
530 }
531 })
532 .collect()
533}
534
535impl Completer for CqlCompleter {
536 type Candidate = Pair;
537
538 fn complete(
539 &self,
540 line: &str,
541 pos: usize,
542 _ctx: &Context<'_>,
543 ) -> rustyline::Result<(usize, Vec<Pair>)> {
544 let needs_refresh = tokio::task::block_in_place(|| {
546 self.rt_handle
547 .block_on(async { self.cache.read().await.is_stale() })
548 });
549 if needs_refresh {
550 tokio::task::block_in_place(|| {
552 self.rt_handle.block_on(async {
553 if let Ok(mut cache) = self.cache.try_write() {
555 if cache.is_stale() {
557 cache.invalidate();
560 }
561 }
562 })
563 });
564 }
565
566 let context = self.detect_context(line, pos);
567
568 let before_cursor = &line[..pos];
570 let word_start = before_cursor
571 .rfind(|c: char| c.is_whitespace() || c == '.' || c == '\'' || c == '"')
572 .map(|i| i + 1)
573 .unwrap_or(0);
574 let prefix = &line[word_start..pos];
575
576 let completions = self.complete_for_context(&context, prefix);
577
578 Ok((word_start, completions))
579 }
580}
581
582impl Hinter for CqlCompleter {
583 type Hint = String;
584
585 fn hint(&self, _line: &str, _pos: usize, _ctx: &Context<'_>) -> Option<String> {
586 None
587 }
588}
589
590impl Highlighter for CqlCompleter {
591 fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
592 let colored = self.colorizer.colorize_line(line);
593 if colored == line {
594 Cow::Borrowed(line)
595 } else {
596 Cow::Owned(colored)
597 }
598 }
599
600 fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
601 &'s self,
602 prompt: &'p str,
603 _default: bool,
604 ) -> Cow<'b, str> {
605 Cow::Borrowed(prompt)
606 }
607
608 fn highlight_char(
609 &self,
610 _line: &str,
611 _pos: usize,
612 _forced: rustyline::highlight::CmdKind,
613 ) -> bool {
614 true
616 }
617}
618
619impl Validator for CqlCompleter {}
620
621impl Helper for CqlCompleter {}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime shared by every completer built in these tests.
    ///
    /// `CqlCompleter` stores only a `Handle`. A runtime owned by a local inside
    /// `make_completer` would be dropped, and thereby shut down, as soon as that function
    /// returned, leaving each completer with a handle to a dead runtime.
    fn test_runtime() -> &'static tokio::runtime::Runtime {
        static RT: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();
        RT.get_or_init(|| tokio::runtime::Runtime::new().expect("failed to build test runtime"))
    }

    /// Builds a completer over an empty schema cache with no current keyspace.
    fn make_completer() -> CqlCompleter {
        let cache = Arc::new(RwLock::new(SchemaCache::new()));
        let current_ks = Arc::new(RwLock::new(None::<String>));
        CqlCompleter::new(cache, current_ks, test_runtime().handle().clone(), false)
    }

    #[test]
    fn completer_can_be_created() {
        let _c = make_completer();
    }

    #[test]
    fn detect_empty_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("", 0), CompletionContext::Empty);
    }

    #[test]
    fn detect_keyword_prefix() {
        let c = make_completer();
        assert_eq!(c.detect_context("SEL", 3), CompletionContext::Empty);
    }

    #[test]
    fn detect_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("CONSISTENCY ", 12),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_serial_consistency_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SERIAL CONSISTENCY ", 19),
            CompletionContext::ConsistencyLevel
        );
    }

    #[test]
    fn detect_use_keyspace_context() {
        let c = make_completer();
        assert_eq!(c.detect_context("USE ", 4), CompletionContext::KeyspaceName);
    }

    #[test]
    fn detect_describe_sub_command() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE ", 9),
            CompletionContext::DescribeTarget
        );
    }

    #[test]
    fn detect_describe_table_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE TABLE ", 15),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn detect_describe_keyspace_name() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("DESCRIBE KEYSPACE ", 18),
            CompletionContext::KeyspaceName
        );
    }

    #[test]
    fn detect_source_file_path() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SOURCE '/tmp/", 13),
            CompletionContext::FilePath
        );
    }

    #[test]
    fn detect_capture_file_path() {
        let c = make_completer();
        assert_eq!(c.detect_context("CAPTURE ", 8), CompletionContext::FilePath);
    }

    #[test]
    fn detect_from_table_context() {
        let c = make_completer();
        assert_eq!(
            c.detect_context("SELECT * FROM ", 14),
            CompletionContext::TableName { keyspace: None }
        );
    }

    #[test]
    fn complete_keyword_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::Empty, "SEL");
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn complete_clause_keyword_prefix_is_case_insensitive() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::ClauseKeyword, "whe");
        assert!(pairs.iter().any(|p| p.replacement == "WHERE"));
    }

    #[test]
    fn complete_consistency_level_prefix() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::ConsistencyLevel, "QU");
        assert!(pairs.iter().any(|p| p.replacement == "QUORUM"));
    }

    #[test]
    fn complete_describe_sub_command() {
        let c = make_completer();
        let pairs = c.complete_for_context(&CompletionContext::DescribeTarget, "KEY");
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACE"));
        assert!(pairs.iter().any(|p| p.replacement == "KEYSPACES"));
    }

    #[test]
    fn complete_keyspace_names_from_sync_context() {
        // Exercises the block_in_place/block_on bridge against the shared runtime.
        let c = make_completer();
        let _ = c.complete_for_context(&CompletionContext::KeyspaceName, "");
    }

    #[test]
    fn filter_is_case_insensitive_for_keywords() {
        let pairs = filter_candidates(CQL_KEYWORDS, "sel", true);
        assert!(pairs.iter().any(|p| p.replacement == "SELECT"));
    }

    #[test]
    fn file_path_completion_tmp() {
        let pairs = complete_file_path("/tmp/");
        assert!(
            !pairs.is_empty() || std::fs::read_dir("/tmp").map(|d| d.count()).unwrap_or(0) == 0
        );
    }

    #[test]
    fn file_path_completion_missing_dir_is_empty() {
        assert!(complete_file_path("/definitely/not/a/real/dir/").is_empty());
    }
}
763}