Skip to main content

cqlsh_rs/
executor.rs

1//! Non-interactive CQL execution engine with injectable output writer.
2
3use std::io::{self, BufRead, Write};
4
5use crate::colorizer::CqlColorizer;
6use crate::config::MergedConfig;
7use crate::parser::{self, ParseResult, StatementParser};
8use crate::session::CqlSession;
9use crate::{describe, error, formatter};
10
11/// Execute a CQL string from the `-e` flag (semicolon-separated statements).
12///
13/// All output is written to `writer`; errors go to stderr.
14/// Returns exit code: 0 on success, 1 on any CQL error.
15pub async fn execute_cql_string(
16    session: &mut CqlSession,
17    config: &MergedConfig,
18    colorizer: &CqlColorizer,
19    cql_string: &str,
20    writer: &mut dyn Write,
21) -> i32 {
22    // Python cqlsh accepts `-e "SELECT 1"` without a trailing semicolon.
23    let with_semi;
24    let cql_string = {
25        let t = cql_string.trim_end();
26        // Python cqlsh accepts `-e "SELECT 1"` without a trailing semicolon.
27        if !t.is_empty() && !t.ends_with(';') {
28            with_semi = format!("{t};");
29            &with_semi
30        } else {
31            cql_string
32        }
33    };
34    let statements = parser::parse_batch(cql_string);
35    let mut had_error = false;
36    let mut debug = config.debug;
37
38    for stmt in statements {
39        if !execute_single_statement(
40            session, config, colorizer, &mut debug, &stmt, None, 0, writer,
41        )
42        .await
43        {
44            had_error = true;
45        }
46    }
47
48    if had_error {
49        1
50    } else {
51        0
52    }
53}
54
55/// Execute CQL statements from a file (`-f` flag).
56///
57/// Returns exit code: 0 on success, 1 on any CQL or I/O error.
58pub async fn execute_cql_file(
59    session: &mut CqlSession,
60    config: &MergedConfig,
61    colorizer: &CqlColorizer,
62    file_path: &str,
63    writer: &mut dyn Write,
64) -> i32 {
65    let file = match std::fs::File::open(file_path) {
66        Ok(f) => f,
67        Err(e) => {
68            eprintln!("Could not open '{}': {e}", file_path);
69            return 1;
70        }
71    };
72    let reader = io::BufReader::new(file);
73    execute_cql_reader(session, config, colorizer, reader, file_path, writer).await
74}
75
76/// Execute CQL statements from any `BufRead` source (file or stdin).
77///
78/// Returns exit code: 0 on success, 1 on any CQL or I/O error.
79pub async fn execute_cql_reader<R: io::BufRead>(
80    session: &mut CqlSession,
81    config: &MergedConfig,
82    colorizer: &CqlColorizer,
83    reader: R,
84    source_name: &str,
85    writer: &mut dyn Write,
86) -> i32 {
87    let mut stmt_parser = StatementParser::new();
88    let mut had_error = false;
89    let mut debug = config.debug;
90    let mut line_number: usize = 0;
91    let mut stmt_start_line: usize = 1;
92
93    for line_result in reader.lines() {
94        let line = match line_result {
95            Ok(l) => l,
96            Err(e) => {
97                eprintln!("Error reading '{}': {e}", source_name);
98                return 1;
99            }
100        };
101        line_number += 1;
102
103        if stmt_parser.is_empty() {
104            stmt_start_line = line_number;
105        }
106
107        let trimmed = line.trim();
108        if stmt_parser.is_empty() && !trimmed.is_empty() && parser::is_shell_command(trimmed) {
109            let clean = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
110            if !execute_single_statement(
111                session,
112                config,
113                colorizer,
114                &mut debug,
115                clean,
116                Some(source_name),
117                stmt_start_line,
118                writer,
119            )
120            .await
121            {
122                had_error = true;
123            }
124            continue;
125        }
126
127        if let ParseResult::Complete(statements) = stmt_parser.feed_line(&line) {
128            for stmt in statements {
129                if !execute_single_statement(
130                    session,
131                    config,
132                    colorizer,
133                    &mut debug,
134                    &stmt,
135                    Some(source_name),
136                    stmt_start_line,
137                    writer,
138                )
139                .await
140                {
141                    had_error = true;
142                }
143            }
144            stmt_start_line = line_number + 1;
145        }
146    }
147
148    if had_error {
149        1
150    } else {
151        0
152    }
153}
154
155/// Execute a single CQL statement or shell command in non-interactive mode.
156///
157/// Output is written to `writer`; errors and warnings are written to stderr.
158/// `debug` is mutable so that `DEBUG ON/OFF` affects subsequent statements.
159///
160/// Returns `true` on success, `false` on error.
161#[allow(clippy::too_many_arguments)]
162pub fn execute_single_statement<'a>(
163    session: &'a mut CqlSession,
164    config: &'a MergedConfig,
165    colorizer: &'a CqlColorizer,
166    debug: &'a mut bool,
167    input: &'a str,
168    source_name: Option<&'a str>,
169    line_number: usize,
170    writer: &'a mut dyn Write,
171) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + 'a>> {
172    Box::pin(async move {
173        let trimmed = input.trim();
174        if trimmed.is_empty() {
175            return true;
176        }
177
178        let upper = trimmed.to_uppercase();
179
180        if upper == "DEBUG" {
181            if *debug {
182                let _ = writeln!(
183                    writer,
184                    "Debug output is currently enabled. Use DEBUG OFF to disable."
185                );
186            } else {
187                let _ = writeln!(
188                    writer,
189                    "Debug output is currently disabled. Use DEBUG ON to enable."
190                );
191            }
192            return true;
193        }
194        if upper == "DEBUG ON" {
195            *debug = true;
196            let _ = writeln!(writer, "Now printing debug output.");
197            return true;
198        }
199        if upper == "DEBUG OFF" {
200            *debug = false;
201            let _ = writeln!(writer, "Disabled debug output.");
202            return true;
203        }
204
205        if upper == "UNICODE" {
206            let _ = writeln!(
207                writer,
208                "Encoding: {}\nDefault encoding: utf-8",
209                config.encoding
210            );
211            return true;
212        }
213
214        if upper == "CONSISTENCY" {
215            let cl = session.get_consistency();
216            let _ = writeln!(writer, "Current consistency level is {cl}.");
217            return true;
218        }
219        if let Some(rest) = upper.strip_prefix("CONSISTENCY ") {
220            let level = rest.trim();
221            match session.set_consistency_str(level) {
222                Ok(()) => {
223                    let _ = writeln!(writer, "Consistency level set to {level}.");
224                }
225                Err(e) => {
226                    eprintln!("{e}");
227                    return false;
228                }
229            }
230            return true;
231        }
232        if upper == "SERIAL CONSISTENCY" {
233            match session.get_serial_consistency() {
234                Some(scl) => {
235                    let _ = writeln!(writer, "Current serial consistency level is {scl}.");
236                }
237                None => {
238                    let _ = writeln!(writer, "Current serial consistency level is SERIAL.");
239                }
240            }
241            return true;
242        }
243        if let Some(rest) = upper.strip_prefix("SERIAL CONSISTENCY ") {
244            let level = rest.trim();
245            match session.set_serial_consistency_str(level) {
246                Ok(()) => {
247                    let _ = writeln!(writer, "Serial consistency level set to {level}.");
248                }
249                Err(e) => {
250                    eprintln!("{e}");
251                    return false;
252                }
253            }
254            return true;
255        }
256        if upper == "TRACING OFF" || upper == "TRACING" {
257            session.set_tracing(false);
258            let _ = writeln!(writer, "Tracing is disabled");
259            return true;
260        }
261        if upper == "TRACING ON" {
262            session.set_tracing(true);
263            let _ = writeln!(writer, "Tracing is enabled");
264            return true;
265        }
266        if upper == "SHOW VERSION" {
267            let _ = writeln!(writer, "[cqlsh {}]", env!("CARGO_PKG_VERSION"));
268            return true;
269        }
270        if upper == "SHOW HOST" {
271            let _ = writeln!(writer, "Connected to: {}", session.connection_display);
272            return true;
273        }
274
275        if upper == "DESCRIBE"
276            || upper == "DESC"
277            || upper.starts_with("DESCRIBE ")
278            || upper.starts_with("DESC ")
279        {
280            let args = if upper.starts_with("DESCRIBE ") {
281                trimmed["DESCRIBE ".len()..].trim()
282            } else if upper.starts_with("DESC ") {
283                trimmed["DESC ".len()..].trim()
284            } else {
285                ""
286            };
287            let mut buf = Vec::new();
288            match describe::execute(session, args, &mut buf).await {
289                Ok(()) => {
290                    let _ = writer.write_all(&buf);
291                }
292                Err(e) => {
293                    eprintln!("Error: {e}");
294                    return false;
295                }
296            }
297            return true;
298        }
299
300        if upper.starts_with("SOURCE ") {
301            let path = trimmed["SOURCE ".len()..].trim();
302            let path = strip_quotes(path);
303            if config.no_file_io {
304                eprintln!("File I/O is disabled (--no-file-io).");
305                return true;
306            }
307            let expanded = expand_tilde(path);
308            let file = match std::fs::File::open(&expanded) {
309                Ok(f) => f,
310                Err(e) => {
311                    eprintln!("Could not open '{}': {e}", expanded.display());
312                    return false;
313                }
314            };
315            let reader = io::BufReader::new(file);
316            let source_name_str = expanded.display().to_string();
317            let mut stmt_parser = StatementParser::new();
318            let mut had_error = false;
319            let mut src_line_number: usize = 0;
320            let mut src_stmt_start: usize = 1;
321            for line_result in reader.lines() {
322                let line: String = match line_result {
323                    Ok(l) => l,
324                    Err(e) => {
325                        eprintln!("Error reading '{}': {e}", source_name_str);
326                        return false;
327                    }
328                };
329                src_line_number += 1;
330                if stmt_parser.is_empty() {
331                    src_stmt_start = src_line_number;
332                }
333                let ltrimmed = line.trim();
334                if stmt_parser.is_empty()
335                    && !ltrimmed.is_empty()
336                    && parser::is_shell_command(ltrimmed)
337                {
338                    let clean = ltrimmed.strip_suffix(';').unwrap_or(ltrimmed).trim_end();
339                    if !execute_single_statement(
340                        session,
341                        config,
342                        colorizer,
343                        debug,
344                        clean,
345                        Some(&source_name_str),
346                        src_stmt_start,
347                        writer,
348                    )
349                    .await
350                    {
351                        had_error = true;
352                    }
353                    continue;
354                }
355                if let ParseResult::Complete(statements) = stmt_parser.feed_line(&line) {
356                    for stmt in statements {
357                        if !execute_single_statement(
358                            session,
359                            config,
360                            colorizer,
361                            debug,
362                            &stmt,
363                            Some(&source_name_str),
364                            src_stmt_start,
365                            writer,
366                        )
367                        .await
368                        {
369                            had_error = true;
370                        }
371                    }
372                    src_stmt_start = src_line_number + 1;
373                }
374            }
375            return !had_error;
376        }
377        if upper == "SOURCE" {
378            eprintln!("SOURCE requires a file path argument.");
379            return true;
380        }
381
382        if upper == "CLEAR" || upper == "CLS" {
383            let _ = write!(writer, "\x1B[2J\x1B[1;1H");
384            return true;
385        }
386
387        if upper == "LOGIN" {
388            eprintln!("Usage: LOGIN <username> [<password>]");
389            return false;
390        }
391        if upper.starts_with("LOGIN ") {
392            let args = trimmed["LOGIN ".len()..].trim();
393            let parts: Vec<&str> = args.splitn(2, char::is_whitespace).collect();
394            let new_user = parts[0].to_string();
395            let new_pass = if parts.len() > 1 {
396                Some(parts[1].trim_matches('\'').to_string())
397            } else {
398                None
399            };
400            let mut new_config = config.clone();
401            new_config.username = Some(new_user);
402            new_config.password = new_pass;
403            let prev_keyspace = session.current_keyspace().map(str::to_string);
404            match crate::session::CqlSession::connect(&new_config).await {
405                Ok(mut new_session) => {
406                    if let Some(ks) = prev_keyspace {
407                        if let Err(e) = new_session.use_keyspace(&ks).await {
408                            eprintln!("Warning: could not restore keyspace '{ks}': {e}");
409                        }
410                    }
411                    *session = new_session;
412                }
413                Err(e) => {
414                    eprintln!("{}", error::format_error_colored(&e, colorizer));
415                    return false;
416                }
417            }
418            return true;
419        }
420
421        // Handle COPY TO
422        if upper.starts_with("COPY ") && upper.contains(" TO ") {
423            if config.no_file_io {
424                eprintln!("File I/O is disabled (--no-file-io).");
425                return false;
426            }
427            match crate::copy::parse_copy_to(trimmed) {
428                Ok(cmd) => {
429                    let ks = session.current_keyspace();
430                    match crate::copy::execute_copy_to(session, &cmd, ks).await {
431                        Ok(()) => return true,
432                        Err(e) => {
433                            eprintln!("COPY TO error: {e}");
434                            return false;
435                        }
436                    }
437                }
438                Err(e) => {
439                    eprintln!("Invalid COPY TO syntax: {e}");
440                    return false;
441                }
442            }
443        }
444
445        // Handle COPY FROM
446        if upper.starts_with("COPY ") && upper.contains(" FROM ") {
447            if config.no_file_io {
448                eprintln!("File I/O is disabled (--no-file-io).");
449                return false;
450            }
451            match crate::copy::parse_copy_from(trimmed) {
452                Ok(cmd) => {
453                    let ks = session.current_keyspace();
454                    match crate::copy::execute_copy_from(session, &cmd, ks).await {
455                        Ok(()) => return true,
456                        Err(e) => {
457                            eprintln!("COPY FROM error: {e}");
458                            return false;
459                        }
460                    }
461                }
462                Err(e) => {
463                    eprintln!("Invalid COPY FROM syntax: {e}");
464                    return false;
465                }
466            }
467        }
468
469        if upper == "QUIT"
470            || upper == "EXIT"
471            || upper == "HELP"
472            || upper == "?"
473            || upper.starts_with("HELP ")
474            || upper == "EXPAND"
475            || upper == "EXPAND ON"
476            || upper == "EXPAND OFF"
477            || upper == "PAGING"
478            || upper == "PAGING ON"
479            || upper == "PAGING OFF"
480            || upper.starts_with("PAGING ")
481            || upper == "CAPTURE"
482            || upper == "CAPTURE OFF"
483            || upper.starts_with("CAPTURE ")
484        {
485            return true;
486        }
487
488        match session.execute(trimmed).await {
489            Ok(result) => {
490                for warning in &result.warnings {
491                    let msg = format!("Warnings: {warning}");
492                    eprintln!("{}", colorizer.colorize_warning(&msg));
493                }
494
495                if !result.columns.is_empty() {
496                    let mut buf = Vec::new();
497                    formatter::print_tabular(&result, colorizer, &mut buf);
498                    let _ = writer.write_all(&buf);
499                }
500
501                if session.is_tracing_enabled() && !upper.contains("SYSTEM_TRACES") {
502                    if let Some(trace_id) = result.tracing_id {
503                        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
504                        match session.get_trace_session(trace_id).await {
505                            Ok(Some(trace)) => {
506                                let mut buf = Vec::new();
507                                formatter::print_trace(&trace, colorizer, &mut buf);
508                                let _ = writer.write_all(&buf);
509                            }
510                            Ok(None) => {
511                                eprintln!(
512                                    "Trace {trace_id} not yet available. Use SHOW SESSION {trace_id} to view later."
513                                );
514                            }
515                            Err(e) => {
516                                eprintln!("Error fetching trace: {e}");
517                            }
518                        }
519                    }
520                }
521
522                true
523            }
524            Err(e) => {
525                let err_msg = error::format_error_colored(&e, colorizer);
526                if let Some(src) = source_name {
527                    eprintln!("{src}:{line_number}:{err_msg}");
528                } else {
529                    eprintln!("{err_msg}");
530                }
531                if *debug {
532                    eprintln!("Debug: {e:?}");
533                }
534                false
535            }
536        }
537    })
538}
539
/// Remove one matching pair of surrounding double or single quotes from `s`.
///
/// The string is returned unchanged when it is not wrapped in a matching pair.
/// A lone quote character (e.g. the argument of `SOURCE "`) is returned as-is
/// rather than being sliced out of bounds.
fn strip_quotes(s: &str) -> &str {
    for quote in ['"', '\''] {
        // Both strips must succeed on distinct characters, so a single quote
        // char can never count as its own opening and closing quote.
        if let Some(inner) = s.strip_prefix(quote).and_then(|rest| rest.strip_suffix(quote)) {
            return inner;
        }
    }
    s
}
547
548fn expand_tilde(path: &str) -> std::path::PathBuf {
549    if let Some(rest) = path.strip_prefix("~/") {
550        if let Some(home) = dirs::home_dir() {
551            return home.join(rest);
552        }
553    } else if path == "~" {
554        if let Some(home) = dirs::home_dir() {
555            return home;
556        }
557    }
558    std::path::PathBuf::from(path)
559}