Skip to main content

cqlsh_rs/
session.rs

1//! CQL session management layer.
2//!
3//! Wraps the driver with higher-level session state management including
4//! keyspace tracking, consistency level management, and tracing control.
5//! This mirrors the Python cqlsh `Shell` session state.
6
7use anyhow::{bail, Result};
8
9use crate::config::MergedConfig;
10use crate::driver::types::CqlValue;
11use crate::driver::{
12    AggregateMetadata, ConnectionConfig, Consistency, CqlDriver, CqlResult, CqlRowStream,
13    FunctionMetadata, KeyspaceMetadata, PreparedId, ScyllaDriver, SslConfig, TableMetadata,
14    TracingSession, UdtMetadata,
15};
16
17/// High-level CQL session managing driver state and user preferences.
18pub struct CqlSession {
19    driver: ScyllaDriver,
20    /// Current keyspace (updated on USE commands).
21    current_keyspace: Option<String>,
22    /// Display name for the connection (host:port).
23    pub connection_display: String,
24    /// Cluster name retrieved after connecting.
25    pub cluster_name: Option<String>,
26    /// CQL version from the connected node.
27    pub cql_version: Option<String>,
28    /// Release version of the connected node.
29    pub release_version: Option<String>,
30    /// ScyllaDB version (None if connected to Apache Cassandra).
31    pub scylla_version: Option<String>,
32}
33
34impl CqlSession {
35    /// Create a new session by connecting using the merged configuration.
36    pub async fn connect(config: &MergedConfig) -> Result<Self> {
37        let ssl_config = if config.ssl {
38            Some(SslConfig {
39                certfile: config.cqlshrc.ssl.certfile.clone(),
40                validate: config.cqlshrc.ssl.validate.unwrap_or(false),
41                userkey: config.cqlshrc.ssl.userkey.clone(),
42                usercert: config.cqlshrc.ssl.usercert.clone(),
43                host_certfiles: config.cqlshrc.certfiles.clone(),
44            })
45        } else {
46            None
47        };
48
49        let conn_config = ConnectionConfig {
50            host: config.host.clone(),
51            port: config.port,
52            username: config.username.clone(),
53            password: config.password.clone(),
54            keyspace: config.keyspace.clone(),
55            connect_timeout: config.connect_timeout,
56            request_timeout: config.request_timeout,
57            ssl: config.ssl,
58            ssl_config,
59            protocol_version: config.protocol_version,
60        };
61
62        let driver = ScyllaDriver::connect(&conn_config).await?;
63
64        let connection_display = format!("{}:{}", config.host, config.port);
65
66        // Fetch cluster metadata after connecting
67        let cluster_name = driver.get_cluster_name().await.ok().flatten();
68        let cql_version = driver.get_cql_version().await.ok().flatten();
69        let release_version = driver.get_release_version().await.ok().flatten();
70        let scylla_version = driver.get_scylla_version().await.ok().flatten();
71
72        // Set initial consistency from config
73        if let Some(cl_str) = &config.consistency_level {
74            if let Some(cl) = Consistency::from_str_cql(cl_str) {
75                driver.set_consistency(cl);
76            }
77        }
78
79        // Set initial serial consistency from config
80        if let Some(scl_str) = &config.serial_consistency_level {
81            if let Some(scl) = Consistency::from_str_cql(scl_str) {
82                driver.set_serial_consistency(Some(scl));
83            }
84        }
85
86        Ok(CqlSession {
87            driver,
88            current_keyspace: config.keyspace.clone(),
89            connection_display,
90            cluster_name,
91            cql_version,
92            release_version,
93            scylla_version,
94        })
95    }
96
97    /// Check if all nodes agree on the schema version.
98    /// Returns `true` if schema is in agreement, `false` if there's a mismatch.
99    /// Errors during the check are silently ignored (non-critical).
100    pub async fn check_schema_agreement(&self) -> bool {
101        use std::collections::HashSet;
102
103        let mut versions = HashSet::new();
104
105        // Get local node's schema version
106        if let Ok(result) = self
107            .driver
108            .execute_unpaged("SELECT schema_version FROM system.local WHERE key='local'")
109            .await
110        {
111            for row in &result.rows {
112                if let Some(v) = row.get(0) {
113                    versions.insert(v.to_string());
114                }
115            }
116        }
117
118        // Get peer nodes' schema versions
119        if let Ok(result) = self
120            .driver
121            .execute_unpaged("SELECT schema_version FROM system.peers")
122            .await
123        {
124            for row in &result.rows {
125                if let Some(v) = row.get(0) {
126                    versions.insert(v.to_string());
127                }
128            }
129        }
130
131        // Agreement means exactly 0 or 1 distinct versions
132        versions.len() <= 1
133    }
134
135    /// Execute a CQL statement. Handles USE keyspace commands specially.
136    pub async fn execute(&mut self, query: &str) -> Result<CqlResult> {
137        let trimmed = query.trim();
138
139        // Detect USE keyspace commands
140        if let Some(keyspace) = parse_use_command(trimmed) {
141            self.use_keyspace(&keyspace).await?;
142            return Ok(CqlResult::empty());
143        }
144
145        self.driver.execute_unpaged(query).await
146    }
147
148    /// Execute a raw CQL query without USE interception.
149    ///
150    /// Used by DESCRIBE and other internal commands that need to query
151    /// system tables directly.
152    pub async fn execute_query(&self, query: &str) -> Result<CqlResult> {
153        self.driver.execute_unpaged(query).await
154    }
155
156    /// Execute a CQL statement with paging.
157    pub async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
158        self.driver.execute_paged(query, page_size).await
159    }
160
161    pub async fn execute_streaming(&self, query: &str, page_size: i32) -> Result<CqlRowStream> {
162        self.driver.execute_streaming(query, page_size).await
163    }
164
165    /// Prepare a CQL statement.
166    pub async fn prepare(&self, query: &str) -> Result<PreparedId> {
167        self.driver.prepare(query).await
168    }
169
170    /// Execute a previously prepared statement with typed bound values.
171    pub async fn execute_prepared(
172        &self,
173        id: &PreparedId,
174        values: &[CqlValue],
175    ) -> Result<CqlResult> {
176        self.driver.execute_prepared(id, values).await
177    }
178
179    /// Switch to a different keyspace.
180    pub async fn use_keyspace(&mut self, keyspace: &str) -> Result<()> {
181        self.driver.use_keyspace(keyspace).await?;
182        self.current_keyspace = Some(keyspace.to_string());
183        Ok(())
184    }
185
186    /// Get the current keyspace.
187    pub fn current_keyspace(&self) -> Option<&str> {
188        self.current_keyspace.as_deref()
189    }
190
191    /// Get the current consistency level.
192    pub fn get_consistency(&self) -> Consistency {
193        self.driver.get_consistency()
194    }
195
196    /// Set the consistency level.
197    pub fn set_consistency(&self, consistency: Consistency) {
198        self.driver.set_consistency(consistency);
199    }
200
201    /// Set the consistency level from a string. Returns error if invalid.
202    pub fn set_consistency_str(&self, level: &str) -> Result<()> {
203        let consistency = Consistency::from_str_cql(level)
204            .ok_or_else(|| anyhow::anyhow!("invalid consistency level: {level}"))?;
205        self.driver.set_consistency(consistency);
206        Ok(())
207    }
208
209    /// Get the current serial consistency level.
210    pub fn get_serial_consistency(&self) -> Option<Consistency> {
211        self.driver.get_serial_consistency()
212    }
213
214    /// Set the serial consistency level.
215    pub fn set_serial_consistency(&self, consistency: Option<Consistency>) {
216        self.driver.set_serial_consistency(consistency);
217    }
218
219    /// Set the serial consistency level from a string. Returns error if invalid.
220    pub fn set_serial_consistency_str(&self, level: &str) -> Result<()> {
221        let consistency = Consistency::from_str_cql(level)
222            .ok_or_else(|| anyhow::anyhow!("invalid serial consistency level: {level}"))?;
223        match consistency {
224            Consistency::Serial | Consistency::LocalSerial => {
225                self.driver.set_serial_consistency(Some(consistency));
226                Ok(())
227            }
228            _ => bail!("serial consistency must be SERIAL or LOCAL_SERIAL, got {level}"),
229        }
230    }
231
232    /// Enable or disable tracing.
233    pub fn set_tracing(&self, enabled: bool) {
234        self.driver.set_tracing(enabled);
235    }
236
237    /// Check if tracing is enabled.
238    pub fn is_tracing_enabled(&self) -> bool {
239        self.driver.is_tracing_enabled()
240    }
241
242    /// Get the last tracing session ID.
243    pub fn last_trace_id(&self) -> Option<uuid::Uuid> {
244        self.driver.last_trace_id()
245    }
246
247    /// Retrieve tracing session data.
248    pub async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>> {
249        self.driver.get_trace_session(trace_id).await
250    }
251
252    /// Get metadata for all keyspaces.
253    pub async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
254        self.driver.get_keyspaces().await
255    }
256
257    /// Get metadata for tables in a keyspace.
258    pub async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
259        self.driver.get_tables(keyspace).await
260    }
261
262    /// Get metadata for a specific table.
263    pub async fn get_table_metadata(
264        &self,
265        keyspace: &str,
266        table: &str,
267    ) -> Result<Option<TableMetadata>> {
268        self.driver.get_table_metadata(keyspace, table).await
269    }
270
271    /// Get metadata for all user-defined types in a keyspace.
272    pub async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
273        self.driver.get_udts(keyspace).await
274    }
275
276    /// Get metadata for all user-defined functions in a keyspace.
277    pub async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
278        self.driver.get_functions(keyspace).await
279    }
280
281    /// Get metadata for all user-defined aggregates in a keyspace.
282    pub async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
283        self.driver.get_aggregates(keyspace).await
284    }
285
286    /// Check if the connection is still alive.
287    pub async fn is_connected(&self) -> bool {
288        self.driver.is_connected().await
289    }
290}
291
/// Parse a USE keyspace command, returning the keyspace name if matched.
///
/// Surrounding whitespace and trailing semicolons are ignored. The `USE`
/// keyword is matched case-insensitively (`USE`, `use`, `Use`, ...) and must
/// be followed by at least one whitespace character. A name wrapped in a
/// matching pair of double or single quotes has that pair removed; the case
/// of the name is preserved either way.
///
/// Returns `None` when the statement is not a USE command or when no
/// keyspace name remains after trimming and unquoting.
fn parse_use_command(query: &str) -> Option<String> {
    let statement = query.trim().trim_end_matches(';').trim();

    // `get` (rather than slicing) yields `None` for input shorter than the
    // keyword or when byte 3 is not a character boundary, so arbitrary UTF-8
    // input cannot panic here.
    let keyword = statement.get(..3)?;
    if !keyword.eq_ignore_ascii_case("USE") {
        return None;
    }

    // Require whitespace after the keyword so that identifiers which merely
    // start with "use" (e.g. `USER`) are not treated as a USE command.
    let rest = &statement[3..];
    if !rest.starts_with(char::is_whitespace) {
        return None;
    }
    let keyspace = rest.trim();

    // Remove one pair of matching quotes if present. Stripping the prefix and
    // then the suffix means a lone quote character is left untouched instead
    // of producing an out-of-range slice.
    let unquote = |quote: char| {
        keyspace
            .strip_prefix(quote)
            .and_then(|inner| inner.strip_suffix(quote))
    };
    let keyspace = unquote('"')
        .or_else(|| unquote('\''))
        .unwrap_or(keyspace);

    if keyspace.is_empty() {
        None
    } else {
        Some(keyspace.to_string())
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `USE <name>` yields the name unchanged.
    #[test]
    fn parse_use_simple() {
        let parsed = parse_use_command("USE my_keyspace");
        assert_eq!(parsed.as_deref(), Some("my_keyspace"));
    }

    /// A trailing semicolon is not part of the keyspace name.
    #[test]
    fn parse_use_semicolon() {
        let parsed = parse_use_command("USE my_keyspace;");
        assert_eq!(parsed.as_deref(), Some("my_keyspace"));
    }

    /// The keyword is accepted in lowercase.
    #[test]
    fn parse_use_lowercase() {
        let parsed = parse_use_command("use test_ks");
        assert_eq!(parsed.as_deref(), Some("test_ks"));
    }

    /// Double quotes are stripped and the case of the name is preserved.
    #[test]
    fn parse_use_quoted() {
        let parsed = parse_use_command("USE \"MyKeyspace\"");
        assert_eq!(parsed.as_deref(), Some("MyKeyspace"));
    }

    /// Single quotes are stripped as well.
    #[test]
    fn parse_use_single_quoted() {
        let parsed = parse_use_command("USE 'my_ks'");
        assert_eq!(parsed.as_deref(), Some("my_ks"));
    }

    /// Padding around the keyword, the name and the semicolon is ignored.
    #[test]
    fn parse_use_with_whitespace() {
        let parsed = parse_use_command("  USE  my_keyspace  ;  ");
        assert_eq!(parsed.as_deref(), Some("my_keyspace"));
    }

    /// Statements that are not USE commands are rejected.
    #[test]
    fn parse_not_use_command() {
        for query in ["SELECT * FROM table", "INSERT INTO users"] {
            assert_eq!(parse_use_command(query), None);
        }
    }

    /// A USE keyword with no keyspace name is rejected.
    #[test]
    fn parse_use_empty() {
        for query in ["USE ", "USE ;"] {
            assert_eq!(parse_use_command(query), None);
        }
    }
}
391}