Skip to main content

cqlsh_rs/driver/
scylla_driver.rs

1//! ScyllaDriver — CqlDriver implementation using the `scylla` crate.
2//!
3//! Provides connectivity to Apache Cassandra and ScyllaDB clusters using
4//! the scylla-rust-driver, with support for authentication, SSL/TLS,
5//! prepared statements, paging, and schema metadata queries.
6
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::{StreamExt, TryStreamExt};
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22    Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23    CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlRowStream, CqlValue};
28use super::{
29    AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30    KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31    UdtMetadata,
32};
33
34/// ScyllaDriver wraps a scylla `Session` and provides the `CqlDriver` trait.
35pub struct ScyllaDriver {
36    session: Session,
37    /// Cache of prepared statements keyed by internal ID.
38    prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39    /// Current consistency level.
40    consistency: Mutex<Consistency>,
41    /// Current serial consistency level.
42    serial_consistency: Mutex<Option<Consistency>>,
43    /// Whether tracing is enabled for queries.
44    tracing_enabled: AtomicBool,
45    /// Last tracing session ID.
46    last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50    /// Build the TLS configuration from SslConfig.
51    fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52        use rustls::pki_types::CertificateDer;
53        use std::fs::File;
54        use std::io::BufReader;
55
56        if !ssl_config.validate {
57            return Self::build_rustls_config_no_verify(ssl_config);
58        }
59
60        let mut root_store = rustls::RootCertStore::empty();
61
62        // Load CA certificate if provided
63        if let Some(certfile) = &ssl_config.certfile {
64            let file = File::open(certfile)
65                .with_context(|| format!("opening CA certificate: {certfile}"))?;
66            let mut reader = BufReader::new(file);
67            let certs = rustls_pemfile::certs(&mut reader)
68                .collect::<std::result::Result<Vec<_>, _>>()
69                .with_context(|| format!("parsing CA certificate: {certfile}"))?;
70            for cert in certs {
71                root_store
72                    .add(cert)
73                    .context("adding CA certificate to root store")?;
74            }
75        }
76
77        let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
78
79        // Client certificate authentication (mutual TLS)
80        let config = if let (Some(usercert_path), Some(userkey_path)) =
81            (&ssl_config.usercert, &ssl_config.userkey)
82        {
83            let cert_file = File::open(usercert_path)
84                .with_context(|| format!("opening client certificate: {usercert_path}"))?;
85            let mut cert_reader = BufReader::new(cert_file);
86            let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
87                .collect::<std::result::Result<Vec<_>, _>>()
88                .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
89
90            let key_file = File::open(userkey_path)
91                .with_context(|| format!("opening client key: {userkey_path}"))?;
92            let mut key_reader = BufReader::new(key_file);
93            let key = rustls_pemfile::private_key(&mut key_reader)
94                .with_context(|| format!("parsing client key: {userkey_path}"))?
95                .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
96
97            builder
98                .with_client_auth_cert(certs, key)
99                .context("configuring mutual TLS")?
100        } else {
101            builder.with_no_client_auth()
102        };
103
104        Ok(Arc::new(config))
105    }
106
107    fn build_rustls_config_no_verify(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
108        use rustls::client::danger::{
109            HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
110        };
111        use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
112        use rustls::{DigitallySignedStruct, Error, SignatureScheme};
113        use std::fs::File;
114        use std::io::BufReader;
115
116        #[derive(Debug)]
117        struct NoVerifier;
118
119        impl ServerCertVerifier for NoVerifier {
120            fn verify_server_cert(
121                &self,
122                _end_entity: &CertificateDer<'_>,
123                _intermediates: &[CertificateDer<'_>],
124                _server_name: &ServerName<'_>,
125                _ocsp_response: &[u8],
126                _now: UnixTime,
127            ) -> std::result::Result<ServerCertVerified, Error> {
128                Ok(ServerCertVerified::assertion())
129            }
130
131            fn verify_tls12_signature(
132                &self,
133                _message: &[u8],
134                _cert: &CertificateDer<'_>,
135                _dss: &DigitallySignedStruct,
136            ) -> std::result::Result<HandshakeSignatureValid, Error> {
137                Ok(HandshakeSignatureValid::assertion())
138            }
139
140            fn verify_tls13_signature(
141                &self,
142                _message: &[u8],
143                _cert: &CertificateDer<'_>,
144                _dss: &DigitallySignedStruct,
145            ) -> std::result::Result<HandshakeSignatureValid, Error> {
146                Ok(HandshakeSignatureValid::assertion())
147            }
148
149            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
150                rustls::crypto::ring::default_provider()
151                    .signature_verification_algorithms
152                    .supported_schemes()
153            }
154        }
155
156        let builder = rustls::ClientConfig::builder()
157            .dangerous()
158            .with_custom_certificate_verifier(Arc::new(NoVerifier));
159
160        let config = if let (Some(usercert_path), Some(userkey_path)) =
161            (&ssl_config.usercert, &ssl_config.userkey)
162        {
163            let cert_file = File::open(usercert_path)
164                .with_context(|| format!("opening client certificate: {usercert_path}"))?;
165            let mut cert_reader = BufReader::new(cert_file);
166            let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
167                .collect::<std::result::Result<Vec<_>, _>>()
168                .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
169
170            let key_file = File::open(userkey_path)
171                .with_context(|| format!("opening client key: {userkey_path}"))?;
172            let mut key_reader = BufReader::new(key_file);
173            let key = rustls_pemfile::private_key(&mut key_reader)
174                .with_context(|| format!("parsing client key: {userkey_path}"))?
175                .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
176
177            builder
178                .with_client_auth_cert(certs, key)
179                .context("configuring mutual TLS")?
180        } else {
181            builder.with_no_client_auth()
182        };
183
184        Ok(Arc::new(config))
185    }
186
187    /// Extract a `Vec<String>` from a `CqlValue::List` column value.
188    fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
189        match val {
190            Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
191            _ => Vec::new(),
192        }
193    }
194
195    /// Convert a scylla QueryResult into our CqlResult type.
196    fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
197        let tracing_id = result.tracing_id();
198        let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
199
200        // Check if this is a non-row result (DDL/DML)
201        if !result.is_rows() {
202            return Ok(CqlResult {
203                columns: Vec::new(),
204                rows: Vec::new(),
205                has_rows: false,
206                tracing_id,
207                warnings,
208            });
209        }
210
211        // Convert to QueryRowsResult to access typed rows
212        let rows_result = result
213            .into_rows_result()
214            .context("converting query result to rows")?;
215
216        // Extract column metadata
217        let col_specs = rows_result.column_specs();
218        let columns: Vec<CqlColumn> = col_specs
219            .iter()
220            .map(|spec| CqlColumn {
221                name: spec.name().to_string(),
222                type_name: format!("{:?}", spec.typ()),
223            })
224            .collect();
225
226        // Deserialize rows as untyped Row (Vec<Option<CqlValue>>)
227        let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
228
229        let mut cql_rows = Vec::new();
230        for row_result in typed_rows {
231            let row = row_result.context("deserializing row")?;
232            let values: Vec<CqlValue> = row
233                .columns
234                .into_iter()
235                .enumerate()
236                .map(|(col_idx, opt_val)| match opt_val {
237                    Some(v) => {
238                        tracing::debug!(
239                            column = col_idx,
240                            variant = ?std::mem::discriminant(&v),
241                            "converting ScyllaCqlValue: {v:?}"
242                        );
243                        Self::convert_scylla_value(v)
244                    }
245                    None => {
246                        tracing::debug!(column = col_idx, "column value is None (null)");
247                        CqlValue::Null
248                    }
249                })
250                .collect();
251            cql_rows.push(CqlRow { values });
252        }
253
254        Ok(CqlResult {
255            columns,
256            rows: cql_rows,
257            has_rows: true,
258            tracing_id,
259            warnings,
260        })
261    }
262
263    /// Convert a scylla CqlValue to our CqlValue type.
264    fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
265        match value {
266            ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
267            ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
268            ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
269            ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
270            ScyllaCqlValue::Decimal(d) => {
271                let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
272                let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
273                CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
274            }
275            ScyllaCqlValue::Date(d) => {
276                // scylla CqlDate wraps u32 days since epoch center (2^31)
277                let days = d.0;
278                let epoch_offset = days as i64 - (1i64 << 31);
279                match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
280                {
281                    Some(date) => CqlValue::Date(date),
282                    None => CqlValue::Text(format!("<invalid date: {days}>")),
283                }
284            }
285            ScyllaCqlValue::Double(d) => CqlValue::Double(d),
286            ScyllaCqlValue::Duration(d) => CqlValue::Duration {
287                months: d.months,
288                days: d.days,
289                nanoseconds: d.nanoseconds,
290            },
291            ScyllaCqlValue::Empty => CqlValue::Null,
292            ScyllaCqlValue::Float(f) => CqlValue::Float(f),
293            ScyllaCqlValue::Int(i) => CqlValue::Int(i),
294            ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
295            ScyllaCqlValue::Text(s) => CqlValue::Text(s),
296            ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
297            ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
298            ScyllaCqlValue::List(items) => {
299                CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
300            }
301            ScyllaCqlValue::Map(entries) => CqlValue::Map(
302                entries
303                    .into_iter()
304                    .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
305                    .collect(),
306            ),
307            ScyllaCqlValue::Set(items) => {
308                CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
309            }
310            ScyllaCqlValue::UserDefinedType {
311                keyspace,
312                name,
313                fields,
314            } => CqlValue::UserDefinedType {
315                keyspace,
316                type_name: name,
317                fields: fields
318                    .into_iter()
319                    .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
320                    .collect(),
321            },
322            ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
323            ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
324            ScyllaCqlValue::Time(t) => {
325                let nanos = t.0;
326                let secs = (nanos / 1_000_000_000) as u32;
327                let nano_part = (nanos % 1_000_000_000) as u32;
328                match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
329                    Some(time) => CqlValue::Time(time),
330                    None => CqlValue::Text(format!("<invalid time: {nanos}>")),
331                }
332            }
333            ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
334            ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
335                items
336                    .into_iter()
337                    .map(|v| v.map(Self::convert_scylla_value))
338                    .collect(),
339            ),
340            ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
341            ScyllaCqlValue::Varint(v) => {
342                let big_int =
343                    num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
344                CqlValue::Varint(big_int)
345            }
346            // CqlValue is non-exhaustive; handle future variants gracefully
347            _ => {
348                tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
349                CqlValue::Text(format!("{value:?}"))
350            }
351        }
352    }
353
354    /// Convert our internal CqlValue to scylla's CqlValue (reverse of convert_scylla_value).
355    fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
356        match v {
357            CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
358            CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
359            CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
360            CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
361            CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
362            CqlValue::Duration {
363                months,
364                days,
365                nanoseconds,
366            } => ScyllaCqlValue::Duration(CqlDuration {
367                months: *months,
368                days: *days,
369                nanoseconds: *nanoseconds,
370            }),
371            CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
372            CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
373            CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
374            CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
375            CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
376            CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
377            CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
378            CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
379            CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
380            CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
381            CqlValue::Date(d) => {
382                // Convert NaiveDate back to scylla's u32 days offset from 2^31 epoch
383                let days_from_ce = d.num_days_from_ce();
384                let epoch_offset = days_from_ce as i64 - 719_163;
385                let cql_days = (epoch_offset + (1i64 << 31)) as u32;
386                ScyllaCqlValue::Date(CqlDate(cql_days))
387            }
388            CqlValue::Time(t) => {
389                let nanos =
390                    t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
391                ScyllaCqlValue::Time(CqlTime(nanos))
392            }
393            CqlValue::Varint(bi) => {
394                let bytes = bi.to_signed_bytes_be();
395                ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
396            }
397            CqlValue::Decimal(d) => {
398                let (int_val, scale) = d.as_bigint_and_exponent();
399                let bytes = int_val.to_signed_bytes_be();
400                ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
401                    &bytes,
402                    scale as i32,
403                ))
404            }
405            CqlValue::List(items) => {
406                ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
407            }
408            CqlValue::Set(items) => {
409                ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
410            }
411            CqlValue::Map(entries) => ScyllaCqlValue::Map(
412                entries
413                    .iter()
414                    .map(|(k, v)| {
415                        (
416                            Self::internal_to_scylla_cql(k),
417                            Self::internal_to_scylla_cql(v),
418                        )
419                    })
420                    .collect(),
421            ),
422            CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
423                items
424                    .iter()
425                    .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
426                    .collect(),
427            ),
428            CqlValue::UserDefinedType {
429                keyspace,
430                type_name,
431                fields,
432            } => ScyllaCqlValue::UserDefinedType {
433                keyspace: keyspace.clone(),
434                name: type_name.clone(),
435                fields: fields
436                    .iter()
437                    .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
438                    .collect(),
439            },
440            CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
441        }
442    }
443
444    /// Convert our Consistency to scylla's Consistency.
445    fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
446        use scylla::statement::Consistency as SC;
447        match c {
448            Consistency::Any => SC::Any,
449            Consistency::One => SC::One,
450            Consistency::Two => SC::Two,
451            Consistency::Three => SC::Three,
452            Consistency::Quorum => SC::Quorum,
453            Consistency::All => SC::All,
454            Consistency::LocalQuorum => SC::LocalQuorum,
455            Consistency::EachQuorum => SC::EachQuorum,
456            Consistency::Serial => SC::Serial,
457            Consistency::LocalSerial => SC::LocalSerial,
458            Consistency::LocalOne => SC::LocalOne,
459        }
460    }
461
462    /// Convert our Consistency to scylla's SerialConsistency.
463    fn to_scylla_serial_consistency(
464        c: Consistency,
465    ) -> Option<scylla::statement::SerialConsistency> {
466        use scylla::statement::SerialConsistency as SC;
467        match c {
468            Consistency::Serial => Some(SC::Serial),
469            Consistency::LocalSerial => Some(SC::LocalSerial),
470            _ => None,
471        }
472    }
473
474    /// Build a Statement with the current consistency and tracing settings.
475    fn build_query(&self, cql: &str) -> Statement {
476        let mut stmt = Statement::new(cql);
477
478        let consistency = *self.consistency.lock().unwrap();
479        stmt.set_consistency(Self::to_scylla_consistency(consistency));
480
481        let serial = *self.serial_consistency.lock().unwrap();
482        if let Some(sc) = serial {
483            if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
484                stmt.set_serial_consistency(Some(sc));
485            }
486        }
487
488        if self.tracing_enabled.load(Ordering::Relaxed) {
489            stmt.set_tracing(true);
490        }
491
492        stmt
493    }
494
495    /// Store tracing ID from a result if present.
496    fn store_trace_id(&self, result: &QueryResult) {
497        if let Some(trace_id) = result.tracing_id() {
498            *self.last_trace_id.lock().unwrap() = Some(trace_id);
499        }
500    }
501}
502
503#[async_trait]
504impl CqlDriver for ScyllaDriver {
505    async fn connect(config: &ConnectionConfig) -> Result<Self> {
506        let addr = format!("{}:{}", config.host, config.port);
507
508        let mut builder = SessionBuilder::new().known_node(&addr);
509
510        // cqlsh is a single-user interactive tool — one connection per host suffices
511        // and avoids connection explosion when using a proxy translator.
512        builder = builder.pool_size(scylla::client::PoolSize::PerHost(
513            std::num::NonZeroUsize::new(1).unwrap(),
514        ));
515
516        if let (Some(username), Some(password)) = (&config.username, &config.password) {
517            builder = builder.user(username, password);
518        }
519
520        builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
521
522        if let Some(keyspace) = &config.keyspace {
523            builder = builder.use_keyspace(keyspace, false);
524        }
525
526        // Always install the proxy address translator. Since known_node addresses
527        // are never translated (only peers from system.peers are), this is safe:
528        // - Direct connections: peers share the same network, translator is a no-op
529        //   in practice (peers are reachable anyway, and translated to contact point
530        //   which also works since it's the same node in single-node or resolves correctly)
531        // - Proxy connections: peers have unreachable internal IPs, translator
532        //   redirects all of them to the proxy contact point
533        let contact_point = tokio::net::lookup_host(&addr)
534            .await
535            .ok()
536            .and_then(|mut addrs| addrs.next());
537        if let Some(contact_point) = contact_point {
538            let translator = Arc::new(
539                super::proxy_address_translator::ProxyAddressTranslator::new(contact_point),
540            );
541            builder = builder.address_translator(translator);
542        }
543
544        if config.ssl {
545            let tls_config = if let Some(ssl_config) = &config.ssl_config {
546                Self::build_rustls_config(ssl_config)?
547            } else {
548                Self::build_rustls_config_no_verify(&SslConfig::default())?
549            };
550            builder = builder.tls_context(Some(tls_config));
551        }
552
553        let session = builder.build().await.context("connecting to cluster")?;
554
555        Ok(ScyllaDriver {
556            session,
557            prepared_cache: Mutex::new(HashMap::new()),
558            consistency: Mutex::new(Consistency::One),
559            serial_consistency: Mutex::new(None),
560            tracing_enabled: AtomicBool::new(false),
561            last_trace_id: Mutex::new(None),
562        })
563    }
564
565    async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
566        let stmt = self.build_query(query);
567
568        let result = self.session.query_unpaged(stmt, ()).await?;
569
570        self.store_trace_id(&result);
571        Self::convert_query_result(result)
572    }
573
574    async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
575        let mut stmt = self.build_query(query);
576        stmt.set_page_size(page_size);
577
578        let query_pager = self
579            .session
580            .query_iter(stmt, ())
581            .await
582            .context("starting paged query")?;
583
584        // Get column metadata from the pager
585        let col_specs = query_pager.column_specs();
586        let columns: Vec<CqlColumn> = col_specs
587            .iter()
588            .map(|spec| CqlColumn {
589                name: spec.name().to_string(),
590                type_name: format!("{:?}", spec.typ()),
591            })
592            .collect();
593
594        // Stream all rows using the untyped Row type
595        let mut rows_stream = query_pager.rows_stream::<Row>()?;
596        let mut cql_rows = Vec::new();
597
598        while let Some(row) = rows_stream.try_next().await? {
599            let values: Vec<CqlValue> = row
600                .columns
601                .into_iter()
602                .map(|opt_val| match opt_val {
603                    Some(v) => Self::convert_scylla_value(v),
604                    None => CqlValue::Null,
605                })
606                .collect();
607            cql_rows.push(CqlRow { values });
608        }
609
610        Ok(CqlResult {
611            columns,
612            rows: cql_rows,
613            has_rows: true,
614            tracing_id: None,
615            warnings: Vec::new(),
616        })
617    }
618
619    async fn execute_streaming(&self, query: &str, page_size: i32) -> Result<CqlRowStream> {
620        let mut stmt = self.build_query(query);
621        stmt.set_page_size(page_size);
622
623        let query_pager = self
624            .session
625            .query_iter(stmt, ())
626            .await
627            .context("starting streaming query")?;
628
629        let col_specs = query_pager.column_specs();
630        let columns: Vec<CqlColumn> = col_specs
631            .iter()
632            .map(|spec| CqlColumn {
633                name: spec.name().to_string(),
634                type_name: format!("{:?}", spec.typ()),
635            })
636            .collect();
637
638        let rows_stream = query_pager.rows_stream::<Row>()?;
639
640        let mapped_stream = rows_stream.map(|row_result| {
641            row_result
642                .map(|row| {
643                    let values: Vec<CqlValue> = row
644                        .columns
645                        .into_iter()
646                        .map(|opt_val| match opt_val {
647                            Some(v) => Self::convert_scylla_value(v),
648                            None => CqlValue::Null,
649                        })
650                        .collect();
651                    CqlRow { values }
652                })
653                .map_err(|e| anyhow::anyhow!("{}", e))
654        });
655
656        Ok(CqlRowStream {
657            columns,
658            rows: Box::pin(mapped_stream),
659        })
660    }
661
662    async fn prepare(&self, query: &str) -> Result<PreparedId> {
663        let prepared = self
664            .session
665            .prepare(query)
666            .await
667            .context("preparing CQL statement")?;
668
669        let id = prepared.get_id().to_vec();
670        self.prepared_cache
671            .lock()
672            .unwrap()
673            .insert(id.clone(), prepared);
674
675        Ok(PreparedId { inner: id })
676    }
677
678    async fn execute_prepared(
679        &self,
680        prepared_id: &PreparedId,
681        values: &[CqlValue],
682    ) -> Result<CqlResult> {
683        let prepared = self
684            .prepared_cache
685            .lock()
686            .unwrap()
687            .get(&prepared_id.inner)
688            .cloned()
689            .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
690
691        // Convert internal CqlValues to scylla CqlValues for binding.
692        // Null/Unset become None (bound as null), all others become Some(value).
693        let scylla_values: Vec<Option<ScyllaCqlValue>> = values
694            .iter()
695            .map(|v| match v {
696                CqlValue::Null | CqlValue::Unset => None,
697                other => Some(Self::internal_to_scylla_cql(other)),
698            })
699            .collect();
700
701        let result = self
702            .session
703            .execute_unpaged(&prepared, scylla_values)
704            .await
705            .context("executing prepared statement")?;
706
707        self.store_trace_id(&result);
708        Self::convert_query_result(result)
709    }
710
711    async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
712        self.session
713            .use_keyspace(keyspace, false)
714            .await
715            .with_context(|| format!("switching to keyspace: {keyspace}"))?;
716        Ok(())
717    }
718
719    fn get_consistency(&self) -> Consistency {
720        *self.consistency.lock().unwrap()
721    }
722
723    fn set_consistency(&self, consistency: Consistency) {
724        *self.consistency.lock().unwrap() = consistency;
725    }
726
727    fn get_serial_consistency(&self) -> Option<Consistency> {
728        *self.serial_consistency.lock().unwrap()
729    }
730
731    fn set_serial_consistency(&self, consistency: Option<Consistency>) {
732        *self.serial_consistency.lock().unwrap() = consistency;
733    }
734
735    fn set_tracing(&self, enabled: bool) {
736        self.tracing_enabled.store(enabled, Ordering::Relaxed);
737    }
738
739    fn is_tracing_enabled(&self) -> bool {
740        self.tracing_enabled.load(Ordering::Relaxed)
741    }
742
743    fn last_trace_id(&self) -> Option<Uuid> {
744        *self.last_trace_id.lock().unwrap()
745    }
746
747    async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
748        let query = format!(
749            "SELECT client, command, coordinator, duration, parameters, request, started_at \
750             FROM system_traces.sessions WHERE session_id = {}",
751            trace_id
752        );
753        let result = self.execute_unpaged(&query).await?;
754
755        if result.rows.is_empty() {
756            return Ok(None);
757        }
758
759        let events_query = format!(
760            "SELECT activity, source, source_elapsed, thread \
761             FROM system_traces.events WHERE session_id = {}",
762            trace_id
763        );
764        let events_result = self.execute_unpaged(&events_query).await?;
765
766        let events: Vec<TracingEvent> = events_result
767            .rows
768            .iter()
769            .map(|row| TracingEvent {
770                activity: row.get(0).and_then(cql_value_to_string),
771                source: row.get(1).and_then(cql_value_to_string),
772                source_elapsed: row.get(2).and_then(cql_value_to_i32),
773                thread: row.get(3).and_then(cql_value_to_string),
774            })
775            .collect();
776
777        let session_row = &result.rows[0];
778        Ok(Some(TracingSession {
779            trace_id,
780            client: session_row.get(0).and_then(cql_value_to_string),
781            command: session_row.get(1).and_then(cql_value_to_string),
782            coordinator: session_row.get(2).and_then(cql_value_to_string),
783            duration: session_row.get(3).and_then(cql_value_to_i32),
784            parameters: HashMap::new(),
785            request: session_row.get(5).and_then(cql_value_to_string),
786            started_at: session_row.get(6).and_then(cql_value_to_string),
787            events,
788        }))
789    }
790
791    async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
792        let result = self
793            .execute_unpaged(
794                "SELECT keyspace_name, replication, durable_writes \
795                 FROM system_schema.keyspaces",
796            )
797            .await?;
798
799        let mut keyspaces = Vec::new();
800        for row in &result.rows {
801            let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
802            let durable_writes = match row.get(2) {
803                Some(CqlValue::Boolean(b)) => *b,
804                _ => true,
805            };
806
807            keyspaces.push(KeyspaceMetadata {
808                name,
809                replication: HashMap::new(),
810                durable_writes,
811            });
812        }
813
814        Ok(keyspaces)
815    }
816
817    async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
818        let ks_escaped = keyspace.replace('\'', "''");
819
820        let result = self
821            .execute_unpaged(&format!(
822                "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{ks_escaped}'"
823            ))
824            .await?;
825
826        let mut tables = Vec::new();
827        for row in &result.rows {
828            let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
829            let tbl_escaped = table_name.replace('\'', "''");
830
831            let col_result = self
832                .execute_unpaged(&format!(
833                    "SELECT column_name, type, kind, position, clustering_order \
834                     FROM system_schema.columns \
835                     WHERE keyspace_name = '{ks_escaped}' AND table_name = '{tbl_escaped}'"
836                ))
837                .await?;
838
839            let mut pk_cols: Vec<(i32, String, String)> = Vec::new();
840            let mut ck_cols: Vec<(i32, String, String, String)> = Vec::new();
841            let mut regular_cols: Vec<(String, String)> = Vec::new();
842
843            for col_row in &col_result.rows {
844                let col_name = col_row
845                    .get_by_name("column_name", &col_result.columns)
846                    .map(|v| v.to_string())
847                    .unwrap_or_default();
848                let col_type = col_row
849                    .get_by_name("type", &col_result.columns)
850                    .map(|v| v.to_string())
851                    .unwrap_or_default();
852                let kind = col_row
853                    .get_by_name("kind", &col_result.columns)
854                    .map(|v| v.to_string())
855                    .unwrap_or_default();
856                let position = col_row
857                    .get_by_name("position", &col_result.columns)
858                    .and_then(|v| v.to_string().parse::<i32>().ok())
859                    .unwrap_or(0);
860                let clustering_order = col_row
861                    .get_by_name("clustering_order", &col_result.columns)
862                    .map(|v| v.to_string())
863                    .unwrap_or_else(|| "none".to_string());
864
865                match kind.as_str() {
866                    "partition_key" => pk_cols.push((position, col_name, col_type)),
867                    "clustering" => ck_cols.push((position, col_name, col_type, clustering_order)),
868                    _ => regular_cols.push((col_name, col_type)),
869                }
870            }
871
872            pk_cols.sort_by_key(|c| c.0);
873            ck_cols.sort_by_key(|c| c.0);
874
875            let partition_key: Vec<String> = pk_cols.iter().map(|c| c.1.clone()).collect();
876            let clustering_key: Vec<String> = ck_cols.iter().map(|c| c.1.clone()).collect();
877            let clustering_order: Vec<String> = ck_cols
878                .iter()
879                .map(|c| {
880                    let order = c.3.to_uppercase();
881                    if order == "NONE" || order.is_empty() {
882                        "ASC".to_string()
883                    } else {
884                        order
885                    }
886                })
887                .collect();
888
889            let mut columns: Vec<ColumnMetadata> = Vec::new();
890            for (_, name, typ) in &pk_cols {
891                columns.push(ColumnMetadata {
892                    name: name.clone(),
893                    type_name: typ.clone(),
894                });
895            }
896            for (_, name, typ, _) in &ck_cols {
897                columns.push(ColumnMetadata {
898                    name: name.clone(),
899                    type_name: typ.clone(),
900                });
901            }
902            for (name, typ) in &regular_cols {
903                columns.push(ColumnMetadata {
904                    name: name.clone(),
905                    type_name: typ.clone(),
906                });
907            }
908
909            let props_result = self
910                .execute_unpaged(&format!(
911                    "SELECT bloom_filter_fp_chance, caching, comment, compaction, compression, \
912                     crc_check_chance, default_time_to_live, gc_grace_seconds, \
913                     max_index_interval, memtable_flush_period_in_ms, min_index_interval, \
914                     speculative_retry \
915                     FROM system_schema.tables \
916                     WHERE keyspace_name = '{ks_escaped}' AND table_name = '{tbl_escaped}'"
917                ))
918                .await?;
919
920            let mut properties = std::collections::BTreeMap::new();
921            if let Some(props_row) = props_result.rows.first() {
922                let prop_names = [
923                    "bloom_filter_fp_chance",
924                    "caching",
925                    "comment",
926                    "compaction",
927                    "compression",
928                    "crc_check_chance",
929                    "default_time_to_live",
930                    "gc_grace_seconds",
931                    "max_index_interval",
932                    "memtable_flush_period_in_ms",
933                    "min_index_interval",
934                    "speculative_retry",
935                ];
936                for prop_name in &prop_names {
937                    if let Some(val) = props_row.get_by_name(prop_name, &props_result.columns) {
938                        properties.insert(prop_name.to_string(), val.to_string());
939                    }
940                }
941            }
942
943            tables.push(TableMetadata {
944                keyspace: keyspace.to_string(),
945                name: table_name,
946                columns,
947                partition_key,
948                clustering_key,
949                clustering_order,
950                properties,
951            });
952        }
953
954        Ok(tables)
955    }
956
957    async fn get_table_metadata(
958        &self,
959        keyspace: &str,
960        table: &str,
961    ) -> Result<Option<TableMetadata>> {
962        let tables = self.get_tables(keyspace).await?;
963        Ok(tables.into_iter().find(|t| t.name == table))
964    }
965
966    async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
967        let query = format!(
968            "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
969            keyspace.replace('\'', "''")
970        );
971        let result = self.execute_unpaged(&query).await?;
972        let udts = result
973            .rows
974            .iter()
975            .filter_map(|row| {
976                let name = row.get_by_name("type_name", &result.columns)?.to_string();
977                let field_names =
978                    Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
979                let field_types =
980                    Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
981                Some(UdtMetadata {
982                    keyspace: keyspace.to_string(),
983                    name,
984                    field_names,
985                    field_types,
986                })
987            })
988            .collect();
989        Ok(udts)
990    }
991
992    async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
993        let query = format!(
994            "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
995            keyspace.replace('\'', "''")
996        );
997        let result = self.execute_unpaged(&query).await?;
998        let functions = result
999            .rows
1000            .iter()
1001            .filter_map(|row| {
1002                let name = row
1003                    .get_by_name("function_name", &result.columns)?
1004                    .to_string();
1005                let argument_types = Self::extract_string_list_val(
1006                    row.get_by_name("argument_types", &result.columns),
1007                );
1008                let return_type = row
1009                    .get_by_name("return_type", &result.columns)
1010                    .map(|v| v.to_string())
1011                    .unwrap_or_default();
1012                Some(FunctionMetadata {
1013                    keyspace: keyspace.to_string(),
1014                    name,
1015                    argument_types,
1016                    return_type,
1017                })
1018            })
1019            .collect();
1020        Ok(functions)
1021    }
1022
1023    async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
1024        let query = format!(
1025            "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
1026            keyspace.replace('\'', "''")
1027        );
1028        let result = self.execute_unpaged(&query).await?;
1029        let aggregates = result
1030            .rows
1031            .iter()
1032            .filter_map(|row| {
1033                let name = row
1034                    .get_by_name("aggregate_name", &result.columns)?
1035                    .to_string();
1036                let argument_types = Self::extract_string_list_val(
1037                    row.get_by_name("argument_types", &result.columns),
1038                );
1039                let return_type = row
1040                    .get_by_name("return_type", &result.columns)
1041                    .map(|v| v.to_string())
1042                    .unwrap_or_default();
1043                Some(AggregateMetadata {
1044                    keyspace: keyspace.to_string(),
1045                    name,
1046                    argument_types,
1047                    return_type,
1048                })
1049            })
1050            .collect();
1051        Ok(aggregates)
1052    }
1053
1054    async fn get_cluster_name(&self) -> Result<Option<String>> {
1055        let result = self
1056            .execute_unpaged("SELECT cluster_name FROM system.local")
1057            .await?;
1058        Ok(result
1059            .rows
1060            .first()
1061            .and_then(|row| row.get(0))
1062            .and_then(cql_value_to_string))
1063    }
1064
1065    async fn get_cql_version(&self) -> Result<Option<String>> {
1066        let result = self
1067            .execute_unpaged("SELECT cql_version FROM system.local")
1068            .await?;
1069        Ok(result
1070            .rows
1071            .first()
1072            .and_then(|row| row.get(0))
1073            .and_then(cql_value_to_string))
1074    }
1075
1076    async fn get_release_version(&self) -> Result<Option<String>> {
1077        let result = self
1078            .execute_unpaged("SELECT release_version FROM system.local")
1079            .await?;
1080        Ok(result
1081            .rows
1082            .first()
1083            .and_then(|row| row.get(0))
1084            .and_then(cql_value_to_string))
1085    }
1086
1087    async fn get_scylla_version(&self) -> Result<Option<String>> {
1088        // ScyllaDB exposes its version in system.local.scylla_version
1089        // This column doesn't exist in Apache Cassandra, so errors are expected.
1090        let result = self
1091            .execute_unpaged("SELECT scylla_version FROM system.local")
1092            .await;
1093        match result {
1094            Ok(r) => Ok(r
1095                .rows
1096                .first()
1097                .and_then(|row| row.get(0))
1098                .and_then(cql_value_to_string)),
1099            Err(_) => Ok(None), // Column doesn't exist → not ScyllaDB
1100        }
1101    }
1102
1103    async fn is_connected(&self) -> bool {
1104        self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
1105            .await
1106            .is_ok()
1107    }
1108}
1109
1110/// Helper: extract a string from a CqlValue.
1111fn cql_value_to_string(v: &CqlValue) -> Option<String> {
1112    match v {
1113        CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
1114        CqlValue::Inet(addr) => Some(addr.to_string()),
1115        CqlValue::Null => None,
1116        other => Some(other.to_string()),
1117    }
1118}
1119
1120/// Helper: extract an i32 from a CqlValue.
1121fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
1122    match v {
1123        CqlValue::Int(i) => Some(*i),
1124        CqlValue::BigInt(i) => Some(*i as i32),
1125        CqlValue::SmallInt(i) => Some(*i as i32),
1126        CqlValue::TinyInt(i) => Some(*i as i32),
1127        _ => None,
1128    }
1129}
1130
#[cfg(test)]
mod tests {
    // Unit tests for the value/consistency conversion helpers. None of them
    // construct a `ScyllaDriver`, so no session is involved.
    use super::*;

