diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e50d58..02c4bdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # Change Log All notable changes to this project will be documented in this file. This change log follows the conventions of [keepachangelog.com](http://keepachangelog.com/). +## [0.1.241] - 2021-10-14 +### Performance +- Added an `:inline` function to `make-serde-wrapper` to remove serialization overhead on primitives +- Added multimethod implementations for primitives in (de)serialization functions, rather than using the default + +### Fixed +- `cfn` didn't add serializers with non-primitive types in some cases + ## [0.1.220] - 2021-10-09 ### Fixed - All-primitive method types still used serialization when called from `cfn` @@ -47,6 +55,7 @@ All notable changes to this project will be documented in this file. This change - Support for serializing and deserializing arbitrary Clojure functions - Support for serializing and deserializing arbitrary Clojure data structures +[0.1.241]: https://github.com/IGJoshua/coffi/compare/v0.1.220...v0.1.241 [0.1.220]: https://github.com/IGJoshua/coffi/compare/v0.1.205...v0.1.220 [0.1.205]: https://github.com/IGJoshua/coffi/compare/v0.1.192...v0.1.205 [0.1.192]: https://github.com/IGJoshua/coffi/compare/v0.1.184...v0.1.192 diff --git a/README.md b/README.md index c896dce..b57713c 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ This library is available on Clojars. Add one of the following entries to the `:deps` key of your `deps.edn`: ```clojure -org.suskalo/coffi {:mvn/version "0.1.220"} -io.github.IGJoshua/coffi {:git/tag "v0.1.220" :git/sha "abcbf0f"} +org.suskalo/coffi {:mvn/version "0.1.241"} +io.github.IGJoshua/coffi {:git/tag "v0.1.241" :git/sha "5fa1f15"} ``` If you use this library as a git dependency, you will need to prepare the diff --git a/src/clj/coffi/ffi.clj b/src/clj/coffi/ffi.clj index a3bc11c..ebb8094 100644 --- a/src/clj/coffi/ffi.clj +++ b/src/clj/coffi/ffi.clj @@ -20,7 +20,6 @@ CLinker FunctionDescriptor MemoryLayout - MemorySegment SegmentAllocator))) ;;; FFI Code loading and function access @@ -231,19 +230,145 @@ (let [args (concat required-args types)] (make-downcall symbol args ret))))) +(def ^:private primitive-cast-sym + "Map from non-pointer primitive types to functions that cast to the appropriate + java primitive." + {::mem/byte `byte + ::mem/short `short + ::mem/int `int + ::mem/long `long + ::mem/long-long `long + ::mem/char `char + ::mem/float `float + ::mem/double `double}) + +(defn- inline-serde-wrapper + "Builds a form that returns a function that calls `downcall` with serdes. + + The return type and any arguments that are primitives will not + be (de)serialized except to be cast. If all arguments and return are + primitive, the `downcall` is returned directly. In cases where arguments must + be serialized, a new [[mem/stack-scope]] is generated." + [downcall arg-types ret-type] + (let [const-ret? (s/valid? ::mem/type ret-type) + primitive-ret? (mem/primitive? ret-type) + scope (gensym "scope") + downcall-sym (gensym "downcall")] + `(let [~downcall-sym ~downcall] + ~(if-not (seqable? arg-types) + (let [args (gensym "args") + ret (gensym "ret") + serialized-args `(map (fn [arg# type#] (mem/serialize arg# type# ~scope)) ~args ~arg-types) + prim-call `(apply ~downcall-sym ~serialized-args) + non-prim-call `(apply ~downcall-sym (mem/scope-allocator ~scope) ~serialized-args)] + (cond + (and const-ret? + primitive-ret?) + `(fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scope)] + ~prim-call)) + + const-ret? + `(let [~ret ~ret-type] + (fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scopee)] + ~(if (mem/primitive-type ret-type) + `(mem/deserialize* ~prim-call ~ret) + `(mem/deserialize-from ~non-prim-call ~ret))))) + + :else + `(let [~ret ~ret-type] + (if (mem/primitive-type ~ret) + (fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scope)] + (mem/deserialize* ~prim-call ~ret))) + (fn ~'native-fn + [~'& ~args] + (with-open [~scope (mem/stack-scope)] + (mem/deserialize-from ~non-prim-call ~ret))))))) + (let [arg-syms (repeatedly (count arg-types) #(gensym "arg")) + ret (gensym "ret") + serialize-args (map (fn [sym type] + (if (s/valid? ::mem/type type) + (if-not (mem/primitive? type) + (list sym + (if (mem/primitive-type type) + `(mem/serialize* ~sym ~type ~scope) + `(let [alloc# (mem/alloc-instance ~type ~scope)] + (mem/serialize-into ~sym ~type alloc# ~scope) + alloc#))) + (if (primitive-cast-sym type) + (list sym (list (primitive-cast-sym type) sym)) + nil)) + (list sym `(mem/serialize ~sym ~type ~scope)))) + arg-syms arg-types) + wrap-serialize (fn [expr] + `(with-open [~scope (mem/stack-scope)] + (let [~@(mapcat identity serialize-args)] + ~expr))) + native-fn (fn [expr] + `(fn ~'native-fn [~@arg-syms] + ~expr)) + none-to-serialize? (zero? (count (filter some? serialize-args)))] + (cond + (and none-to-serialize? + primitive-ret?) + downcall-sym + + primitive-ret? + (-> (cons downcall-sym arg-syms) + wrap-serialize + native-fn) + + :else + `(let [~ret ~ret-type] + ~(let [call (cons downcall-sym arg-syms) + prim-call `(mem/deserialize* ~call ~ret) + non-prim-call `(mem/deserialize-from ~(list* (first call) + `(mem/scope-allocator ~scope) + (rest call)) + ~ret)] + (cond + (and none-to-serialize? + const-ret?) + (native-fn (if (mem/primitive-type ret-type) + prim-call + non-prim-call)) + + none-to-serialize? + (if (mem/primitive-type ~ret) + ~(native-fn prim-call) + ~(native-fn non-prim-call)) + + const-ret? + (native-fn (wrap-serialize + (if (mem/primitive-type ret-type) + prim-call + non-prim-call))) + + :else + `(if (mem/primitive-type ~ret) + ~(native-fn (wrap-serialize prim-call)) + ~(native-fn (wrap-serialize non-prim-call)))))))))))) + (defn make-serde-wrapper "Constructs a wrapper function for the `downcall` which serializes the arguments and deserializes the return value." + {:inline (fn [downcall arg-types ret-type] + (inline-serde-wrapper downcall arg-types ret-type))} [downcall arg-types ret-type] (if (mem/primitive-type ret-type) (fn native-fn [& args] (with-open [scope (mem/stack-scope)] - (mem/deserialize + (mem/deserialize* (apply downcall (map #(mem/serialize %1 %2 scope) args arg-types)) ret-type))) (fn native-fn [& args] (with-open [scope (mem/stack-scope)] - (mem/deserialize + (mem/deserialize-from (apply downcall (mem/scope-allocator scope) (map #(mem/serialize %1 %2 scope) args arg-types)) ret-type))))) @@ -264,15 +389,15 @@ "Constructs a Clojure function to call the native function referenced by `symbol`. The function returned will serialize any passed arguments into the `args` - types, and deserialize the return to the `ret` type." + types, and deserialize the return to the `ret` type. + + If your `args` and `ret` are constants, then it is more efficient to + call [[make-downcall]] followed by [[make-serde-wrapper]] because the latter + has an inline definition which will result in less overhead from serdes." [symbol args ret] (-> symbol - ensure-address (make-downcall args ret) - (cond-> - (every? #(= % (mem/primitive-type %)) - (cons ret args)) - (make-serde-wrapper args ret)))) + (make-serde-wrapper args ret))) (defn vacfn-factory "Constructs a varargs factory to call the native function referenced by `symbol`. @@ -548,8 +673,6 @@ :style/indent [:defn]} [& args] (let [args (s/conform ::defcfn-args args) - args-types (gensym "args-types") - ret-type (gensym "ret-type") address (gensym "symbol") native-sym (gensym "native") [arity fn-tail] (-> args :wrapper :fn-tail) @@ -561,10 +684,11 @@ :single-arity [fn-tail] :multi-arity fn-tail nil))] - `(let [~args-types ~(:native-arglist args) - ~ret-type ~(:return-type args) - ~address (find-symbol ~(name (:symbol args))) - ~native-sym (cfn ~address ~args-types ~ret-type) + `(let [~address (find-symbol ~(name (:symbol args))) + ~(or (-> args :wrapper :native-fn) + native-sym) + (-> (make-downcall ~address ~(:native-arglist args) ~(:return-type args)) + (make-serde-wrapper ~(:native-arglist args) ~(:return-type args))) fun# ~(if (:wrapper args) `(fn ~(:name args) ~@fn-tail) diff --git a/src/clj/coffi/mem.clj b/src/clj/coffi/mem.clj index 7378384..8e87d4a 100644 --- a/src/clj/coffi/mem.clj +++ b/src/clj/coffi/mem.clj @@ -205,14 +205,6 @@ (map #(slice segment (* % size) size) (range num-segments)))) -(def primitive-types - "A set of keywords representing all the primitive types which may be passed to - or returned from native functions." - #{::byte ::short ::int ::long ::long-long - ::char - ::float ::double - ::pointer ::void}) - (defn- type-dispatch "Gets a type dispatch value from a (potentially composite) type." [type] @@ -220,6 +212,11 @@ (qualified-keyword? type) type (sequential? type) (keyword (first type)))) +(def primitive? + "A set of all primitive types." + #{::byte ::short ::int ::long ::long-long + ::char ::float ::double ::pointer}) + (defmulti primitive-type "Gets the primitive type that is used to pass as an argument for the `type`. @@ -227,28 +224,55 @@ but which need additional logic to be performed during serialization and deserialization. + Implementations of this method should take into account that type arguments + may not always be evaluated before passing to this function. + Returns nil for any type which does not have a primitive representation." type-dispatch) (defmethod primitive-type :default - [type] - (primitive-types type)) + [_type] + nil) + +(defmethod primitive-type ::byte + [_type] + ::byte) + +(defmethod primitive-type ::short + [_type] + ::short) + +(defmethod primitive-type ::int + [_type] + ::int) + +(defmethod primitive-type ::long + [_type] + ::long) + +(defmethod primitive-type ::long-long + [_type] + ::long-long) + +(defmethod primitive-type ::char + [_type] + ::char) + +(defmethod primitive-type ::float + [_type] + ::float) + +(defmethod primitive-type ::double + [_type] + ::double) (defmethod primitive-type ::pointer [_type] ::pointer) -(def c-prim-layout - "Map of primitive type names to the [[CLinker]] types for a method handle." - {::byte CLinker/C_CHAR - ::short CLinker/C_SHORT - ::int CLinker/C_INT - ::long CLinker/C_LONG - ::long-long CLinker/C_LONG_LONG - ::char CLinker/C_CHAR - ::float CLinker/C_FLOAT - ::double CLinker/C_DOUBLE - ::pointer CLinker/C_POINTER}) +(defmethod primitive-type ::void + [_type] + ::void) (defmulti c-layout "Gets the layout object for a given `type`. @@ -261,7 +285,43 @@ (defmethod c-layout :default [type] - (c-prim-layout (or (primitive-type type) type))) + (c-layout (primitive-type type))) + +(defmethod c-layout ::byte + [_type] + CLinker/C_CHAR) + +(defmethod c-layout ::short + [_type] + CLinker/C_SHORT) + +(defmethod c-layout ::int + [_type] + CLinker/C_INT) + +(defmethod c-layout ::long + [_type] + CLinker/C_LONG) + +(defmethod c-layout ::long-long + [_type] + CLinker/C_LONG_LONG) + +(defmethod c-layout ::char + [_type] + CLinker/C_CHAR) + +(defmethod c-layout ::float + [_type] + CLinker/C_FLOAT) + +(defmethod c-layout ::double + [_type] + CLinker/C_DOUBLE) + +(defmethod c-layout ::pointer + [_type] + CLinker/C_POINTER) (def java-prim-layout "Map of primitive type names to the Java types for a method handle." @@ -308,25 +368,43 @@ [obj type scope] (type-dispatch type))) -(def ^:private primitive-cast - "Map from primitive type names to the function to cast it to a primitive." - {::byte byte - ::short short - ::int int - ::long long - ::long-long long - ::char char - ::float float - ::double double}) - (defmethod serialize* :default [obj type _scope] - (if-let [prim (primitive-type type)] - (when-not (= ::void prim) - ((primitive-cast prim) obj)) - (throw (ex-info "Attempted to serialize a non-primitive type with primitive methods" - {:type type - :object obj})))) + (throw (ex-info "Attempted to serialize a non-primitive type with primitive methods" + {:type type + :object obj}))) + +(defmethod serialize* ::byte + [obj _type _scope] + (byte obj)) + +(defmethod serialize* ::short + [obj _type _scope] + (short obj)) + +(defmethod serialize* ::int + [obj _type _scope] + (int obj)) + +(defmethod serialize* ::long + [obj _type _scope] + (long obj)) + +(defmethod serialize* ::long-long + [obj _type _scope] + (long obj)) + +(defmethod serialize* ::char + [obj _type _scope] + (char obj)) + +(defmethod serialize* ::float + [obj _type _scope] + (float obj)) + +(defmethod serialize* ::double + [obj _type _scope] + (double obj)) (defmethod serialize* ::pointer [obj type scope] @@ -439,10 +517,9 @@ (defmethod deserialize-from :default [segment type] (if-some [prim (primitive-type type)] - (with-acquired [(segment-scope segment)] - (-> segment - (deserialize-from prim) - (deserialize* type))) + (-> segment + (deserialize-from prim) + (deserialize* type)) (throw (ex-info "Attempted to deserialize a non-primitive type that has not been overriden" {:type type :segment segment})))) @@ -497,11 +574,41 @@ (defmethod deserialize* :default [obj type] - (if (primitive-type type) - obj - (throw (ex-info "Attempted to deserialize a non-primitive type with primitive methods" - {:type type - :segment obj})))) + (throw (ex-info "Attempted to deserialize a non-primitive type with primitive methods" + {:type type + :segment obj}))) + +(defmethod deserialize* ::byte + [obj _type] + obj) + +(defmethod deserialize* ::short + [obj _type] + obj) + +(defmethod deserialize* ::int + [obj _type] + obj) + +(defmethod deserialize* ::long + [obj _type] + obj) + +(defmethod deserialize* ::long-long + [obj _type] + obj) + +(defmethod deserialize* ::char + [obj _type] + obj) + +(defmethod deserialize* ::float + [obj _type] + obj) + +(defmethod deserialize* ::double + [obj _type] + obj) (defmethod deserialize* ::pointer [addr type] @@ -511,6 +618,10 @@ (second type)) addr))) +(defmethod deserialize* ::void + [_obj _type] + nil) + (defn deserialize "Deserializes an arbitrary type.