diff --git a/src/coffi/ffi.clj b/src/coffi/ffi.clj index ab79866..3d9e312 100644 --- a/src/coffi/ffi.clj +++ b/src/coffi/ffi.clj @@ -157,14 +157,15 @@ "Gets the primitive type that is used to pass as an argument for the `type`. This is for objects which are passed to native functions as primitive types, - but which need additional logic to be performed during serialization. + but which need additional logic to be performed during serialization and + deserialization. - Returns nil for any type " + Returns nil for any type which does not have a primitive representation." (fn [type] type)) (defmethod primitive-type :default [type] - (contains? primitive-types type)) + (primitive-types type)) (defmethod primitive-type ::c-string [_type] @@ -191,20 +192,37 @@ [type] (java-prim-layout type MemorySegment)) +(defn size-of + "The size in bytes of the given `type`." + [type] + (let [layout ^MemoryLayout (c-layout type)] + (.byteSize + (cond-> layout + (qualified-keyword? layout) ^MemoryLayout c-layout)))) + +(defn alloc-instance + "Allocates a memory segment for the given `type`." + ([type] (alloc-instance type (ResourceScope/newImplicitScope))) + ([type scope] (MemorySegment/allocateNative ^long (size-of type) ^ResourceScope scope))) + (defmulti serialize* "Constructs a serialized version of the `obj` and returns it. Any new allocations made during the serialization should be tied to the given `scope`, except in extenuating circumstances. - This method should only be implemented for types serialize to primitives." + This method should only be implemented for types that serialize to primitives." (fn #_{:clj-kondo/ignore [:unused-binding]} [obj type scope])) -(defmethod serialize* ::c-string - [obj _type scope] - (address-of (CLinker/toCString (str obj) ^ResourceScope scope))) +(defmethod serialize* :default + [obj type _scope] + (if (primitive-type type) + obj + (throw (ex-info "Attempted to serialize a non-primitive type with primitive methods" + {:type type + :object obj})))) (defmulti serialize-into "Writes a serialized version of the `obj` to the given `segment`. @@ -222,6 +240,10 @@ [obj type segment scope] type)) +(defmethod serialize* ::c-string + [obj _type scope] + (address-of (CLinker/toCString (str obj) ^ResourceScope scope))) + (defmethod serialize-into :default [obj type segment scope] (let [new-type (c-layout type)] @@ -310,7 +332,7 @@ [segment _type] (MemoryAccess/getAddress segment)) -(defmulti deserialize +(defmulti deserialize* "Deserializes a primitive object into a Clojure data structure. This is intended for use with types that are returned as a primitive but which @@ -320,7 +342,7 @@ [obj type] type)) -(defmethod deserialize :default +(defmethod deserialize* :default [obj _type] obj) @@ -328,24 +350,26 @@ [segment type] (-> segment (deserialize-from ::pointer) - (deserialize type))) + (deserialize* type))) -(defmethod deserialize ::c-string +(defmethod deserialize* ::c-string [obj _type] (CLinker/toJavaString obj)) -(defn size-of - "The size in bytes of the given `type`." - [type] - (let [layout ^MemoryLayout (c-layout type)] - (.byteSize - (cond-> layout - (qualified-keyword? layout) ^MemoryLayout c-layout)))) +(defn serialize + []) -(defn alloc-instance - "Allocates a memory segment for the given `type`." - ([type] (alloc-instance type (ResourceScope/newImplicitScope))) - ([type scope] (MemorySegment/allocateNative ^long (size-of type) ^ResourceScope scope))) +(defn deserialize + "Deserializes an arbitrary type regardless of if it is primitive. + + For types which have a primitive representation, this deserializes the + primitive representation. For types which do not, this deserializes out of + a [[MemorySegment]]." + [obj type] + ((if (primitive-type type) + deserialize* + deserialize-from) + obj type)) (defn serialize "Serializes the `obj` into a newly-allocated [[MemorySegment]]." @@ -392,6 +416,50 @@ [address method-type function-descriptor] (.downcallHandle (CLinker/getInstance) address method-type function-descriptor)) +(s/def ::defcfn-args + (s/cat :name simple-symbol? + :doc (s/? string?) + :symbol (s/nonconforming + (s/or :string string? + :symbol simple-symbol?)) + :native-arglist (s/coll-of qualified-keyword? :kind vector?) + :return-type qualified-keyword? + :fn-tail (s/? + (s/cat :arglist (s/coll-of simple-symbol? :kind vector?) + :body (s/* any?))))) + +(defmacro defcfn + {:arglists '([name docstring? symbol arg-types ret-type arglist & body])} + [& args] + (let [args (s/conform ::defcfn-args args) + scope (gensym "scope") + arg-syms (repeatedly (count (:native-arglist args)) #(gensym "arg"))] + `(let [args-types# ~(:native-arglist args) + ret-type# ~(:return-type args) + downcall# (downcall-handle + (find-symbol ~(:symbol args)) + (method-type args-types# ret-type#) + (function-descriptor args-types# ret-type#)) + ~(:name args) (fn [& args#] + (with-open [~scope (stack-scope)] + (let [[~@arg-syms] (map #(serialize ))] + (.invoke downcall# ~@arg-syms)))) + fun# ~(if (:fn-tail args) + `(fn ~(-> args :fn-tail :arglist) + ~@(-> args :fn-tail :body)) + (:name args))] + (def + ~(vary-meta (:name args) + update :arglists + (fn [old-list] + (or old-list + (list + (or (-> args :fn-tail :arglist) + (mapv (comp symbol name) + (:native-arglist args))))))) + ~@(list (:doc args)) + fun#)))) + (comment (let [args-types [::c-string] @@ -403,7 +471,7 @@ strlen (fn [str] (with-open [scope (stack-scope)] (let [arg1 (serialize (nth args-types 0) str scope)] - (deserialize (.invoke downcall arg1) ret-type))))] + (deserialize* (.invoke downcall arg1) ret-type))))] (def ^{:arglists '([str])} strlen