about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2023-06-22 15:47:33 +0300
committerzsloan2023-06-22 10:18:31 -0500
commit18e2c59a2eb9b1bf952bec6ddfec0cd1abc7cc89 (patch)
treec98f664a3702c505a7d8f5d6f8f7dbf60035c684
parentb36758b1de6e7609129359d7f48a92558834e22d (diff)
downloadgenenetwork2-18e2c59a2eb9b1bf952bec6ddfec0cd1abc7cc89.tar.gz
Pass in the URI to the database
Pass in the URI to the database, rather than coupling the
`database_connection` function to the application environment and
settings.
-rw-r--r--wqflask/wqflask/database.py22
1 files changed, 2 insertions, 20 deletions
diff --git a/wqflask/wqflask/database.py b/wqflask/wqflask/database.py
index 663e2ebf..ad2d8216 100644
--- a/wqflask/wqflask/database.py
+++ b/wqflask/wqflask/database.py
@@ -14,24 +14,6 @@ class Connection(Protocol):
     def cursor(self) -> Any:
         ...
 
-
-def read_from_pyfile(pyfile: str, setting: str) -> Any:
-    orig_sys_path = sys.path[:]
-    sys.path.insert(0, os.path.dirname(pyfile))
-    module = importlib.import_module(os.path.basename(pyfile).strip(".py"))
-    sys.path = orig_sys_path[:]
-    return module.__dict__.get(setting)
-
-
-def get_setting(setting: str) -> str:
-    """Read setting from the environment or settings file."""
-    return os.environ.get(
-        setting, read_from_pyfile(
-            os.environ.get(
-                "GN2_SETTINGS", os.path.abspath("../etc/default_settings.py")),
-            setting))
-
-
 def parse_db_url(sql_uri: str) -> Tuple:
     """
     Parse SQL_URI env variable from an sql URI
@@ -44,14 +26,14 @@ def parse_db_url(sql_uri: str) -> Tuple:
 
 
 @contextlib.contextmanager
-def database_connection() -> Iterator[Connection]:
+def database_connection(sql_uri: str) -> Iterator[Connection]:
     """Provide a context manager for opening, closing, and rolling
     back - if supported - a database connection.  Should an error occur,
     and if the table supports transactions, the connection will be
     rolled back.
 
     """
-    host, user, passwd, db_name, port = parse_db_url(get_setting("SQL_URI"))
+    host, user, passwd, db_name, port = parse_db_url(sql_uri)
     connection = MySQLdb.connect(
         db=db_name, user=user, passwd=passwd or '', host=host,
         port=(port or 3306), autocommit=False  # Required for roll-backs