Skip to main content

cqlsh_rs/
repl.rs

1//! Interactive REPL (Read-Eval-Print Loop) for cqlsh-rs.
2//!
3//! Integrates rustyline for line editing, history, and prompt management.
4//! Mirrors the Python cqlsh interactive behavior including multi-line input,
5//! prompt formatting, and Ctrl-C/Ctrl-D handling.
6
7use std::fs::File;
8use std::io::{self, BufRead, IsTerminal, Write};
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use anyhow::Result;
13use rustyline::error::ReadlineError;
14use rustyline::history::DefaultHistory;
15use rustyline::{CompletionType, Config, EditMode, Editor};
16use tokio::sync::RwLock;
17
18use crate::colorizer::CqlColorizer;
19use crate::completer::CqlCompleter;
20use crate::config::MergedConfig;
21use crate::describe;
22use crate::error;
23use crate::formatter;
24use crate::parser::{self, ParseResult, StatementParser};
25use crate::schema_cache::SchemaCache;
26use crate::session::CqlSession;
27
/// Directory (relative to the home directory) that holds the history file.
/// Together with [`DEFAULT_HISTORY_FILE`] this gives `~/.cassandra/cql_history`.
const DEFAULT_HISTORY_DIR: &str = ".cassandra";
/// File name of the default history file inside [`DEFAULT_HISTORY_DIR`].
/// NOTE(review): Python cqlsh may name this `cqlsh_history`; confirm whether
/// sharing its file is intended before changing.
const DEFAULT_HISTORY_FILE: &str = "cql_history";
/// Upper bound on retained history entries (matches Python cqlsh default).
const DEFAULT_HISTORY_SIZE: usize = 1_000;
/// Prompt shown while a statement spans several lines (Python cqlsh style).
const CONTINUATION_PROMPT: &str = "   ... ";
35
/// Build the primary prompt in Python cqlsh's layout: `[username@]cqlsh[:keyspace]> `.
///
/// Each optional piece is rendered only when present:
/// - neither → `cqlsh> `
/// - keyspace only → `cqlsh:my_ks> `
/// - username only → `admin@cqlsh> `
/// - both → `admin@cqlsh:my_ks> `
pub fn build_prompt(username: Option<&str>, keyspace: Option<&str>) -> String {
    let user_segment = username.map(|user| format!("{user}@")).unwrap_or_default();
    let keyspace_segment = keyspace.map(|ks| format!(":{ks}")).unwrap_or_default();
    format!("{user_segment}cqlsh{keyspace_segment}> ")
}
59
60/// Resolve the history file path.
61///
62/// Priority: CQL_HISTORY env var > ~/.cassandra/cql_history
63fn resolve_history_path(config: &MergedConfig) -> Option<PathBuf> {
64    if config.disable_history {
65        return None;
66    }
67
68    // Check CQL_HISTORY env var (already captured in EnvConfig, but we
69    // also respect it directly here for simplicity)
70    if let Ok(path) = std::env::var("CQL_HISTORY") {
71        return Some(PathBuf::from(path));
72    }
73
74    dirs::home_dir().map(|home| home.join(DEFAULT_HISTORY_DIR).join(DEFAULT_HISTORY_FILE))
75}
76
77/// Mutable shell state for commands like EXPAND, PAGING, and CAPTURE.
78struct ShellState {
79    /// Whether expanded (vertical) output is enabled.
80    expand: bool,
81    /// Whether to pipe output through the built-in pager.
82    paging_enabled: bool,
83    /// Whether stdout is a TTY (controls pager auto-disable).
84    is_tty: bool,
85    /// Whether debug mode is enabled (toggled via DEBUG command).
86    debug: bool,
87    /// Active CAPTURE file handle (output is tee'd to this file).
88    capture_file: Option<File>,
89    /// Path of the active capture file (for display).
90    capture_path: Option<PathBuf>,
91    /// Shared schema cache for tab completion (invalidated on DDL).
92    schema_cache: Option<Arc<RwLock<SchemaCache>>>,
93    /// Shared current keyspace for tab completion.
94    shared_keyspace: Option<Arc<RwLock<Option<String>>>>,
95    /// Output colorizer for result values, headers, and errors.
96    colorizer: CqlColorizer,
97}
98
99impl ShellState {
100    /// Write output line to both stdout and the capture file (if active).
101    /// Used for short shell command output that doesn't need paging.
102    fn outputln(&mut self, text: &str) {
103        println!("{text}");
104        if let Some(ref mut f) = self.capture_file {
105            let _ = writeln!(f, "{text}");
106        }
107    }
108
109    /// Display output, routing through the pager if enabled, and writing to capture file.
110    /// An optional `title` is shown at the top of the pager (e.g., column names).
111    fn display_output(&mut self, content: &[u8], title: &str) {
112        // Write to capture file if active
113        if let Some(ref mut f) = self.capture_file {
114            let _ = f.write_all(content);
115        }
116
117        let text = String::from_utf8_lossy(content);
118
119        // Route through pager or print directly
120        if self.paging_enabled && self.is_tty {
121            if crate::pager::page_content(&text, title).is_err() {
122                // Fallback: print directly if pager fails
123                print!("{text}");
124            }
125        } else {
126            print!("{text}");
127        }
128    }
129}
130
131// Statement parsing is now handled by the parser module (SP4).
132// The REPL uses `parser::StatementParser` for incremental, context-aware
133// statement detection that correctly handles strings, comments, and
134// multi-line input.
135
/// Return `true` when `stmt` is a `SELECT` statement, i.e. one that can produce
/// rows worth streaming into the pager.
///
/// Leading whitespace is skipped. The keyword is then matched
/// ASCII-case-insensitively and must end at a word boundary:
/// - `SELECT\r\n…` (pasted CRLF input) and `SELECT*FROM t` qualify;
/// - `SELECTED…` and a bare `SELECT` do not.
///
/// Only the keyword is inspected, so the rest of a potentially large statement
/// is never copied.
fn could_return_rows(stmt: &str) -> bool {
    const KEYWORD: &str = "SELECT";
    let stmt = stmt.trim_start();
    // `get` returns None for short input or a non-char-boundary cut.
    let Some(head) = stmt.get(..KEYWORD.len()) else {
        return false;
    };
    if !head.eq_ignore_ascii_case(KEYWORD) {
        return false;
    }
    // `head` is pure ASCII here, so this index is a valid char boundary.
    stmt[KEYWORD.len()..]
        .chars()
        .next()
        .is_some_and(|c| !(c.is_alphanumeric() || c == '_'))
}
140
141/// Run the interactive REPL loop.
142///
143/// Reads lines from the user, handles multi-line input, and dispatches
144/// complete statements to the session for execution.
145pub async fn run(session: &mut CqlSession, config: &MergedConfig) -> Result<()> {
146    let rl_config = Config::builder()
147        .max_history_size(DEFAULT_HISTORY_SIZE)
148        .expect("valid history size")
149        .edit_mode(EditMode::Emacs)
150        .auto_add_history(true)
151        .completion_type(CompletionType::List)
152        .build();
153
154    // Set up schema cache and tab completer
155    let schema_cache = Arc::new(RwLock::new(SchemaCache::new()));
156    let current_keyspace: Arc<RwLock<Option<String>>> =
157        Arc::new(RwLock::new(session.current_keyspace().map(String::from)));
158
159    // Initial schema cache population (best-effort)
160    {
161        let mut cache = schema_cache.write().await;
162        if let Err(e) = cache.refresh(session).await {
163            eprintln!("Warning: could not load schema for tab completion: {e}");
164        }
165    }
166
167    // Resolve color mode: Auto → check if stdout is a terminal
168    // --tty flag forces TTY behavior even when piped
169    let is_tty = config.tty || std::io::stdout().is_terminal();
170    let color_enabled = match config.color {
171        crate::config::ColorMode::On => true,
172        crate::config::ColorMode::Off => false,
173        crate::config::ColorMode::Auto => is_tty,
174    };
175
176    let completer = CqlCompleter::new(
177        Arc::clone(&schema_cache),
178        Arc::clone(&current_keyspace),
179        tokio::runtime::Handle::current(),
180        color_enabled,
181    );
182
183    let mut rl: Editor<CqlCompleter, DefaultHistory> = Editor::with_config(rl_config)?;
184    rl.set_helper(Some(completer));
185
186    // Load history
187    let history_path = resolve_history_path(config);
188    if let Some(ref path) = history_path {
189        // Ensure the parent directory exists
190        if let Some(parent) = path.parent() {
191            let _ = std::fs::create_dir_all(parent);
192        }
193        let _ = rl.load_history(path);
194    }
195
196    let username = config.username.as_deref();
197    let mut stmt_parser = StatementParser::new();
198    let colorizer = CqlColorizer::new(color_enabled);
199    let mut shell = ShellState {
200        expand: false,
201        paging_enabled: true,
202        is_tty,
203        debug: config.debug,
204        capture_file: None,
205        capture_path: None,
206        schema_cache: Some(Arc::clone(&schema_cache)),
207        shared_keyspace: Some(Arc::clone(&current_keyspace)),
208        colorizer,
209    };
210
211    loop {
212        let prompt = if stmt_parser.is_empty() {
213            build_prompt(username, session.current_keyspace())
214        } else {
215            CONTINUATION_PROMPT.to_string()
216        };
217
218        match rl.readline(&prompt) {
219            Ok(line) => {
220                // BUG-5 fix: Split pasted multi-line input into individual
221                // lines so each is processed separately.
222                let lines: Vec<&str> = line.split('\n').collect();
223                for sub_line in lines {
224                    process_line(sub_line, &mut stmt_parser, session, config, &mut shell).await;
225                }
226            }
227            Err(ReadlineError::Interrupted) => {
228                // Ctrl-C: cancel current input buffer, return to prompt
229                stmt_parser.reset();
230            }
231            Err(ReadlineError::Eof) => {
232                // Ctrl-D: exit
233                break;
234            }
235            Err(err) => {
236                eprintln!("Error reading input: {err}");
237                break;
238            }
239        }
240    }
241
242    // Save history
243    if let Some(ref path) = history_path {
244        let _ = rl.save_history(path);
245    }
246
247    Ok(())
248}
249
250/// Process a single line of input through the REPL pipeline.
251///
252/// Handles shell command detection, incremental parsing, and dispatch.
253async fn process_line(
254    line: &str,
255    stmt_parser: &mut StatementParser,
256    session: &mut CqlSession,
257    config: &MergedConfig,
258    shell: &mut ShellState,
259) {
260    let trimmed = line.trim();
261
262    // On an empty primary prompt, just show the prompt again
263    if stmt_parser.is_empty() && trimmed.is_empty() {
264        return;
265    }
266
267    // Shell commands are complete without semicolons (only on first line)
268    if stmt_parser.is_empty() && parser::is_shell_command(trimmed) {
269        // Strip trailing semicolon before dispatch — is_shell_command tolerates
270        // the semicolon for detection, but handlers expect clean input.
271        let clean = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
272        dispatch_input(session, config, shell, clean).await;
273        return;
274    }
275
276    // Feed line to the incremental parser
277    if let ParseResult::Complete(statements) = stmt_parser.feed_line(line) {
278        for stmt in statements {
279            dispatch_input(session, config, shell, &stmt).await;
280        }
281    }
282}
283
284/// Dispatch a complete input line/statement to the session.
285///
286/// Handles built-in shell commands and CQL statements.
287/// Uses `Box::pin` to support recursive calls from `execute_source`.
288fn dispatch_input<'a>(
289    session: &'a mut CqlSession,
290    config: &'a MergedConfig,
291    shell: &'a mut ShellState,
292    input: &'a str,
293) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>> {
294    Box::pin(async move {
295        let trimmed = input.trim();
296        let upper = trimmed.to_uppercase();
297
298        // Handle QUIT/EXIT
299        if upper == "QUIT" || upper == "EXIT" {
300            std::process::exit(0);
301        }
302
303        // Handle HELP [topic]
304        if upper == "HELP" || upper == "?" || upper.starts_with("HELP ") {
305            if let Some(topic) = upper.strip_prefix("HELP ") {
306                print_help_topic(topic.trim(), &mut std::io::stdout());
307            } else {
308                print_help(&mut std::io::stdout());
309            }
310            return;
311        }
312
313        // Handle CLEAR/CLS
314        if upper == "CLEAR" || upper == "CLS" {
315            print!("\x1B[2J\x1B[1;1H");
316            return;
317        }
318
319        // Handle CONSISTENCY
320        if upper == "CONSISTENCY" {
321            let cl = session.get_consistency();
322            shell.outputln(&format!("Current consistency level is {cl}."));
323            return;
324        }
325        if let Some(rest) = upper.strip_prefix("CONSISTENCY ") {
326            let level = rest.trim();
327            match session.set_consistency_str(level) {
328                Ok(()) => shell.outputln(&format!("Consistency level set to {level}.")),
329                Err(e) => eprintln!("{e}"),
330            }
331            return;
332        }
333
334        // Handle SERIAL CONSISTENCY
335        if upper == "SERIAL CONSISTENCY" {
336            match session.get_serial_consistency() {
337                Some(scl) => shell.outputln(&format!("Current serial consistency level is {scl}.")),
338                None => shell.outputln("Current serial consistency level is SERIAL."),
339            }
340            return;
341        }
342        if let Some(rest) = upper.strip_prefix("SERIAL CONSISTENCY ") {
343            let level = rest.trim();
344            match session.set_serial_consistency_str(level) {
345                Ok(()) => shell.outputln(&format!("Serial consistency level set to {level}.")),
346                Err(e) => eprintln!("{e}"),
347            }
348            return;
349        }
350
351        // Handle TRACING
352        if upper == "TRACING" || upper == "TRACING OFF" {
353            session.set_tracing(false);
354            shell.outputln("Tracing is disabled");
355            return;
356        }
357        if upper == "TRACING ON" {
358            session.set_tracing(true);
359            shell.outputln("Tracing is enabled");
360            return;
361        }
362
363        // Handle EXPAND
364        if upper == "EXPAND" {
365            if shell.expand {
366                shell.outputln("Expanded output is currently enabled. Use EXPAND OFF to disable.");
367            } else {
368                shell.outputln("Expanded output is currently disabled. Use EXPAND ON to enable.");
369            }
370            return;
371        }
372        if upper == "EXPAND ON" {
373            shell.expand = true;
374            shell.outputln("Now printing expanded output.");
375            return;
376        }
377        if upper == "EXPAND OFF" {
378            shell.expand = false;
379            shell.outputln("Disabled expanded output.");
380            return;
381        }
382
383        // Handle PAGING
384        if upper == "PAGING" {
385            if shell.paging_enabled {
386                shell.outputln("Query paging is currently enabled. Use PAGING OFF to disable.");
387            } else {
388                shell.outputln("Query paging is currently disabled. Use PAGING ON to enable.");
389            }
390            return;
391        }
392        if upper == "PAGING ON" {
393            shell.paging_enabled = true;
394            shell.outputln("Now query paging is enabled.");
395            return;
396        }
397        if upper == "PAGING OFF" {
398            shell.paging_enabled = false;
399            shell.outputln("Disabled paging.");
400            return;
401        }
402        if upper.strip_prefix("PAGING ").is_some() {
403            // Accept PAGING <N> for compatibility — enables paging
404            shell.paging_enabled = true;
405            shell.outputln("Now query paging is enabled.");
406            return;
407        }
408
409        // Handle SOURCE
410        if upper.starts_with("SOURCE ") {
411            let path = trimmed["SOURCE ".len()..].trim();
412            let path = strip_quotes(path);
413            if config.no_file_io {
414                eprintln!("File I/O is disabled (--no-file-io).");
415            } else {
416                execute_source(session, config, shell, path).await;
417            }
418            return;
419        }
420        if upper == "SOURCE" {
421            eprintln!("SOURCE requires a file path argument.");
422            return;
423        }
424
425        // Handle CAPTURE
426        if upper == "CAPTURE" {
427            match &shell.capture_path {
428                Some(path) => {
429                    shell.outputln(&format!("Currently capturing to '{}'.", path.display()))
430                }
431                None => shell.outputln("Not currently capturing."),
432            }
433            return;
434        }
435        if upper == "CAPTURE OFF" {
436            if shell.capture_file.is_some() {
437                let path = shell.capture_path.take().unwrap();
438                shell.capture_file = None;
439                shell.outputln(&format!(
440                    "Stopped capture. Output saved to '{}'.",
441                    path.display()
442                ));
443            } else {
444                shell.outputln("Not currently capturing.");
445            }
446            return;
447        }
448        if upper.strip_prefix("CAPTURE ").is_some() {
449            let path = trimmed["CAPTURE ".len()..].trim();
450            let path = strip_quotes(path);
451            if config.no_file_io {
452                eprintln!("File I/O is disabled (--no-file-io).");
453            } else {
454                let expanded = expand_tilde(path);
455                match File::create(&expanded) {
456                    Ok(file) => {
457                        shell.outputln(&format!(
458                            "Now capturing query output to '{}'.",
459                            expanded.display()
460                        ));
461                        shell.capture_file = Some(file);
462                        shell.capture_path = Some(expanded);
463                    }
464                    Err(e) => eprintln!("Unable to open '{}' for writing: {e}", expanded.display()),
465                }
466            }
467            return;
468        }
469
470        // Handle DEBUG
471        if upper == "DEBUG" {
472            if shell.debug {
473                shell.outputln("Debug output is currently enabled. Use DEBUG OFF to disable.");
474            } else {
475                shell.outputln("Debug output is currently disabled. Use DEBUG ON to enable.");
476            }
477            return;
478        }
479        if upper == "DEBUG ON" {
480            shell.debug = true;
481            shell.outputln("Now printing debug output.");
482            return;
483        }
484        if upper == "DEBUG OFF" {
485            shell.debug = false;
486            shell.outputln("Disabled debug output.");
487            return;
488        }
489
490        // Handle UNICODE
491        if upper == "UNICODE" {
492            shell.outputln(&format!(
493                "Encoding: {}\nDefault encoding: utf-8",
494                config.encoding
495            ));
496            return;
497        }
498
499        // Handle LOGIN
500        if upper == "LOGIN" {
501            eprintln!("Usage: LOGIN <username> [<password>]");
502            return;
503        }
504        if upper.starts_with("LOGIN ") {
505            let args = trimmed["LOGIN ".len()..].trim();
506            let parts: Vec<&str> = args.splitn(2, char::is_whitespace).collect();
507            let new_user = parts[0].to_string();
508            let new_pass = if parts.len() > 1 {
509                Some(parts[1].to_string())
510            } else {
511                // Prompt for password
512                eprint!("Password: ");
513                let _ = io::stderr().flush();
514                let mut pass = String::new();
515                if io::stdin().read_line(&mut pass).is_ok() {
516                    Some(pass.trim().to_string())
517                } else {
518                    None
519                }
520            };
521            // Reconnect with new credentials
522            let prev_keyspace = session.current_keyspace().map(str::to_string);
523            let mut new_config = config.clone();
524            new_config.username = Some(new_user);
525            new_config.password = new_pass;
526            match crate::session::CqlSession::connect(&new_config).await {
527                Ok(mut new_session) => {
528                    if let Some(ks) = prev_keyspace {
529                        if let Err(e) = new_session.use_keyspace(&ks).await {
530                            eprintln!("Warning: could not restore keyspace '{ks}': {e}");
531                        }
532                    }
533                    *session = new_session;
534                    shell.outputln("Login successful.");
535                }
536                Err(e) => {
537                    eprintln!("Login failed: {e}");
538                }
539            }
540            return;
541        }
542
543        // Handle COPY TO
544        if upper.starts_with("COPY ") && upper.contains(" TO ") {
545            if config.no_file_io {
546                eprintln!("File I/O is disabled (--no-file-io).");
547            } else {
548                match crate::copy::parse_copy_to(trimmed) {
549                    Ok(cmd) => {
550                        let ks = session.current_keyspace();
551                        match crate::copy::execute_copy_to(session, &cmd, ks).await {
552                            Ok(()) => {}
553                            Err(e) => eprintln!("COPY TO error: {e}"),
554                        }
555                    }
556                    Err(e) => eprintln!("Invalid COPY TO syntax: {e}"),
557                }
558            }
559            return;
560        }
561
562        // Handle COPY FROM
563        if upper.starts_with("COPY ") && upper.contains(" FROM ") {
564            if config.no_file_io {
565                eprintln!("File I/O is disabled (--no-file-io).");
566            } else {
567                match crate::copy::parse_copy_from(trimmed) {
568                    Ok(cmd) => {
569                        let ks = session.current_keyspace();
570                        match crate::copy::execute_copy_from(session, &cmd, ks).await {
571                            Ok(()) => {}
572                            Err(e) => eprintln!("COPY FROM error: {e}"),
573                        }
574                    }
575                    Err(e) => eprintln!("Invalid COPY FROM syntax: {e}"),
576                }
577            }
578            return;
579        }
580
581        // Handle DESCRIBE / DESC
582        if upper == "DESCRIBE"
583            || upper == "DESC"
584            || upper.starts_with("DESCRIBE ")
585            || upper.starts_with("DESC ")
586        {
587            let args = if upper.starts_with("DESCRIBE ") {
588                trimmed["DESCRIBE ".len()..].trim()
589            } else if upper.starts_with("DESC ") {
590                trimmed["DESC ".len()..].trim()
591            } else {
592                ""
593            };
594            let mut buf = Vec::new();
595            match describe::execute(session, args, &mut buf).await {
596                Ok(()) => shell.display_output(&buf, ""),
597                Err(e) => eprintln!("Error: {e}"),
598            }
599            return;
600        }
601
602        // Handle SHOW VERSION
603        if upper == "SHOW VERSION" {
604            shell.outputln(&format!("[cqlsh {}]", env!("CARGO_PKG_VERSION")));
605            return;
606        }
607
608        // Handle SHOW HOST
609        if upper == "SHOW HOST" {
610            shell.outputln(&format!("Connected to: {}", session.connection_display));
611            return;
612        }
613
614        // Handle SHOW SESSION <uuid>
615        if let Some(rest) = upper.strip_prefix("SHOW SESSION ") {
616            let uuid_str = rest.trim();
617            match uuid::Uuid::parse_str(uuid_str) {
618                Ok(trace_id) => match session.get_trace_session(trace_id).await {
619                    Ok(Some(trace)) => {
620                        let mut buf = Vec::new();
621                        formatter::print_trace(&trace, &shell.colorizer, &mut buf);
622                        shell.display_output(&buf, "");
623                    }
624                    Ok(None) => eprintln!("Trace session {trace_id} not found."),
625                    Err(e) => eprintln!("Error fetching trace: {e}"),
626                },
627                Err(_) => eprintln!("Invalid UUID: {uuid_str}"),
628            }
629            return;
630        }
631        if upper == "SHOW SESSION" {
632            eprintln!("Usage: SHOW SESSION <trace-uuid>");
633            return;
634        }
635
636        // Execute as CQL statement
637        if shell.paging_enabled && shell.is_tty && could_return_rows(trimmed) {
638            match session.execute_streaming(trimmed, config.fetch_size).await {
639                Ok(mut row_stream) => {
640                    if !row_stream.columns.is_empty() {
641                        let col_title = row_stream
642                            .columns
643                            .iter()
644                            .map(|c| c.name.as_str())
645                            .collect::<Vec<_>>()
646                            .join(" | ");
647
648                        match crate::pager::page_stream(&col_title) {
649                            Ok(mut pipe_writer) => {
650                                use futures::StreamExt;
651                                let is_file_mode = pipe_writer.is_file_mode();
652                                let mut fmt = if shell.expand {
653                                    formatter::StreamingTableFormatter::new_expanded(
654                                        row_stream.columns.clone(),
655                                        &shell.colorizer,
656                                        &mut pipe_writer,
657                                    )
658                                } else {
659                                    formatter::StreamingTableFormatter::new(
660                                        row_stream.columns.clone(),
661                                        &shell.colorizer,
662                                        &mut pipe_writer,
663                                        100,
664                                    )
665                                };
666                                let mut row_count: usize = 0;
667                                while let Some(row_result) = row_stream.rows.next().await {
668                                    match row_result {
669                                        Ok(row) => {
670                                            if fmt.add_row(row).is_err() {
671                                                break;
672                                            }
673                                            row_count += 1;
674                                            let _ = fmt.flush_writer();
675                                            if !is_file_mode && row_count.is_multiple_of(1000) {
676                                                eprint!("\rFetched {row_count} rows...");
677                                            }
678                                        }
679                                        Err(e) => {
680                                            let msg = format!("Error fetching row: {e}");
681                                            eprintln!("{}", shell.colorizer.colorize_error(&msg));
682                                            break;
683                                        }
684                                    }
685                                }
686                                if row_count >= 1000 && !is_file_mode {
687                                    eprintln!("\rFetched {row_count} rows.   ");
688                                }
689                                let _ = fmt.finish();
690                                // pipe_writer dropped here; pager thread finishes
691                                return;
692                            }
693                            Err(_) => {
694                                // Pager failed to start — fall through to non-streaming path
695                            }
696                        }
697                    } else {
698                        // No columns (non-SELECT result) — fall through to non-streaming path
699                    }
700                }
701                Err(e) => {
702                    eprintln!("{}", error::format_error_colored(&e, &shell.colorizer));
703                    if config.debug {
704                        eprintln!("Debug: {e:?}");
705                    }
706                    return;
707                }
708            }
709        }
710
711        match session.execute(trimmed).await {
712            Ok(result) => {
713                // Sync current keyspace for tab completion after USE
714                let upper_stmt = trimmed.to_uppercase();
715                if upper_stmt.starts_with("USE ") {
716                    if let Some(ref shared_ks) = shell.shared_keyspace {
717                        let ks = session.current_keyspace().map(String::from);
718                        let shared = Arc::clone(shared_ks);
719                        *shared.write().await = ks;
720                    }
721                }
722
723                // Invalidate schema cache after DDL statements
724                if upper_stmt.starts_with("CREATE ")
725                    || upper_stmt.starts_with("ALTER ")
726                    || upper_stmt.starts_with("DROP ")
727                {
728                    if let Some(ref cache) = shell.schema_cache {
729                        let mut c = cache.write().await;
730                        c.invalidate();
731                        let _ = c.refresh(session).await;
732                    }
733                }
734
735                // Print warnings if present (red bold when colored)
736                for warning in &result.warnings {
737                    let msg = format!("Warnings: {warning}");
738                    eprintln!("{}", shell.colorizer.colorize_warning(&msg));
739                }
740
741                if !result.columns.is_empty() {
742                    // Build column list for pager title (sticky header context)
743                    let col_title = result
744                        .columns
745                        .iter()
746                        .map(|c| c.name.as_str())
747                        .collect::<Vec<_>>()
748                        .join(" | ");
749
750                    let mut buf = Vec::new();
751                    if shell.expand {
752                        formatter::print_expanded(&result, &shell.colorizer, &mut buf);
753                    } else {
754                        formatter::print_tabular(&result, &shell.colorizer, &mut buf);
755                    }
756                    shell.display_output(&buf, &col_title);
757                }
758
759                // Print trace info if tracing is enabled
760                if session.is_tracing_enabled() && !upper_stmt.contains("SYSTEM_TRACES") {
761                    if let Some(trace_id) = result.tracing_id {
762                        // Brief delay to allow trace data to propagate
763                        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
764                        match session.get_trace_session(trace_id).await {
765                            Ok(Some(trace)) => {
766                                let mut buf = Vec::new();
767                                formatter::print_trace(&trace, &shell.colorizer, &mut buf);
768                                shell.display_output(&buf, "");
769                            }
770                            Ok(None) => {
771                                shell.outputln(&format!(
772                                "Trace {trace_id} not yet available. Use SHOW SESSION {trace_id} to view later."
773                            ));
774                            }
775                            Err(e) => {
776                                eprintln!("Error fetching trace: {e}");
777                            }
778                        }
779                    }
780                }
781            }
782            Err(e) => {
783                eprintln!("{}", error::format_error_colored(&e, &shell.colorizer));
784                if config.debug {
785                    eprintln!("Debug: {e:?}");
786                }
787            }
788        }
789    })
790}
791
/// Write the top-level `HELP` listing in the same layout Python cqlsh uses.
///
/// The listing covers documented shell commands, partially implemented
/// commands (COPY), and a pointer to CQL statements.
///
/// Output is best-effort: any write error on `writer` is silently ignored.
pub fn print_help(writer: &mut dyn std::io::Write) {
    // Kept as a single literal so the column alignment is easy to eyeball.
    const HELP_TEXT: &str = "\
Documented shell commands:
  CAPTURE       Capture output to file
  CLEAR         Clear the terminal screen
  CONSISTENCY   Get/set consistency level
  DEBUG         Toggle debug mode
  DESCRIBE      Schema introspection (CLUSTER, KEYSPACES, TABLE, etc.)
  EXIT / QUIT   Exit the shell
  EXPAND        Toggle expanded (vertical) output
  HELP          Show this help or help on a topic
  LOGIN         Re-authenticate with new credentials
  PAGING        Configure automatic paging
  SERIAL        Get/set serial consistency level
  SHOW          Show version, host, or session trace info
  SOURCE        Execute CQL from a file
  TRACING       Toggle request tracing
  UNICODE       Show Unicode character handling info

Partially implemented:
  COPY TO       Export table data to CSV file
  COPY FROM     Import CSV data into a table

CQL statements (executed via the database):
  SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, USE, etc.";

    let _ = writeln!(writer, "{HELP_TEXT}");
}
823
/// Print help for a specific topic.
///
/// `topic` is matched case-insensitively after trimming surrounding
/// whitespace and one trailing `;`, so `help select;` behaves like
/// `HELP SELECT`. Known topics are either shell commands or CQL topics that
/// use underscore naming (e.g. `CREATE_TABLE`).
///
/// This is a stub — full per-topic help text will be added in a later phase.
/// For now, print a message indicating the topic exists or is unknown.
/// Write errors on `writer` are ignored (help output is best-effort).
pub fn print_help_topic(topic: &str, writer: &mut dyn std::io::Write) {
    // Shell (non-CQL) commands handled by cqlsh itself.
    const SHELL_COMMANDS: &[&str] = &[
        "CAPTURE",
        "CLEAR",
        "CLS",
        "CONSISTENCY",
        "COPY",
        "DEBUG",
        "DESC",
        "DESCRIBE",
        "EXIT",
        "EXPAND",
        "HELP",
        "LOGIN",
        "PAGING",
        "QUIT",
        "SERIAL",
        "SHOW",
        "SOURCE",
        "TRACING",
        "UNICODE",
    ];
    // CQL statement topics. `USE` lives only here: listing it twice
    // previously had no effect on matching.
    const CQL_TOPICS: &[&str] = &[
        "AGGREGATES",
        "ALTER_KEYSPACE",
        "ALTER_TABLE",
        "ALTER_TYPE",
        "ALTER_USER",
        "APPLY",
        "BEGIN",
        "CREATE_AGGREGATE",
        "CREATE_FUNCTION",
        "CREATE_INDEX",
        "CREATE_KEYSPACE",
        "CREATE_TABLE",
        "CREATE_TRIGGER",
        "CREATE_TYPE",
        "CREATE_USER",
        "DELETE",
        "DROP_AGGREGATE",
        "DROP_FUNCTION",
        "DROP_INDEX",
        "DROP_KEYSPACE",
        "DROP_TABLE",
        "DROP_TRIGGER",
        "DROP_TYPE",
        "DROP_USER",
        "GRANT",
        "INSERT",
        "LIST_PERMISSIONS",
        "LIST_USERS",
        "PERMISSIONS",
        "REVOKE",
        "SELECT",
        "TEXT_OUTPUT",
        "TRUNCATE",
        "TYPES",
        "UPDATE",
        "USE",
    ];

    // Normalize: surrounding whitespace and a trailing statement terminator
    // are not part of the topic name.
    let topic = topic.trim();
    let topic = topic.strip_suffix(';').unwrap_or(topic).trim_end();

    let upper = topic.to_uppercase();
    let known = SHELL_COMMANDS
        .iter()
        .chain(CQL_TOPICS)
        .any(|name| *name == upper);

    if known {
        writeln!(writer, "Help topic: {upper}").ok();
        writeln!(writer, "(Detailed help text not yet implemented.)").ok();
    } else {
        writeln!(
            writer,
            "No help topic matching '{topic}'. Try HELP for a list of topics."
        )
        .ok();
    }
}
902
/// Strip one pair of matching surrounding single or double quotes from a string.
///
/// The opening and closing quote must be the same character and must be two
/// distinct characters. A lone `"` or `'` is therefore returned unchanged;
/// the previous slicing implementation panicked on it. Mismatched or missing
/// quotes also leave the input untouched.
fn strip_quotes(s: &str) -> &str {
    ['"', '\'']
        .into_iter()
        .find_map(|quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}
911
912/// Expand `~` at the start of a path to the user's home directory.
913fn expand_tilde(path: &str) -> PathBuf {
914    if let Some(rest) = path.strip_prefix("~/") {
915        if let Some(home) = dirs::home_dir() {
916            return home.join(rest);
917        }
918    } else if path == "~" {
919        if let Some(home) = dirs::home_dir() {
920            return home;
921        }
922    }
923    PathBuf::from(path)
924}
925
926/// Execute a SOURCE file: read CQL statements and execute them sequentially.
927///
928/// Shell commands in the file (SHOW, CONSISTENCY, etc.) are routed through
929/// `dispatch_input` just like interactive input — they are not sent to the DB.
930fn execute_source<'a>(
931    session: &'a mut CqlSession,
932    config: &'a MergedConfig,
933    shell: &'a mut ShellState,
934    path: &'a str,
935) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>> {
936    Box::pin(async move {
937        let expanded = expand_tilde(path);
938        let file = match File::open(&expanded) {
939            Ok(f) => f,
940            Err(e) => {
941                eprintln!("Could not open '{}': {e}", expanded.display());
942                return;
943            }
944        };
945
946        let reader = io::BufReader::new(file);
947        let mut parser = StatementParser::new();
948
949        for line_result in reader.lines() {
950            let line = match line_result {
951                Ok(l) => l,
952                Err(e) => {
953                    eprintln!("Error reading '{}': {e}", expanded.display());
954                    return;
955                }
956            };
957
958            // Check if it's a shell command on a fresh line
959            let trimmed = line.trim();
960            if parser.is_empty() && !trimmed.is_empty() && parser::is_shell_command(trimmed) {
961                dispatch_input(session, config, shell, trimmed).await;
962                continue;
963            }
964
965            match parser.feed_line(&line) {
966                ParseResult::Complete(statements) => {
967                    for stmt in statements {
968                        dispatch_input(session, config, shell, &stmt).await;
969                    }
970                }
971                ParseResult::Incomplete => {}
972            }
973        }
974    })
975}
976
#[cfg(test)]
mod tests {
    use super::*;

