diff --git a/Mirador_backend/app.py b/Mirador_backend/app.py index 5c652c5d664be3e80522dbbaedb1dcc666158835..9ac4c9337c118174b9e3443e9829bf4235386cad 100644 --- a/Mirador_backend/app.py +++ b/Mirador_backend/app.py @@ -6,20 +6,22 @@ from flask_apispec.extension import FlaskApiSpec from Mirador_backend.routes import v1 from Mirador_backend.utils.config import ConfigFactory from Mirador_backend.utils.database import db +from Mirador_backend.utils.app import setApp app = Flask(__name__) +setApp(app) api = Api(app) docs = FlaskApiSpec(app) -app.config.from_object(ConfigFactory.getConfig()) -app.db = db +app.config.from_object(ConfigFactory()) +app.db = db(app.config['DB'], app.config['DEBUG']) v1.register_routes(api, docs) @app.teardown_appcontext def shutdown_session(exception=None): - db.close() + app.db.close() if __name__ == '__main__': diff --git a/Mirador_backend/models.py b/Mirador_backend/models.py index 5847f07b2286dd32c282f6568169871768a89a3c..c35529c34ba82bc88086b71c0c3d8e50d95a5e59 100644 --- a/Mirador_backend/models.py +++ b/Mirador_backend/models.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, JSON from sqlalchemy import select from sqlalchemy.exc import NoResultFound import json -from Mirador_backend.utils.database import db +from Mirador_backend.utils.app import getApp class NotFound(NoResultFound): @@ -16,7 +16,10 @@ class BaseModel(DeclarativeBase): """A generic model class """ - session = db.db_session + session = getApp().db.session + """Access the db session""" + + engine = getApp().db.engine """Access the db session""" fillable = [] @@ -30,9 +33,13 @@ class BaseModel(DeclarativeBase): } """Actual data fields""" - def create_tables(): + def create_all(): """Creates all tables for all models""" - BaseModel.metadata.create_all(db.engine) + BaseModel.metadata.create_all(getApp().db.engine) + + def drop_all(): + """drop all tables for all models""" + BaseModel.metadata.drop_all(getApp().db.engine) @classmethod def query(cls, stmt): diff --git a/Mirador_backend/tests/test_mirador_resource.py b/Mirador_backend/tests/test_mirador_resource.py index 37da0e2066a14383a8a473bf7fccb3135a400ad2..a0a05581b9dbbe998773de46cb1e65241858c1d1 100644 --- a/Mirador_backend/tests/test_mirador_resource.py +++ b/Mirador_backend/tests/test_mirador_resource.py @@ -4,15 +4,21 @@ from Mirador_backend.tests.tester import TestCase class MiradorResourceTest(TestCase): + def getData(self): + return self.getFixtureRecordsByModel(self.fixtures[0], 'Mirador_backend.models.MiradorResource') + def testGetOne(self): + id = 1 payload = json.dumps({}) - response = self.client.get(self.base + '/1', headers={"Content-Type": "application/json"}, data=payload) + response = self.client.get(f'{self.base}/{id}', headers={"Content-Type": "application/json"}, data=payload) self.assertEqual(200, response.status_code) - self.assertEqual(None, response.json) + + expected = next(x for x in self.getData() if x['id'] == id) + self.assertEqual(expected, response.json) def testGetAll(self): payload = json.dumps({}) response = self.client.get(self.base, headers={"Content-Type": "application/json"}, data=payload) self.assertEqual(200, response.status_code) - self.assertEqual([], response.json) + self.assertEqual({str(x['id']): x for x in self.getData()}, response.json) diff --git a/Mirador_backend/tests/tester.py b/Mirador_backend/tests/tester.py index b1aae8b2bd9dfa96c6934171eb957fb62b89f32d..57a9bff763954ead9f7f854f7e33bec6cc55d424 100644 --- a/Mirador_backend/tests/tester.py +++ b/Mirador_backend/tests/tester.py @@ -1,18 +1,28 @@ import unittest +import json +from os.path import dirname, abspath +from flask_fixtures import FixturesMixin from Mirador_backend.app import app -from flask_fixtures import FixturesMixin +from Mirador_backend.models import BaseModel class TestCase(unittest.TestCase, FixturesMixin): fixtures = ['test.json'] app = app - db = app.db + db = BaseModel def setUp(self): self.base = 'v1/mirador_resource' self.client = app.test_client() + def getFixture(self, name): + with open(f'{dirname(abspath(__file__))}/fixtures/{name}') as f: + return json.load(f) + + def getFixtureRecordsByModel(self, name, model): + return next(x['records'] for x in self.getFixture(name) if x['model'] == model) + def tearDown(self): # TODO pass diff --git a/Mirador_backend/utils/app.py b/Mirador_backend/utils/app.py new file mode 100644 index 0000000000000000000000000000000000000000..71f6b56ec5782e1acf0afb0376b8cb088d70f74b --- /dev/null +++ b/Mirador_backend/utils/app.py @@ -0,0 +1,9 @@ +def getApp(): + return getApp.app + + +getApp.app = None + + +def setApp(app): + getApp.app = app diff --git a/Mirador_backend/utils/config.py b/Mirador_backend/utils/config.py index a7f9976b9863113db59ade31fa655abc2030ca2b..002af568656fde59a17a71b24f981e74ee97f8fe 100644 --- a/Mirador_backend/utils/config.py +++ b/Mirador_backend/utils/config.py @@ -14,7 +14,7 @@ class Config(object): APISPEC_SWAGGER_UI_URL = '/swagger-ui/' # URI to access UI of API Doc ENV = getenv('ENV', 'dev') TESTING = False - DEBUG = False + DEBUG = getenv('ENV', 'dev') != 'prod' DB = { 'TYPE': getenv('DB_TYPE', 'mariadb+mariadbconnector'), 'USER': getenv('MYSQL_USER', 'user'), @@ -26,14 +26,10 @@ class Config(object): class TestingConfig(Config): TESTING = True + DEBUG = True DB = None FIXTURES_DIRS = ['tests/fixtures'] -class ConfigFactory(): - config = None - - def getConfig(): - if ConfigFactory.config is None: - ConfigFactory.config = TestingConfig() if getenv('ENV', 'dev') == 'test' else Config() - return ConfigFactory.config +def ConfigFactory(): + return TestingConfig if getenv('ENV', 'dev') == 'test' else Config diff --git a/Mirador_backend/utils/database.py b/Mirador_backend/utils/database.py index 20bbc19e2ad29b501d6c080987ecb494245d286d..8d35a0d9336999665dea11abc82a34d1e10736dc 100644 --- a/Mirador_backend/utils/database.py +++ b/Mirador_backend/utils/database.py @@ -1,20 +1,20 @@ from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker -from Mirador_backend.utils.config import ConfigFactory - - -def getDBUri(): - params = ConfigFactory.getConfig().DB - if params is None: - return 'sqlite://' - return f'{params["TYPE"]}://{params["USER"]}:{params["PASSWORD"]}@{params["HOST"]}/{params["BASE"]}' class db(): - engine = create_engine(getDBUri(), echo=ConfigFactory.getConfig().DEBUG) - db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) + def __getDBUri(self): + params = self.config + if params is None: + return 'sqlite://' + return f'{params["TYPE"]}://{params["USER"]}:{params["PASSWORD"]}@{params["HOST"]}/{params["BASE"]}' + + def __init__(self, config, debug): + self.config = config + self.engine = create_engine(self.__getDBUri(), echo=debug) + self.session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=self.engine)) - def close(): - db.db_session.remove() + def close(self): + self.session.remove() diff --git a/cli b/cli index 60329d4a7ceef40382cbc9a209fc0787a1023414..d40535c1a16acef24889953e6749dec17cb27142 100755 --- a/cli +++ b/cli @@ -69,7 +69,7 @@ case $action in set +x ;; "mysql_init") - $cmd python -c 'from Mirador_backend.models import BaseModel; BaseModel.create_tables()' + $cmd python -c 'from Mirador_backend.models import BaseModel; BaseModel.create_all()' ;; "shell") $cmd ipython