Skip to main content

cqlsh_rs/driver/
mod.rs

1//! Driver abstraction layer for CQL database connectivity.
2//!
3//! Provides a trait-based abstraction over the underlying database driver,
4//! enabling testability and future flexibility. The primary implementation
5//! uses the `scylla` crate for Cassandra/ScyllaDB connectivity.
6//!
7//! Many types and trait methods are defined ahead of their use in later
8//! development phases (REPL, DESCRIBE, COPY, etc.).
9
10pub mod proxy_address_translator;
11pub mod scylla_driver;
12pub mod types;
13
14use std::collections::HashMap;
15
16use anyhow::Result;
17use async_trait::async_trait;
18
19pub use scylla_driver::ScyllaDriver;
20#[allow(unused_imports)]
21pub use types::{CqlColumn, CqlResult, CqlRow, CqlRowStream, CqlValue};
22
/// Configuration for establishing a database connection.
///
/// `Debug` is implemented by hand (not derived) so that formatting a config
/// with `{:?}` — e.g. in logs or error context — never reveals the password.
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Contact point host (e.g., "127.0.0.1").
    pub host: String,
    /// Native transport port (default: 9042).
    pub port: u16,
    /// Optional username for authentication.
    pub username: Option<String>,
    /// Optional password for authentication. Redacted in `Debug` output.
    pub password: Option<String>,
    /// Optional default keyspace.
    pub keyspace: Option<String>,
    /// Connection timeout in seconds.
    pub connect_timeout: u64,
    /// Per-request timeout in seconds.
    pub request_timeout: u64,
    /// Whether to use SSL/TLS.
    pub ssl: bool,
    /// SSL/TLS configuration.
    pub ssl_config: Option<SslConfig>,
    /// Protocol version (None = auto-negotiate).
    pub protocol_version: Option<u8>,
}

impl std::fmt::Debug for ConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // deciding how (or whether) to print it becomes a compile error.
        let Self {
            host,
            port,
            username,
            password,
            keyspace,
            connect_timeout,
            request_timeout,
            ssl,
            ssl_config,
            protocol_version,
        } = self;
        f.debug_struct("ConnectionConfig")
            .field("host", host)
            .field("port", port)
            .field("username", username)
            // Reveal only whether a password is set, never its value.
            .field("password", &password.as_ref().map(|_| "<redacted>"))
            .field("keyspace", keyspace)
            .field("connect_timeout", connect_timeout)
            .field("request_timeout", request_timeout)
            .field("ssl", ssl)
            .field("ssl_config", ssl_config)
            .field("protocol_version", protocol_version)
            .finish()
    }
}