    // --- Prompt tests ---

    #[test]
    fn prompt_default() {
        assert_eq!(build_prompt(None, None), "cqlsh> ");
    }

    #[test]
    fn prompt_with_keyspace() {
        assert_eq!(build_prompt(None, Some("my_ks")), "cqlsh:my_ks> ");
    }

    #[test]
    fn prompt_with_username() {
        assert_eq!(build_prompt(Some("admin"), None), "admin@cqlsh> ");
    }

    #[test]
    fn prompt_with_username_and_keyspace() {
        assert_eq!(
            build_prompt(Some("admin"), Some("system")),
            "admin@cqlsh:system> "
        );
    }

    // --- Helper function tests ---

    #[test]
    fn strip_quotes_double() {
        assert_eq!(strip_quotes("\"hello\""), "hello");
    }

    #[test]
    fn strip_quotes_single() {
        assert_eq!(strip_quotes("'hello'"), "hello");
    }

    #[test]
    fn strip_quotes_none() {
        assert_eq!(strip_quotes("hello"), "hello");
    }

    #[test]
    fn strip_quotes_mismatched() {
        assert_eq!(strip_quotes("\"hello'"), "\"hello'");
    }

    #[test]
    fn strip_quotes_empty_quoted_string() {
        assert_eq!(strip_quotes("''"), "");
        assert_eq!(strip_quotes("\"\""), "");
    }

