1use std::io::{self, BufRead, Write};
4
5use crate::colorizer::CqlColorizer;
6use crate::config::MergedConfig;
7use crate::parser::{self, ParseResult, StatementParser};
8use crate::session::CqlSession;
9use crate::{describe, error, formatter};
10
11pub async fn execute_cql_string(
16 session: &mut CqlSession,
17 config: &MergedConfig,
18 colorizer: &CqlColorizer,
19 cql_string: &str,
20 writer: &mut dyn Write,
21) -> i32 {
22 let with_semi;
24 let cql_string = {
25 let t = cql_string.trim_end();
26 if !t.is_empty() && !t.ends_with(';') {
28 with_semi = format!("{t};");
29 &with_semi
30 } else {
31 cql_string
32 }
33 };
34 let statements = parser::parse_batch(cql_string);
35 let mut had_error = false;
36 let mut debug = config.debug;
37
38 for stmt in statements {
39 if !execute_single_statement(
40 session, config, colorizer, &mut debug, &stmt, None, 0, writer,
41 )
42 .await
43 {
44 had_error = true;
45 }
46 }
47
48 if had_error {
49 1
50 } else {
51 0
52 }
53}
54
55pub async fn execute_cql_file(
59 session: &mut CqlSession,
60 config: &MergedConfig,
61 colorizer: &CqlColorizer,
62 file_path: &str,
63 writer: &mut dyn Write,
64) -> i32 {
65 let file = match std::fs::File::open(file_path) {
66 Ok(f) => f,
67 Err(e) => {
68 eprintln!("Could not open '{}': {e}", file_path);
69 return 1;
70 }
71 };
72 let reader = io::BufReader::new(file);
73 execute_cql_reader(session, config, colorizer, reader, file_path, writer).await
74}
75
76pub async fn execute_cql_reader<R: io::BufRead>(
80 session: &mut CqlSession,
81 config: &MergedConfig,
82 colorizer: &CqlColorizer,
83 reader: R,
84 source_name: &str,
85 writer: &mut dyn Write,
86) -> i32 {
87 let mut stmt_parser = StatementParser::new();
88 let mut had_error = false;
89 let mut debug = config.debug;
90 let mut line_number: usize = 0;
91 let mut stmt_start_line: usize = 1;
92
93 for line_result in reader.lines() {
94 let line = match line_result {
95 Ok(l) => l,
96 Err(e) => {
97 eprintln!("Error reading '{}': {e}", source_name);
98 return 1;
99 }
100 };
101 line_number += 1;
102
103 if stmt_parser.is_empty() {
104 stmt_start_line = line_number;
105 }
106
107 let trimmed = line.trim();
108 if stmt_parser.is_empty() && !trimmed.is_empty() && parser::is_shell_command(trimmed) {
109 let clean = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
110 if !execute_single_statement(
111 session,
112 config,
113 colorizer,
114 &mut debug,
115 clean,
116 Some(source_name),
117 stmt_start_line,
118 writer,
119 )
120 .await
121 {
122 had_error = true;
123 }
124 continue;
125 }
126
127 if let ParseResult::Complete(statements) = stmt_parser.feed_line(&line) {
128 for stmt in statements {
129 if !execute_single_statement(
130 session,
131 config,
132 colorizer,
133 &mut debug,
134 &stmt,
135 Some(source_name),
136 stmt_start_line,
137 writer,
138 )
139 .await
140 {
141 had_error = true;
142 }
143 }
144 stmt_start_line = line_number + 1;
145 }
146 }
147
148 if had_error {
149 1
150 } else {
151 0
152 }
153}
154
155#[allow(clippy::too_many_arguments)]
162pub fn execute_single_statement<'a>(
163 session: &'a mut CqlSession,
164 config: &'a MergedConfig,
165 colorizer: &'a CqlColorizer,
166 debug: &'a mut bool,
167 input: &'a str,
168 source_name: Option<&'a str>,
169 line_number: usize,
170 writer: &'a mut dyn Write,
171) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + 'a>> {
172 Box::pin(async move {
173 let trimmed = input.trim();
174 if trimmed.is_empty() {
175 return true;
176 }
177
178 let upper = trimmed.to_uppercase();
179
180 if upper == "DEBUG" {
181 if *debug {
182 let _ = writeln!(
183 writer,
184 "Debug output is currently enabled. Use DEBUG OFF to disable."
185 );
186 } else {
187 let _ = writeln!(
188 writer,
189 "Debug output is currently disabled. Use DEBUG ON to enable."
190 );
191 }
192 return true;
193 }
194 if upper == "DEBUG ON" {
195 *debug = true;
196 let _ = writeln!(writer, "Now printing debug output.");
197 return true;
198 }
199 if upper == "DEBUG OFF" {
200 *debug = false;
201 let _ = writeln!(writer, "Disabled debug output.");
202 return true;
203 }
204
205 if upper == "UNICODE" {
206 let _ = writeln!(
207 writer,
208 "Encoding: {}\nDefault encoding: utf-8",
209 config.encoding
210 );
211 return true;
212 }
213
214 if upper == "CONSISTENCY" {
215 let cl = session.get_consistency();
216 let _ = writeln!(writer, "Current consistency level is {cl}.");
217 return true;
218 }
219 if let Some(rest) = upper.strip_prefix("CONSISTENCY ") {
220 let level = rest.trim();
221 match session.set_consistency_str(level) {
222 Ok(()) => {
223 let _ = writeln!(writer, "Consistency level set to {level}.");
224 }
225 Err(e) => {
226 eprintln!("{e}");
227 return false;
228 }
229 }
230 return true;
231 }
232 if upper == "SERIAL CONSISTENCY" {
233 match session.get_serial_consistency() {
234 Some(scl) => {
235 let _ = writeln!(writer, "Current serial consistency level is {scl}.");
236 }
237 None => {
238 let _ = writeln!(writer, "Current serial consistency level is SERIAL.");
239 }
240 }
241 return true;
242 }
243 if let Some(rest) = upper.strip_prefix("SERIAL CONSISTENCY ") {
244 let level = rest.trim();
245 match session.set_serial_consistency_str(level) {
246 Ok(()) => {
247 let _ = writeln!(writer, "Serial consistency level set to {level}.");
248 }
249 Err(e) => {
250 eprintln!("{e}");
251 return false;
252 }
253 }
254 return true;
255 }
256 if upper == "TRACING OFF" || upper == "TRACING" {
257 session.set_tracing(false);
258 let _ = writeln!(writer, "Tracing is disabled");
259 return true;
260 }
261 if upper == "TRACING ON" {
262 session.set_tracing(true);
263 let _ = writeln!(writer, "Tracing is enabled");
264 return true;
265 }
266 if upper == "SHOW VERSION" {
267 let _ = writeln!(writer, "[cqlsh {}]", env!("CARGO_PKG_VERSION"));
268 return true;
269 }
270 if upper == "SHOW HOST" {
271 let _ = writeln!(writer, "Connected to: {}", session.connection_display);
272 return true;
273 }
274
275 if upper == "DESCRIBE"
276 || upper == "DESC"
277 || upper.starts_with("DESCRIBE ")
278 || upper.starts_with("DESC ")
279 {
280 let args = if upper.starts_with("DESCRIBE ") {
281 trimmed["DESCRIBE ".len()..].trim()
282 } else if upper.starts_with("DESC ") {
283 trimmed["DESC ".len()..].trim()
284 } else {
285 ""
286 };
287 let mut buf = Vec::new();
288 match describe::execute(session, args, &mut buf).await {
289 Ok(()) => {
290 let _ = writer.write_all(&buf);
291 }
292 Err(e) => {
293 eprintln!("Error: {e}");
294 return false;
295 }
296 }
297 return true;
298 }
299
300 if upper.starts_with("SOURCE ") {
301 let path = trimmed["SOURCE ".len()..].trim();
302 let path = strip_quotes(path);
303 if config.no_file_io {
304 eprintln!("File I/O is disabled (--no-file-io).");
305 return true;
306 }
307 let expanded = expand_tilde(path);
308 let file = match std::fs::File::open(&expanded) {
309 Ok(f) => f,
310 Err(e) => {
311 eprintln!("Could not open '{}': {e}", expanded.display());
312 return false;
313 }
314 };
315 let reader = io::BufReader::new(file);
316 let source_name_str = expanded.display().to_string();
317 let mut stmt_parser = StatementParser::new();
318 let mut had_error = false;
319 let mut src_line_number: usize = 0;
320 let mut src_stmt_start: usize = 1;
321 for line_result in reader.lines() {
322 let line: String = match line_result {
323 Ok(l) => l,
324 Err(e) => {
325 eprintln!("Error reading '{}': {e}", source_name_str);
326 return false;
327 }
328 };
329 src_line_number += 1;
330 if stmt_parser.is_empty() {
331 src_stmt_start = src_line_number;
332 }
333 let ltrimmed = line.trim();
334 if stmt_parser.is_empty()
335 && !ltrimmed.is_empty()
336 && parser::is_shell_command(ltrimmed)
337 {
338 let clean = ltrimmed.strip_suffix(';').unwrap_or(ltrimmed).trim_end();
339 if !execute_single_statement(
340 session,
341 config,
342 colorizer,
343 debug,
344 clean,
345 Some(&source_name_str),
346 src_stmt_start,
347 writer,
348 )
349 .await
350 {
351 had_error = true;
352 }
353 continue;
354 }
355 if let ParseResult::Complete(statements) = stmt_parser.feed_line(&line) {
356 for stmt in statements {
357 if !execute_single_statement(
358 session,
359 config,
360 colorizer,
361 debug,
362 &stmt,
363 Some(&source_name_str),
364 src_stmt_start,
365 writer,
366 )
367 .await
368 {
369 had_error = true;
370 }
371 }
372 src_stmt_start = src_line_number + 1;
373 }
374 }
375 return !had_error;
376 }
377 if upper == "SOURCE" {
378 eprintln!("SOURCE requires a file path argument.");
379 return true;
380 }
381
382 if upper == "CLEAR" || upper == "CLS" {
383 let _ = write!(writer, "\x1B[2J\x1B[1;1H");
384 return true;
385 }
386
387 if upper == "LOGIN" {
388 eprintln!("Usage: LOGIN <username> [<password>]");
389 return false;
390 }
391 if upper.starts_with("LOGIN ") {
392 let args = trimmed["LOGIN ".len()..].trim();
393 let parts: Vec<&str> = args.splitn(2, char::is_whitespace).collect();
394 let new_user = parts[0].to_string();
395 let new_pass = if parts.len() > 1 {
396 Some(parts[1].trim_matches('\'').to_string())
397 } else {
398 None
399 };
400 let mut new_config = config.clone();
401 new_config.username = Some(new_user);
402 new_config.password = new_pass;
403 let prev_keyspace = session.current_keyspace().map(str::to_string);
404 match crate::session::CqlSession::connect(&new_config).await {
405 Ok(mut new_session) => {
406 if let Some(ks) = prev_keyspace {
407 if let Err(e) = new_session.use_keyspace(&ks).await {
408 eprintln!("Warning: could not restore keyspace '{ks}': {e}");
409 }
410 }
411 *session = new_session;
412 }
413 Err(e) => {
414 eprintln!("{}", error::format_error_colored(&e, colorizer));
415 return false;
416 }
417 }
418 return true;
419 }
420
421 if upper.starts_with("COPY ") && upper.contains(" TO ") {
423 if config.no_file_io {
424 eprintln!("File I/O is disabled (--no-file-io).");
425 return false;
426 }
427 match crate::copy::parse_copy_to(trimmed) {
428 Ok(cmd) => {
429 let ks = session.current_keyspace();
430 match crate::copy::execute_copy_to(session, &cmd, ks).await {
431 Ok(()) => return true,
432 Err(e) => {
433 eprintln!("COPY TO error: {e}");
434 return false;
435 }
436 }
437 }
438 Err(e) => {
439 eprintln!("Invalid COPY TO syntax: {e}");
440 return false;
441 }
442 }
443 }
444
445 if upper.starts_with("COPY ") && upper.contains(" FROM ") {
447 if config.no_file_io {
448 eprintln!("File I/O is disabled (--no-file-io).");
449 return false;
450 }
451 match crate::copy::parse_copy_from(trimmed) {
452 Ok(cmd) => {
453 let ks = session.current_keyspace();
454 match crate::copy::execute_copy_from(session, &cmd, ks).await {
455 Ok(()) => return true,
456 Err(e) => {
457 eprintln!("COPY FROM error: {e}");
458 return false;
459 }
460 }
461 }
462 Err(e) => {
463 eprintln!("Invalid COPY FROM syntax: {e}");
464 return false;
465 }
466 }
467 }
468
469 if upper == "QUIT"
470 || upper == "EXIT"
471 || upper == "HELP"
472 || upper == "?"
473 || upper.starts_with("HELP ")
474 || upper == "EXPAND"
475 || upper == "EXPAND ON"
476 || upper == "EXPAND OFF"
477 || upper == "PAGING"
478 || upper == "PAGING ON"
479 || upper == "PAGING OFF"
480 || upper.starts_with("PAGING ")
481 || upper == "CAPTURE"
482 || upper == "CAPTURE OFF"
483 || upper.starts_with("CAPTURE ")
484 {
485 return true;
486 }
487
488 match session.execute(trimmed).await {
489 Ok(result) => {
490 for warning in &result.warnings {
491 let msg = format!("Warnings: {warning}");
492 eprintln!("{}", colorizer.colorize_warning(&msg));
493 }
494
495 if !result.columns.is_empty() {
496 let mut buf = Vec::new();
497 formatter::print_tabular(&result, colorizer, &mut buf);
498 let _ = writer.write_all(&buf);
499 }
500
501 if session.is_tracing_enabled() && !upper.contains("SYSTEM_TRACES") {
502 if let Some(trace_id) = result.tracing_id {
503 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
504 match session.get_trace_session(trace_id).await {
505 Ok(Some(trace)) => {
506 let mut buf = Vec::new();
507 formatter::print_trace(&trace, colorizer, &mut buf);
508 let _ = writer.write_all(&buf);
509 }
510 Ok(None) => {
511 eprintln!(
512 "Trace {trace_id} not yet available. Use SHOW SESSION {trace_id} to view later."
513 );
514 }
515 Err(e) => {
516 eprintln!("Error fetching trace: {e}");
517 }
518 }
519 }
520 }
521
522 true
523 }
524 Err(e) => {
525 let err_msg = error::format_error_colored(&e, colorizer);
526 if let Some(src) = source_name {
527 eprintln!("{src}:{line_number}:{err_msg}");
528 } else {
529 eprintln!("{err_msg}");
530 }
531 if *debug {
532 eprintln!("Debug: {e:?}");
533 }
534 false
535 }
536 }
537 })
538}
539
/// Removes one pair of matching surrounding quotes (`"..."` or `'...'`).
///
/// The input is returned unchanged when it is not wrapped in a matching pair.
/// That includes a string consisting of a single quote character, which has
/// no closing partner and must not be sliced.
fn strip_quotes(s: &str) -> &str {
    // `strip_prefix` consumes the opening quote first, so the closing quote
    // has to be a *different* character position; a lone `"` yields `None`.
    let double_quoted = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));
    let single_quoted = || {
        s.strip_prefix('\'')
            .and_then(|rest| rest.strip_suffix('\''))
    };
    double_quoted.or_else(single_quoted).unwrap_or(s)
}
547
548fn expand_tilde(path: &str) -> std::path::PathBuf {
549 if let Some(rest) = path.strip_prefix("~/") {
550 if let Some(home) = dirs::home_dir() {
551 return home.join(rest);
552 }
553 } else if path == "~" {
554 if let Some(home) = dirs::home_dir() {
555 return home;
556 }
557 }
558 std::path::PathBuf::from(path)
559}