Allow multiple function tails in defcfn

This commit is contained in:
Joshua Suskalo 2021-09-17 12:53:27 -05:00
parent e8a3c555bb
commit 3030fc66cb

View file

@ -692,18 +692,26 @@
(make-downcall symbol args ret))))) (make-downcall symbol args ret)))))
(s/def ::defcfn-args (s/def ::defcfn-args
(s/cat :name simple-symbol? (s/and
:doc (s/? string?) (s/cat :name simple-symbol?
:attr-map (s/? map?) :doc (s/? string?)
:symbol (s/nonconforming :attr-map (s/? map?)
(s/or :string string? :symbol (s/nonconforming
:symbol simple-symbol?)) (s/or :string string?
:native-arglist (s/coll-of qualified-keyword? :kind vector?) :symbol simple-symbol?))
:return-type qualified-keyword? :native-arglist (s/coll-of qualified-keyword? :kind vector?)
:fn-tail (s/? :return-type qualified-keyword?
(s/nonconforming :wrapper (s/?
(s/cat :arglist (s/coll-of simple-symbol? :kind vector?) (s/cat
:body (s/* any?)))))) :native-fn simple-symbol?
:fn-tail (let [fn-tail (s/cat :arglist (s/coll-of simple-symbol? :kind vector?)
:body (s/* any?))]
(s/alt
:single-arity fn-tail
:multi-arity (s/+ (s/spec fn-tail)))))))
#(if (:wrapper %)
(not= (:name %) (-> % :wrapper :native-fn))
true)))
(defmacro defcfn (defmacro defcfn
"Defines a Clojure function which maps to a native function. "Defines a Clojure function which maps to a native function.
@ -713,10 +721,10 @@
`arg-types` is a vector of qualified keywords representing the argument types. `arg-types` is a vector of qualified keywords representing the argument types.
`ret-type` is a single qualified keyword representing the return type. `ret-type` is a single qualified keyword representing the return type.
`fn-tail` is the body of the function (potentially with multiple arities) `fn-tail` is the body of the function (potentially with multiple arities)
which wraps the native one. Inside the function, `name` is bound to a function which wraps the native one. Inside the function, `native-fn` is bound to a
that will serialize its arguments, call the native function, and deserialize function that will serialize its arguments, call the native function, and
its return type. If any body is present, you must call this function in order deserialize its return type. If any body is present, you must call this
to call the native code. function in order to call the native code.
If no `fn-tail` is provided, then the resulting function will simply serialize If no `fn-tail` is provided, then the resulting function will simply serialize
the arguments according to `arg-types`, call the native function, and the arguments according to `arg-types`, call the native function, and
@ -728,38 +736,46 @@
See [[serialize]], [[deserialize]], [[make-downcall]]." See [[serialize]], [[deserialize]], [[make-downcall]]."
{:arglists '([name docstring? attr-map? symbol arg-types ret-type] {:arglists '([name docstring? attr-map? symbol arg-types ret-type]
[name docstring? attr-map? symbol arg-types ret-type & fn-tail])} [name docstring? attr-map? symbol arg-types ret-type native-fn & fn-tail])}
[& args] [& args]
(let [args (s/conform ::defcfn-args args) (let [args (s/conform ::defcfn-args args)
scope (gensym "scope") scope (gensym "scope")
arg-syms (repeatedly (count (:native-arglist args)) #(gensym "arg")) arg-syms (repeatedly (count (:native-arglist args)) #(gensym "arg"))
arg-types (repeatedly (count (:native-arglist args)) #(gensym "arg-type")) arg-types (repeatedly (count (:native-arglist args)) #(gensym "arg-type"))
ret-type (gensym "ret-type") ret-type (gensym "ret-type")
invoke (gensym "invoke")] invoke (gensym "invoke")
native-sym (gensym "native")
[arity fn-tail] (-> args :wrapper :fn-tail)
fn-tail (case arity
:single-arity (cons (:arglist fn-tail) (:body fn-tail))
:multi-arity (map #(cons (:arglist %) (:body %)) fn-tail)
nil)
arglists (map first (case arity
:single-arity [fn-tail]
:multi-arity fn-tail
nil))]
`(let [args-types# ~(:native-arglist args) `(let [args-types# ~(:native-arglist args)
[~@arg-types] args-types# [~@arg-types] args-types#
~ret-type ~(:return-type args) ~ret-type ~(:return-type args)
~invoke (-> (find-symbol ~(name (:symbol args))) ~invoke (make-downcall ~(name (:symbol args)) args-types# ~ret-type)
(downcall-handle ~(or (-> args :wrapper :native-fn) native-sym)
(method-type args-types# ~ret-type) ~(if (and (every? #(= % (primitive-type %))
(function-descriptor args-types# ~ret-type)) (:native-arglist args))
(downcall-fn args-types# ~ret-type)) (= (:return-type args)
~(:name args) ~(if (and (every? #(= % (primitive-type %)) (primitive-type (:return-type args))))
(:native-arglist args)) invoke
(= (:return-type args) `(fn [~@arg-syms]
(primitive-type (:return-type args)))) (with-open [~scope (stack-scope)]
invoke (deserialize (~invoke
`(fn [~@arg-syms] ~@(map
(with-open [~scope (stack-scope)] (fn [sym type]
(deserialize (~invoke `(serialize ~sym ~type ~scope))
~@(map arg-syms arg-types))
(fn [sym type] ~ret-type))))
`(serialize ~sym ~type ~scope)) fun# ~(if (:wrapper args)
arg-syms arg-types)) `(fn ~(:name args)
~ret-type)))) ~@fn-tail)
fun# ~(if (:fn-tail args) native-sym)]
`(fn ~@(:fn-tail args))
(:name args))]
(def (def
~(with-meta (:name args) ~(with-meta (:name args)
(merge (update (meta (:name args)) :arglists (merge (update (meta (:name args)) :arglists
@ -767,10 +783,10 @@
(list (list
'quote 'quote
(or old-list (or old-list
(seq arglists)
(list (list
(or (-> args :fn-tail :arglist) (mapv (comp symbol name)
(mapv (comp symbol name) (:native-arglist args)))))))
(:native-arglist args))))))))
(:attr-map args))) (:attr-map args)))
~@(list (:doc args)) ~@(list (:doc args))
fun#)))) fun#))))
@ -794,10 +810,11 @@
(defcfn some-func (defcfn some-func
"Gets some output value" "Gets some output value"
"someFunc" [::pointer] ::int "someFunc" [::pointer] ::int
native-func
[] []
(with-open [scope (stack-scope)] (with-open [scope (stack-scope)]
(let [out-int (alloc-instance ::int scope) (let [out-int (alloc-instance ::int scope)
success? (zero? (some-func (address-of out-int)))] success? (zero? (native-func (address-of out-int)))]
(if success? (if success?
(deserialize-from ::int out-int) (deserialize-from ::int out-int)
(throw (ex-info (getErrorString) {})))))) (throw (ex-info (getErrorString) {}))))))