diff --git a/src/net/cgrand/xforms.clj b/src/net/cgrand/xforms.clj index 3902125..a4c9335 100644 --- a/src/net/cgrand/xforms.clj +++ b/src/net/cgrand/xforms.clj @@ -26,7 +26,12 @@ ([~acc] (~rf ~acc)) ([~acc ~binding] ~body))))) -(defprotocol KvRf "Marker protocol for reducing fns that takes key and val as separate arguments.") +(defprotocol KvRfable "Protocol for reducing fns that takes key and val as separate arguments." + (some-kvrf [f] "Returns a kvrf or nil")) + +(extend-protocol KvRfable + Object (some-kvrf [_] nil) + nil (some-kvrf [_] nil)) (defmacro kvrf [name? & fn-bodies] (let [name (if (symbol? name?) name? (gensym '_)) @@ -34,7 +39,8 @@ fn-bodies (if (vector? (first fn-bodies)) (list fn-bodies) fn-bodies)] `(reify clojure.lang.Fn - KvRf + KvRfable + (some-kvrf [this#] this#) clojure.lang.IFn ~@(clj/for [[args & body] fn-bodies] `(invoke [~name ~@args] ~@body))))) @@ -45,7 +51,7 @@ ([f] (fn [rf] (let [vacc (volatile! (f))] - (if (satisfies? KvRf f) + (if-some [f (some-kvrf f)] (kvrf ([] (rf)) ([acc] (rf (rf acc (f (unreduced @vacc))))) @@ -131,12 +137,10 @@ ([kfn vfn xform] (by-key kfn vfn vector xform)) ([kfn vfn pair xform] (fn [rf] - (let [make-rf (cond - (and (= vector pair) (satisfies? KvRf rf)) - (fn [k] (fn ([acc] acc) ([acc v] (rf acc k v)))) - pair - (fn [k] (fn ([acc] acc) ([acc v] (rf acc (pair k v))))) - :else + (let [make-rf (if pair + (if-some [rf (when (identical? vector pair) (some-kvrf rf))] + (fn [k] (fn ([acc] acc) ([acc v] (rf acc k v)))) + (fn [k] (fn ([acc] acc) ([acc v] (rf acc (pair k v)))))) (constantly (multiplexable rf))) m (volatile! (transient {}))] (if (and (= key' kfn) (= val' vfn))