    #[test]
    fn expand_tilde_plain_path() {
        assert_eq!(
            expand_tilde("/tmp/file.cql"),
            PathBuf::from("/tmp/file.cql")
        );
    }

    #[test]
    fn expand_tilde_home() {
        if let Some(home) = dirs::home_dir() {
            assert_eq!(expand_tilde("~/test.cql"), home.join("test.cql"));
        }
    }

    #[test]
    fn expand_tilde_bare_tilde_is_home() {
        if let Some(home) = dirs::home_dir() {
            assert_eq!(expand_tilde("~"), home);
        }
    }

    #[test]
    fn expand_tilde_other_user_not_expanded() {
        // Only `~` / `~/...` are expanded; `~user` forms are passed through.
        assert_eq!(expand_tilde("~bob/x.cql"), PathBuf::from("~bob/x.cql"));
    }

    #[test]
    fn shell_state_initial() {
        let state = ShellState {
            expand: false,
            paging_enabled: true,
            is_tty: false,
            debug: false,
            capture_file: None,
            capture_path: None,
            schema_cache: None,
            shared_keyspace: None,
            colorizer: CqlColorizer::new(false),
        };
        assert!(!state.expand);
        assert!(state.paging_enabled);
        assert!(state.capture_file.is_none());
        assert!(state.capture_path.is_none());
    }

