1use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::TryStreamExt;
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22 Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23 CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlValue};
28use super::{
29 AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30 KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31 UdtMetadata,
32};
33
34pub struct ScyllaDriver {
36 session: Session,
37 prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39 consistency: Mutex<Consistency>,
41 serial_consistency: Mutex<Option<Consistency>>,
43 tracing_enabled: AtomicBool,
45 last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50 fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52 use rustls::pki_types::CertificateDer;
53 use std::fs::File;
54 use std::io::BufReader;
55
56 let mut root_store = rustls::RootCertStore::empty();
57
58 if let Some(certfile) = &ssl_config.certfile {
60 let file = File::open(certfile)
61 .with_context(|| format!("opening CA certificate: {certfile}"))?;
62 let mut reader = BufReader::new(file);
63 let certs = rustls_pemfile::certs(&mut reader)
64 .collect::<std::result::Result<Vec<_>, _>>()
65 .with_context(|| format!("parsing CA certificate: {certfile}"))?;
66 for cert in certs {
67 root_store
68 .add(cert)
69 .context("adding CA certificate to root store")?;
70 }
71 }
72
73 let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
74
75 let config = if let (Some(usercert_path), Some(userkey_path)) =
77 (&ssl_config.usercert, &ssl_config.userkey)
78 {
79 let cert_file = File::open(usercert_path)
80 .with_context(|| format!("opening client certificate: {usercert_path}"))?;
81 let mut cert_reader = BufReader::new(cert_file);
82 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
83 .collect::<std::result::Result<Vec<_>, _>>()
84 .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
85
86 let key_file = File::open(userkey_path)
87 .with_context(|| format!("opening client key: {userkey_path}"))?;
88 let mut key_reader = BufReader::new(key_file);
89 let key = rustls_pemfile::private_key(&mut key_reader)
90 .with_context(|| format!("parsing client key: {userkey_path}"))?
91 .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
92
93 builder
94 .with_client_auth_cert(certs, key)
95 .context("configuring mutual TLS")?
96 } else {
97 builder.with_no_client_auth()
98 };
99
100 Ok(Arc::new(config))
101 }
102
103 fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
105 match val {
106 Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
107 _ => Vec::new(),
108 }
109 }
110
111 fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
113 let tracing_id = result.tracing_id();
114 let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
115
116 if !result.is_rows() {
118 return Ok(CqlResult {
119 columns: Vec::new(),
120 rows: Vec::new(),
121 has_rows: false,
122 tracing_id,
123 warnings,
124 });
125 }
126
127 let rows_result = result
129 .into_rows_result()
130 .context("converting query result to rows")?;
131
132 let col_specs = rows_result.column_specs();
134 let columns: Vec<CqlColumn> = col_specs
135 .iter()
136 .map(|spec| CqlColumn {
137 name: spec.name().to_string(),
138 type_name: format!("{:?}", spec.typ()),
139 })
140 .collect();
141
142 let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
144
145 let mut cql_rows = Vec::new();
146 for row_result in typed_rows {
147 let row = row_result.context("deserializing row")?;
148 let values: Vec<CqlValue> = row
149 .columns
150 .into_iter()
151 .enumerate()
152 .map(|(col_idx, opt_val)| match opt_val {
153 Some(v) => {
154 tracing::debug!(
155 column = col_idx,
156 variant = ?std::mem::discriminant(&v),
157 "converting ScyllaCqlValue: {v:?}"
158 );
159 Self::convert_scylla_value(v)
160 }
161 None => {
162 tracing::debug!(column = col_idx, "column value is None (null)");
163 CqlValue::Null
164 }
165 })
166 .collect();
167 cql_rows.push(CqlRow { values });
168 }
169
170 Ok(CqlResult {
171 columns,
172 rows: cql_rows,
173 has_rows: true,
174 tracing_id,
175 warnings,
176 })
177 }
178
179 fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
181 match value {
182 ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
183 ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
184 ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
185 ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
186 ScyllaCqlValue::Decimal(d) => {
187 let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
188 let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
189 CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
190 }
191 ScyllaCqlValue::Date(d) => {
192 let days = d.0;
194 let epoch_offset = days as i64 - (1i64 << 31);
195 match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
196 {
197 Some(date) => CqlValue::Date(date),
198 None => CqlValue::Text(format!("<invalid date: {days}>")),
199 }
200 }
201 ScyllaCqlValue::Double(d) => CqlValue::Double(d),
202 ScyllaCqlValue::Duration(d) => CqlValue::Duration {
203 months: d.months,
204 days: d.days,
205 nanoseconds: d.nanoseconds,
206 },
207 ScyllaCqlValue::Empty => CqlValue::Null,
208 ScyllaCqlValue::Float(f) => CqlValue::Float(f),
209 ScyllaCqlValue::Int(i) => CqlValue::Int(i),
210 ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
211 ScyllaCqlValue::Text(s) => CqlValue::Text(s),
212 ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
213 ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
214 ScyllaCqlValue::List(items) => {
215 CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
216 }
217 ScyllaCqlValue::Map(entries) => CqlValue::Map(
218 entries
219 .into_iter()
220 .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
221 .collect(),
222 ),
223 ScyllaCqlValue::Set(items) => {
224 CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
225 }
226 ScyllaCqlValue::UserDefinedType {
227 keyspace,
228 name,
229 fields,
230 } => CqlValue::UserDefinedType {
231 keyspace,
232 type_name: name,
233 fields: fields
234 .into_iter()
235 .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
236 .collect(),
237 },
238 ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
239 ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
240 ScyllaCqlValue::Time(t) => {
241 let nanos = t.0;
242 let secs = (nanos / 1_000_000_000) as u32;
243 let nano_part = (nanos % 1_000_000_000) as u32;
244 match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
245 Some(time) => CqlValue::Time(time),
246 None => CqlValue::Text(format!("<invalid time: {nanos}>")),
247 }
248 }
249 ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
250 ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
251 items
252 .into_iter()
253 .map(|v| v.map(Self::convert_scylla_value))
254 .collect(),
255 ),
256 ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
257 ScyllaCqlValue::Varint(v) => {
258 let big_int =
259 num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
260 CqlValue::Varint(big_int)
261 }
262 _ => {
264 tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
265 CqlValue::Text(format!("{value:?}"))
266 }
267 }
268 }
269
270 fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
272 match v {
273 CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
274 CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
275 CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
276 CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
277 CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
278 CqlValue::Duration {
279 months,
280 days,
281 nanoseconds,
282 } => ScyllaCqlValue::Duration(CqlDuration {
283 months: *months,
284 days: *days,
285 nanoseconds: *nanoseconds,
286 }),
287 CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
288 CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
289 CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
290 CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
291 CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
292 CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
293 CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
294 CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
295 CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
296 CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
297 CqlValue::Date(d) => {
298 let days_from_ce = d.num_days_from_ce();
300 let epoch_offset = days_from_ce as i64 - 719_163;
301 let cql_days = (epoch_offset + (1i64 << 31)) as u32;
302 ScyllaCqlValue::Date(CqlDate(cql_days))
303 }
304 CqlValue::Time(t) => {
305 let nanos =
306 t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
307 ScyllaCqlValue::Time(CqlTime(nanos))
308 }
309 CqlValue::Varint(bi) => {
310 let bytes = bi.to_signed_bytes_be();
311 ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
312 }
313 CqlValue::Decimal(d) => {
314 let (int_val, scale) = d.as_bigint_and_exponent();
315 let bytes = int_val.to_signed_bytes_be();
316 ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
317 &bytes,
318 scale as i32,
319 ))
320 }
321 CqlValue::List(items) => {
322 ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
323 }
324 CqlValue::Set(items) => {
325 ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
326 }
327 CqlValue::Map(entries) => ScyllaCqlValue::Map(
328 entries
329 .iter()
330 .map(|(k, v)| {
331 (
332 Self::internal_to_scylla_cql(k),
333 Self::internal_to_scylla_cql(v),
334 )
335 })
336 .collect(),
337 ),
338 CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
339 items
340 .iter()
341 .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
342 .collect(),
343 ),
344 CqlValue::UserDefinedType {
345 keyspace,
346 type_name,
347 fields,
348 } => ScyllaCqlValue::UserDefinedType {
349 keyspace: keyspace.clone(),
350 name: type_name.clone(),
351 fields: fields
352 .iter()
353 .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
354 .collect(),
355 },
356 CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
357 }
358 }
359
360 fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
362 use scylla::statement::Consistency as SC;
363 match c {
364 Consistency::Any => SC::Any,
365 Consistency::One => SC::One,
366 Consistency::Two => SC::Two,
367 Consistency::Three => SC::Three,
368 Consistency::Quorum => SC::Quorum,
369 Consistency::All => SC::All,
370 Consistency::LocalQuorum => SC::LocalQuorum,
371 Consistency::EachQuorum => SC::EachQuorum,
372 Consistency::Serial => SC::Serial,
373 Consistency::LocalSerial => SC::LocalSerial,
374 Consistency::LocalOne => SC::LocalOne,
375 }
376 }
377
378 fn to_scylla_serial_consistency(
380 c: Consistency,
381 ) -> Option<scylla::statement::SerialConsistency> {
382 use scylla::statement::SerialConsistency as SC;
383 match c {
384 Consistency::Serial => Some(SC::Serial),
385 Consistency::LocalSerial => Some(SC::LocalSerial),
386 _ => None,
387 }
388 }
389
390 fn build_query(&self, cql: &str) -> Statement {
392 let mut stmt = Statement::new(cql);
393
394 let consistency = *self.consistency.lock().unwrap();
395 stmt.set_consistency(Self::to_scylla_consistency(consistency));
396
397 let serial = *self.serial_consistency.lock().unwrap();
398 if let Some(sc) = serial {
399 if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
400 stmt.set_serial_consistency(Some(sc));
401 }
402 }
403
404 if self.tracing_enabled.load(Ordering::Relaxed) {
405 stmt.set_tracing(true);
406 }
407
408 stmt
409 }
410
411 fn store_trace_id(&self, result: &QueryResult) {
413 if let Some(trace_id) = result.tracing_id() {
414 *self.last_trace_id.lock().unwrap() = Some(trace_id);
415 }
416 }
417}
418
419#[async_trait]
420impl CqlDriver for ScyllaDriver {
421 async fn connect(config: &ConnectionConfig) -> Result<Self> {
422 let addr = format!("{}:{}", config.host, config.port);
423
424 let mut builder = SessionBuilder::new().known_node(&addr);
425
426 if let (Some(username), Some(password)) = (&config.username, &config.password) {
428 builder = builder.user(username, password);
429 }
430
431 builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
433
434 if let Some(keyspace) = &config.keyspace {
436 builder = builder.use_keyspace(keyspace, false);
437 }
438
439 if config.ssl {
441 let tls_config = if let Some(ssl_config) = &config.ssl_config {
442 Self::build_rustls_config(ssl_config)?
443 } else {
444 let root_store = rustls::RootCertStore::empty();
446 Arc::new(
447 rustls::ClientConfig::builder()
448 .with_root_certificates(root_store)
449 .with_no_client_auth(),
450 )
451 };
452 builder = builder.tls_context(Some(tls_config));
453 }
454
455 let session = builder.build().await.context("connecting to cluster")?;
461
462 Ok(ScyllaDriver {
463 session,
464 prepared_cache: Mutex::new(HashMap::new()),
465 consistency: Mutex::new(Consistency::One),
466 serial_consistency: Mutex::new(None),
467 tracing_enabled: AtomicBool::new(false),
468 last_trace_id: Mutex::new(None),
469 })
470 }
471
472 async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
473 let stmt = self.build_query(query);
474
475 let result = self.session.query_unpaged(stmt, ()).await?;
476
477 self.store_trace_id(&result);
478 Self::convert_query_result(result)
479 }
480
481 async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
482 let mut stmt = self.build_query(query);
483 stmt.set_page_size(page_size);
484
485 let query_pager = self
486 .session
487 .query_iter(stmt, ())
488 .await
489 .context("starting paged query")?;
490
491 let col_specs = query_pager.column_specs();
493 let columns: Vec<CqlColumn> = col_specs
494 .iter()
495 .map(|spec| CqlColumn {
496 name: spec.name().to_string(),
497 type_name: format!("{:?}", spec.typ()),
498 })
499 .collect();
500
501 let mut rows_stream = query_pager.rows_stream::<Row>()?;
503 let mut cql_rows = Vec::new();
504
505 while let Some(row) = rows_stream.try_next().await? {
506 let values: Vec<CqlValue> = row
507 .columns
508 .into_iter()
509 .map(|opt_val| match opt_val {
510 Some(v) => Self::convert_scylla_value(v),
511 None => CqlValue::Null,
512 })
513 .collect();
514 cql_rows.push(CqlRow { values });
515 }
516
517 Ok(CqlResult {
518 columns,
519 rows: cql_rows,
520 has_rows: true,
521 tracing_id: None,
522 warnings: Vec::new(),
523 })
524 }
525
526 async fn prepare(&self, query: &str) -> Result<PreparedId> {
527 let prepared = self
528 .session
529 .prepare(query)
530 .await
531 .context("preparing CQL statement")?;
532
533 let id = prepared.get_id().to_vec();
534 self.prepared_cache
535 .lock()
536 .unwrap()
537 .insert(id.clone(), prepared);
538
539 Ok(PreparedId { inner: id })
540 }
541
542 async fn execute_prepared(
543 &self,
544 prepared_id: &PreparedId,
545 values: &[CqlValue],
546 ) -> Result<CqlResult> {
547 let prepared = self
548 .prepared_cache
549 .lock()
550 .unwrap()
551 .get(&prepared_id.inner)
552 .cloned()
553 .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
554
555 let scylla_values: Vec<Option<ScyllaCqlValue>> = values
558 .iter()
559 .map(|v| match v {
560 CqlValue::Null | CqlValue::Unset => None,
561 other => Some(Self::internal_to_scylla_cql(other)),
562 })
563 .collect();
564
565 let result = self
566 .session
567 .execute_unpaged(&prepared, scylla_values)
568 .await
569 .context("executing prepared statement")?;
570
571 self.store_trace_id(&result);
572 Self::convert_query_result(result)
573 }
574
575 async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
576 self.session
577 .use_keyspace(keyspace, false)
578 .await
579 .with_context(|| format!("switching to keyspace: {keyspace}"))?;
580 Ok(())
581 }
582
583 fn get_consistency(&self) -> Consistency {
584 *self.consistency.lock().unwrap()
585 }
586
587 fn set_consistency(&self, consistency: Consistency) {
588 *self.consistency.lock().unwrap() = consistency;
589 }
590
591 fn get_serial_consistency(&self) -> Option<Consistency> {
592 *self.serial_consistency.lock().unwrap()
593 }
594
595 fn set_serial_consistency(&self, consistency: Option<Consistency>) {
596 *self.serial_consistency.lock().unwrap() = consistency;
597 }
598
599 fn set_tracing(&self, enabled: bool) {
600 self.tracing_enabled.store(enabled, Ordering::Relaxed);
601 }
602
603 fn is_tracing_enabled(&self) -> bool {
604 self.tracing_enabled.load(Ordering::Relaxed)
605 }
606
607 fn last_trace_id(&self) -> Option<Uuid> {
608 *self.last_trace_id.lock().unwrap()
609 }
610
611 async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
612 let query = format!(
613 "SELECT client, command, coordinator, duration, parameters, request, started_at \
614 FROM system_traces.sessions WHERE session_id = {}",
615 trace_id
616 );
617 let result = self.execute_unpaged(&query).await?;
618
619 if result.rows.is_empty() {
620 return Ok(None);
621 }
622
623 let events_query = format!(
624 "SELECT activity, source, source_elapsed, thread \
625 FROM system_traces.events WHERE session_id = {}",
626 trace_id
627 );
628 let events_result = self.execute_unpaged(&events_query).await?;
629
630 let events: Vec<TracingEvent> = events_result
631 .rows
632 .iter()
633 .map(|row| TracingEvent {
634 activity: row.get(0).and_then(cql_value_to_string),
635 source: row.get(1).and_then(cql_value_to_string),
636 source_elapsed: row.get(2).and_then(cql_value_to_i32),
637 thread: row.get(3).and_then(cql_value_to_string),
638 })
639 .collect();
640
641 let session_row = &result.rows[0];
642 Ok(Some(TracingSession {
643 trace_id,
644 client: session_row.get(0).and_then(cql_value_to_string),
645 command: session_row.get(1).and_then(cql_value_to_string),
646 coordinator: session_row.get(2).and_then(cql_value_to_string),
647 duration: session_row.get(3).and_then(cql_value_to_i32),
648 parameters: HashMap::new(),
649 request: session_row.get(5).and_then(cql_value_to_string),
650 started_at: session_row.get(6).and_then(cql_value_to_string),
651 events,
652 }))
653 }
654
655 async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
656 let result = self
657 .execute_unpaged(
658 "SELECT keyspace_name, replication, durable_writes \
659 FROM system_schema.keyspaces",
660 )
661 .await?;
662
663 let mut keyspaces = Vec::new();
664 for row in &result.rows {
665 let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
666 let durable_writes = match row.get(2) {
667 Some(CqlValue::Boolean(b)) => *b,
668 _ => true,
669 };
670
671 keyspaces.push(KeyspaceMetadata {
672 name,
673 replication: HashMap::new(),
674 durable_writes,
675 });
676 }
677
678 Ok(keyspaces)
679 }
680
681 async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
682 let result = self
683 .execute_unpaged(&format!(
684 "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{}'",
685 keyspace.replace('\'', "''")
686 ))
687 .await?;
688
689 let mut tables = Vec::new();
690 for row in &result.rows {
691 let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
692
693 let col_result = self
694 .execute_unpaged(&format!(
695 "SELECT column_name, type, kind \
696 FROM system_schema.columns \
697 WHERE keyspace_name = '{}' AND table_name = '{}'",
698 keyspace.replace('\'', "''"),
699 table_name.replace('\'', "''")
700 ))
701 .await?;
702
703 let mut columns = Vec::new();
704 let mut partition_key = Vec::new();
705 let mut clustering_key = Vec::new();
706
707 for col_row in &col_result.rows {
708 let col_name = col_row
709 .get(0)
710 .and_then(cql_value_to_string)
711 .unwrap_or_default();
712 let col_type = col_row
713 .get(1)
714 .and_then(cql_value_to_string)
715 .unwrap_or_default();
716 let kind = col_row
717 .get(2)
718 .and_then(cql_value_to_string)
719 .unwrap_or_default();
720
721 columns.push(ColumnMetadata {
722 name: col_name.clone(),
723 type_name: col_type,
724 });
725
726 match kind.as_str() {
727 "partition_key" => partition_key.push(col_name),
728 "clustering" => clustering_key.push(col_name),
729 _ => {}
730 }
731 }
732
733 tables.push(TableMetadata {
734 keyspace: keyspace.to_string(),
735 name: table_name,
736 columns,
737 partition_key,
738 clustering_key,
739 });
740 }
741
742 Ok(tables)
743 }
744
745 async fn get_table_metadata(
746 &self,
747 keyspace: &str,
748 table: &str,
749 ) -> Result<Option<TableMetadata>> {
750 let tables = self.get_tables(keyspace).await?;
751 Ok(tables.into_iter().find(|t| t.name == table))
752 }
753
754 async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
755 let query = format!(
756 "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
757 keyspace.replace('\'', "''")
758 );
759 let result = self.execute_unpaged(&query).await?;
760 let udts = result
761 .rows
762 .iter()
763 .filter_map(|row| {
764 let name = row.get_by_name("type_name", &result.columns)?.to_string();
765 let field_names =
766 Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
767 let field_types =
768 Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
769 Some(UdtMetadata {
770 keyspace: keyspace.to_string(),
771 name,
772 field_names,
773 field_types,
774 })
775 })
776 .collect();
777 Ok(udts)
778 }
779
780 async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
781 let query = format!(
782 "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
783 keyspace.replace('\'', "''")
784 );
785 let result = self.execute_unpaged(&query).await?;
786 let functions = result
787 .rows
788 .iter()
789 .filter_map(|row| {
790 let name = row
791 .get_by_name("function_name", &result.columns)?
792 .to_string();
793 let argument_types = Self::extract_string_list_val(
794 row.get_by_name("argument_types", &result.columns),
795 );
796 let return_type = row
797 .get_by_name("return_type", &result.columns)
798 .map(|v| v.to_string())
799 .unwrap_or_default();
800 Some(FunctionMetadata {
801 keyspace: keyspace.to_string(),
802 name,
803 argument_types,
804 return_type,
805 })
806 })
807 .collect();
808 Ok(functions)
809 }
810
811 async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
812 let query = format!(
813 "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
814 keyspace.replace('\'', "''")
815 );
816 let result = self.execute_unpaged(&query).await?;
817 let aggregates = result
818 .rows
819 .iter()
820 .filter_map(|row| {
821 let name = row
822 .get_by_name("aggregate_name", &result.columns)?
823 .to_string();
824 let argument_types = Self::extract_string_list_val(
825 row.get_by_name("argument_types", &result.columns),
826 );
827 let return_type = row
828 .get_by_name("return_type", &result.columns)
829 .map(|v| v.to_string())
830 .unwrap_or_default();
831 Some(AggregateMetadata {
832 keyspace: keyspace.to_string(),
833 name,
834 argument_types,
835 return_type,
836 })
837 })
838 .collect();
839 Ok(aggregates)
840 }
841
842 async fn get_cluster_name(&self) -> Result<Option<String>> {
843 let result = self
844 .execute_unpaged("SELECT cluster_name FROM system.local")
845 .await?;
846 Ok(result
847 .rows
848 .first()
849 .and_then(|row| row.get(0))
850 .and_then(cql_value_to_string))
851 }
852
853 async fn get_cql_version(&self) -> Result<Option<String>> {
854 let result = self
855 .execute_unpaged("SELECT cql_version FROM system.local")
856 .await?;
857 Ok(result
858 .rows
859 .first()
860 .and_then(|row| row.get(0))
861 .and_then(cql_value_to_string))
862 }
863
864 async fn get_release_version(&self) -> Result<Option<String>> {
865 let result = self
866 .execute_unpaged("SELECT release_version FROM system.local")
867 .await?;
868 Ok(result
869 .rows
870 .first()
871 .and_then(|row| row.get(0))
872 .and_then(cql_value_to_string))
873 }
874
875 async fn get_scylla_version(&self) -> Result<Option<String>> {
876 let result = self
879 .execute_unpaged("SELECT scylla_version FROM system.local")
880 .await;
881 match result {
882 Ok(r) => Ok(r
883 .rows
884 .first()
885 .and_then(|row| row.get(0))
886 .and_then(cql_value_to_string)),
887 Err(_) => Ok(None), }
889 }
890
891 async fn is_connected(&self) -> bool {
892 self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
893 .await
894 .is_ok()
895 }
896}
897
898fn cql_value_to_string(v: &CqlValue) -> Option<String> {
900 match v {
901 CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
902 CqlValue::Inet(addr) => Some(addr.to_string()),
903 CqlValue::Null => None,
904 other => Some(other.to_string()),
905 }
906}
907
908fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
910 match v {
911 CqlValue::Int(i) => Some(*i),
912 CqlValue::BigInt(i) => Some(*i as i32),
913 CqlValue::SmallInt(i) => Some(*i as i32),
914 CqlValue::TinyInt(i) => Some(*i as i32),
915 _ => None,
916 }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the conversion exercised by most tests below.
    fn convert(input: ScyllaCqlValue) -> CqlValue {
        ScyllaDriver::convert_scylla_value(input)
    }

    #[test]
    fn convert_scylla_value_text() {
        let expected = CqlValue::Text("hello".to_string());
        assert_eq!(convert(ScyllaCqlValue::Text("hello".to_string())), expected);
    }

    #[test]
    fn convert_scylla_value_int() {
        let expected = CqlValue::Int(42);
        assert_eq!(convert(ScyllaCqlValue::Int(42)), expected);
    }

    #[test]
    fn convert_scylla_value_boolean() {
        let expected = CqlValue::Boolean(true);
        assert_eq!(convert(ScyllaCqlValue::Boolean(true)), expected);
    }

    #[test]
    fn convert_scylla_value_null() {
        // `Empty` is surfaced as a null value.
        assert_eq!(convert(ScyllaCqlValue::Empty), CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let input = ScyllaCqlValue::List(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let nil = Uuid::nil();
        assert_eq!(convert(ScyllaCqlValue::Uuid(nil)), CqlValue::Uuid(nil));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let payload = vec![0xde, 0xad, 0xbe, 0xef];
        let expected = CqlValue::Blob(payload.clone());
        assert_eq!(convert(ScyllaCqlValue::Blob(payload)), expected);
    }

    #[test]
    fn convert_scylla_value_float() {
        assert_eq!(convert(ScyllaCqlValue::Float(1.5)), CqlValue::Float(1.5));
    }

    #[test]
    fn convert_scylla_value_double() {
        assert_eq!(convert(ScyllaCqlValue::Double(1.5)), CqlValue::Double(1.5));
    }

    #[test]
    fn convert_scylla_value_map() {
        let input = ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]);
        let expected = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_set() {
        let input = ScyllaCqlValue::Set(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_udt() {
        let input = ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        let expected = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(CqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as Driver;

        let one = ScyllaDriver::to_scylla_consistency(Consistency::One);
        assert!(matches!(one, Driver::One));

        let quorum = ScyllaDriver::to_scylla_consistency(Consistency::Quorum);
        assert!(matches!(quorum, Driver::Quorum));

        let local_quorum = ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum);
        assert!(matches!(local_quorum, Driver::LocalQuorum));

        let all = ScyllaDriver::to_scylla_consistency(Consistency::All);
        assert!(matches!(all, Driver::All));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as Driver;

        let serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial);
        assert!(matches!(serial, Some(Driver::Serial)));

        let local_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial);
        assert!(matches!(local_serial, Some(Driver::LocalSerial)));

        // Non-serial levels have no serial counterpart.
        assert!(ScyllaDriver::to_scylla_serial_consistency(Consistency::One).is_none());
    }

    #[test]
    fn cql_value_to_string_helper() {
        let text = CqlValue::Text("hello".to_string());
        assert_eq!(cql_value_to_string(&text), Some("hello".to_string()));

        let number = CqlValue::Int(42);
        assert_eq!(cql_value_to_string(&number), Some("42".to_string()));

        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_i32_helper() {
        assert_eq!(cql_value_to_i32(&CqlValue::Int(42)), Some(42));
        assert_eq!(cql_value_to_i32(&CqlValue::BigInt(100)), Some(100));

        let not_a_number = CqlValue::Text("x".to_string());
        assert_eq!(cql_value_to_i32(&not_a_number), None);
    }
}
1080}