/// SSL/TLS configuration options.
///
/// NOTE(review): the derived `Default` sets `validate` to `false`, whereas
/// Python cqlsh validates server certificates unless configured otherwise —
/// confirm that callers set `validate` explicitly.
#[derive(Debug, Clone, Default)]
pub struct SslConfig {
    /// Path to CA certificate file for server verification.
    pub certfile: Option<String>,
    /// Whether to validate the server certificate.
    pub validate: bool,
    /// Path to client private key file (for mutual TLS).
    pub userkey: Option<String>,
    /// Path to client certificate file (for mutual TLS).
    pub usercert: Option<String>,
    /// Per-host certificate files.
    pub host_certfiles: HashMap<String, String>,
}
62
/// Name and CQL type of a single column, as used in a result set.
#[derive(Clone, Debug)]
pub struct ColumnMetadata {
    /// Column name.
    pub name: String,
    /// CQL type of the column, stored as text.
    pub type_name: String,
}
69
/// Schema-level description of one keyspace.
#[derive(Clone, Debug)]
pub struct KeyspaceMetadata {
    /// Keyspace name.
    pub name: String,
    /// Replication options as string key/value pairs (presumably the
    /// `replication` map from the schema tables — confirm with the driver).
    pub replication: std::collections::HashMap<String, String>,
    /// Value of the keyspace's `durable_writes` flag.
    pub durable_writes: bool,
}
77
78/// Metadata about a table.
79#[derive(Debug, Clone)]
80pub struct TableMetadata {
81    pub keyspace: String,
82    pub name: String,
83    pub columns: Vec<ColumnMetadata>,
84    pub partition_key: Vec<String>,
85    pub clustering_key: Vec<String>,
86    /// Clustering order for each clustering column (e.g., "ASC" or "DESC").
87    /// Parallel to `clustering_key`.
88    pub clustering_order: Vec<String>,
89    /// Table properties from system_schema.tables (e.g., bloom_filter_fp_chance, compaction, etc.).
90    pub properties: std::collections::BTreeMap<String, String>,
91}
92
/// Description of a user-defined type (UDT).
#[derive(Clone, Debug)]
pub struct UdtMetadata {
    /// Keyspace that owns the type.
    pub keyspace: String,
    /// Type name.
    pub name: String,
    /// Field names, in declaration order.
    pub field_names: Vec<String>,
    /// Field types as CQL text; `field_types[i]` belongs to `field_names[i]`.
    pub field_types: Vec<String>,
}
101
/// Signature of a user-defined function (UDF).
#[derive(Clone, Debug)]
pub struct FunctionMetadata {
    /// Keyspace that owns the function.
    pub keyspace: String,
    /// Function name.
    pub name: String,
    /// Argument types as CQL text, in parameter order (may be empty).
    pub argument_types: Vec<String>,
    /// Return type as CQL text.
    pub return_type: String,
}
110
/// Signature of a user-defined aggregate (UDA).
#[derive(Clone, Debug)]
pub struct AggregateMetadata {
    /// Keyspace that owns the aggregate.
    pub keyspace: String,
    /// Aggregate name.
    pub name: String,
    /// Argument types as CQL text, in parameter order.
    pub argument_types: Vec<String>,
    /// Return type as CQL text.
    pub return_type: String,
}
119
/// Consistency levels matching CQL specification.
///
/// Each variant corresponds to the CQL keyword returned by
/// [`Consistency::as_cql_str`], which is also what `Display` prints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Consistency {
    /// `ANY`
    Any,
    /// `ONE`
    One,
    /// `TWO`
    Two,
    /// `THREE`
    Three,
    /// `QUORUM`
    Quorum,
    /// `ALL`
    All,
    /// `LOCAL_QUORUM`
    LocalQuorum,
    /// `EACH_QUORUM`
    EachQuorum,
    /// `SERIAL`
    Serial,
    /// `LOCAL_SERIAL`
    LocalSerial,
    /// `LOCAL_ONE`
    LocalOne,
}

impl Consistency {
    /// Every level, in declaration order. `from_str_cql` searches this list
    /// and compares against `as_cql_str`, so each keyword is spelled out in
    /// exactly one place.
    const LEVELS: [Self; 11] = [
        Self::Any,
        Self::One,
        Self::Two,
        Self::Three,
        Self::Quorum,
        Self::All,
        Self::LocalQuorum,
        Self::EachQuorum,
        Self::Serial,
        Self::LocalSerial,
        Self::LocalOne,
    ];

    /// Parse a consistency level from its CQL name, ignoring ASCII case.
    ///
    /// Returns `None` when `s` is not a known level name. The comparison is
    /// deliberately ASCII-only: full Unicode upper-casing maps characters such
    /// as U+017F (`ſ`) to `S` or U+0131 (`ı`) to `I`, which would accept
    /// strings that are not CQL keywords. It also avoids allocating.
    pub fn from_str_cql(s: &str) -> Option<Self> {
        Self::LEVELS
            .iter()
            .copied()
            .find(|level| level.as_cql_str().eq_ignore_ascii_case(s))
    }

    /// Return the CQL keyword for this level (e.g. `LOCAL_QUORUM`).
    pub fn as_cql_str(&self) -> &'static str {
        match self {
            Self::Any => "ANY",
            Self::One => "ONE",
            Self::Two => "TWO",
            Self::Three => "THREE",
            Self::Quorum => "QUORUM",
            Self::All => "ALL",
            Self::LocalQuorum => "LOCAL_QUORUM",
            Self::EachQuorum => "EACH_QUORUM",
            Self::Serial => "SERIAL",
            Self::LocalSerial => "LOCAL_SERIAL",
            Self::LocalOne => "LOCAL_ONE",
        }
    }
}