    // --- History path tests ---

    #[test]
    fn history_disabled_returns_none() {
        let config = test_config(true);
        assert!(resolve_history_path(&config).is_none());
    }

    #[test]
    fn history_enabled_returns_path() {
        let config = test_config(false);
        let path = resolve_history_path(&config);
        if dirs::home_dir().is_some() {
            assert!(path.is_some());
            let p = path.unwrap();
            assert!(p.to_string_lossy().contains("cql_history"));
        }
    }

    /// Create a minimal MergedConfig for testing.
    fn test_config(disable_history: bool) -> MergedConfig {
        use crate::config::{ColorMode, CqlshrcConfig};

        MergedConfig {
            host: "127.0.0.1".to_string(),
            port: 9042,
            username: None,
            password: None,
            keyspace: None,
            ssl: false,
            color: ColorMode::Auto,
            debug: false,
            tty: false,
            no_file_io: false,
            no_compact: false,
            disable_history,
            execute: None,
            file: None,
            connect_timeout: 5,
            request_timeout: 10,
            encoding: "utf-8".to_string(),
            cqlversion: None,
            protocol_version: None,
            consistency_level: None,
            serial_consistency_level: None,
            browser: None,
            secure_connect_bundle: None,
            fetch_size: 100,
            cqlshrc_path: PathBuf::from("/dev/null"),
            cqlshrc: CqlshrcConfig::default(),
        }
    }

