cqlsh_rs/driver/
scylla_driver.rs

1//! ScyllaDriver — CqlDriver implementation using the `scylla` crate.
2//!
3//! Provides connectivity to Apache Cassandra and ScyllaDB clusters using
4//! the scylla-rust-driver, with support for authentication, SSL/TLS,
5//! prepared statements, paging, and schema metadata queries.
6
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::TryStreamExt;
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22    Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23    CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlValue};
28use super::{
29    AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30    KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31    UdtMetadata,
32};
33
34/// ScyllaDriver wraps a scylla `Session` and provides the `CqlDriver` trait.
35pub struct ScyllaDriver {
36    session: Session,
37    /// Cache of prepared statements keyed by internal ID.
38    prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39    /// Current consistency level.
40    consistency: Mutex<Consistency>,
41    /// Current serial consistency level.
42    serial_consistency: Mutex<Option<Consistency>>,
43    /// Whether tracing is enabled for queries.
44    tracing_enabled: AtomicBool,
45    /// Last tracing session ID.
46    last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50    /// Build the TLS configuration from SslConfig.
51    fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52        use rustls::pki_types::CertificateDer;
53        use std::fs::File;
54        use std::io::BufReader;
55
56        let mut root_store = rustls::RootCertStore::empty();
57
58        // Load CA certificate if provided
59        if let Some(certfile) = &ssl_config.certfile {
60            let file = File::open(certfile)
61                .with_context(|| format!("opening CA certificate: {certfile}"))?;
62            let mut reader = BufReader::new(file);
63            let certs = rustls_pemfile::certs(&mut reader)
64                .collect::<std::result::Result<Vec<_>, _>>()
65                .with_context(|| format!("parsing CA certificate: {certfile}"))?;
66            for cert in certs {
67                root_store
68                    .add(cert)
69                    .context("adding CA certificate to root store")?;
70            }
71        }
72
73        let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
74
75        // Client certificate authentication (mutual TLS)
76        let config = if let (Some(usercert_path), Some(userkey_path)) =
77            (&ssl_config.usercert, &ssl_config.userkey)
78        {
79            let cert_file = File::open(usercert_path)
80                .with_context(|| format!("opening client certificate: {usercert_path}"))?;
81            let mut cert_reader = BufReader::new(cert_file);
82            let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
83                .collect::<std::result::Result<Vec<_>, _>>()
84                .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
85
86            let key_file = File::open(userkey_path)
87                .with_context(|| format!("opening client key: {userkey_path}"))?;
88            let mut key_reader = BufReader::new(key_file);
89            let key = rustls_pemfile::private_key(&mut key_reader)
90                .with_context(|| format!("parsing client key: {userkey_path}"))?
91                .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
92
93            builder
94                .with_client_auth_cert(certs, key)
95                .context("configuring mutual TLS")?
96        } else {
97            builder.with_no_client_auth()
98        };
99
100        Ok(Arc::new(config))
101    }
102
103    /// Extract a `Vec<String>` from a `CqlValue::List` column value.
104    fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
105        match val {
106            Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
107            _ => Vec::new(),
108        }
109    }
110
111    /// Convert a scylla QueryResult into our CqlResult type.
112    fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
113        let tracing_id = result.tracing_id();
114        let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
115
116        // Check if this is a non-row result (DDL/DML)
117        if !result.is_rows() {
118            return Ok(CqlResult {
119                columns: Vec::new(),
120                rows: Vec::new(),
121                has_rows: false,
122                tracing_id,
123                warnings,
124            });
125        }
126
127        // Convert to QueryRowsResult to access typed rows
128        let rows_result = result
129            .into_rows_result()
130            .context("converting query result to rows")?;
131
132        // Extract column metadata
133        let col_specs = rows_result.column_specs();
134        let columns: Vec<CqlColumn> = col_specs
135            .iter()
136            .map(|spec| CqlColumn {
137                name: spec.name().to_string(),
138                type_name: format!("{:?}", spec.typ()),
139            })
140            .collect();
141
142        // Deserialize rows as untyped Row (Vec<Option<CqlValue>>)
143        let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
144
145        let mut cql_rows = Vec::new();
146        for row_result in typed_rows {
147            let row = row_result.context("deserializing row")?;
148            let values: Vec<CqlValue> = row
149                .columns
150                .into_iter()
151                .enumerate()
152                .map(|(col_idx, opt_val)| match opt_val {
153                    Some(v) => {
154                        tracing::debug!(
155                            column = col_idx,
156                            variant = ?std::mem::discriminant(&v),
157                            "converting ScyllaCqlValue: {v:?}"
158                        );
159                        Self::convert_scylla_value(v)
160                    }
161                    None => {
162                        tracing::debug!(column = col_idx, "column value is None (null)");
163                        CqlValue::Null
164                    }
165                })
166                .collect();
167            cql_rows.push(CqlRow { values });
168        }
169
170        Ok(CqlResult {
171            columns,
172            rows: cql_rows,
173            has_rows: true,
174            tracing_id,
175            warnings,
176        })
177    }
178
179    /// Convert a scylla CqlValue to our CqlValue type.
180    fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
181        match value {
182            ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
183            ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
184            ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
185            ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
186            ScyllaCqlValue::Decimal(d) => {
187                let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
188                let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
189                CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
190            }
191            ScyllaCqlValue::Date(d) => {
192                // scylla CqlDate wraps u32 days since epoch center (2^31)
193                let days = d.0;
194                let epoch_offset = days as i64 - (1i64 << 31);
195                match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
196                {
197                    Some(date) => CqlValue::Date(date),
198                    None => CqlValue::Text(format!("<invalid date: {days}>")),
199                }
200            }
201            ScyllaCqlValue::Double(d) => CqlValue::Double(d),
202            ScyllaCqlValue::Duration(d) => CqlValue::Duration {
203                months: d.months,
204                days: d.days,
205                nanoseconds: d.nanoseconds,
206            },
207            ScyllaCqlValue::Empty => CqlValue::Null,
208            ScyllaCqlValue::Float(f) => CqlValue::Float(f),
209            ScyllaCqlValue::Int(i) => CqlValue::Int(i),
210            ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
211            ScyllaCqlValue::Text(s) => CqlValue::Text(s),
212            ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
213            ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
214            ScyllaCqlValue::List(items) => {
215                CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
216            }
217            ScyllaCqlValue::Map(entries) => CqlValue::Map(
218                entries
219                    .into_iter()
220                    .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
221                    .collect(),
222            ),
223            ScyllaCqlValue::Set(items) => {
224                CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
225            }
226            ScyllaCqlValue::UserDefinedType {
227                keyspace,
228                name,
229                fields,
230            } => CqlValue::UserDefinedType {
231                keyspace,
232                type_name: name,
233                fields: fields
234                    .into_iter()
235                    .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
236                    .collect(),
237            },
238            ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
239            ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
240            ScyllaCqlValue::Time(t) => {
241                let nanos = t.0;
242                let secs = (nanos / 1_000_000_000) as u32;
243                let nano_part = (nanos % 1_000_000_000) as u32;
244                match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
245                    Some(time) => CqlValue::Time(time),
246                    None => CqlValue::Text(format!("<invalid time: {nanos}>")),
247                }
248            }
249            ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
250            ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
251                items
252                    .into_iter()
253                    .map(|v| v.map(Self::convert_scylla_value))
254                    .collect(),
255            ),
256            ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
257            ScyllaCqlValue::Varint(v) => {
258                let big_int =
259                    num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
260                CqlValue::Varint(big_int)
261            }
262            // CqlValue is non-exhaustive; handle future variants gracefully
263            _ => {
264                tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
265                CqlValue::Text(format!("{value:?}"))
266            }
267        }
268    }
269
270    /// Convert our internal CqlValue to scylla's CqlValue (reverse of convert_scylla_value).
271    fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
272        match v {
273            CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
274            CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
275            CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
276            CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
277            CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
278            CqlValue::Duration {
279                months,
280                days,
281                nanoseconds,
282            } => ScyllaCqlValue::Duration(CqlDuration {
283                months: *months,
284                days: *days,
285                nanoseconds: *nanoseconds,
286            }),
287            CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
288            CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
289            CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
290            CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
291            CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
292            CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
293            CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
294            CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
295            CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
296            CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
297            CqlValue::Date(d) => {
298                // Convert NaiveDate back to scylla's u32 days offset from 2^31 epoch
299                let days_from_ce = d.num_days_from_ce();
300                let epoch_offset = days_from_ce as i64 - 719_163;
301                let cql_days = (epoch_offset + (1i64 << 31)) as u32;
302                ScyllaCqlValue::Date(CqlDate(cql_days))
303            }
304            CqlValue::Time(t) => {
305                let nanos =
306                    t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
307                ScyllaCqlValue::Time(CqlTime(nanos))
308            }
309            CqlValue::Varint(bi) => {
310                let bytes = bi.to_signed_bytes_be();
311                ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
312            }
313            CqlValue::Decimal(d) => {
314                let (int_val, scale) = d.as_bigint_and_exponent();
315                let bytes = int_val.to_signed_bytes_be();
316                ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
317                    &bytes,
318                    scale as i32,
319                ))
320            }
321            CqlValue::List(items) => {
322                ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
323            }
324            CqlValue::Set(items) => {
325                ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
326            }
327            CqlValue::Map(entries) => ScyllaCqlValue::Map(
328                entries
329                    .iter()
330                    .map(|(k, v)| {
331                        (
332                            Self::internal_to_scylla_cql(k),
333                            Self::internal_to_scylla_cql(v),
334                        )
335                    })
336                    .collect(),
337            ),
338            CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
339                items
340                    .iter()
341                    .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
342                    .collect(),
343            ),
344            CqlValue::UserDefinedType {
345                keyspace,
346                type_name,
347                fields,
348            } => ScyllaCqlValue::UserDefinedType {
349                keyspace: keyspace.clone(),
350                name: type_name.clone(),
351                fields: fields
352                    .iter()
353                    .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
354                    .collect(),
355            },
356            CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
357        }
358    }
359
360    /// Convert our Consistency to scylla's Consistency.
361    fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
362        use scylla::statement::Consistency as SC;
363        match c {
364            Consistency::Any => SC::Any,
365            Consistency::One => SC::One,
366            Consistency::Two => SC::Two,
367            Consistency::Three => SC::Three,
368            Consistency::Quorum => SC::Quorum,
369            Consistency::All => SC::All,
370            Consistency::LocalQuorum => SC::LocalQuorum,
371            Consistency::EachQuorum => SC::EachQuorum,
372            Consistency::Serial => SC::Serial,
373            Consistency::LocalSerial => SC::LocalSerial,
374            Consistency::LocalOne => SC::LocalOne,
375        }
376    }
377
378    /// Convert our Consistency to scylla's SerialConsistency.
379    fn to_scylla_serial_consistency(
380        c: Consistency,
381    ) -> Option<scylla::statement::SerialConsistency> {
382        use scylla::statement::SerialConsistency as SC;
383        match c {
384            Consistency::Serial => Some(SC::Serial),
385            Consistency::LocalSerial => Some(SC::LocalSerial),
386            _ => None,
387        }
388    }
389
390    /// Build a Statement with the current consistency and tracing settings.
391    fn build_query(&self, cql: &str) -> Statement {
392        let mut stmt = Statement::new(cql);
393
394        let consistency = *self.consistency.lock().unwrap();
395        stmt.set_consistency(Self::to_scylla_consistency(consistency));
396
397        let serial = *self.serial_consistency.lock().unwrap();
398        if let Some(sc) = serial {
399            if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
400                stmt.set_serial_consistency(Some(sc));
401            }
402        }
403
404        if self.tracing_enabled.load(Ordering::Relaxed) {
405            stmt.set_tracing(true);
406        }
407
408        stmt
409    }
410
411    /// Store tracing ID from a result if present.
412    fn store_trace_id(&self, result: &QueryResult) {
413        if let Some(trace_id) = result.tracing_id() {
414            *self.last_trace_id.lock().unwrap() = Some(trace_id);
415        }
416    }
417}
418
419#[async_trait]
420impl CqlDriver for ScyllaDriver {
421    async fn connect(config: &ConnectionConfig) -> Result<Self> {
422        let addr = format!("{}:{}", config.host, config.port);
423
424        let mut builder = SessionBuilder::new().known_node(&addr);
425
426        // Authentication
427        if let (Some(username), Some(password)) = (&config.username, &config.password) {
428            builder = builder.user(username, password);
429        }
430
431        // Connection timeout
432        builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
433
434        // Default keyspace
435        if let Some(keyspace) = &config.keyspace {
436            builder = builder.use_keyspace(keyspace, false);
437        }
438
439        // SSL/TLS
440        if config.ssl {
441            let tls_config = if let Some(ssl_config) = &config.ssl_config {
442                Self::build_rustls_config(ssl_config)?
443            } else {
444                // SSL enabled but no config — use default (no validation)
445                let root_store = rustls::RootCertStore::empty();
446                Arc::new(
447                    rustls::ClientConfig::builder()
448                        .with_root_certificates(root_store)
449                        .with_no_client_auth(),
450                )
451            };
452            builder = builder.tls_context(Some(tls_config));
453        }
454
455        // NOTE: config.protocol_version is accepted for CLI compatibility but
456        // scylla-rust-driver 1.5.0 auto-negotiates the native protocol version.
457        // SessionBuilder has no method to force a specific protocol version.
458        // Similarly, the driver hardcodes CQL_VERSION="4.0.0" in the STARTUP frame.
459
460        let session = builder.build().await.context("connecting to cluster")?;
461
462        Ok(ScyllaDriver {
463            session,
464            prepared_cache: Mutex::new(HashMap::new()),
465            consistency: Mutex::new(Consistency::One),
466            serial_consistency: Mutex::new(None),
467            tracing_enabled: AtomicBool::new(false),
468            last_trace_id: Mutex::new(None),
469        })
470    }
471
472    async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
473        let stmt = self.build_query(query);
474
475        let result = self.session.query_unpaged(stmt, ()).await?;
476
477        self.store_trace_id(&result);
478        Self::convert_query_result(result)
479    }
480
481    async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
482        let mut stmt = self.build_query(query);
483        stmt.set_page_size(page_size);
484
485        let query_pager = self
486            .session
487            .query_iter(stmt, ())
488            .await
489            .context("starting paged query")?;
490
491        // Get column metadata from the pager
492        let col_specs = query_pager.column_specs();
493        let columns: Vec<CqlColumn> = col_specs
494            .iter()
495            .map(|spec| CqlColumn {
496                name: spec.name().to_string(),
497                type_name: format!("{:?}", spec.typ()),
498            })
499            .collect();
500
501        // Stream all rows using the untyped Row type
502        let mut rows_stream = query_pager.rows_stream::<Row>()?;
503        let mut cql_rows = Vec::new();
504
505        while let Some(row) = rows_stream.try_next().await? {
506            let values: Vec<CqlValue> = row
507                .columns
508                .into_iter()
509                .map(|opt_val| match opt_val {
510                    Some(v) => Self::convert_scylla_value(v),
511                    None => CqlValue::Null,
512                })
513                .collect();
514            cql_rows.push(CqlRow { values });
515        }
516
517        Ok(CqlResult {
518            columns,
519            rows: cql_rows,
520            has_rows: true,
521            tracing_id: None,
522            warnings: Vec::new(),
523        })
524    }
525
526    async fn prepare(&self, query: &str) -> Result<PreparedId> {
527        let prepared = self
528            .session
529            .prepare(query)
530            .await
531            .context("preparing CQL statement")?;
532
533        let id = prepared.get_id().to_vec();
534        self.prepared_cache
535            .lock()
536            .unwrap()
537            .insert(id.clone(), prepared);
538
539        Ok(PreparedId { inner: id })
540    }
541
542    async fn execute_prepared(
543        &self,
544        prepared_id: &PreparedId,
545        values: &[CqlValue],
546    ) -> Result<CqlResult> {
547        let prepared = self
548            .prepared_cache
549            .lock()
550            .unwrap()
551            .get(&prepared_id.inner)
552            .cloned()
553            .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
554
555        // Convert internal CqlValues to scylla CqlValues for binding.
556        // Null/Unset become None (bound as null), all others become Some(value).
557        let scylla_values: Vec<Option<ScyllaCqlValue>> = values
558            .iter()
559            .map(|v| match v {
560                CqlValue::Null | CqlValue::Unset => None,
561                other => Some(Self::internal_to_scylla_cql(other)),
562            })
563            .collect();
564
565        let result = self
566            .session
567            .execute_unpaged(&prepared, scylla_values)
568            .await
569            .context("executing prepared statement")?;
570
571        self.store_trace_id(&result);
572        Self::convert_query_result(result)
573    }
574
575    async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
576        self.session
577            .use_keyspace(keyspace, false)
578            .await
579            .with_context(|| format!("switching to keyspace: {keyspace}"))?;
580        Ok(())
581    }
582
583    fn get_consistency(&self) -> Consistency {
584        *self.consistency.lock().unwrap()
585    }
586
587    fn set_consistency(&self, consistency: Consistency) {
588        *self.consistency.lock().unwrap() = consistency;
589    }
590
591    fn get_serial_consistency(&self) -> Option<Consistency> {
592        *self.serial_consistency.lock().unwrap()
593    }
594
595    fn set_serial_consistency(&self, consistency: Option<Consistency>) {
596        *self.serial_consistency.lock().unwrap() = consistency;
597    }
598
599    fn set_tracing(&self, enabled: bool) {
600        self.tracing_enabled.store(enabled, Ordering::Relaxed);
601    }
602
603    fn is_tracing_enabled(&self) -> bool {
604        self.tracing_enabled.load(Ordering::Relaxed)
605    }
606
607    fn last_trace_id(&self) -> Option<Uuid> {
608        *self.last_trace_id.lock().unwrap()
609    }
610
611    async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
612        let query = format!(
613            "SELECT client, command, coordinator, duration, parameters, request, started_at \
614             FROM system_traces.sessions WHERE session_id = {}",
615            trace_id
616        );
617        let result = self.execute_unpaged(&query).await?;
618
619        if result.rows.is_empty() {
620            return Ok(None);
621        }
622
623        let events_query = format!(
624            "SELECT activity, source, source_elapsed, thread \
625             FROM system_traces.events WHERE session_id = {}",
626            trace_id
627        );
628        let events_result = self.execute_unpaged(&events_query).await?;
629
630        let events: Vec<TracingEvent> = events_result
631            .rows
632            .iter()
633            .map(|row| TracingEvent {
634                activity: row.get(0).and_then(cql_value_to_string),
635                source: row.get(1).and_then(cql_value_to_string),
636                source_elapsed: row.get(2).and_then(cql_value_to_i32),
637                thread: row.get(3).and_then(cql_value_to_string),
638            })
639            .collect();
640
641        let session_row = &result.rows[0];
642        Ok(Some(TracingSession {
643            trace_id,
644            client: session_row.get(0).and_then(cql_value_to_string),
645            command: session_row.get(1).and_then(cql_value_to_string),
646            coordinator: session_row.get(2).and_then(cql_value_to_string),
647            duration: session_row.get(3).and_then(cql_value_to_i32),
648            parameters: HashMap::new(),
649            request: session_row.get(5).and_then(cql_value_to_string),
650            started_at: session_row.get(6).and_then(cql_value_to_string),
651            events,
652        }))
653    }
654
655    async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
656        let result = self
657            .execute_unpaged(
658                "SELECT keyspace_name, replication, durable_writes \
659                 FROM system_schema.keyspaces",
660            )
661            .await?;
662
663        let mut keyspaces = Vec::new();
664        for row in &result.rows {
665            let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
666            let durable_writes = match row.get(2) {
667                Some(CqlValue::Boolean(b)) => *b,
668                _ => true,
669            };
670
671            keyspaces.push(KeyspaceMetadata {
672                name,
673                replication: HashMap::new(),
674                durable_writes,
675            });
676        }
677
678        Ok(keyspaces)
679    }
680
681    async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
682        let result = self
683            .execute_unpaged(&format!(
684                "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{}'",
685                keyspace.replace('\'', "''")
686            ))
687            .await?;
688
689        let mut tables = Vec::new();
690        for row in &result.rows {
691            let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
692
693            let col_result = self
694                .execute_unpaged(&format!(
695                    "SELECT column_name, type, kind \
696                     FROM system_schema.columns \
697                     WHERE keyspace_name = '{}' AND table_name = '{}'",
698                    keyspace.replace('\'', "''"),
699                    table_name.replace('\'', "''")
700                ))
701                .await?;
702
703            let mut columns = Vec::new();
704            let mut partition_key = Vec::new();
705            let mut clustering_key = Vec::new();
706
707            for col_row in &col_result.rows {
708                let col_name = col_row
709                    .get(0)
710                    .and_then(cql_value_to_string)
711                    .unwrap_or_default();
712                let col_type = col_row
713                    .get(1)
714                    .and_then(cql_value_to_string)
715                    .unwrap_or_default();
716                let kind = col_row
717                    .get(2)
718                    .and_then(cql_value_to_string)
719                    .unwrap_or_default();
720
721                columns.push(ColumnMetadata {
722                    name: col_name.clone(),
723                    type_name: col_type,
724                });
725
726                match kind.as_str() {
727                    "partition_key" => partition_key.push(col_name),
728                    "clustering" => clustering_key.push(col_name),
729                    _ => {}
730                }
731            }
732
733            tables.push(TableMetadata {
734                keyspace: keyspace.to_string(),
735                name: table_name,
736                columns,
737                partition_key,
738                clustering_key,
739            });
740        }
741
742        Ok(tables)
743    }
744
745    async fn get_table_metadata(
746        &self,
747        keyspace: &str,
748        table: &str,
749    ) -> Result<Option<TableMetadata>> {
750        let tables = self.get_tables(keyspace).await?;
751        Ok(tables.into_iter().find(|t| t.name == table))
752    }
753
754    async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
755        let query = format!(
756            "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
757            keyspace.replace('\'', "''")
758        );
759        let result = self.execute_unpaged(&query).await?;
760        let udts = result
761            .rows
762            .iter()
763            .filter_map(|row| {
764                let name = row.get_by_name("type_name", &result.columns)?.to_string();
765                let field_names =
766                    Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
767                let field_types =
768                    Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
769                Some(UdtMetadata {
770                    keyspace: keyspace.to_string(),
771                    name,
772                    field_names,
773                    field_types,
774                })
775            })
776            .collect();
777        Ok(udts)
778    }
779
780    async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
781        let query = format!(
782            "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
783            keyspace.replace('\'', "''")
784        );
785        let result = self.execute_unpaged(&query).await?;
786        let functions = result
787            .rows
788            .iter()
789            .filter_map(|row| {
790                let name = row
791                    .get_by_name("function_name", &result.columns)?
792                    .to_string();
793                let argument_types = Self::extract_string_list_val(
794                    row.get_by_name("argument_types", &result.columns),
795                );
796                let return_type = row
797                    .get_by_name("return_type", &result.columns)
798                    .map(|v| v.to_string())
799                    .unwrap_or_default();
800                Some(FunctionMetadata {
801                    keyspace: keyspace.to_string(),
802                    name,
803                    argument_types,
804                    return_type,
805                })
806            })
807            .collect();
808        Ok(functions)
809    }
810
811    async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
812        let query = format!(
813            "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
814            keyspace.replace('\'', "''")
815        );
816        let result = self.execute_unpaged(&query).await?;
817        let aggregates = result
818            .rows
819            .iter()
820            .filter_map(|row| {
821                let name = row
822                    .get_by_name("aggregate_name", &result.columns)?
823                    .to_string();
824                let argument_types = Self::extract_string_list_val(
825                    row.get_by_name("argument_types", &result.columns),
826                );
827                let return_type = row
828                    .get_by_name("return_type", &result.columns)
829                    .map(|v| v.to_string())
830                    .unwrap_or_default();
831                Some(AggregateMetadata {
832                    keyspace: keyspace.to_string(),
833                    name,
834                    argument_types,
835                    return_type,
836                })
837            })
838            .collect();
839        Ok(aggregates)
840    }
841
842    async fn get_cluster_name(&self) -> Result<Option<String>> {
843        let result = self
844            .execute_unpaged("SELECT cluster_name FROM system.local")
845            .await?;
846        Ok(result
847            .rows
848            .first()
849            .and_then(|row| row.get(0))
850            .and_then(cql_value_to_string))
851    }
852
853    async fn get_cql_version(&self) -> Result<Option<String>> {
854        let result = self
855            .execute_unpaged("SELECT cql_version FROM system.local")
856            .await?;
857        Ok(result
858            .rows
859            .first()
860            .and_then(|row| row.get(0))
861            .and_then(cql_value_to_string))
862    }
863
864    async fn get_release_version(&self) -> Result<Option<String>> {
865        let result = self
866            .execute_unpaged("SELECT release_version FROM system.local")
867            .await?;
868        Ok(result
869            .rows
870            .first()
871            .and_then(|row| row.get(0))
872            .and_then(cql_value_to_string))
873    }
874
875    async fn get_scylla_version(&self) -> Result<Option<String>> {
876        // ScyllaDB exposes its version in system.local.scylla_version
877        // This column doesn't exist in Apache Cassandra, so errors are expected.
878        let result = self
879            .execute_unpaged("SELECT scylla_version FROM system.local")
880            .await;
881        match result {
882            Ok(r) => Ok(r
883                .rows
884                .first()
885                .and_then(|row| row.get(0))
886                .and_then(cql_value_to_string)),
887            Err(_) => Ok(None), // Column doesn't exist → not ScyllaDB
888        }
889    }
890
891    async fn is_connected(&self) -> bool {
892        self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
893            .await
894            .is_ok()
895    }
896}
897
898/// Helper: extract a string from a CqlValue.
899fn cql_value_to_string(v: &CqlValue) -> Option<String> {
900    match v {
901        CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
902        CqlValue::Inet(addr) => Some(addr.to_string()),
903        CqlValue::Null => None,
904        other => Some(other.to_string()),
905    }
906}
907
908/// Helper: extract an i32 from a CqlValue.
909fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
910    match v {
911        CqlValue::Int(i) => Some(*i),
912        CqlValue::BigInt(i) => Some(*i as i32),
913        CqlValue::SmallInt(i) => Some(*i as i32),
914        CqlValue::TinyInt(i) => Some(*i as i32),
915        _ => None,
916    }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the value conversion under test.
    fn conv(value: ScyllaCqlValue) -> CqlValue {
        ScyllaDriver::convert_scylla_value(value)
    }

    #[test]
    fn convert_scylla_value_text() {
        let expected = CqlValue::Text("hello".to_string());
        assert_eq!(conv(ScyllaCqlValue::Text("hello".to_string())), expected);
    }

    #[test]
    fn convert_scylla_value_int() {
        assert_eq!(conv(ScyllaCqlValue::Int(42)), CqlValue::Int(42));
    }

    #[test]
    fn convert_scylla_value_boolean() {
        assert_eq!(conv(ScyllaCqlValue::Boolean(true)), CqlValue::Boolean(true));
    }

    #[test]
    fn convert_scylla_value_null() {
        assert_eq!(conv(ScyllaCqlValue::Empty), CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let input = ScyllaCqlValue::List(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(conv(input), expected);
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let nil = Uuid::nil();
        assert_eq!(conv(ScyllaCqlValue::Uuid(nil)), CqlValue::Uuid(nil));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let bytes = vec![0xde, 0xad, 0xbe, 0xef];
        assert_eq!(conv(ScyllaCqlValue::Blob(bytes.clone())), CqlValue::Blob(bytes));
    }

    #[test]
    fn convert_scylla_value_float() {
        assert_eq!(conv(ScyllaCqlValue::Float(1.5)), CqlValue::Float(1.5));
    }

    #[test]
    fn convert_scylla_value_double() {
        assert_eq!(conv(ScyllaCqlValue::Double(1.5)), CqlValue::Double(1.5));
    }

    #[test]
    fn convert_scylla_value_map() {
        let input = ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]);
        let expected = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        assert_eq!(conv(input), expected);
    }

    #[test]
    fn convert_scylla_value_set() {
        let input = ScyllaCqlValue::Set(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(conv(input), expected);
    }

    #[test]
    fn convert_scylla_value_udt() {
        let input = ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        let expected = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(CqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        assert_eq!(conv(input), expected);
    }

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as SC;
        let one = ScyllaDriver::to_scylla_consistency(Consistency::One);
        assert!(matches!(one, SC::One));
        let quorum = ScyllaDriver::to_scylla_consistency(Consistency::Quorum);
        assert!(matches!(quorum, SC::Quorum));
        let local_quorum = ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum);
        assert!(matches!(local_quorum, SC::LocalQuorum));
        let all = ScyllaDriver::to_scylla_consistency(Consistency::All);
        assert!(matches!(all, SC::All));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as SC;
        let serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial);
        assert!(matches!(serial, Some(SC::Serial)));
        let local_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial);
        assert!(matches!(local_serial, Some(SC::LocalSerial)));
        // A non-serial level has no serial counterpart.
        let non_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::One);
        assert!(non_serial.is_none());
    }

    #[test]
    fn cql_value_to_string_helper() {
        let text = CqlValue::Text("hello".to_string());
        assert_eq!(cql_value_to_string(&text), Some("hello".to_string()));
        assert_eq!(cql_value_to_string(&CqlValue::Int(42)), Some("42".to_string()));
        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_i32_helper() {
        assert_eq!(cql_value_to_i32(&CqlValue::Int(42)), Some(42));
        assert_eq!(cql_value_to_i32(&CqlValue::BigInt(100)), Some(100));
        let text = CqlValue::Text("x".to_string());
        assert_eq!(cql_value_to_i32(&text), None);
    }
}