1use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use anyhow::{anyhow, Context, Result};
13use async_trait::async_trait;
14use chrono::{Datelike, Timelike};
15use futures::{StreamExt, TryStreamExt};
16use scylla::client::session::Session;
17use scylla::client::session_builder::SessionBuilder;
18use scylla::response::query_result::QueryResult;
19use scylla::statement::prepared::PreparedStatement;
20use scylla::statement::Statement;
21use scylla::value::{
22 Counter as ScyllaCounter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid,
23 CqlValue as ScyllaCqlValue, CqlVarint, Row,
24};
25use uuid::Uuid;
26
27use super::types::{CqlColumn, CqlResult, CqlRow, CqlRowStream, CqlValue};
28use super::{
29 AggregateMetadata, ColumnMetadata, ConnectionConfig, Consistency, CqlDriver, FunctionMetadata,
30 KeyspaceMetadata, PreparedId, SslConfig, TableMetadata, TracingEvent, TracingSession,
31 UdtMetadata,
32};
33
34pub struct ScyllaDriver {
36 session: Session,
37 prepared_cache: Mutex<HashMap<Vec<u8>, PreparedStatement>>,
39 consistency: Mutex<Consistency>,
41 serial_consistency: Mutex<Option<Consistency>>,
43 tracing_enabled: AtomicBool,
45 last_trace_id: Mutex<Option<Uuid>>,
47}
48
49impl ScyllaDriver {
50 fn build_rustls_config(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
52 use rustls::pki_types::CertificateDer;
53 use std::fs::File;
54 use std::io::BufReader;
55
56 if !ssl_config.validate {
57 return Self::build_rustls_config_no_verify(ssl_config);
58 }
59
60 let mut root_store = rustls::RootCertStore::empty();
61
62 if let Some(certfile) = &ssl_config.certfile {
64 let file = File::open(certfile)
65 .with_context(|| format!("opening CA certificate: {certfile}"))?;
66 let mut reader = BufReader::new(file);
67 let certs = rustls_pemfile::certs(&mut reader)
68 .collect::<std::result::Result<Vec<_>, _>>()
69 .with_context(|| format!("parsing CA certificate: {certfile}"))?;
70 for cert in certs {
71 root_store
72 .add(cert)
73 .context("adding CA certificate to root store")?;
74 }
75 }
76
77 let builder = rustls::ClientConfig::builder().with_root_certificates(root_store);
78
79 let config = if let (Some(usercert_path), Some(userkey_path)) =
81 (&ssl_config.usercert, &ssl_config.userkey)
82 {
83 let cert_file = File::open(usercert_path)
84 .with_context(|| format!("opening client certificate: {usercert_path}"))?;
85 let mut cert_reader = BufReader::new(cert_file);
86 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
87 .collect::<std::result::Result<Vec<_>, _>>()
88 .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
89
90 let key_file = File::open(userkey_path)
91 .with_context(|| format!("opening client key: {userkey_path}"))?;
92 let mut key_reader = BufReader::new(key_file);
93 let key = rustls_pemfile::private_key(&mut key_reader)
94 .with_context(|| format!("parsing client key: {userkey_path}"))?
95 .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
96
97 builder
98 .with_client_auth_cert(certs, key)
99 .context("configuring mutual TLS")?
100 } else {
101 builder.with_no_client_auth()
102 };
103
104 Ok(Arc::new(config))
105 }
106
107 fn build_rustls_config_no_verify(ssl_config: &SslConfig) -> Result<Arc<rustls::ClientConfig>> {
108 use rustls::client::danger::{
109 HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
110 };
111 use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
112 use rustls::{DigitallySignedStruct, Error, SignatureScheme};
113 use std::fs::File;
114 use std::io::BufReader;
115
116 #[derive(Debug)]
117 struct NoVerifier;
118
119 impl ServerCertVerifier for NoVerifier {
120 fn verify_server_cert(
121 &self,
122 _end_entity: &CertificateDer<'_>,
123 _intermediates: &[CertificateDer<'_>],
124 _server_name: &ServerName<'_>,
125 _ocsp_response: &[u8],
126 _now: UnixTime,
127 ) -> std::result::Result<ServerCertVerified, Error> {
128 Ok(ServerCertVerified::assertion())
129 }
130
131 fn verify_tls12_signature(
132 &self,
133 _message: &[u8],
134 _cert: &CertificateDer<'_>,
135 _dss: &DigitallySignedStruct,
136 ) -> std::result::Result<HandshakeSignatureValid, Error> {
137 Ok(HandshakeSignatureValid::assertion())
138 }
139
140 fn verify_tls13_signature(
141 &self,
142 _message: &[u8],
143 _cert: &CertificateDer<'_>,
144 _dss: &DigitallySignedStruct,
145 ) -> std::result::Result<HandshakeSignatureValid, Error> {
146 Ok(HandshakeSignatureValid::assertion())
147 }
148
149 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
150 rustls::crypto::ring::default_provider()
151 .signature_verification_algorithms
152 .supported_schemes()
153 }
154 }
155
156 let builder = rustls::ClientConfig::builder()
157 .dangerous()
158 .with_custom_certificate_verifier(Arc::new(NoVerifier));
159
160 let config = if let (Some(usercert_path), Some(userkey_path)) =
161 (&ssl_config.usercert, &ssl_config.userkey)
162 {
163 let cert_file = File::open(usercert_path)
164 .with_context(|| format!("opening client certificate: {usercert_path}"))?;
165 let mut cert_reader = BufReader::new(cert_file);
166 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
167 .collect::<std::result::Result<Vec<_>, _>>()
168 .with_context(|| format!("parsing client certificate: {usercert_path}"))?;
169
170 let key_file = File::open(userkey_path)
171 .with_context(|| format!("opening client key: {userkey_path}"))?;
172 let mut key_reader = BufReader::new(key_file);
173 let key = rustls_pemfile::private_key(&mut key_reader)
174 .with_context(|| format!("parsing client key: {userkey_path}"))?
175 .ok_or_else(|| anyhow!("no private key found in {userkey_path}"))?;
176
177 builder
178 .with_client_auth_cert(certs, key)
179 .context("configuring mutual TLS")?
180 } else {
181 builder.with_no_client_auth()
182 };
183
184 Ok(Arc::new(config))
185 }
186
187 fn extract_string_list_val(val: Option<&CqlValue>) -> Vec<String> {
189 match val {
190 Some(CqlValue::List(items)) => items.iter().map(|v| v.to_string()).collect(),
191 _ => Vec::new(),
192 }
193 }
194
195 fn convert_query_result(result: QueryResult) -> Result<CqlResult> {
197 let tracing_id = result.tracing_id();
198 let warnings: Vec<String> = result.warnings().map(|s| s.to_string()).collect();
199
200 if !result.is_rows() {
202 return Ok(CqlResult {
203 columns: Vec::new(),
204 rows: Vec::new(),
205 has_rows: false,
206 tracing_id,
207 warnings,
208 });
209 }
210
211 let rows_result = result
213 .into_rows_result()
214 .context("converting query result to rows")?;
215
216 let col_specs = rows_result.column_specs();
218 let columns: Vec<CqlColumn> = col_specs
219 .iter()
220 .map(|spec| CqlColumn {
221 name: spec.name().to_string(),
222 type_name: format!("{:?}", spec.typ()),
223 })
224 .collect();
225
226 let typed_rows = rows_result.rows::<Row>().context("deserializing rows")?;
228
229 let mut cql_rows = Vec::new();
230 for row_result in typed_rows {
231 let row = row_result.context("deserializing row")?;
232 let values: Vec<CqlValue> = row
233 .columns
234 .into_iter()
235 .enumerate()
236 .map(|(col_idx, opt_val)| match opt_val {
237 Some(v) => {
238 tracing::debug!(
239 column = col_idx,
240 variant = ?std::mem::discriminant(&v),
241 "converting ScyllaCqlValue: {v:?}"
242 );
243 Self::convert_scylla_value(v)
244 }
245 None => {
246 tracing::debug!(column = col_idx, "column value is None (null)");
247 CqlValue::Null
248 }
249 })
250 .collect();
251 cql_rows.push(CqlRow { values });
252 }
253
254 Ok(CqlResult {
255 columns,
256 rows: cql_rows,
257 has_rows: true,
258 tracing_id,
259 warnings,
260 })
261 }
262
263 fn convert_scylla_value(value: ScyllaCqlValue) -> CqlValue {
265 match value {
266 ScyllaCqlValue::Ascii(s) => CqlValue::Ascii(s),
267 ScyllaCqlValue::Boolean(b) => CqlValue::Boolean(b),
268 ScyllaCqlValue::Blob(bytes) => CqlValue::Blob(bytes),
269 ScyllaCqlValue::Counter(c) => CqlValue::Counter(c.0),
270 ScyllaCqlValue::Decimal(d) => {
271 let (int_val, scale) = d.as_signed_be_bytes_slice_and_exponent();
272 let big_int = num_bigint::BigInt::from_signed_bytes_be(int_val);
273 CqlValue::Decimal(bigdecimal::BigDecimal::new(big_int, scale.into()))
274 }
275 ScyllaCqlValue::Date(d) => {
276 let days = d.0;
278 let epoch_offset = days as i64 - (1i64 << 31);
279 match chrono::NaiveDate::from_num_days_from_ce_opt((epoch_offset + 719_163) as i32)
280 {
281 Some(date) => CqlValue::Date(date),
282 None => CqlValue::Text(format!("<invalid date: {days}>")),
283 }
284 }
285 ScyllaCqlValue::Double(d) => CqlValue::Double(d),
286 ScyllaCqlValue::Duration(d) => CqlValue::Duration {
287 months: d.months,
288 days: d.days,
289 nanoseconds: d.nanoseconds,
290 },
291 ScyllaCqlValue::Empty => CqlValue::Null,
292 ScyllaCqlValue::Float(f) => CqlValue::Float(f),
293 ScyllaCqlValue::Int(i) => CqlValue::Int(i),
294 ScyllaCqlValue::BigInt(i) => CqlValue::BigInt(i),
295 ScyllaCqlValue::Text(s) => CqlValue::Text(s),
296 ScyllaCqlValue::Timestamp(t) => CqlValue::Timestamp(t.0),
297 ScyllaCqlValue::Inet(addr) => CqlValue::Inet(addr),
298 ScyllaCqlValue::List(items) => {
299 CqlValue::List(items.into_iter().map(Self::convert_scylla_value).collect())
300 }
301 ScyllaCqlValue::Map(entries) => CqlValue::Map(
302 entries
303 .into_iter()
304 .map(|(k, v)| (Self::convert_scylla_value(k), Self::convert_scylla_value(v)))
305 .collect(),
306 ),
307 ScyllaCqlValue::Set(items) => {
308 CqlValue::Set(items.into_iter().map(Self::convert_scylla_value).collect())
309 }
310 ScyllaCqlValue::UserDefinedType {
311 keyspace,
312 name,
313 fields,
314 } => CqlValue::UserDefinedType {
315 keyspace,
316 type_name: name,
317 fields: fields
318 .into_iter()
319 .map(|(n, val)| (n, val.map(Self::convert_scylla_value)))
320 .collect(),
321 },
322 ScyllaCqlValue::SmallInt(i) => CqlValue::SmallInt(i),
323 ScyllaCqlValue::TinyInt(i) => CqlValue::TinyInt(i),
324 ScyllaCqlValue::Time(t) => {
325 let nanos = t.0;
326 let secs = (nanos / 1_000_000_000) as u32;
327 let nano_part = (nanos % 1_000_000_000) as u32;
328 match chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part) {
329 Some(time) => CqlValue::Time(time),
330 None => CqlValue::Text(format!("<invalid time: {nanos}>")),
331 }
332 }
333 ScyllaCqlValue::Timeuuid(u) => CqlValue::TimeUuid(u.into()),
334 ScyllaCqlValue::Tuple(items) => CqlValue::Tuple(
335 items
336 .into_iter()
337 .map(|v| v.map(Self::convert_scylla_value))
338 .collect(),
339 ),
340 ScyllaCqlValue::Uuid(u) => CqlValue::Uuid(u),
341 ScyllaCqlValue::Varint(v) => {
342 let big_int =
343 num_bigint::BigInt::from_signed_bytes_be(v.as_signed_bytes_be_slice());
344 CqlValue::Varint(big_int)
345 }
346 _ => {
348 tracing::warn!("unhandled ScyllaCqlValue variant: {value:?}");
349 CqlValue::Text(format!("{value:?}"))
350 }
351 }
352 }
353
354 fn internal_to_scylla_cql(v: &CqlValue) -> ScyllaCqlValue {
356 match v {
357 CqlValue::Ascii(s) => ScyllaCqlValue::Ascii(s.clone()),
358 CqlValue::Boolean(b) => ScyllaCqlValue::Boolean(*b),
359 CqlValue::Blob(bytes) => ScyllaCqlValue::Blob(bytes.clone()),
360 CqlValue::Counter(n) => ScyllaCqlValue::Counter(ScyllaCounter(*n)),
361 CqlValue::Double(d) => ScyllaCqlValue::Double(*d),
362 CqlValue::Duration {
363 months,
364 days,
365 nanoseconds,
366 } => ScyllaCqlValue::Duration(CqlDuration {
367 months: *months,
368 days: *days,
369 nanoseconds: *nanoseconds,
370 }),
371 CqlValue::Float(f) => ScyllaCqlValue::Float(*f),
372 CqlValue::Int(i) => ScyllaCqlValue::Int(*i),
373 CqlValue::BigInt(i) => ScyllaCqlValue::BigInt(*i),
374 CqlValue::SmallInt(i) => ScyllaCqlValue::SmallInt(*i),
375 CqlValue::TinyInt(i) => ScyllaCqlValue::TinyInt(*i),
376 CqlValue::Text(s) => ScyllaCqlValue::Text(s.clone()),
377 CqlValue::Timestamp(ms) => ScyllaCqlValue::Timestamp(CqlTimestamp(*ms)),
378 CqlValue::Inet(addr) => ScyllaCqlValue::Inet(*addr),
379 CqlValue::Uuid(u) => ScyllaCqlValue::Uuid(*u),
380 CqlValue::TimeUuid(u) => ScyllaCqlValue::Timeuuid(CqlTimeuuid::from(*u)),
381 CqlValue::Date(d) => {
382 let days_from_ce = d.num_days_from_ce();
384 let epoch_offset = days_from_ce as i64 - 719_163;
385 let cql_days = (epoch_offset + (1i64 << 31)) as u32;
386 ScyllaCqlValue::Date(CqlDate(cql_days))
387 }
388 CqlValue::Time(t) => {
389 let nanos =
390 t.num_seconds_from_midnight() as i64 * 1_000_000_000 + t.nanosecond() as i64;
391 ScyllaCqlValue::Time(CqlTime(nanos))
392 }
393 CqlValue::Varint(bi) => {
394 let bytes = bi.to_signed_bytes_be();
395 ScyllaCqlValue::Varint(CqlVarint::from_signed_bytes_be(bytes))
396 }
397 CqlValue::Decimal(d) => {
398 let (int_val, scale) = d.as_bigint_and_exponent();
399 let bytes = int_val.to_signed_bytes_be();
400 ScyllaCqlValue::Decimal(CqlDecimal::from_signed_be_bytes_slice_and_exponent(
401 &bytes,
402 scale as i32,
403 ))
404 }
405 CqlValue::List(items) => {
406 ScyllaCqlValue::List(items.iter().map(Self::internal_to_scylla_cql).collect())
407 }
408 CqlValue::Set(items) => {
409 ScyllaCqlValue::Set(items.iter().map(Self::internal_to_scylla_cql).collect())
410 }
411 CqlValue::Map(entries) => ScyllaCqlValue::Map(
412 entries
413 .iter()
414 .map(|(k, v)| {
415 (
416 Self::internal_to_scylla_cql(k),
417 Self::internal_to_scylla_cql(v),
418 )
419 })
420 .collect(),
421 ),
422 CqlValue::Tuple(items) => ScyllaCqlValue::Tuple(
423 items
424 .iter()
425 .map(|opt| opt.as_ref().map(Self::internal_to_scylla_cql))
426 .collect(),
427 ),
428 CqlValue::UserDefinedType {
429 keyspace,
430 type_name,
431 fields,
432 } => ScyllaCqlValue::UserDefinedType {
433 keyspace: keyspace.clone(),
434 name: type_name.clone(),
435 fields: fields
436 .iter()
437 .map(|(n, v)| (n.clone(), v.as_ref().map(Self::internal_to_scylla_cql)))
438 .collect(),
439 },
440 CqlValue::Null | CqlValue::Unset => ScyllaCqlValue::Empty,
441 }
442 }
443
444 fn to_scylla_consistency(c: Consistency) -> scylla::statement::Consistency {
446 use scylla::statement::Consistency as SC;
447 match c {
448 Consistency::Any => SC::Any,
449 Consistency::One => SC::One,
450 Consistency::Two => SC::Two,
451 Consistency::Three => SC::Three,
452 Consistency::Quorum => SC::Quorum,
453 Consistency::All => SC::All,
454 Consistency::LocalQuorum => SC::LocalQuorum,
455 Consistency::EachQuorum => SC::EachQuorum,
456 Consistency::Serial => SC::Serial,
457 Consistency::LocalSerial => SC::LocalSerial,
458 Consistency::LocalOne => SC::LocalOne,
459 }
460 }
461
462 fn to_scylla_serial_consistency(
464 c: Consistency,
465 ) -> Option<scylla::statement::SerialConsistency> {
466 use scylla::statement::SerialConsistency as SC;
467 match c {
468 Consistency::Serial => Some(SC::Serial),
469 Consistency::LocalSerial => Some(SC::LocalSerial),
470 _ => None,
471 }
472 }
473
474 fn build_query(&self, cql: &str) -> Statement {
476 let mut stmt = Statement::new(cql);
477
478 let consistency = *self.consistency.lock().unwrap();
479 stmt.set_consistency(Self::to_scylla_consistency(consistency));
480
481 let serial = *self.serial_consistency.lock().unwrap();
482 if let Some(sc) = serial {
483 if let Some(sc) = Self::to_scylla_serial_consistency(sc) {
484 stmt.set_serial_consistency(Some(sc));
485 }
486 }
487
488 if self.tracing_enabled.load(Ordering::Relaxed) {
489 stmt.set_tracing(true);
490 }
491
492 stmt
493 }
494
495 fn store_trace_id(&self, result: &QueryResult) {
497 if let Some(trace_id) = result.tracing_id() {
498 *self.last_trace_id.lock().unwrap() = Some(trace_id);
499 }
500 }
501}
502
503#[async_trait]
504impl CqlDriver for ScyllaDriver {
505 async fn connect(config: &ConnectionConfig) -> Result<Self> {
506 let addr = format!("{}:{}", config.host, config.port);
507
508 let mut builder = SessionBuilder::new().known_node(&addr);
509
510 builder = builder.pool_size(scylla::client::PoolSize::PerHost(
513 std::num::NonZeroUsize::new(1).unwrap(),
514 ));
515
516 if let (Some(username), Some(password)) = (&config.username, &config.password) {
517 builder = builder.user(username, password);
518 }
519
520 builder = builder.connection_timeout(Duration::from_secs(config.connect_timeout));
521
522 if let Some(keyspace) = &config.keyspace {
523 builder = builder.use_keyspace(keyspace, false);
524 }
525
526 let contact_point = tokio::net::lookup_host(&addr)
534 .await
535 .ok()
536 .and_then(|mut addrs| addrs.next());
537 if let Some(contact_point) = contact_point {
538 let translator = Arc::new(
539 super::proxy_address_translator::ProxyAddressTranslator::new(contact_point),
540 );
541 builder = builder.address_translator(translator);
542 }
543
544 if config.ssl {
545 let tls_config = if let Some(ssl_config) = &config.ssl_config {
546 Self::build_rustls_config(ssl_config)?
547 } else {
548 Self::build_rustls_config_no_verify(&SslConfig::default())?
549 };
550 builder = builder.tls_context(Some(tls_config));
551 }
552
553 let session = builder.build().await.context("connecting to cluster")?;
554
555 Ok(ScyllaDriver {
556 session,
557 prepared_cache: Mutex::new(HashMap::new()),
558 consistency: Mutex::new(Consistency::One),
559 serial_consistency: Mutex::new(None),
560 tracing_enabled: AtomicBool::new(false),
561 last_trace_id: Mutex::new(None),
562 })
563 }
564
565 async fn execute_unpaged(&self, query: &str) -> Result<CqlResult> {
566 let stmt = self.build_query(query);
567
568 let result = self.session.query_unpaged(stmt, ()).await?;
569
570 self.store_trace_id(&result);
571 Self::convert_query_result(result)
572 }
573
574 async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
575 let mut stmt = self.build_query(query);
576 stmt.set_page_size(page_size);
577
578 let query_pager = self
579 .session
580 .query_iter(stmt, ())
581 .await
582 .context("starting paged query")?;
583
584 let col_specs = query_pager.column_specs();
586 let columns: Vec<CqlColumn> = col_specs
587 .iter()
588 .map(|spec| CqlColumn {
589 name: spec.name().to_string(),
590 type_name: format!("{:?}", spec.typ()),
591 })
592 .collect();
593
594 let mut rows_stream = query_pager.rows_stream::<Row>()?;
596 let mut cql_rows = Vec::new();
597
598 while let Some(row) = rows_stream.try_next().await? {
599 let values: Vec<CqlValue> = row
600 .columns
601 .into_iter()
602 .map(|opt_val| match opt_val {
603 Some(v) => Self::convert_scylla_value(v),
604 None => CqlValue::Null,
605 })
606 .collect();
607 cql_rows.push(CqlRow { values });
608 }
609
610 Ok(CqlResult {
611 columns,
612 rows: cql_rows,
613 has_rows: true,
614 tracing_id: None,
615 warnings: Vec::new(),
616 })
617 }
618
619 async fn execute_streaming(&self, query: &str, page_size: i32) -> Result<CqlRowStream> {
620 let mut stmt = self.build_query(query);
621 stmt.set_page_size(page_size);
622
623 let query_pager = self
624 .session
625 .query_iter(stmt, ())
626 .await
627 .context("starting streaming query")?;
628
629 let col_specs = query_pager.column_specs();
630 let columns: Vec<CqlColumn> = col_specs
631 .iter()
632 .map(|spec| CqlColumn {
633 name: spec.name().to_string(),
634 type_name: format!("{:?}", spec.typ()),
635 })
636 .collect();
637
638 let rows_stream = query_pager.rows_stream::<Row>()?;
639
640 let mapped_stream = rows_stream.map(|row_result| {
641 row_result
642 .map(|row| {
643 let values: Vec<CqlValue> = row
644 .columns
645 .into_iter()
646 .map(|opt_val| match opt_val {
647 Some(v) => Self::convert_scylla_value(v),
648 None => CqlValue::Null,
649 })
650 .collect();
651 CqlRow { values }
652 })
653 .map_err(|e| anyhow::anyhow!("{}", e))
654 });
655
656 Ok(CqlRowStream {
657 columns,
658 rows: Box::pin(mapped_stream),
659 })
660 }
661
662 async fn prepare(&self, query: &str) -> Result<PreparedId> {
663 let prepared = self
664 .session
665 .prepare(query)
666 .await
667 .context("preparing CQL statement")?;
668
669 let id = prepared.get_id().to_vec();
670 self.prepared_cache
671 .lock()
672 .unwrap()
673 .insert(id.clone(), prepared);
674
675 Ok(PreparedId { inner: id })
676 }
677
678 async fn execute_prepared(
679 &self,
680 prepared_id: &PreparedId,
681 values: &[CqlValue],
682 ) -> Result<CqlResult> {
683 let prepared = self
684 .prepared_cache
685 .lock()
686 .unwrap()
687 .get(&prepared_id.inner)
688 .cloned()
689 .ok_or_else(|| anyhow!("prepared statement not found in cache"))?;
690
691 let scylla_values: Vec<Option<ScyllaCqlValue>> = values
694 .iter()
695 .map(|v| match v {
696 CqlValue::Null | CqlValue::Unset => None,
697 other => Some(Self::internal_to_scylla_cql(other)),
698 })
699 .collect();
700
701 let result = self
702 .session
703 .execute_unpaged(&prepared, scylla_values)
704 .await
705 .context("executing prepared statement")?;
706
707 self.store_trace_id(&result);
708 Self::convert_query_result(result)
709 }
710
711 async fn use_keyspace(&self, keyspace: &str) -> Result<()> {
712 self.session
713 .use_keyspace(keyspace, false)
714 .await
715 .with_context(|| format!("switching to keyspace: {keyspace}"))?;
716 Ok(())
717 }
718
719 fn get_consistency(&self) -> Consistency {
720 *self.consistency.lock().unwrap()
721 }
722
723 fn set_consistency(&self, consistency: Consistency) {
724 *self.consistency.lock().unwrap() = consistency;
725 }
726
727 fn get_serial_consistency(&self) -> Option<Consistency> {
728 *self.serial_consistency.lock().unwrap()
729 }
730
731 fn set_serial_consistency(&self, consistency: Option<Consistency>) {
732 *self.serial_consistency.lock().unwrap() = consistency;
733 }
734
735 fn set_tracing(&self, enabled: bool) {
736 self.tracing_enabled.store(enabled, Ordering::Relaxed);
737 }
738
739 fn is_tracing_enabled(&self) -> bool {
740 self.tracing_enabled.load(Ordering::Relaxed)
741 }
742
743 fn last_trace_id(&self) -> Option<Uuid> {
744 *self.last_trace_id.lock().unwrap()
745 }
746
747 async fn get_trace_session(&self, trace_id: Uuid) -> Result<Option<TracingSession>> {
748 let query = format!(
749 "SELECT client, command, coordinator, duration, parameters, request, started_at \
750 FROM system_traces.sessions WHERE session_id = {}",
751 trace_id
752 );
753 let result = self.execute_unpaged(&query).await?;
754
755 if result.rows.is_empty() {
756 return Ok(None);
757 }
758
759 let events_query = format!(
760 "SELECT activity, source, source_elapsed, thread \
761 FROM system_traces.events WHERE session_id = {}",
762 trace_id
763 );
764 let events_result = self.execute_unpaged(&events_query).await?;
765
766 let events: Vec<TracingEvent> = events_result
767 .rows
768 .iter()
769 .map(|row| TracingEvent {
770 activity: row.get(0).and_then(cql_value_to_string),
771 source: row.get(1).and_then(cql_value_to_string),
772 source_elapsed: row.get(2).and_then(cql_value_to_i32),
773 thread: row.get(3).and_then(cql_value_to_string),
774 })
775 .collect();
776
777 let session_row = &result.rows[0];
778 Ok(Some(TracingSession {
779 trace_id,
780 client: session_row.get(0).and_then(cql_value_to_string),
781 command: session_row.get(1).and_then(cql_value_to_string),
782 coordinator: session_row.get(2).and_then(cql_value_to_string),
783 duration: session_row.get(3).and_then(cql_value_to_i32),
784 parameters: HashMap::new(),
785 request: session_row.get(5).and_then(cql_value_to_string),
786 started_at: session_row.get(6).and_then(cql_value_to_string),
787 events,
788 }))
789 }
790
791 async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
792 let result = self
793 .execute_unpaged(
794 "SELECT keyspace_name, replication, durable_writes \
795 FROM system_schema.keyspaces",
796 )
797 .await?;
798
799 let mut keyspaces = Vec::new();
800 for row in &result.rows {
801 let name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
802 let durable_writes = match row.get(2) {
803 Some(CqlValue::Boolean(b)) => *b,
804 _ => true,
805 };
806
807 keyspaces.push(KeyspaceMetadata {
808 name,
809 replication: HashMap::new(),
810 durable_writes,
811 });
812 }
813
814 Ok(keyspaces)
815 }
816
817 async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
818 let ks_escaped = keyspace.replace('\'', "''");
819
820 let result = self
821 .execute_unpaged(&format!(
822 "SELECT table_name FROM system_schema.tables WHERE keyspace_name = '{ks_escaped}'"
823 ))
824 .await?;
825
826 let mut tables = Vec::new();
827 for row in &result.rows {
828 let table_name = row.get(0).and_then(cql_value_to_string).unwrap_or_default();
829 let tbl_escaped = table_name.replace('\'', "''");
830
831 let col_result = self
832 .execute_unpaged(&format!(
833 "SELECT column_name, type, kind, position, clustering_order \
834 FROM system_schema.columns \
835 WHERE keyspace_name = '{ks_escaped}' AND table_name = '{tbl_escaped}'"
836 ))
837 .await?;
838
839 let mut pk_cols: Vec<(i32, String, String)> = Vec::new();
840 let mut ck_cols: Vec<(i32, String, String, String)> = Vec::new();
841 let mut regular_cols: Vec<(String, String)> = Vec::new();
842
843 for col_row in &col_result.rows {
844 let col_name = col_row
845 .get_by_name("column_name", &col_result.columns)
846 .map(|v| v.to_string())
847 .unwrap_or_default();
848 let col_type = col_row
849 .get_by_name("type", &col_result.columns)
850 .map(|v| v.to_string())
851 .unwrap_or_default();
852 let kind = col_row
853 .get_by_name("kind", &col_result.columns)
854 .map(|v| v.to_string())
855 .unwrap_or_default();
856 let position = col_row
857 .get_by_name("position", &col_result.columns)
858 .and_then(|v| v.to_string().parse::<i32>().ok())
859 .unwrap_or(0);
860 let clustering_order = col_row
861 .get_by_name("clustering_order", &col_result.columns)
862 .map(|v| v.to_string())
863 .unwrap_or_else(|| "none".to_string());
864
865 match kind.as_str() {
866 "partition_key" => pk_cols.push((position, col_name, col_type)),
867 "clustering" => ck_cols.push((position, col_name, col_type, clustering_order)),
868 _ => regular_cols.push((col_name, col_type)),
869 }
870 }
871
872 pk_cols.sort_by_key(|c| c.0);
873 ck_cols.sort_by_key(|c| c.0);
874
875 let partition_key: Vec<String> = pk_cols.iter().map(|c| c.1.clone()).collect();
876 let clustering_key: Vec<String> = ck_cols.iter().map(|c| c.1.clone()).collect();
877 let clustering_order: Vec<String> = ck_cols
878 .iter()
879 .map(|c| {
880 let order = c.3.to_uppercase();
881 if order == "NONE" || order.is_empty() {
882 "ASC".to_string()
883 } else {
884 order
885 }
886 })
887 .collect();
888
889 let mut columns: Vec<ColumnMetadata> = Vec::new();
890 for (_, name, typ) in &pk_cols {
891 columns.push(ColumnMetadata {
892 name: name.clone(),
893 type_name: typ.clone(),
894 });
895 }
896 for (_, name, typ, _) in &ck_cols {
897 columns.push(ColumnMetadata {
898 name: name.clone(),
899 type_name: typ.clone(),
900 });
901 }
902 for (name, typ) in ®ular_cols {
903 columns.push(ColumnMetadata {
904 name: name.clone(),
905 type_name: typ.clone(),
906 });
907 }
908
909 let props_result = self
910 .execute_unpaged(&format!(
911 "SELECT bloom_filter_fp_chance, caching, comment, compaction, compression, \
912 crc_check_chance, default_time_to_live, gc_grace_seconds, \
913 max_index_interval, memtable_flush_period_in_ms, min_index_interval, \
914 speculative_retry \
915 FROM system_schema.tables \
916 WHERE keyspace_name = '{ks_escaped}' AND table_name = '{tbl_escaped}'"
917 ))
918 .await?;
919
920 let mut properties = std::collections::BTreeMap::new();
921 if let Some(props_row) = props_result.rows.first() {
922 let prop_names = [
923 "bloom_filter_fp_chance",
924 "caching",
925 "comment",
926 "compaction",
927 "compression",
928 "crc_check_chance",
929 "default_time_to_live",
930 "gc_grace_seconds",
931 "max_index_interval",
932 "memtable_flush_period_in_ms",
933 "min_index_interval",
934 "speculative_retry",
935 ];
936 for prop_name in &prop_names {
937 if let Some(val) = props_row.get_by_name(prop_name, &props_result.columns) {
938 properties.insert(prop_name.to_string(), val.to_string());
939 }
940 }
941 }
942
943 tables.push(TableMetadata {
944 keyspace: keyspace.to_string(),
945 name: table_name,
946 columns,
947 partition_key,
948 clustering_key,
949 clustering_order,
950 properties,
951 });
952 }
953
954 Ok(tables)
955 }
956
957 async fn get_table_metadata(
958 &self,
959 keyspace: &str,
960 table: &str,
961 ) -> Result<Option<TableMetadata>> {
962 let tables = self.get_tables(keyspace).await?;
963 Ok(tables.into_iter().find(|t| t.name == table))
964 }
965
966 async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
967 let query = format!(
968 "SELECT type_name, field_names, field_types FROM system_schema.types WHERE keyspace_name = '{}'",
969 keyspace.replace('\'', "''")
970 );
971 let result = self.execute_unpaged(&query).await?;
972 let udts = result
973 .rows
974 .iter()
975 .filter_map(|row| {
976 let name = row.get_by_name("type_name", &result.columns)?.to_string();
977 let field_names =
978 Self::extract_string_list_val(row.get_by_name("field_names", &result.columns));
979 let field_types =
980 Self::extract_string_list_val(row.get_by_name("field_types", &result.columns));
981 Some(UdtMetadata {
982 keyspace: keyspace.to_string(),
983 name,
984 field_names,
985 field_types,
986 })
987 })
988 .collect();
989 Ok(udts)
990 }
991
992 async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
993 let query = format!(
994 "SELECT function_name, argument_types, return_type FROM system_schema.functions WHERE keyspace_name = '{}'",
995 keyspace.replace('\'', "''")
996 );
997 let result = self.execute_unpaged(&query).await?;
998 let functions = result
999 .rows
1000 .iter()
1001 .filter_map(|row| {
1002 let name = row
1003 .get_by_name("function_name", &result.columns)?
1004 .to_string();
1005 let argument_types = Self::extract_string_list_val(
1006 row.get_by_name("argument_types", &result.columns),
1007 );
1008 let return_type = row
1009 .get_by_name("return_type", &result.columns)
1010 .map(|v| v.to_string())
1011 .unwrap_or_default();
1012 Some(FunctionMetadata {
1013 keyspace: keyspace.to_string(),
1014 name,
1015 argument_types,
1016 return_type,
1017 })
1018 })
1019 .collect();
1020 Ok(functions)
1021 }
1022
1023 async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
1024 let query = format!(
1025 "SELECT aggregate_name, argument_types, return_type FROM system_schema.aggregates WHERE keyspace_name = '{}'",
1026 keyspace.replace('\'', "''")
1027 );
1028 let result = self.execute_unpaged(&query).await?;
1029 let aggregates = result
1030 .rows
1031 .iter()
1032 .filter_map(|row| {
1033 let name = row
1034 .get_by_name("aggregate_name", &result.columns)?
1035 .to_string();
1036 let argument_types = Self::extract_string_list_val(
1037 row.get_by_name("argument_types", &result.columns),
1038 );
1039 let return_type = row
1040 .get_by_name("return_type", &result.columns)
1041 .map(|v| v.to_string())
1042 .unwrap_or_default();
1043 Some(AggregateMetadata {
1044 keyspace: keyspace.to_string(),
1045 name,
1046 argument_types,
1047 return_type,
1048 })
1049 })
1050 .collect();
1051 Ok(aggregates)
1052 }
1053
1054 async fn get_cluster_name(&self) -> Result<Option<String>> {
1055 let result = self
1056 .execute_unpaged("SELECT cluster_name FROM system.local")
1057 .await?;
1058 Ok(result
1059 .rows
1060 .first()
1061 .and_then(|row| row.get(0))
1062 .and_then(cql_value_to_string))
1063 }
1064
1065 async fn get_cql_version(&self) -> Result<Option<String>> {
1066 let result = self
1067 .execute_unpaged("SELECT cql_version FROM system.local")
1068 .await?;
1069 Ok(result
1070 .rows
1071 .first()
1072 .and_then(|row| row.get(0))
1073 .and_then(cql_value_to_string))
1074 }
1075
1076 async fn get_release_version(&self) -> Result<Option<String>> {
1077 let result = self
1078 .execute_unpaged("SELECT release_version FROM system.local")
1079 .await?;
1080 Ok(result
1081 .rows
1082 .first()
1083 .and_then(|row| row.get(0))
1084 .and_then(cql_value_to_string))
1085 }
1086
1087 async fn get_scylla_version(&self) -> Result<Option<String>> {
1088 let result = self
1091 .execute_unpaged("SELECT scylla_version FROM system.local")
1092 .await;
1093 match result {
1094 Ok(r) => Ok(r
1095 .rows
1096 .first()
1097 .and_then(|row| row.get(0))
1098 .and_then(cql_value_to_string)),
1099 Err(_) => Ok(None), }
1101 }
1102
1103 async fn is_connected(&self) -> bool {
1104 self.execute_unpaged("SELECT key FROM system.local LIMIT 1")
1105 .await
1106 .is_ok()
1107 }
1108}
1109
1110fn cql_value_to_string(v: &CqlValue) -> Option<String> {
1112 match v {
1113 CqlValue::Text(s) | CqlValue::Ascii(s) => Some(s.clone()),
1114 CqlValue::Inet(addr) => Some(addr.to_string()),
1115 CqlValue::Null => None,
1116 other => Some(other.to_string()),
1117 }
1118}
1119
1120fn cql_value_to_i32(v: &CqlValue) -> Option<i32> {
1122 match v {
1123 CqlValue::Int(i) => Some(*i),
1124 CqlValue::BigInt(i) => Some(*i as i32),
1125 CqlValue::SmallInt(i) => Some(*i as i32),
1126 CqlValue::TinyInt(i) => Some(*i as i32),
1127 _ => None,
1128 }
1129}
1130
#[cfg(test)]
mod tests {
    //! Unit tests for the pure conversion helpers; none of them touch a cluster.

