1use anyhow::{bail, Result};
8
9use crate::config::MergedConfig;
10use crate::driver::types::CqlValue;
11use crate::driver::{
12 AggregateMetadata, ConnectionConfig, Consistency, CqlDriver, CqlResult, CqlRowStream,
13 FunctionMetadata, KeyspaceMetadata, PreparedId, ScyllaDriver, SslConfig, TableMetadata,
14 TracingSession, UdtMetadata,
15};
16
17pub struct CqlSession {
19 driver: ScyllaDriver,
20 current_keyspace: Option<String>,
22 pub connection_display: String,
24 pub cluster_name: Option<String>,
26 pub cql_version: Option<String>,
28 pub release_version: Option<String>,
30 pub scylla_version: Option<String>,
32}
33
34impl CqlSession {
35 pub async fn connect(config: &MergedConfig) -> Result<Self> {
37 let ssl_config = if config.ssl {
38 Some(SslConfig {
39 certfile: config.cqlshrc.ssl.certfile.clone(),
40 validate: config.cqlshrc.ssl.validate.unwrap_or(false),
41 userkey: config.cqlshrc.ssl.userkey.clone(),
42 usercert: config.cqlshrc.ssl.usercert.clone(),
43 host_certfiles: config.cqlshrc.certfiles.clone(),
44 })
45 } else {
46 None
47 };
48
49 let conn_config = ConnectionConfig {
50 host: config.host.clone(),
51 port: config.port,
52 username: config.username.clone(),
53 password: config.password.clone(),
54 keyspace: config.keyspace.clone(),
55 connect_timeout: config.connect_timeout,
56 request_timeout: config.request_timeout,
57 ssl: config.ssl,
58 ssl_config,
59 protocol_version: config.protocol_version,
60 };
61
62 let driver = ScyllaDriver::connect(&conn_config).await?;
63
64 let connection_display = format!("{}:{}", config.host, config.port);
65
66 let cluster_name = driver.get_cluster_name().await.ok().flatten();
68 let cql_version = driver.get_cql_version().await.ok().flatten();
69 let release_version = driver.get_release_version().await.ok().flatten();
70 let scylla_version = driver.get_scylla_version().await.ok().flatten();
71
72 if let Some(cl_str) = &config.consistency_level {
74 if let Some(cl) = Consistency::from_str_cql(cl_str) {
75 driver.set_consistency(cl);
76 }
77 }
78
79 if let Some(scl_str) = &config.serial_consistency_level {
81 if let Some(scl) = Consistency::from_str_cql(scl_str) {
82 driver.set_serial_consistency(Some(scl));
83 }
84 }
85
86 Ok(CqlSession {
87 driver,
88 current_keyspace: config.keyspace.clone(),
89 connection_display,
90 cluster_name,
91 cql_version,
92 release_version,
93 scylla_version,
94 })
95 }
96
97 pub async fn check_schema_agreement(&self) -> bool {
101 use std::collections::HashSet;
102
103 let mut versions = HashSet::new();
104
105 if let Ok(result) = self
107 .driver
108 .execute_unpaged("SELECT schema_version FROM system.local WHERE key='local'")
109 .await
110 {
111 for row in &result.rows {
112 if let Some(v) = row.get(0) {
113 versions.insert(v.to_string());
114 }
115 }
116 }
117
118 if let Ok(result) = self
120 .driver
121 .execute_unpaged("SELECT schema_version FROM system.peers")
122 .await
123 {
124 for row in &result.rows {
125 if let Some(v) = row.get(0) {
126 versions.insert(v.to_string());
127 }
128 }
129 }
130
131 versions.len() <= 1
133 }
134
135 pub async fn execute(&mut self, query: &str) -> Result<CqlResult> {
137 let trimmed = query.trim();
138
139 if let Some(keyspace) = parse_use_command(trimmed) {
141 self.use_keyspace(&keyspace).await?;
142 return Ok(CqlResult::empty());
143 }
144
145 self.driver.execute_unpaged(query).await
146 }
147
148 pub async fn execute_query(&self, query: &str) -> Result<CqlResult> {
153 self.driver.execute_unpaged(query).await
154 }
155
156 pub async fn execute_paged(&self, query: &str, page_size: i32) -> Result<CqlResult> {
158 self.driver.execute_paged(query, page_size).await
159 }
160
161 pub async fn execute_streaming(&self, query: &str, page_size: i32) -> Result<CqlRowStream> {
162 self.driver.execute_streaming(query, page_size).await
163 }
164
165 pub async fn prepare(&self, query: &str) -> Result<PreparedId> {
167 self.driver.prepare(query).await
168 }
169
170 pub async fn execute_prepared(
172 &self,
173 id: &PreparedId,
174 values: &[CqlValue],
175 ) -> Result<CqlResult> {
176 self.driver.execute_prepared(id, values).await
177 }
178
179 pub async fn use_keyspace(&mut self, keyspace: &str) -> Result<()> {
181 self.driver.use_keyspace(keyspace).await?;
182 self.current_keyspace = Some(keyspace.to_string());
183 Ok(())
184 }
185
186 pub fn current_keyspace(&self) -> Option<&str> {
188 self.current_keyspace.as_deref()
189 }
190
191 pub fn get_consistency(&self) -> Consistency {
193 self.driver.get_consistency()
194 }
195
196 pub fn set_consistency(&self, consistency: Consistency) {
198 self.driver.set_consistency(consistency);
199 }
200
201 pub fn set_consistency_str(&self, level: &str) -> Result<()> {
203 let consistency = Consistency::from_str_cql(level)
204 .ok_or_else(|| anyhow::anyhow!("invalid consistency level: {level}"))?;
205 self.driver.set_consistency(consistency);
206 Ok(())
207 }
208
209 pub fn get_serial_consistency(&self) -> Option<Consistency> {
211 self.driver.get_serial_consistency()
212 }
213
214 pub fn set_serial_consistency(&self, consistency: Option<Consistency>) {
216 self.driver.set_serial_consistency(consistency);
217 }
218
219 pub fn set_serial_consistency_str(&self, level: &str) -> Result<()> {
221 let consistency = Consistency::from_str_cql(level)
222 .ok_or_else(|| anyhow::anyhow!("invalid serial consistency level: {level}"))?;
223 match consistency {
224 Consistency::Serial | Consistency::LocalSerial => {
225 self.driver.set_serial_consistency(Some(consistency));
226 Ok(())
227 }
228 _ => bail!("serial consistency must be SERIAL or LOCAL_SERIAL, got {level}"),
229 }
230 }
231
232 pub fn set_tracing(&self, enabled: bool) {
234 self.driver.set_tracing(enabled);
235 }
236
237 pub fn is_tracing_enabled(&self) -> bool {
239 self.driver.is_tracing_enabled()
240 }
241
242 pub fn last_trace_id(&self) -> Option<uuid::Uuid> {
244 self.driver.last_trace_id()
245 }
246
247 pub async fn get_trace_session(&self, trace_id: uuid::Uuid) -> Result<Option<TracingSession>> {
249 self.driver.get_trace_session(trace_id).await
250 }
251
252 pub async fn get_keyspaces(&self) -> Result<Vec<KeyspaceMetadata>> {
254 self.driver.get_keyspaces().await
255 }
256
257 pub async fn get_tables(&self, keyspace: &str) -> Result<Vec<TableMetadata>> {
259 self.driver.get_tables(keyspace).await
260 }
261
262 pub async fn get_table_metadata(
264 &self,
265 keyspace: &str,
266 table: &str,
267 ) -> Result<Option<TableMetadata>> {
268 self.driver.get_table_metadata(keyspace, table).await
269 }
270
271 pub async fn get_udts(&self, keyspace: &str) -> Result<Vec<UdtMetadata>> {
273 self.driver.get_udts(keyspace).await
274 }
275
276 pub async fn get_functions(&self, keyspace: &str) -> Result<Vec<FunctionMetadata>> {
278 self.driver.get_functions(keyspace).await
279 }
280
281 pub async fn get_aggregates(&self, keyspace: &str) -> Result<Vec<AggregateMetadata>> {
283 self.driver.get_aggregates(keyspace).await
284 }
285
286 pub async fn is_connected(&self) -> bool {
288 self.driver.is_connected().await
289 }
290}
291
/// Recognises a client-side `USE <keyspace>` statement and returns the keyspace name.
///
/// Parsing rules:
/// - The `USE` keyword is matched case-insensitively (`USE`, `use`, `Use`, ...).
/// - Any whitespace may separate the keyword from the name.
/// - Surrounding whitespace and trailing semicolons are ignored.
/// - A name wrapped in matching double or single quotes is returned without the quotes.
///
/// Returns `None` in these cases:
/// - `query` is not a `USE` statement.
/// - The keyspace name is empty.
/// - The name opens a quote that is never closed. Such a statement is left for the
///   server to reject.
fn parse_use_command(query: &str) -> Option<String> {
    let statement = query.trim().trim_end_matches(';').trim();

    // The first whitespace-delimited word must be `USE`, in any letter case.
    let (keyword, rest) = statement.split_once(char::is_whitespace)?;
    if !keyword.eq_ignore_ascii_case("USE") {
        return None;
    }
    let name = rest.trim();

    // `[open, .., close]` only matches two or more bytes, so a lone quote can no longer
    // produce the out-of-order slice `[1..0]`. Quotes are single-byte ASCII, so both
    // cut points fall on char boundaries.
    let name = match name.as_bytes() {
        [b'"', .., b'"'] | [b'\'', .., b'\''] => &name[1..name.len() - 1],
        [b'"' | b'\'', ..] => return None,
        _ => name,
    };

    (!name.is_empty()).then(|| name.to_string())
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: parse `stmt` and borrow the resulting keyspace name, if any.
    fn keyspace_of(stmt: &str) -> Option<String> {
        parse_use_command(stmt)
    }

    #[test]
    fn parse_use_simple() {
        assert_eq!(keyspace_of("USE my_keyspace").as_deref(), Some("my_keyspace"));
    }

    #[test]
    fn parse_use_semicolon() {
        assert_eq!(keyspace_of("USE my_keyspace;").as_deref(), Some("my_keyspace"));
    }

    #[test]
    fn parse_use_lowercase() {
        assert_eq!(keyspace_of("use test_ks").as_deref(), Some("test_ks"));
    }

    #[test]
    fn parse_use_quoted() {
        assert_eq!(keyspace_of("USE \"MyKeyspace\"").as_deref(), Some("MyKeyspace"));
    }

    #[test]
    fn parse_use_single_quoted() {
        assert_eq!(keyspace_of("USE 'my_ks'").as_deref(), Some("my_ks"));
    }

    #[test]
    fn parse_use_with_whitespace() {
        assert_eq!(keyspace_of("  USE my_keyspace ; ").as_deref(), Some("my_keyspace"));
    }

    #[test]
    fn parse_not_use_command() {
        for stmt in ["SELECT * FROM table", "INSERT INTO users"] {
            assert!(keyspace_of(stmt).is_none(), "unexpected USE match for {stmt:?}");
        }
    }

    #[test]
    fn parse_use_empty() {
        for stmt in ["USE ", "USE ;"] {
            assert!(keyspace_of(stmt).is_none(), "unexpected USE match for {stmt:?}");
        }
    }
}