impl std::fmt::Display for Consistency {
    /// Formats the level as its CQL keyword.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_cql_str())
    }
}
178
179/// The core driver trait abstracting database operations.
180///
181/// All methods are async and return `Result` for proper error propagation.
182/// Implementations must be `Send + Sync` for use across async tasks.
183#[async_trait]
184pub trait CqlDriver: Send + Sync {
185    /// Establish a connection to the database cluster.
186    async fn connect(config: &ConnectionConfig) -> Result<Self>
187    where
188        Self: Sized;
189
190    /// Execute a raw CQL query string without parameters.
191    async fn execute_unpaged(&self, query: &str) -> Result<CqlResult>;
192
193    /// Execute a CQL query with automatic paging, returning all rows.
194    async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult>;
195
196    /// Execute a CQL query returning a streaming result set.
197    ///
198    /// Rows are fetched lazily page-by-page from the server.
199    async fn execute_streaming(&self, query: &str, page_size: i32) -> Result<CqlRowStream>;
200
201    /// Prepare a CQL statement for repeated execution.
202    async fn prepare(&self, query: &str) -> Result<PreparedId>;
203
204    /// Execute a previously prepared statement with the given values.
205    async fn execute_prepared(
206        &self,
207        prepared_id: &PreparedId,
208        values: &[CqlValue],
209    ) -> Result<CqlResult>;
210
211    /// Switch the current keyspace (USE <keyspace>).
212    async fn use_keyspace(&self, keyspace: &str) -> Result<()>;
213
214    /// Get the current consistency level.
215    fn get_consistency(&self) -> Consistency;
216
217    /// Set the consistency level for subsequent queries.
218    fn set_consistency(&self, consistency: Consistency);
219
220    /// Get the current serial consistency level.
221    fn get_serial_consistency(&self) -> Option<Consistency>;
222
223    /// Set the serial consistency level for subsequent queries.
224    fn set_serial_consistency(&self, consistency: Option<Consistency>);
225
226    /// Enable or disable request tracing.
227    fn set_tracing(&self, enabled: bool);
228
229    /// Check if tracing is currently enabled.
230    fn is_tracing_enabled(&self) -> bool;
231
232    /// Get the last tracing session ID (if tracing was enabled).
233    fn last_trace_id(&self) -> Option<uuid::Uuid>;
234
235    /// Retrieve tracing session data for a given trace ID.
236    async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>>;
237
238    /// Get metadata for all keyspaces.
239    async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>>;
240
241    /// Get metadata for all tables in a keyspace.
242    async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>>;
243
244    /// Get metadata for a specific table.
245    async fn get_table_metadata(
246        &self,
247        keyspace: &str,
248        table: &str,
249    ) -> Result<Option<TableMetadata>>;
250
251    /// Get metadata for all user-defined types in a keyspace.
252    async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>>;
253
254    /// Get metadata for all user-defined functions in a keyspace.
255    async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>>;
256
257    /// Get metadata for all user-defined aggregates in a keyspace.
258    async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>>;
259
260    /// Get the cluster name.
261    async fn get_cluster_name(&self) -> Result<Option<String>>;
262
263    /// Get the CQL version from the connected node.
264    async fn get_cql_version(&self) -> Result<Option<String>>;
265
266    /// Get the release version of the connected node.
267    async fn get_release_version(&self) -> Result<Option<String>>;
268
269    /// Get the ScyllaDB version (None if not ScyllaDB).
270    async fn get_scylla_version(&self) -> Result<Option<String>>;
271
272    /// Check if the connection is still alive.
273    async fn is_connected(&self) -> bool;
274}
275
/// Opaque handle to a statement produced by `CqlDriver::prepare`.
///
/// Callers pass it back to `CqlDriver::execute_prepared`. The contents are
/// visible only inside this crate, to the driver implementation.
#[derive(Clone, Debug)]
pub struct PreparedId {
    /// Implementation-defined identifier bytes.
    pub(crate) inner: Vec<u8>,
}
282
283/// Tracing session data returned by the database.
284#[derive(Debug, Clone)]
285pub struct TracingSession {
286    pub trace_id: uuid::Uuid,
287    pub client: Option<String>,
288    pub command: Option<String>,
289    pub coordinator: Option<String>,
290    pub duration: Option<i32>,
291    pub parameters: HashMap<String, String>,
292    pub request: Option<String>,
293    pub started_at: Option<String>,
294    pub events: Vec<TracingEvent>,
295}
296
/// One event recorded inside a tracing session.
#[derive(Clone, Debug)]
pub struct TracingEvent {
    /// Description of what happened, if reported.
    pub activity: Option<String>,
    /// Node that recorded the event, if reported.
    pub source: Option<String>,
    /// Time elapsed on the source node when the event occurred; presumably
    /// microseconds (as in `system_traces.events`) — confirm with the driver.
    pub source_elapsed: Option<i32>,
    /// Thread that recorded the event, if reported.
    pub thread: Option<String>,
}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn udt_metadata_fields() {
        let udt = UdtMetadata {
            keyspace: "ks".into(),
            name: "address".into(),
            field_names: strings(&["street", "city"]),
            field_types: strings(&["text", "text"]),
        };
        assert_eq!(udt.keyspace, "ks");
        assert_eq!(udt.name, "address");
        assert_eq!(udt.field_names.len(), 2);
        assert_eq!(udt.field_types.len(), 2);
        assert_eq!(udt.field_names[0], "street");
        assert_eq!(udt.field_types[0], "text");
    }

    #[test]
    fn function_metadata_fields() {
        let func = FunctionMetadata {
            keyspace: "ks".into(),
            name: "my_func".into(),
            argument_types: strings(&["int", "text"]),
            return_type: "boolean".into(),
        };
        assert_eq!(func.keyspace, "ks");
        assert_eq!(func.name, "my_func");
        assert_eq!(func.argument_types, vec!["int", "text"]);
        assert_eq!(func.return_type, "boolean");
    }

    #[test]
    fn aggregate_metadata_fields() {
        let agg = AggregateMetadata {
            keyspace: "ks".into(),
            name: "my_agg".into(),
            argument_types: strings(&["int"]),
            return_type: "bigint".into(),
        };
        assert_eq!(agg.keyspace, "ks");
        assert_eq!(agg.name, "my_agg");
        assert_eq!(agg.argument_types, vec!["int"]);
        assert_eq!(agg.return_type, "bigint");
    }

    #[test]
    fn udt_metadata_clone() {
        let original = UdtMetadata {
            keyspace: "ks".into(),
            name: "my_type".into(),
            field_names: strings(&["f1"]),
            field_types: strings(&["int"]),
        };
        let copy = original.clone();
        assert_eq!(copy.keyspace, original.keyspace);
        assert_eq!(copy.name, original.name);
    }

    #[test]
    fn function_metadata_empty_args() {
        let func = FunctionMetadata {
            keyspace: "ks".into(),
            name: "no_args_func".into(),
            argument_types: Vec::new(),
            return_type: "text".into(),
        };
        assert!(func.argument_types.is_empty());
    }

    #[test]
    fn aggregate_metadata_clone() {
        let original = AggregateMetadata {
            keyspace: "ks".into(),
            name: "my_agg".into(),
            argument_types: strings(&["int"]),
            return_type: "bigint".into(),
        };
        let copy = original.clone();
        assert_eq!(copy.return_type, original.return_type);
    }

    #[test]
    fn consistency_from_str() {
        let cases = [
            ("QUORUM", Some(Consistency::Quorum)),
            ("local_quorum", Some(Consistency::LocalQuorum)),
            ("LOCAL_SERIAL", Some(Consistency::LocalSerial)),
            ("INVALID", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(Consistency::from_str_cql(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn consistency_display() {
        assert_eq!(Consistency::One.to_string(), "ONE");
        assert_eq!(Consistency::LocalQuorum.to_string(), "LOCAL_QUORUM");
    }

    #[test]
    fn consistency_roundtrip() {
        use super::Consistency as C;
        let every_level = [
            C::Any,
            C::One,
            C::Two,
            C::Three,
            C::Quorum,
            C::All,
            C::LocalQuorum,
            C::EachQuorum,
            C::Serial,
            C::LocalSerial,
            C::LocalOne,
        ];
        for level in every_level.iter().copied() {
            assert_eq!(C::from_str_cql(level.as_cql_str()), Some(level));
        }
    }
}