    // --- SHOW SESSION tests ---

    #[test]
    fn show_session_parses_uuid() {
        let input = "SHOW SESSION 12345678-1234-1234-1234-123456789abc";
        let upper = input.trim().to_uppercase();
        assert!(upper.starts_with("SHOW SESSION "));
        let uuid_str = input.trim()["SHOW SESSION ".len()..].trim();
        let uuid = uuid::Uuid::parse_str(uuid_str).unwrap();
        assert_eq!(uuid.to_string(), "12345678-1234-1234-1234-123456789abc");
    }

    #[test]
    fn show_session_rejects_invalid_uuid() {
        let uuid_str = "not-a-uuid";
        assert!(uuid::Uuid::parse_str(uuid_str).is_err());
    }

    #[test]
    fn show_session_bare_detected_as_shell_command() {
        assert!(parser::is_shell_command(
            "SHOW SESSION 12345678-1234-1234-1234-123456789abc"
        ));
        assert!(parser::is_shell_command("SHOW SESSION"));
    }

    // --- Shell command semicolon tests ---

    #[test]
    fn shell_command_semicolon_stripped_before_dispatch() {
        // Bug: `DESCRIBE KEYSPACES;` was dispatched with `;` intact,
        // causing describe::execute to receive args="KEYSPACES;" which didn't match.
        let input = "DESCRIBE KEYSPACES;";
        let clean = input.strip_suffix(';').unwrap_or(input).trim_end();
        assert_eq!(clean, "DESCRIBE KEYSPACES");
    }

