diff --git a/src/clj/coffi/ffi.clj b/src/clj/coffi/ffi.clj index a94982d..d1cfc1a 100644 --- a/src/clj/coffi/ffi.clj +++ b/src/clj/coffi/ffi.clj @@ -230,20 +230,138 @@ (let [args (concat required-args types)] (make-downcall symbol args ret))))) -;; TODO(Joshua): Optimize this to not serialize things if possible +(def ^:private primitive-cast-sym + "Map from non-pointer primitive types to functions that cast to the appropriate + java primitive." + {::mem/byte `byte + ::mem/short `short + ::mem/int `int + ::mem/long `long + ::mem/long-long `long + ::mem/char `char + ::mem/float `float + ::mem/double `double}) + +(defn- inline-serde-wrapper + "Builds a form that returns a function that calls `downcall` with serdes. + + The return type and any arguments that are primitives will not + be (de)serialized except to be cast. If all arguments and return are + primitive, the `downcall` is returned directly. In cases where arguments must + be serialized, a new [[mem/stack-scope]] is generated." + [downcall arg-types ret-type] + (let [const-ret? (s/valid? ::mem/type ret-type) + primitive-ret? (mem/primitive? ret-type) + scope (gensym "scope")] + (if-not (seqable? arg-types) + (let [args (gensym "args") + serialized-args `(map (fn [arg# type#] (mem/serialize arg# type# ~scope)) ~args ~arg-types) + prim-call `(apply ~downcall ~serialized-args) + non-prim-call `(apply ~downcall (mem/scope-allocator ~scope) ~serialized-args)] + (cond + (and const-ret? + primitive-ret?) + `(fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scope)] + ~prim-call)) + + const-ret? + `(fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scopee)] + ~(if (mem/primitive-type ret-type) + `(mem/deserialize* ~prim-call ~ret-type) + `(mem/deserialize-from ~non-prim-call ~ret-type)))) + + :else + `(if (mem/primitive-type ~ret-type) + (fn ~'native-fn + [~'& ~args] + (with-open [~scope mem/stack-scope] + (mem/deserialize* ~prim-call ~ret-type))) + (fn ~'native-fn + [~'& ~args] + (with-open [~scope mem/stack-scope] + (mem/deserialize-from ~non-prim-call ~ret-type)))))) + (let [arg-syms (repeatedly (count arg-types) #(gensym "arg")) + serialize-args (map (fn [sym type] + (if (s/valid? ::mem/type type) + (if-not (mem/primitive? type) + (list sym + (if (mem/primitive-type type) + `(mem/serialize* ~sym ~type ~scope) + `(let [alloc# (mem/alloc-instance ~type ~scope)] + (mem/serialize-into ~sym ~type alloc# ~scope) + alloc#))) + (if (primitive-cast-sym type) + (list sym (list (primitive-cast-sym type) sym)) + nil)) + (list sym `(mem/serialize ~sym ~type ~scope)))) + arg-syms arg-types) + wrap-serialize (fn [expr] + `(with-open [~scope (mem/stack-scope)] + (let [~@(mapcat identity serialize-args)] + ~expr))) + native-fn (fn [expr] + `(fn ~'native-fn [~@arg-syms] + ~expr)) + none-to-serialize? (zero? (count (filter some? serialize-args)))] + (cond + (and none-to-serialize? + primitive-ret?) + downcall + + primitive-ret? + (-> (cons downcall arg-syms) + wrap-serialize + native-fn) + + :else + (let [call (cons downcall arg-syms) + prim-call `(mem/deserialize* ~call ~ret-type) + non-prim-call `(mem/deserialize-from ~(list* (first call) + `(mem/scope-allocator ~scope) + (rest call)) + ~ret-type)] + (cond + (and none-to-serialize? + const-ret?) + (native-fn (if (mem/primitive-type ret-type) + prim-call + non-prim-call)) + + none-to-serialize? + `(if (mem/primitive-type ~ret-type) + ~(native-fn prim-call) + ~(native-fn non-prim-call)) + + const-ret? + (native-fn (wrap-serialize + (if (mem/primitive-type ret-type) + prim-call + non-prim-call))) + + :else + `(if (mem/primitive-type ~ret-type) + ~(native-fn (wrap-serialize prim-call)) + ~(native-fn (wrap-serialize non-prim-call)))))))))) + (defn make-serde-wrapper "Constructs a wrapper function for the `downcall` which serializes the arguments and deserializes the return value." + {:inline (fn [downcall arg-types ret-type] + (inline-serde-wrapper downcall arg-types ret-type))} [downcall arg-types ret-type] (if (mem/primitive-type ret-type) (fn native-fn [& args] (with-open [scope (mem/stack-scope)] - (mem/deserialize + (mem/deserialize* (apply downcall (map #(mem/serialize %1 %2 scope) args arg-types)) ret-type))) (fn native-fn [& args] (with-open [scope (mem/stack-scope)] - (mem/deserialize + (mem/deserialize-from (apply downcall (mem/scope-allocator scope) (map #(mem/serialize %1 %2 scope) args arg-types)) ret-type))))) @@ -264,15 +382,15 @@ "Constructs a Clojure function to call the native function referenced by `symbol`. The function returned will serialize any passed arguments into the `args` - types, and deserialize the return to the `ret` type." + types, and deserialize the return to the `ret` type. + + If your `args` and `ret` are constants, then it is more efficient to + call [[make-downcall]] followed by [[make-serde-wrapper]] because the latter + has an inline definition which will result in less overhead from serdes." [symbol args ret] (-> symbol - ensure-address (make-downcall args ret) - (cond-> - (every? #(= % (mem/primitive-type %)) - (cons ret args)) - (make-serde-wrapper args ret)))) + (make-serde-wrapper args ret))) (defn vacfn-factory "Constructs a varargs factory to call the native function referenced by `symbol`. @@ -560,7 +678,8 @@ :multi-arity fn-tail nil))] `(let [~address (find-symbol ~(name (:symbol args))) - ~native-sym (cfn ~address ~(:native-arglist args) ~(:return-type args)) + ~native-sym (-> (make-downcall ~address ~(:native-arglist args) ~(:return-type args)) + (make-serde-wrapper ~(:native-arglist args) ~(:return-type args))) fun# ~(if (:wrapper args) `(fn ~(:name args) ~@fn-tail)