WIP Finish updating files to fully be on JDK 21

This is still incomplete, I get crashes on upcalls.
This commit is contained in:
Joshua Suskalo 2024-01-29 06:22:39 -06:00
parent 2325abf53b
commit b7092b4af6
No known key found for this signature in database
GPG key ID: 9B6BA586EFF1B9F0
3 changed files with 78 additions and 88 deletions

View file

@ -14,10 +14,9 @@
MethodHandles MethodHandles
MethodType) MethodType)
(java.lang.foreign (java.lang.foreign
Addressable
Linker Linker
Linker$Option
FunctionDescriptor FunctionDescriptor
MemoryAddress
MemoryLayout MemoryLayout
MemorySegment MemorySegment
SegmentAllocator))) SegmentAllocator)))
@ -56,7 +55,8 @@
(defn- downcall-handle (defn- downcall-handle
"Gets the [[MethodHandle]] for the function at the `sym`." "Gets the [[MethodHandle]] for the function at the `sym`."
[sym function-descriptor] [sym function-descriptor]
(.downcallHandle (Linker/nativeLinker) sym function-descriptor)) (.downcallHandle (Linker/nativeLinker) sym function-descriptor
(make-array Linker$Option 0)))
(def ^:private load-instructions (def ^:private load-instructions
"Mapping from primitive types to the instruction used to load them onto the stack." "Mapping from primitive types to the instruction used to load them onto the stack."
@ -130,15 +130,6 @@
[:invokevirtual (prim-classes prim-type) (unbox-fn-for-type prim-type) [prim]]] [:invokevirtual (prim-classes prim-type) (unbox-fn-for-type prim-type) [prim]]]
[])))) []))))
(defn- coerce-addressable
"If the passed `type` is [[MemoryAddress]], returns [[Addressable]], otherwise returns `type`.
This is used to declare the return types of upcall stubs."
[type]
(if (= type MemoryAddress)
Addressable
type))
(defn- downcall-class (defn- downcall-class
"Class definition for an implementation of [[IFn]] which calls a closed over "Class definition for an implementation of [[IFn]] which calls a closed over
method handle without reflection, unboxing primitives when needed." method handle without reflection, unboxing primitives when needed."
@ -175,7 +166,7 @@
args) args)
[:invokevirtual MethodHandle "invokeExact" [:invokevirtual MethodHandle "invokeExact"
(cond->> (cond->>
(conj (mapv (comp coerce-addressable insn-layout) args) (conj (mapv insn-layout args)
(insn-layout ret)) (insn-layout ret))
(not (mem/primitive-type ret)) (cons SegmentAllocator))] (not (mem/primitive-type ret)) (cons SegmentAllocator))]
(to-object-asm ret) (to-object-asm ret)
@ -343,7 +334,7 @@
;; taking restargs, and so the downcall must be applied ;; taking restargs, and so the downcall must be applied
(-> `(~@(when (symbol? args) [`apply]) (-> `(~@(when (symbol? args) [`apply])
~downcall-sym ~downcall-sym
~@(when allocator? [`(mem/session-allocator ~session)]) ~@(when allocator? [`(mem/arena-allocator ~session)])
~@(if (symbol? args) ~@(if (symbol? args)
[args] [args]
args)) args))
@ -410,7 +401,7 @@
(fn native-fn [& args] (fn native-fn [& args]
(with-open [session (mem/stack-session)] (with-open [session (mem/stack-session)]
(mem/deserialize-from (mem/deserialize-from
(apply downcall (mem/session-allocator session) (apply downcall (mem/arena-allocator session)
(map #(mem/serialize %1 %2 session) args arg-types)) (map #(mem/serialize %1 %2 session) args arg-types))
ret-type))))) ret-type)))))
@ -435,6 +426,7 @@
If your `args` and `ret` are constants, then it is more efficient to If your `args` and `ret` are constants, then it is more efficient to
call [[make-downcall]] followed by [[make-serde-wrapper]] because the latter call [[make-downcall]] followed by [[make-serde-wrapper]] because the latter
has an inline definition which will result in less overhead from serdes." has an inline definition which will result in less overhead from serdes."
;; TODO(Joshua): Add an inline arity for when the args and ret types are constant
[symbol args ret] [symbol args ret]
(-> symbol (-> symbol
(make-downcall args ret) (make-downcall args ret)
@ -489,7 +481,7 @@
{:name :upcall {:name :upcall
:flags #{:public} :flags #{:public}
:desc (conj (mapv insn-layout arg-types) :desc (conj (mapv insn-layout arg-types)
(coerce-addressable (insn-layout ret-type))) (insn-layout ret-type))
:emit [[:aload 0] :emit [[:aload 0]
[:getfield :this "upcall_ifn" IFn] [:getfield :this "upcall_ifn" IFn]
(loop [types arg-types (loop [types arg-types
@ -505,7 +497,7 @@
inc))) inc)))
acc)) acc))
[:invokeinterface IFn "invoke" (repeat (inc (count arg-types)) Object)] [:invokeinterface IFn "invoke" (repeat (inc (count arg-types)) Object)]
(to-prim-asm (coerce-addressable ret-type)) (to-prim-asm ret-type)
[(return-for-type ret-type :areturn)]]}]}) [(return-for-type ret-type :areturn)]]}]})
(defn- upcall (defn- upcall
@ -518,7 +510,7 @@
([args] (method-type args ::mem/void)) ([args] (method-type args ::mem/void))
([args ret] ([args ret]
(MethodType/methodType (MethodType/methodType
^Class (coerce-addressable (mem/java-layout ret)) ^Class (mem/java-layout ret)
^"[Ljava.lang.Class;" (into-array Class (map mem/java-layout args))))) ^"[Ljava.lang.Class;" (into-array Class (map mem/java-layout args)))))
(defn- upcall-handle (defn- upcall-handle
@ -545,21 +537,22 @@
(mem/global-session)))) (mem/global-session))))
(defmethod mem/serialize* ::fn (defmethod mem/serialize* ::fn
[f [_fn arg-types ret-type & {:keys [raw-fn?]}] session] [f [_fn arg-types ret-type & {:keys [raw-fn?]}] arena]
(println "Attempting to serialize function of type" (str ret-type "(*)(" (clojure.string/join "," arg-types) ")"))
(.upcallStub (.upcallStub
(Linker/nativeLinker) (Linker/nativeLinker)
(cond-> f ^MethodHandle (cond-> f
(not raw-fn?) (upcall-serde-wrapper arg-types ret-type) (not raw-fn?) (upcall-serde-wrapper arg-types ret-type)
:always (upcall-handle arg-types ret-type)) :always (upcall-handle arg-types ret-type))
(function-descriptor arg-types ret-type) ^FunctionDescriptor (function-descriptor arg-types ret-type)
session)) ^Arena arena
(make-array Linker$Option 0)))
(defmethod mem/deserialize* ::fn (defmethod mem/deserialize* ::fn
[addr [_fn arg-types ret-type & {:keys [raw-fn?]}]] [addr [_fn arg-types ret-type & {:keys [raw-fn?]}]]
(when-not (mem/null? addr) (when-not (mem/null? addr)
(vary-meta (vary-meta
(-> addr (-> ^MemorySegment addr
(MemorySegment/ofAddress mem/pointer-size (mem/connected-session))
(downcall-handle (function-descriptor arg-types ret-type)) (downcall-handle (function-descriptor arg-types ret-type))
(downcall-fn arg-types ret-type) (downcall-fn arg-types ret-type)
(cond-> (not raw-fn?) (make-serde-wrapper arg-types ret-type))) (cond-> (not raw-fn?) (make-serde-wrapper arg-types ret-type)))
@ -640,9 +633,8 @@
See [[freset!]], [[fswap!]]." See [[freset!]], [[fswap!]]."
[symbol-or-addr type] [symbol-or-addr type]
(StaticVariable. (mem/as-segment (.address (ensure-symbol symbol-or-addr)) (StaticVariable. (.reinterpret ^MemorySegment (ensure-symbol symbol-or-addr)
(mem/size-of type) ^long (mem/size-of type))
(mem/global-session))
type (atom nil))) type (atom nil)))
(defmacro defvar (defmacro defvar

View file

@ -24,7 +24,7 @@
(pos? r) (conj [::padding [::mem/padding (- align r)]]) (pos? r) (conj [::padding [::mem/padding (- align r)]])
:always (conj field)) :always (conj field))
fields)) fields))
(let [strongest-alignment (mem/align-of struct-spec) (let [strongest-alignment (reduce max (map (comp mem/align-of second) (nth struct-spec 1)))
r (rem offset strongest-alignment)] r (rem offset strongest-alignment)]
(cond-> aligned-fields (cond-> aligned-fields
(pos? r) (conj [::padding [::mem/padding (- strongest-alignment r)]])))))] (pos? r) (conj [::padding [::mem/padding (- strongest-alignment r)]])))))]