    #[test]
    fn shell_command_without_semicolon_unchanged() {
        let input = "DESCRIBE KEYSPACES";
        let clean = input.strip_suffix(';').unwrap_or(input).trim_end();
        assert_eq!(clean, "DESCRIBE KEYSPACES");
    }

    #[test]
    fn describe_table_semicolon_stripped() {
        let input = "DESCRIBE TABLE test_ks.events;";
        let clean = input.strip_suffix(';').unwrap_or(input).trim_end();
        assert_eq!(clean, "DESCRIBE TABLE test_ks.events");
        // Verify the args extraction matches what dispatch_input does
        let trimmed = clean.trim();
        let upper = trimmed.to_uppercase();
        assert!(upper.starts_with("DESCRIBE "));
        let args = &trimmed["DESCRIBE ".len()..];
        assert_eq!(args.trim(), "TABLE test_ks.events");
    }

    // --- BUG-4: SOURCE file parsing tests ---

    #[test]
    fn parse_batch_includes_shell_commands() {
        let input = "SELECT 1;\nSHOW VERSION\n";
        let stmts = parser::parse_batch(input);
        assert_eq!(stmts.len(), 2);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SHOW VERSION");
    }

    #[test]
    fn parse_batch_shell_command_with_semicolon() {
        let input = "SHOW VERSION;\nSELECT 1;\n";
        let stmts = parser::parse_batch(input);
        assert_eq!(stmts.len(), 2);
        assert_eq!(stmts[0], "SHOW VERSION");
        assert_eq!(stmts[1], "SELECT 1");
    }

