Support custom dialects: addresses #401

Still needs tests and documentation.
This commit is contained in:
Sean Corfield 2022-04-30 22:03:36 -07:00
parent 70e8afc273
commit 8c8b05e67f
2 changed files with 52 additions and 16 deletions

View file

@ -1,5 +1,8 @@
# Changes # Changes
* 2.3.next in progress
* Address [#401](https://github.com/seancorfield/honeysql/issues/401) by adding `register-dialect!` and `get-dialect`, and also making `add-clause-before` and `strop` public so that new dialects are easier to construct.
* 2.2.891 -- 2022-04-23 * 2.2.891 -- 2022-04-23
* Address [#404](https://github.com/seancorfield/honeysql/issues/404) by documenting PostgreSQL's `ARRAY` constructor syntax and how to produce it. * Address [#404](https://github.com/seancorfield/honeysql/issues/404) by documenting PostgreSQL's `ARRAY` constructor syntax and how to produce it.
* Address parts of [#403](https://github.com/seancorfield/honeysql/issues/403) by improving the documentation for `:array` and also improving the exception that was thrown when it was misused. * Address parts of [#403](https://github.com/seancorfield/honeysql/issues/403) by improving the documentation for `:array` and also improving the exception that was thrown when it was misused.

View file

@ -65,7 +65,7 @@
:returning :returning
:with-data]) :with-data])
(defn- add-clause-before (defn add-clause-before
"Low-level helper just to insert a new clause. "Low-level helper just to insert a new clause.
If the clause is already in the list, this moves it to the end." If the clause is already in the list, this moves it to the end."
@ -87,23 +87,24 @@
order)) order))
(conj order clause)))) (conj order clause))))
(defn- strop (defn strop
"Escape any embedded closing strop characters." "Escape any embedded closing strop characters."
[s x e] [s x e]
(str s (str/replace x (str e) (str e e)) e)) (str s (str/replace x (str e) (str e e)) e))
(def ^:private dialects (def ^:private dialects
(reduce-kv (fn [m k v] (atom
(assoc m k (assoc v :dialect k))) (reduce-kv (fn [m k v]
{} (assoc m k (assoc v :dialect k)))
{:ansi {:quote #(strop \" % \")} {}
:sqlserver {:quote #(strop \[ % \])} {:ansi {:quote #(strop \" % \")}
:mysql {:quote #(strop \` % \`) :sqlserver {:quote #(strop \[ % \])}
:clause-order-fn #(add-clause-before % :set :where)} :mysql {:quote #(strop \` % \`)
:oracle {:quote #(strop \" % \") :as false}})) :clause-order-fn #(add-clause-before % :set :where)}
:oracle {:quote #(strop \" % \") :as false}})))
; should become defonce ; should become defonce
(def ^:private default-dialect (atom (:ansi dialects))) (def ^:private default-dialect (atom (:ansi @dialects)))
(def ^:private default-quoted (atom nil)) (def ^:private default-quoted (atom nil))
(def ^:private ^:dynamic *dialect* nil) (def ^:private ^:dynamic *dialect* nil)
@ -219,7 +220,7 @@
(for [v [:foo-bar "foo-bar" ; symbol is the same as keyword (for [v [:foo-bar "foo-bar" ; symbol is the same as keyword
:f-o.b-r :f-o/b-r] :f-o.b-r :f-o/b-r]
a [true false] d [true false] q [true false]] a [true false] d [true false] q [true false]]
(binding [*dialect* (:mysql dialects) *quoted* q] (binding [*dialect* (:mysql @dialects) *quoted* q]
(if q (if q
[v a d (format-entity v {:aliased a :drop-ns d}) [v a d (format-entity v {:aliased a :drop-ns d})
(binding [*quoted-snake* true] (binding [*quoted-snake* true]
@ -1406,9 +1407,9 @@
["?" expr]))) ["?" expr])))
(defn- check-dialect [dialect] (defn- check-dialect [dialect]
(when-not (contains? dialects dialect) (when-not (contains? @dialects dialect)
(throw (ex-info (str "Invalid dialect: " dialect) (throw (ex-info (str "Invalid dialect: " dialect)
{:valid-dialects (vec (sort (keys dialects)))}))) {:valid-dialects (vec (sort (keys @dialects)))})))
dialect) dialect)
(def through-opts (def through-opts
@ -1443,7 +1444,7 @@
([data opts] ([data opts]
(let [cache (:cache opts) (let [cache (:cache opts)
dialect? (contains? opts :dialect) dialect? (contains? opts :dialect)
dialect (when dialect? (get dialects (check-dialect (:dialect opts))))] dialect (when dialect? (get @dialects (check-dialect (:dialect opts))))]
(binding [*dialect* (if dialect? dialect @default-dialect) (binding [*dialect* (if dialect? dialect @default-dialect)
*caching* cache *caching* cache
*checking* (if (contains? opts :checking) *checking* (if (contains? opts :checking)
@ -1482,7 +1483,7 @@
Dialects are always applied to the base order to create the current order." Dialects are always applied to the base order to create the current order."
[dialect & {:keys [quoted]}] [dialect & {:keys [quoted]}]
(reset! default-dialect (get dialects (check-dialect dialect))) (reset! default-dialect (get @dialects (check-dialect dialect)))
(when-let [f (:clause-order-fn @default-dialect)] (when-let [f (:clause-order-fn @default-dialect)]
(reset! current-clause-order (f @base-clause-order))) (reset! current-clause-order (f @base-clause-order)))
(reset! default-quoted quoted)) (reset! default-quoted quoted))
@ -1524,6 +1525,38 @@
(swap! current-clause-order add-clause-before clause before) (swap! current-clause-order add-clause-before clause before)
(swap! clause-format assoc clause f)))) (swap! clause-format assoc clause f))))
(defn register-dialect!
"Register a new dialect. Accepts a dialect name (keyword) and a hash
map that must contain at least a `:quoted` key whose value is a unary
function that accepts a string and returns it quoted per the dialect.
It may also contain a `:clause-order-fn` key whose value is a unary
function that accepts a list of SQL clauses (keywords) in order of
precedence and returns an updated list of SQL clauses in order. It
may use `add-clause-before` to achieve this. Currently, the only
dialect that does this is MySQL, whose `SET` clause (`:set`) has a
non-standard precedence, compared to other SQL dialects."
[dialect dialect-spec]
(when-not (keyword? dialect)
(throw (ex-info "Dialect must be a keyword" {:dialect dialect})))
(when-not (map? dialect-spec)
(throw (ex-info "Dialect spec must be a hash map containing at least a :quoted function"
{:dialect-spec dialect-spec})))
(when-not (fn? (:quoted dialect-spec))
(throw (ex-info "Dialect spec is missing a :quoted function"
{:dialect-spec dialect-spec})))
(when-let [cof (:clause-order-fn dialect-spec)]
(when-not (fn? cof)
(throw (ex-info "Dialect spec contains :clause-order-fn but it is not a function"
{:dialect-spec dialect-spec}))))
(swap! dialects assoc dialect (assoc dialect-spec :dialect dialect)))
(defn get-dialect
"Given a dialect name (keyword), return its definition.
Returns `nil` if the dialect is unknown."
[dialect]
(get @dialects dialect))
(defn register-fn! (defn register-fn!
"Register a new function (as special syntax). The `formatter` is either "Register a new function (as special syntax). The `formatter` is either
a keyword, meaning that this new function should use the same syntax as a keyword, meaning that this new function should use the same syntax as