View file

@ -13,11 +13,7 @@
struct, or array, then [[c-layout]] must be overriden to return the native struct, or array, then [[c-layout]] must be overriden to return the native
layout of the type, and [[serialize-into]] and [[deserialize-from]] should be layout of the type, and [[serialize-into]] and [[deserialize-from]] should be
overriden to allow marshaling values of the type into and out of memory overriden to allow marshaling values of the type into and out of memory
segments. segments."
When writing code that manipulates a segment, it's best practice to
use [[with-acquired]] on the [[segment-session]] in order to ensure it won't be
released during its manipulation."
(:require (:require
[clojure.set :as set] [clojure.set :as set]
[clojure.spec.alpha :as s]) [clojure.spec.alpha :as s])
@ -160,6 +156,17 @@
^Arena [] ^Arena []
(global-arena)) (global-arena))
(defn arena-allocator
"Constructs a [[SegmentAllocator]] from the given [[Arena]].
This is primarily used when working with unwrapped downcall functions. When a
downcall function returns a non-primitive type, it must be provided with an
allocator."
^SegmentAllocator [^Arena scope]
(reify SegmentAllocator
(^MemorySegment allocate [_this ^long byte-size ^long byte-alignment]
(.allocate scope ^long byte-size ^long byte-alignment))))
(defn ^:deprecated session-allocator (defn ^:deprecated session-allocator
"Constructs a segment allocator from the given `session`. "Constructs a segment allocator from the given `session`.
@ -167,7 +174,7 @@
downcall function returns a non-primitive type, it must be provided with an downcall function returns a non-primitive type, it must be provided with an
allocator." allocator."
^SegmentAllocator [^Arena session] ^SegmentAllocator [^Arena session]
(assert false "Segment allocators can no longer be constructed from sessions.")) (arena-allocator session))
(defn ^:deprecated scope-allocator (defn ^:deprecated scope-allocator
"Constructs a segment allocator from the given `scope`. "Constructs a segment allocator from the given `scope`.
@ -176,7 +183,7 @@
downcall function returns a non-primitive type, it must be provided with an downcall function returns a non-primitive type, it must be provided with an
allocator." allocator."
^SegmentAllocator [^Arena scope] ^SegmentAllocator [^Arena scope]
(assert false "Segment allocators can no longer be constructed from scopes.")) (arena-allocator scope))
(defn ^:deprecated segment-session (defn ^:deprecated segment-session
"Gets the memory session used to construct the `segment`." "Gets the memory session used to construct the `segment`."
@ -282,8 +289,7 @@
"Clones the content of `segment` into a new segment of the same size." "Clones the content of `segment` into a new segment of the same size."
(^MemorySegment [segment] (clone-segment segment (connected-session))) (^MemorySegment [segment] (clone-segment segment (connected-session)))
(^MemorySegment [^MemorySegment segment session] (^MemorySegment [^MemorySegment segment session]
(with-acquired [(segment-session segment) session] (copy-segment ^MemorySegment (alloc (.byteSize segment) session) segment)))
(copy-segment ^MemorySegment (alloc (.byteSize segment) session) segment))))
(defn slice-segments (defn slice-segments
"Constructs a lazy seq of `size`-length memory segments, sliced from `segment`." "Constructs a lazy seq of `size`-length memory segments, sliced from `segment`."
@ -338,7 +344,7 @@
"The [[MemoryLayout]] for a c-sized double in [[native-endian]] [[ByteOrder]]." "The [[MemoryLayout]] for a c-sized double in [[native-endian]] [[ByteOrder]]."
ValueLayout/JAVA_DOUBLE) ValueLayout/JAVA_DOUBLE)
(def ^ValueLayout$OfAddress pointer-layout (def ^AddressLayout pointer-layout
"The [[MemoryLayout]] for a native pointer in [[native-endian]] [[ByteOrder]]." "The [[MemoryLayout]] for a native pointer in [[native-endian]] [[ByteOrder]]."
ValueLayout/ADDRESS) ValueLayout/ADDRESS)
@ -548,20 +554,20 @@
(.get segment (.withOrder ^ValueLayout$OfDouble double-layout byte-order) offset))) (.get segment (.withOrder ^ValueLayout$OfDouble double-layout byte-order) offset)))
(defn read-address (defn read-address
"Reads a [[MemoryAddress]] from the `segment`, at an optional `offset`." "Reads an address from the `segment`, at an optional `offset`, wrapped in a [[MemorySegment]]."
{:inline {:inline
(fn read-address-inline (fn read-address-inline
([segment] ([segment]
`(let [segment# ~segment] `(let [segment# ~segment]
(.get ^MemorySegment segment# ^ValueLayout$OfAddress pointer-layout 0))) (.get ^MemorySegment segment# ^AddressLayout pointer-layout 0)))
([segment offset] ([segment offset]
`(let [segment# ~segment `(let [segment# ~segment
offset# ~offset] offset# ~offset]
(.get ^MemorySegment segment# ^ValueLayout$OfAddress pointer-layout offset#))))} (.get ^MemorySegment segment# ^AddressLayout pointer-layout offset#))))}
(^MemoryAddress [^MemorySegment segment] (^MemorySegment [^MemorySegment segment]
(.get segment ^ValueLayout$OfAddress pointer-layout 0)) (.get segment ^AddressLayout pointer-layout 0))
(^MemoryAddress [^MemorySegment segment ^long offset] (^MemorySegment [^MemorySegment segment ^long offset]
(.get segment ^ValueLayout$OfAddress pointer-layout offset))) (.get segment ^AddressLayout pointer-layout offset)))
(defn write-byte (defn write-byte
"Writes a [[byte]] to the `segment`, at an optional `offset`." "Writes a [[byte]] to the `segment`, at an optional `offset`."
@ -746,22 +752,22 @@
(.set segment (.withOrder ^ValueLayout$OfDouble double-layout byte-order) offset value))) (.set segment (.withOrder ^ValueLayout$OfDouble double-layout byte-order) offset value)))
(defn write-address (defn write-address
"Writes a [[MemoryAddress]] to the `segment`, at an optional `offset`." "Writes the address of the [[MemorySegment]] `value` to the `segment`, at an optional `offset`."
{:inline {:inline
(fn write-address-inline (fn write-address-inline
([segment value] ([segment value]
`(let [segment# ~segment `(let [segment# ~segment
value# ~value] value# ~value]
(.set ^MemorySegment segment# ^ValueLayout$OfAddress pointer-layout 0 ^MemorySegment value#))) (.set ^MemorySegment segment# ^AddressLayout pointer-layout 0 ^MemorySegment value#)))
([segment offset value] ([segment offset value]
`(let [segment# ~segment `(let [segment# ~segment
offset# ~offset offset# ~offset
value# ~value] value# ~value]
(.set ^MemorySegment segment# ^ValueLayout$OfAddress pointer-layout offset# ^MemorySegment value#))))} (.set ^MemorySegment segment# ^AddressLayout pointer-layout offset# ^MemorySegment value#))))}
(^MemoryAddress [^MemorySegment segment ^MemoryAddress value] ([^MemorySegment segment ^MemorySegment value]
(.set segment ^ValueLayout$OfAddress pointer-layout 0 value)) (.set segment ^AddressLayout pointer-layout 0 value))
(^MemoryAddress [^MemorySegment segment ^long offset ^MemoryAddress value] ([^MemorySegment segment ^long offset ^MemorySegment value]
(.set segment ^ValueLayout$OfAddress pointer-layout offset value))) (.set segment ^AddressLayout pointer-layout offset value)))
(defn- type-dispatch (defn- type-dispatch
"Gets a type dispatch value from a (potentially composite) type." "Gets a type dispatch value from a (potentially composite) type."
@ -898,7 +904,7 @@
::char Byte/TYPE ::char Byte/TYPE
::float Float/TYPE ::float Float/TYPE
::double Double/TYPE ::double Double/TYPE
::pointer MemoryAddress ::pointer MemorySegment
::void Void/TYPE}) ::void Void/TYPE})
(defn java-layout (defn java-layout
@ -925,8 +931,8 @@
(defn alloc-instance (defn alloc-instance
"Allocates a memory segment for the given `type`." "Allocates a memory segment for the given `type`."
(^MemorySegment [type] (alloc-instance type (connected-session))) (^MemorySegment [type] (alloc-instance type (auto-arena)))
(^MemorySegment [type session] (MemorySegment/allocateNative ^long (size-of type) ^MemorySession session))) (^MemorySegment [type arena] (.allocate ^Arena arena ^long (size-of type) ^long (align-of type))))
(declare serialize serialize-into) (declare serialize serialize-into)
@ -980,12 +986,11 @@
[obj type session] [obj type session]
(if-not (null? obj) (if-not (null? obj)
(if (sequential? type) (if (sequential? type)
(with-acquired [session]
(let [segment (alloc-instance (second type) session)] (let [segment (alloc-instance (second type) session)]
(serialize-into obj (second type) segment session) (serialize-into obj (second type) segment session)
(address-of segment))) (address-of segment))
obj) obj)
(MemoryAddress/NULL))) (MemorySegment/NULL)))
(defmethod serialize* ::void (defmethod serialize* ::void
[_obj _type _session] [_obj _type _session]
@ -1001,10 +1006,7 @@
override [[c-layout]]. override [[c-layout]].
For any other type, this will serialize it as [[serialize*]] before writing For any other type, this will serialize it as [[serialize*]] before writing
the result value into the `segment`. the result value into the `segment`."
Implementations of this should be inside a [[with-acquired]] block for the
`session` if they perform multiple memory operations."
(fn (fn
#_{:clj-kondo/ignore [:unused-binding]} #_{:clj-kondo/ignore [:unused-binding]}
[obj type segment session] [obj type segment session]
@ -1013,8 +1015,7 @@
(defmethod serialize-into :default (defmethod serialize-into :default
[obj type segment session] [obj type segment session]
(if-some [prim-layout (primitive-type type)] (if-some [prim-layout (primitive-type type)]
(with-acquired [(segment-session segment) session] (serialize-into (serialize* obj type session) prim-layout segment session)
(serialize-into (serialize* obj type session) prim-layout segment session))
(throw (ex-info "Attempted to serialize an object to a type that has not been overridden" (throw (ex-info "Attempted to serialize an object to a type that has not been overridden"
{:type type {:type type
:object obj})))) :object obj}))))
@ -1059,11 +1060,10 @@
(defmethod serialize-into ::pointer (defmethod serialize-into ::pointer
[obj type segment session] [obj type segment session]
(with-acquired [(segment-session segment) session]
(write-address (write-address
segment segment
(cond-> obj (cond-> obj
(sequential? type) (serialize* type session))))) (sequential? type) (serialize* type session))))
(defn serialize (defn serialize
"Serializes an arbitrary type. "Serializes an arbitrary type.
@ -1085,10 +1085,7 @@
"Deserializes the given segment into a Clojure data structure. "Deserializes the given segment into a Clojure data structure.
For types that serialize to primitives, a default implementation will For types that serialize to primitives, a default implementation will
deserialize the primitive before calling [[deserialize*]]. deserialize the primitive before calling [[deserialize*]]."
Implementations of this should be inside a [[with-acquired]] block for the the
`segment`'s session if they perform multiple memory operations."
(fn (fn
#_{:clj-kondo/ignore [:unused-binding]} #_{:clj-kondo/ignore [:unused-binding]}
[segment type] [segment type]
@ -1144,9 +1141,8 @@
(defmethod deserialize-from ::pointer (defmethod deserialize-from ::pointer
[segment type] [segment type]
(with-acquired [(segment-session segment)]
(cond-> (read-address segment) (cond-> (read-address segment)
(sequential? type) (deserialize* type)))) (sequential? type) (deserialize* type)))
(defmulti deserialize* (defmulti deserialize*
"Deserializes a primitive object into a Clojure data structure. "Deserializes a primitive object into a Clojure data structure.
@ -1196,8 +1192,11 @@
[addr type] [addr type]
(when-not (null? addr) (when-not (null? addr)
(if (sequential? type) (if (sequential? type)
(deserialize-from (as-segment addr (size-of (second type))) (let [target-type (second type)]
(second type)) (deserialize-from
(.reinterpret ^MemorySegment (read-address addr)
^long (size-of target-type))
target-type))
addr))) addr)))
(defmethod deserialize* ::void (defmethod deserialize* ::void
@ -1219,8 +1218,7 @@
(defn seq-of (defn seq-of
"Constructs a lazy sequence of `type` elements deserialized from `segment`." "Constructs a lazy sequence of `type` elements deserialized from `segment`."
[type segment] [type segment]
(with-acquired [(segment-session segment)] (map #(deserialize % type) (slice-segments segment (size-of type))))
(map #(deserialize % type) (slice-segments segment (size-of type)))))
;;; Raw composite types ;;; Raw composite types
;; TODO(Joshua): Ensure that all the raw values don't have anything happen on ;; TODO(Joshua): Ensure that all the raw values don't have anything happen on
@ -1251,13 +1249,13 @@
(defmethod serialize* ::c-string (defmethod serialize* ::c-string
[obj _type session] [obj _type session]
(if obj (if obj
(address-of (.allocateUtf8String (session-allocator session) ^String obj)) (address-of (.allocateUtf8String (arena-allocator session) ^String obj))
(MemoryAddress/NULL))) (MemorySegment/NULL)))
(defmethod deserialize* ::c-string (defmethod deserialize* ::c-string
[addr _type] [addr _type]
(when-not (null? addr) (when-not (null? addr)
(.getUtf8String ^MemoryAddress addr 0))) (.getUtf8String (.reinterpret ^MemorySegment addr Integer/MAX_VALUE) 0)))
;;; Union types ;;; Union types
@ -1328,7 +1326,7 @@
(defmethod c-layout ::padding (defmethod c-layout ::padding
[[_padding size]] [[_padding size]]
(MemoryLayout/paddingLayout (* 8 size))) (MemoryLayout/paddingLayout size))
(defmethod serialize-into ::padding (defmethod serialize-into ::padding
[_obj [_padding _size] _segment _session] [_obj [_padding _size] _segment _session]