Skip to content
This repository has been archived by the owner on Jan 18, 2025. It is now read-only.

Commit

Permalink
Minor changes to Python binding generator for cleanliness.
Browse files Browse the repository at this point in the history
This hides all of our internal functions and moves the imports to
the top level to stop linter complaints in downstream packages.
  • Loading branch information
LTLA committed Sep 8, 2023
1 parent 6fb014a commit 6c18353
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
14 changes: 9 additions & 5 deletions src/cpptypes/create_py_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ def create_py_bindings(all_functions: dict, output_path: str, dll_prefix: str):
import os
import ctypes as ct
""")
if with_numpy:
handle.write("""import numpy as np
""")

def catch_errors(f):
handle.write("""
def _catch_errors(f):
def wrapper(*args):
errcode = ct.c_int32(0)
errmsg = ct.c_char_p(0)
Expand Down Expand Up @@ -133,8 +138,7 @@ def wrapper(*args):
if with_numpy:
handle.write("""
import numpy as np
def np2ct(x, expected, contiguous=True):
def _np2ct(x, expected, contiguous=True):
if not isinstance(x, np.ndarray):
raise ValueError('expected a NumPy array')
if x.dtype != expected:
Expand Down Expand Up @@ -185,13 +189,13 @@ def np2ct(x, expected, contiguous=True):
args += x.type.base_type
if "non_contig" in x.type.tags:
args += ", contiguous=False"
argnames2.append("np2ct(" + x.name + args + ")")
argnames2.append("_np2ct(" + x.name + args + ")")
else:
argnames2.append(x.name)
else:
argnames2 = argnames

handle.write("\n\ndef " + k + "(" + ", ".join(argnames) + """):
return catch_errors(lib.py_""" + k + ")(" + ", ".join(argnames2) + """)""")
return _catch_errors(lib.py_""" + k + ")(" + ", ".join(argnames2) + """)""")

return
8 changes: 4 additions & 4 deletions tests/test_create_py_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ def test_create_py_bindings():

elif line == "def cocoa_best_function(chino, rize, midori):":
found_cocoa_def = True
assert handle.readline().strip() == "return catch_errors(lib.py_cocoa_best_function)(np2ct(chino, np.float64), rize, midori)"
assert handle.readline().strip() == "return _catch_errors(lib.py_cocoa_best_function)(_np2ct(chino, np.float64), rize, midori)"

elif line == "def syaro_second_function(chiya, moka, maya, megu):":
found_syaro_def = True
assert handle.readline().strip() == "return catch_errors(lib.py_syaro_second_function)(chiya, moka, np2ct(maya, np.int32, contiguous=False), np2ct(megu, np.float32))"
assert handle.readline().strip() == "return _catch_errors(lib.py_syaro_second_function)(chiya, moka, _np2ct(maya, np.int32, contiguous=False), _np2ct(megu, np.float32))"

assert found_cocoa_types
assert found_syaro_types
Expand Down Expand Up @@ -153,11 +153,11 @@ def test_create_py_bindings():

elif line == "def cocoa_best_function(chino, rize, midori):":
found_cocoa_def = True
assert handle.readline().strip() == "return catch_errors(lib.py_cocoa_best_function)(chino, rize, midori)"
assert handle.readline().strip() == "return _catch_errors(lib.py_cocoa_best_function)(chino, rize, midori)"

elif line == "def syaro_second_function(chiya, moka, maya, megu):":
found_syaro_def = True
assert handle.readline().strip() == "return catch_errors(lib.py_syaro_second_function)(chiya, moka, maya, megu)"
assert handle.readline().strip() == "return _catch_errors(lib.py_syaro_second_function)(chiya, moka, maya, megu)"

assert found_cocoa_types
assert found_syaro_types
Expand Down

0 comments on commit 6c18353

Please sign in to comment.