    #[test]
    fn source_file_line_by_line_detects_shell_commands() {
        let lines = vec!["CONSISTENCY QUORUM", "SELECT * FROM t;", "SHOW HOST"];
        let mut shell_cmds = Vec::new();
        let mut cql_stmts = Vec::new();
        let mut parser = StatementParser::new();

        for line in &lines {
            let trimmed = line.trim();
            if parser.is_empty() && !trimmed.is_empty() && parser::is_shell_command(trimmed) {
                shell_cmds.push(trimmed.to_string());
                continue;
            }
            if let ParseResult::Complete(stmts) = parser.feed_line(line) {
                cql_stmts.extend(stmts);
            }
        }

        assert_eq!(shell_cmds, vec!["CONSISTENCY QUORUM", "SHOW HOST"]);
        assert_eq!(cql_stmts, vec!["SELECT * FROM t"]);
    }

    // --- BUG-5: Multi-line paste tests ---

    #[test]
    fn multiline_paste_splits_into_lines() {
        let pasted = "SHOW VERSION\nSELECT 1;\nSHOW HOST";
        let lines: Vec<&str> = pasted.split('\n').collect();
        assert_eq!(lines.len(), 3);
        assert_eq!(lines[0], "SHOW VERSION");
        assert_eq!(lines[1], "SELECT 1;");
        assert_eq!(lines[2], "SHOW HOST");
        assert!(parser::is_shell_command(lines[0].trim()));
        assert!(parser::is_shell_command(lines[2].trim()));
    }

