From 5b2f8eaf9c4f5a684dbadd8855a1d083e91b2aaa Mon Sep 17 00:00:00 2001 From: "Timothy J. Warren" Date: Fri, 5 Apr 2019 20:46:07 -0400 Subject: [PATCH] Implement basic identifier quoting --- Cargo.toml | 1 + src/drivers.rs | 137 +++++++++++++++++++++++++++++++++++++++++++ src/drivers/mod.rs | 57 ------------------ src/drivers/mssql.rs | 53 +++++++++++++++++ src/drivers/mysql.rs | 39 +++++++++++- src/query_builder.rs | 4 +- 6 files changed, 231 insertions(+), 60 deletions(-) create mode 100644 src/drivers.rs delete mode 100644 src/drivers/mod.rs create mode 100644 src/drivers/mssql.rs diff --git a/Cargo.toml b/Cargo.toml index 5bde318..079f3e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,4 @@ default=['postgres'] postgres=['pg'] sqlite=['slite'] mysql=['my'] +mssql=[] diff --git a/src/drivers.rs b/src/drivers.rs new file mode 100644 index 0000000..dce6283 --- /dev/null +++ b/src/drivers.rs @@ -0,0 +1,137 @@ +//! Drivers +//! +//! Drivers represent a connection to a specific type of database engine +use std::fmt; + +#[cfg(feature = "postgres")] +mod postgres; + +#[cfg(feature = "sqlite")] +mod sqlite; + +#[cfg(feature = "mysql")] +mod mysql; + +#[cfg(feature = "mssql")] +mod mssql; + +#[derive(Debug)] +struct Connection; + +/// Result for a db query +#[derive(Debug)] +struct QueryResult; + +struct DriverBase { + escape_char_open: char, + escape_char_close: char, + has_truncate: bool, +} + +/// Database Driver Trait +/// +/// Interface between the database connection library and the query builder +pub trait DatabaseDriver: fmt::Debug { + /// Get which characters are used to delimit identifiers + /// such as tables, and columns + fn _quotes(&self) -> (char, char) { + ('"', '"') + } + + /// Vector version of `quote_identifier` + fn quote_identifiers(&self, identifiers: Vec) -> Vec { + let mut output: Vec = vec![]; + + for identifier in identifiers { + output.push(self.quote_identifier(&identifier).to_string()); + } + + output + } + + /// Quote the identifiers passed, so the database does not + /// normalize the identifiers (eg, table, column, etc.) + fn quote_identifier(&self, identifier: &str) -> String { + let mut identifier = &mut String::from(identifier); + + // If the identifier is actually a list, + // recurse to quote each identifier in the list + if identifier.contains(",") { + let mut quoted_parts: Vec = vec![]; + + for part in identifier.split(",") { + let new_part = part.trim(); + let new_part = &self.quote_identifier(new_part); + quoted_parts.push(new_part.to_owned()); + } + + // This was the only way I could figure to get + // around mutable string reference scope hell + identifier.replace_range(.., &mut quoted_parts.join(",")); + } + + let (open_char, close_char) = self._quotes(); + + let mut trimmed_hiers: Vec = vec![]; + for hier in identifier.split(".") { + let mut hier = &mut hier.trim(); + + if hier.starts_with(open_char) && hier.ends_with(close_char) { + trimmed_hiers.push(hier.to_string()); + } else { + let mut hier = format!("{}{}{}", open_char, hier, close_char); + trimmed_hiers.push(hier.to_string()); + } + } + trimmed_hiers.join(".") + } + + /// Runs a basic sql query on the database + fn query(&self, query: &str) -> Result<(), ()>; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct TestDriver; + + impl DatabaseDriver for TestDriver { + fn query(&self, _query: &str) -> Result<(), ()> { + Ok(()) + } + } + + #[test] + fn test_quote_identifier() { + let driver = TestDriver {}; + + assert_eq!( + driver.quote_identifier("foo, bar, baz"), + r#""foo","bar","baz""# + ); + assert_eq!( + driver.quote_identifier("foo.bar, baz, fizz"), + r#""foo"."bar","baz","fizz""# + ); + } + + #[test] + fn test_quote_identifiers() { + let driver = TestDriver {}; + + assert_eq!( + driver.quote_identifiers(vec![ + "\tfoo. bar".to_string(), + "baz".to_string(), + "fizz.\n\tbuzz.baz".to_string(), + ]), + vec![ + r#""foo"."bar""#.to_string(), + r#""baz""#.to_string(), + r#""fizz"."buzz"."baz""#.to_string(), + ] + ); + } +} diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs deleted file mode 100644 index f88d913..0000000 --- a/src/drivers/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Drivers -//! -//! Drivers represent a connection to a specific type of database engine -use std::fmt; - -#[cfg(feature = "postgres")] -mod postgres; - -#[cfg(feature = "sqlite")] -mod sqlite; - -#[cfg(feature = "mysql")] -mod mysql; - -#[derive(Debug)] -struct Connection; - -/// Result for a db query -#[derive(Debug)] -struct QueryResult; - -struct DriverBase { - escape_char_open: char, - escape_char_close: char, - has_truncate: bool, -} - -/// Database Driver Trait -/// -/// Interface between the database connection library and the query builder -pub trait DatabaseDriver: fmt::Debug { - /// Get which characters are used to delimit identifiers - /// such as tables, and columns - fn _quotes(&self) -> (char, char) { - ('"','"') - } - - /// Vector version of `quote_identifier` - fn quote_identifiers(&self, identifiers: Vec) -> Vec { - let mut output: Vec = vec![]; - - for identifier in identifiers { - output.push(self.quote_identifier(&identifier)); - } - - output - } - - /// Quote the identifiers passed, so the database does not - /// normalize the identifiers (eg, table, column, etc.) - fn quote_identifier(&self, identifier: &str) -> String { - identifier.to_string() - } - - /// Runs a basic sql query on the database - fn query(&self, query: &str) -> Result<(), ()>; -} diff --git a/src/drivers/mssql.rs b/src/drivers/mssql.rs new file mode 100644 index 0000000..8954568 --- /dev/null +++ b/src/drivers/mssql.rs @@ -0,0 +1,53 @@ +use super::*; + +#[derive(Debug)] +pub struct MSSQL; + +impl DatabaseDriver for MSSQL { + /// Get which characters are used to delimit identifiers + /// such as tables, and columns + fn _quotes(&self) -> (char, char) { + ('[', ']') + } + + fn query(&self, _query: &str) -> Result<(), ()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_quote_identifier_bracket_quote() { + let driver = MSSQL {}; + + assert_eq!( + driver.quote_identifier("foo, bar, baz"), + "[foo],[bar],[baz]" + ); + assert_eq!( + driver.quote_identifier("foo.bar, baz, fizz"), + "[foo].[bar],[baz],[fizz]" + ); + } + + #[test] + fn test_quote_identifiers_bracket_quote() { + let driver = MSSQL {}; + + assert_eq!( + driver.quote_identifiers(vec![ + "\tfoo. bar".to_string(), + "baz".to_string(), + "fizz.\n\tbuzz.baz".to_string(), + ]), + vec![ + "[foo].[bar]".to_string(), + "[baz]".to_string(), + "[fizz].[buzz].[baz]".to_string(), + ] + ); + } +} diff --git a/src/drivers/mysql.rs b/src/drivers/mysql.rs index a2d7c60..6f20aef 100644 --- a/src/drivers/mysql.rs +++ b/src/drivers/mysql.rs @@ -7,10 +7,47 @@ impl DatabaseDriver for MySQL { /// Get which characters are used to delimit identifiers /// such as tables, and columns fn _quotes(&self) -> (char, char) { - ('`','`') + ('`', '`') } fn query(&self, _query: &str) -> Result<(), ()> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_quote_identifier_backtick_quote() { + let driver = MySQL {}; + + assert_eq!( + driver.quote_identifier("foo, bar, baz"), + "`foo`,`bar`,`baz`" + ); + assert_eq!( + driver.quote_identifier("foo.bar, baz, fizz"), + "`foo`.`bar`,`baz`,`fizz`" + ); + } + + #[test] + fn test_quote_identifiers_backtick_quote() { + let driver = MySQL {}; + + assert_eq!( + driver.quote_identifiers(vec![ + "\tfoo. bar".to_string(), + "baz".to_string(), + "fizz.\n\tbuzz.baz".to_string(), + ]), + vec![ + "`foo`.`bar`".to_string(), + "`baz`".to_string(), + "`fizz`.`buzz`.`baz`".to_string(), + ] + ); + } +} diff --git a/src/query_builder.rs b/src/query_builder.rs index 192ac7a..c5d19a7 100644 --- a/src/query_builder.rs +++ b/src/query_builder.rs @@ -159,7 +159,7 @@ impl QueryBuilder { pub fn select(&mut self, fields: &str) -> &mut Self { unimplemented!(); } - + /// Set the fields to select from the database as a Vector pub fn select_vec(&mut self, fields: Vec<&str>) -> &mut Self { let fields = fields.join(","); @@ -243,7 +243,7 @@ impl QueryBuilder { self } - + // Specify a condition for a `where` clause where a column has a value pub fn where_eq(&mut self, key: &str, value: Box) -> &mut Self { self.r#where(key, "=", value)