    /// Shorthand for the driver-to-internal value conversion under test.
    fn convert(value: ScyllaCqlValue) -> CqlValue {
        ScyllaDriver::convert_scylla_value(value)
    }

    #[test]
    fn convert_scylla_value_text() {
        let expected = CqlValue::Text("hello".to_string());
        assert_eq!(convert(ScyllaCqlValue::Text("hello".to_string())), expected);
    }

    #[test]
    fn convert_scylla_value_int() {
        let expected = CqlValue::Int(42);
        assert_eq!(convert(ScyllaCqlValue::Int(42)), expected);
    }

    #[test]
    fn convert_scylla_value_boolean() {
        let expected = CqlValue::Boolean(true);
        assert_eq!(convert(ScyllaCqlValue::Boolean(true)), expected);
    }

    #[test]
    fn convert_scylla_value_null() {
        // The driver's `Empty` value is expected to surface as `Null`.
        assert_eq!(convert(ScyllaCqlValue::Empty), CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let input = ScyllaCqlValue::List(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let id = Uuid::nil();
        assert_eq!(convert(ScyllaCqlValue::Uuid(id)), CqlValue::Uuid(id));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let input = ScyllaCqlValue::Blob(vec![0xde, 0xad, 0xbe, 0xef]);
        let expected = CqlValue::Blob(vec![0xde, 0xad, 0xbe, 0xef]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_float() {
        let expected = CqlValue::Float(1.5);
        assert_eq!(convert(ScyllaCqlValue::Float(1.5)), expected);
    }

    #[test]
    fn convert_scylla_value_double() {
        let expected = CqlValue::Double(1.5);
        assert_eq!(convert(ScyllaCqlValue::Double(1.5)), expected);
    }

    #[test]
    fn convert_scylla_value_map() {
        let input = ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]);
        let expected = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_set() {
        let input = ScyllaCqlValue::Set(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_udt() {
        // One populated field and one absent field, to cover both cases.
        let input = ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        let expected = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(CqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as SC;

        let one = ScyllaDriver::to_scylla_consistency(Consistency::One);
        assert!(matches!(one, SC::One));

        let quorum = ScyllaDriver::to_scylla_consistency(Consistency::Quorum);
        assert!(matches!(quorum, SC::Quorum));

        let local_quorum = ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum);
        assert!(matches!(local_quorum, SC::LocalQuorum));

        let all = ScyllaDriver::to_scylla_consistency(Consistency::All);
        assert!(matches!(all, SC::All));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as SC;

        let serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial);
        assert!(matches!(serial, Some(SC::Serial)));

        let local_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial);
        assert!(matches!(local_serial, Some(SC::LocalSerial)));

        // A non-serial level has no serial counterpart.
        let non_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::One);
        assert!(non_serial.is_none());
    }

    #[test]
    fn cql_value_to_string_helper() {
        let text = CqlValue::Text("hello".to_string());
        assert_eq!(cql_value_to_string(&text), Some("hello".to_string()));

        let number = CqlValue::Int(42);
        assert_eq!(cql_value_to_string(&number), Some("42".to_string()));

        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_i32_helper() {
        let int = CqlValue::Int(42);
        assert_eq!(cql_value_to_i32(&int), Some(42));

        let big = CqlValue::BigInt(100);
        assert_eq!(cql_value_to_i32(&big), Some(100));

        let text = CqlValue::Text("x".to_string());
        assert_eq!(cql_value_to_i32(&text), None);
    }
}