    #[test]
    fn could_return_rows_detects_select() {
        assert!(super::could_return_rows("SELECT * FROM foo"));
        assert!(super::could_return_rows("select count(*) from bar"));
        assert!(!super::could_return_rows("INSERT INTO foo (a) VALUES (1)"));
        assert!(!super::could_return_rows("USE my_keyspace"));
    }

    #[test]
    fn multiline_paste_shell_command_not_concatenated() {
        let pasted = "CAPTURE '/tmp/test.txt'\nSELECT 1;\nCAPTURE OFF";
        let lines: Vec<&str> = pasted.split('\n').collect();
        assert_eq!(lines.len(), 3);
        assert_eq!(lines[0], "CAPTURE '/tmp/test.txt'");
        assert!(parser::is_shell_command(lines[0].trim()));
    }

    // --- HELP tests ---

    #[test]
    fn print_help_writes_output() {
        let mut buf = Vec::<u8>::new();
        print_help(&mut buf);
        let out = String::from_utf8(buf).unwrap();
        assert!(out.contains("HELP"));
        assert!(out.contains("EXIT"));
        assert!(out.contains("DESCRIBE"));
    }

    #[test]
    fn print_help_topic_known() {
        let mut buf = Vec::<u8>::new();
        print_help_topic("CONSISTENCY", &mut buf);
        let out = String::from_utf8(buf).unwrap();
        assert!(!out.is_empty());
        // A known topic must not be reported as unknown.
        assert!(!out.contains("No help topic matching"));
    }

    #[test]
    fn print_help_topic_unknown() {
        let mut buf = Vec::<u8>::new();
        print_help_topic("NONEXISTENT_TOPIC_XYZ", &mut buf);
        let out = String::from_utf8(buf).unwrap();
        // Previously `... || !out.is_empty()`, which passed for any output.
        assert!(out.contains("No help topic matching 'NONEXISTENT_TOPIC_XYZ'"));
    }

    #[test]
    fn print_help_topic_cql() {
        let mut buf = Vec::<u8>::new();
        print_help_topic("SELECT", &mut buf);
        let out = String::from_utf8(buf).unwrap();
        assert!(!out.is_empty());
        assert!(!out.contains("No help topic matching"));
    }
}