Update serialize and deserialize for primitive types

This commit is contained in:
Joshua Suskalo 2021-09-15 16:53:29 -05:00
parent c5df70ac1a
commit 0c040e3a73

View file

@ -1,10 +1,17 @@
(ns coffi.ffi
(:refer-clojure :exclude [defstruct])
(:require
[clojure.java.io :as io]
[clojure.spec.alpha :as s])
(:import
(java.lang.invoke VarHandle)
(java.lang.invoke
VarHandle
MethodHandle
MethodType)
(jdk.incubator.foreign
CLinker
FunctionDescriptor
GroupLayout
MemoryAccess
MemoryAddress
MemoryHandles
@ -12,176 +19,14 @@
MemoryLayout$PathElement
MemoryLayouts
MemorySegment
ResourceScope)))
(defmulti serialize*
"Writes a serialized version of the `obj` to the given `segment`.
Any new allocations made during the serialization should be tied to the given
`scope`, except in extenuating circumstances."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[obj type segment scope]
type))
(defmethod serialize* ::byte
[obj _type segment _scope]
(MemoryAccess/setByte segment (byte obj)))
(defmethod serialize* ::short
[obj _type segment _scope]
(MemoryAccess/setShort segment (short obj)))
(defmethod serialize* ::int
[obj _type segment _scope]
(MemoryAccess/setInt segment (int obj)))
(defmethod serialize* ::long
[obj _type segment _scope]
(MemoryAccess/setLong segment (long obj)))
(defmethod serialize* ::long-long
[obj _type segment _scope]
(MemoryAccess/setLong segment (long obj)))
(defmethod serialize* ::char
[obj _type segment _scope]
(MemoryAccess/setChar segment (char obj)))
(defmethod serialize* ::float
[obj _type segment _scope]
(MemoryAccess/setFloat segment (float obj)))
(defmethod serialize* ::double
[obj _type segment _scope]
(MemoryAccess/setDouble segment (double obj)))
(defmethod serialize* ::pointer
[obj _type segment _scope]
(MemoryAccess/setAddress segment obj))
(defmulti deserialize
"Deserializes the given segment into a Clojure data structure."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[segment type]
type))
(defmethod deserialize ::byte
[segment _type]
(MemoryAccess/getByte segment))
(defmethod deserialize ::short
[segment _type]
(MemoryAccess/getShort segment))
(defmethod deserialize ::int
[segment _type]
(MemoryAccess/getInt segment))
(defmethod deserialize ::long
[segment _type]
(MemoryAccess/getLong segment))
(defmethod deserialize ::long-long
[segment _type]
(MemoryAccess/getLong segment))
(defmethod deserialize ::char
[segment _type]
(MemoryAccess/getChar segment))
(defmethod deserialize ::float
[segment _type]
(MemoryAccess/getFloat segment))
(defmethod deserialize ::double
[segment _type]
(MemoryAccess/getDouble segment))
(defmethod deserialize ::pointer
[segment _type]
(MemoryAccess/getAddress segment))
(defmulti size-of
"The size in bytes of the given `type`."
(fn [type] type))
(defmethod size-of ::byte
[_type]
Byte/SIZE)
(defmethod size-of ::short
[_type]
Short/SIZE)
(defmethod size-of ::int
[_type]
Integer/SIZE)
(defmethod size-of ::long
[_type]
Long/SIZE)
(defmethod size-of ::long-long
[_type]
Long/SIZE)
(defmethod size-of ::char
[_type]
Byte/SIZE)
(defmethod size-of ::float
[_type]
Float/SIZE)
(defmethod size-of ::double
[_type]
Double/SIZE)
(defmethod size-of ::pointer
[_type]
(.byteSize MemoryLayouts/ADDRESS))
(def c-layout
"Map of primitive type names to the [[CLinker]] types for a method handle."
{::byte CLinker/C_CHAR
::short CLinker/C_SHORT
::int CLinker/C_INT
::long CLinker/C_LONG
::long-long CLinker/C_LONG_LONG
::char CLinker/C_CHAR
::float CLinker/C_FLOAT
::double CLinker/C_DOUBLE
::pointer CLinker/C_POINTER})
(def java-layout
"Map of primitive type names to the Java types for a method handle."
{::byte Byte/TYPE
::short Short/TYPE
::int Integer/TYPE
::long Long/TYPE
::long-long Long/TYPE
::char Byte/TYPE
::float Float/TYPE
::double Double/TYPE
::pointer MemoryAddress})
ResourceScope
SegmentAllocator)))
(defn alloc
"Allocates `size` bytes."
([size] (alloc size (ResourceScope/newImplicitScope)))
([size scope] (MemorySegment/allocateNative ^long size ^ResourceScope scope)))
(defn alloc-instance
"Allocates a memory segment for the given `type`."
([type] (alloc-instance type (ResourceScope/newImplicitScope)))
([type scope] (MemorySegment/allocateNative ^long (size-of type) ^ResourceScope scope)))
(defn serialize
"Serializes the `obj` into a newly-allocated [[MemorySegment]]."
([obj type] (serialize obj type (ResourceScope/newImplicitScope)))
([obj type scope] (serialize* obj type (alloc-instance type scope) scope)))
(defn stack-scope
"Constructs a new scope for use only in this thread.
@ -198,6 +43,18 @@
[]
(ResourceScope/newSharedScope))
(defmacro with-acquired
"Acquires a `scope` to ensure it will not be released until the `body` completes.
This is only necessary to do on shared scopes, however if you are operating on
an arbitrary passed scope, it is best practice to wrap code that interacts
with it wrapped in this."
[scope & body]
`(let [scope# ~scope
handle# (.acquire ^ResourceScope scope#)]
(try ~@body
(finally (.release ^ResourceScope scope# handle#)))))
(defn address-of
"Gets the address of a given segment.
@ -222,25 +79,354 @@
([segment offset size]
(.asSlice ^MemorySegment segment ^long offset ^long size)))
(defn slice-at
(defn slice-into
"Get a slice over the `segment` starting at the `address`."
([segment address]
([address segment]
(.asSlice ^MemorySegment segment ^MemoryAddress address))
([segment address size]
([address segment size]
(.asSlice ^MemorySegment segment ^MemoryAddress address ^long size)))
(defn with-offset
"Get a new address `offset` from the old `address`."
[address offset]
(.addOffset ^MemoryAddress address ^long offset))
(defn as-segment
"Dereferences an `address` into a memory segment associated with the `scope`.
If `cleanup` is provided, it is a 0-arity function run when the scope is
closed. This can be used to register a free method for the memory, or do other
cleanup in a way that doesn't require modifying the code at the point of
freeing, and allows shared or garbage collected resources to be freed
correctly."
([address size scope]
(.asSegment ^MemoryAddress address size scope))
([address size scope cleanup]
(.asSegment ^MemoryAddress address size cleanup scope)))
(defn add-close-action!
"Adds a 0-arity function to be run when the `scope` closes."
[scope action]
(.addCloseAction ^ResourceScope scope action))
#_(defn seq-of
"Constructs a lazy sequence of `type` elements deserialized from `segment`."
[type segment]
(let [size (size-of type)]
(letfn [(rec [segment]
(lazy-seq
(when (>= (.byteSize ^MemorySegment segment) size)
(cons (deserialize-from type segment)
(rec (slice segment size))))))]
(rec segment))))
(def primitive-types
"A set of keywords representing all the primitive types which may be passed to
native functions."
#{::byte ::short ::int ::long ::long-long
::char
::float ::double
::pointer})
(def c-prim-layout
"Map of primitive type names to the [[CLinker]] types for a method handle."
{::byte CLinker/C_CHAR
::short CLinker/C_SHORT
::int CLinker/C_INT
::long CLinker/C_LONG
::long-long CLinker/C_LONG_LONG
::char CLinker/C_CHAR
::float CLinker/C_FLOAT
::double CLinker/C_DOUBLE
::pointer CLinker/C_POINTER})
(defmulti c-layout
"Gets the layout object for a given `type`.
If a type is primitive it will return the appropriate primitive
layout (see [[c-prim-layout]]).
Otherwise, it should return a [[GroupLayout]] for the given type."
(fn [type] type))
(defmethod c-layout :default
[type]
(c-prim-layout type))
(defmulti primitive-type
"Gets the primitive type that is used to pass as an argument for the `type`.
This is for objects which are passed to native functions as primitive types,
but which need additional logic to be performed during serialization.
Returns nil for any type "
(fn [type] type))
(defmethod primitive-type :default
[type]
(contains? primitive-types type))
(defmethod primitive-type ::c-string
[_type]
::pointer)
(def java-prim-layout
"Map of primitive type names to the Java types for a method handle."
{::byte Byte/TYPE
::short Short/TYPE
::int Integer/TYPE
::long Long/TYPE
::long-long Long/TYPE
::char Byte/TYPE
::float Float/TYPE
::double Double/TYPE
::pointer MemoryAddress
::void Void/TYPE})
(defmulti java-layout
"Gets the Java class to an argument of this type for a method handle."
(fn [type] type))
(defmethod java-layout :default
[type]
(java-prim-layout type MemorySegment))
(defmulti serialize*
"Constructs a serialized version of the `obj` and returns it.
Any new allocations made during the serialization should be tied to the given
`scope`, except in extenuating circumstances.
This method should only be implemented for types serialize to primitives."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[obj type scope]))
(defmethod serialize* ::c-string
[obj _type scope]
(address-of (CLinker/toCString (str obj) ^ResourceScope scope)))
(defmulti serialize-into
"Writes a serialized version of the `obj` to the given `segment`.
Any new allocations made during the serialization should be tied to the given
`scope`, except in extenuating circumstances.
This method should be implemented for any type which does not
override [[c-layout]].
For any other type, this will serialize it as [[serialize*]] before writing
the result value into the `segment`."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[obj type segment scope]
type))
(defmethod serialize-into :default
[obj type segment scope]
(let [new-type (c-layout type)]
(if (qualified-keyword? new-type)
(serialize-into (serialize* obj type scope) new-type segment scope)
(throw (ex-info "Attempted to serialize an object to a type that has not been overriden."
{:type type
:object obj})))))
(defmethod serialize-into ::byte
[obj _type segment _scope]
(MemoryAccess/setByte segment (byte obj)))
(defmethod serialize-into ::short
[obj _type segment _scope]
(MemoryAccess/setShort segment (short obj)))
(defmethod serialize-into ::int
[obj _type segment _scope]
(MemoryAccess/setInt segment (int obj)))
(defmethod serialize-into ::long
[obj _type segment _scope]
(MemoryAccess/setLong segment (long obj)))
(defmethod serialize-into ::long-long
[obj _type segment _scope]
(MemoryAccess/setLong segment (long obj)))
(defmethod serialize-into ::char
[obj _type segment _scope]
(MemoryAccess/setChar segment (char obj)))
(defmethod serialize-into ::float
[obj _type segment _scope]
(MemoryAccess/setFloat segment (float obj)))
(defmethod serialize-into ::double
[obj _type segment _scope]
(MemoryAccess/setDouble segment (double obj)))
(defmethod serialize-into ::pointer
[obj _type segment _scope]
(MemoryAccess/setAddress segment obj))
(defmulti deserialize-from
"Deserializes the given segment into a Clojure data structure."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[segment type]
type))
(defmethod deserialize-from ::byte
[segment _type]
(MemoryAccess/getByte segment))
(defmethod deserialize-from ::short
[segment _type]
(MemoryAccess/getShort segment))
(defmethod deserialize-from ::int
[segment _type]
(MemoryAccess/getInt segment))
(defmethod deserialize-from ::long
[segment _type]
(MemoryAccess/getLong segment))
(defmethod deserialize-from ::long-long
[segment _type]
(MemoryAccess/getLong segment))
(defmethod deserialize-from ::char
[segment _type]
(MemoryAccess/getChar segment))
(defmethod deserialize-from ::float
[segment _type]
(MemoryAccess/getFloat segment))
(defmethod deserialize-from ::double
[segment _type]
(MemoryAccess/getDouble segment))
(defmethod deserialize-from ::pointer
[segment _type]
(MemoryAccess/getAddress segment))
(defmulti deserialize
"Deserializes a primitive object into a Clojure data structure.
This is intended for use with types that are returned as a primitive but which
need additional processing before they can be returned."
(fn
#_{:clj-kondo/ignore [:unused-binding]}
[obj type]
type))
(defmethod deserialize :default
[obj _type]
obj)
(defmethod deserialize-from ::c-string
[segment type]
(-> segment
(deserialize-from ::pointer)
(deserialize type)))
(defmethod deserialize ::c-string
[obj _type]
(CLinker/toJavaString obj))
(defn size-of
"The size in bytes of the given `type`."
[type]
(let [layout ^MemoryLayout (c-layout type)]
(.byteSize
(cond-> layout
(qualified-keyword? layout) ^MemoryLayout c-layout))))
(defn alloc-instance
"Allocates a memory segment for the given `type`."
([type] (alloc-instance type (ResourceScope/newImplicitScope)))
([type scope] (MemorySegment/allocateNative ^long (size-of type) ^ResourceScope scope)))
(defn serialize
"Serializes the `obj` into a newly-allocated [[MemorySegment]]."
([obj type] (serialize obj type (ResourceScope/newImplicitScope)))
([obj type scope] (serialize-into obj type (alloc-instance type scope) scope)))
(defn load-system-library
"Loads the library named `libname` from the system's load path."
[libname]
(System/loadLibrary (name ~libname)))
(defn load-library
"Loads the library at `path`."
[path]
(System/load (.getAbsolutePath (io/file path))))
(defn- find-symbol
"Gets the [[MemoryAddress]] of a symbol from the loaded libraries."
[sym]
(.. (CLinker/systemLookup) (lookup sym) (get)))
(defn- method-type
"Gets the [[MethodType]] for a set of `args` and `ret` types."
([args] (method-type args ::void))
([args ret]
(MethodType/methodType
^Class ret
^"[Ljava.lang.Class;" (into-array Class (map java-layout args)))))
(defn- function-descriptor
"Gets the [[FunctionDescriptor]] for a set of `args` and `ret` types."
([args] (function-descriptor args ::void))
([args ret]
(let [args-arr (into-array MemoryLayout (map c-layout args))]
(if-not (identical? ret ::void)
(FunctionDescriptor/of
(c-layout ret)
args-arr)
(FunctionDescriptor/ofVoid
args-arr)))))
(defn- downcall-handle
"Gets the [[MethodHandle]] for the function at the `address`."
[address method-type function-descriptor]
(.downcallHandle (CLinker/getInstance) address method-type function-descriptor))
(comment
(let [args-types [::c-string]
ret-type ::int
downcall (downcall-handle
(find-symbol "strlen")
(method-type args-types ret-type)
(function-descriptor args-types ret-type))
strlen (fn [str]
(with-open [scope (stack-scope)]
(let [arg1 (serialize (nth args-types 0) str scope)]
(deserialize (.invoke downcall arg1) ret-type))))]
(def
^{:arglists '([str])}
strlen
"Counts the number of bytes in a C string."
strlen))
)
#_:clj-kondo/ignore
(comment
;;; Prospective syntax for ffi
;; This function has no out params, and no extra marshalling work, so it has no
;; body
(defcfun strlen
(-> (defcfn strlen
"Counts the number of bytes in a C String."
"strlen" [::c-string] ::int)
quote
macroexpand-1)
;; This function has an output parameter and requires some clojure code to
;; translate the values from the c fn to something sensible in clojure.
(defcfun some-func
(defcfn some-func
"Gets some output value"
"someFunc" [::pointer] ::int
[]
@ -248,13 +434,13 @@
(let [out-int (alloc-instance ::int scope)
success? (zero? (some-func (address-of out-int)))]
(if success?
(deserialize ::int out-int)
(deserialize-from ::int out-int)
(throw (ex-info (getErrorString) {}))))))
;; This function probably wouldn't actually get wrapped, since the cost of
;; marshalling is greater than the speed boost of using an in-place sort. That
;; said, this is a nice sample of what more complex marshalling looks like.
(defcfun qsort
(defcfn qsort
"Quicksort implementation"
"qsort"
[::pointer ::long ::long (fn [::pointer ::pointer] ::int)]
@ -262,12 +448,12 @@
[type comparator list]
(with-open [scope (stack-scope)]
(let [copied-list (alloc (* (count list) (size-of type)) scope)
_ (dorun (map #(serialize* %1 type %2 scope) list (seq-of type copied-list)))
_ (dorun (map #(serialize-into %1 type %2 scope) list (seq-of type copied-list)))
comp-fn (fn [addr1 addr2]
(let [obj1 (deserialize type (slice-global addr1 (size-of type)))
obj2 (deserialize type (slice-global addr2 (size-of type)))]
(let [obj1 (deserialize-from type (slice-global addr1 (size-of type)))
obj2 (deserialize-from type (slice-global addr2 (size-of type)))]
(comparator obj1 obj2)))]
(qsort copied-list (count list) (size-of type) comp-fn)
(for [segment (seq-of type copied-list)]
(deserialize type segment)))))
(deserialize-from type segment)))))
)