    use super::*;

    /// Shorthand for the driver-value conversion under test.
    fn convert(value: ScyllaCqlValue) -> CqlValue {
        ScyllaDriver::convert_scylla_value(value)
    }

    #[test]
    fn convert_scylla_value_text() {
        let expected = CqlValue::Text("hello".to_string());
        assert_eq!(convert(ScyllaCqlValue::Text("hello".to_string())), expected);
    }

    #[test]
    fn convert_scylla_value_int() {
        assert_eq!(convert(ScyllaCqlValue::Int(42)), CqlValue::Int(42));
    }

    #[test]
    fn convert_scylla_value_boolean() {
        assert_eq!(convert(ScyllaCqlValue::Boolean(true)), CqlValue::Boolean(true));
    }

    #[test]
    fn convert_scylla_value_null() {
        // An empty driver value surfaces as a CQL null.
        assert_eq!(convert(ScyllaCqlValue::Empty), CqlValue::Null);
    }

    #[test]
    fn convert_scylla_value_list() {
        let input = ScyllaCqlValue::List(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_uuid() {
        let nil = Uuid::nil();
        assert_eq!(convert(ScyllaCqlValue::Uuid(nil)), CqlValue::Uuid(nil));
    }

    #[test]
    fn convert_scylla_value_blob() {
        let bytes = vec![0xde, 0xad, 0xbe, 0xef];
        let converted = convert(ScyllaCqlValue::Blob(bytes.clone()));
        assert_eq!(converted, CqlValue::Blob(bytes));
    }

    #[test]
    fn convert_scylla_value_float() {
        assert_eq!(convert(ScyllaCqlValue::Float(1.5)), CqlValue::Float(1.5));
    }

    #[test]
    fn convert_scylla_value_double() {
        assert_eq!(convert(ScyllaCqlValue::Double(1.5)), CqlValue::Double(1.5));
    }

    #[test]
    fn convert_scylla_value_map() {
        let input = ScyllaCqlValue::Map(vec![(
            ScyllaCqlValue::Text("key".to_string()),
            ScyllaCqlValue::Int(42),
        )]);
        let expected = CqlValue::Map(vec![(CqlValue::Text("key".to_string()), CqlValue::Int(42))]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_set() {
        let input = ScyllaCqlValue::Set(vec![ScyllaCqlValue::Int(1), ScyllaCqlValue::Int(2)]);
        let expected = CqlValue::Set(vec![CqlValue::Int(1), CqlValue::Int(2)]);
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn convert_scylla_value_udt() {
        let input = ScyllaCqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(ScyllaCqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        // The driver's `name` field is carried over as `type_name`; unset
        // fields stay `None`.
        let expected = CqlValue::UserDefinedType {
            keyspace: "ks".to_string(),
            type_name: "my_type".to_string(),
            fields: vec![
                ("f1".to_string(), Some(CqlValue::Int(1))),
                ("f2".to_string(), None),
            ],
        };
        assert_eq!(convert(input), expected);
    }

    #[test]
    fn to_scylla_consistency_mapping() {
        use scylla::statement::Consistency as SC;

        let one = ScyllaDriver::to_scylla_consistency(Consistency::One);
        assert!(matches!(one, SC::One));

        let quorum = ScyllaDriver::to_scylla_consistency(Consistency::Quorum);
        assert!(matches!(quorum, SC::Quorum));

        let local_quorum = ScyllaDriver::to_scylla_consistency(Consistency::LocalQuorum);
        assert!(matches!(local_quorum, SC::LocalQuorum));

        let all = ScyllaDriver::to_scylla_consistency(Consistency::All);
        assert!(matches!(all, SC::All));
    }

    #[test]
    fn to_scylla_serial_consistency_mapping() {
        use scylla::statement::SerialConsistency as SC;

        let serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::Serial);
        assert!(matches!(serial, Some(SC::Serial)));

        let local_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::LocalSerial);
        assert!(matches!(local_serial, Some(SC::LocalSerial)));

        // Non-serial levels have no serial counterpart.
        let non_serial = ScyllaDriver::to_scylla_serial_consistency(Consistency::One);
        assert!(non_serial.is_none());
    }

    #[test]
    fn cql_value_to_string_helper() {
        let text = CqlValue::Text("hello".to_string());
        assert_eq!(cql_value_to_string(&text), Some("hello".to_string()));

        let int = CqlValue::Int(42);
        assert_eq!(cql_value_to_string(&int), Some("42".to_string()));

        assert_eq!(cql_value_to_string(&CqlValue::Null), None);
    }

    #[test]
    fn cql_value_to_i32_helper() {
        let text = CqlValue::Text("x".to_string());
        assert_eq!(cql_value_to_i32(&CqlValue::Int(42)), Some(42));
        assert_eq!(cql_value_to_i32(&CqlValue::BigInt(100)), Some(100));
        assert_eq!(cql_value_to_i32(&text), None);
    }
}