Commit 61bd3e18 authored by echel0n's avatar echel0n

Refactored database `with_session` staticmethod to classmethod.

Renamed old config encrypt/decrypt function names to legacy.
Refactored new config encrypt/decrypt code to functions.
parent 90518a7f
This diff is collapsed.
......@@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with SiCKRAGE. If not, see <http://www.gnu.org/licenses/>.
import datetime
import functools
import os
import pickle
import shutil
......@@ -76,7 +77,7 @@ class ContextSession(sqlalchemy.orm.Session):
for i in range(self.max_attempts):
try:
self.commit()
except OperationalError:
except OperationalError as e:
sickrage.app.log.debug('Retrying database commit, attempt {}'.format(i))
self.rollback()
sleep(1)
......@@ -96,9 +97,23 @@ class ContextSession(sqlalchemy.orm.Session):
self.safe_commit(close=True)
class SRDatabaseBase(object):
def as_dict(self):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
def update(self, **kwargs):
primary_keys = [pk.name for pk in self.__table__.primary_key]
for key, value in kwargs.items():
if key not in primary_keys:
setattr(self, key, value)
class SRDatabase(object):
def __init__(self, name, db_type='sqlite', db_prefix='sickrage', db_host='localhost', db_port='3306', db_username='sickrage', db_password='sickrage'):
session = sessionmaker(class_=ContextSession)
def __init__(self, name, db_version=0, db_type='sqlite', db_prefix='sickrage', db_host='localhost', db_port='3306', db_username='sickrage', db_password='sickrage'):
self.name = name
self.db_version = db_version
self.db_type = db_type
self.db_prefix = db_prefix
self.db_host = db_host
......@@ -111,6 +126,8 @@ class SRDatabase(object):
self.db_path = os.path.join(sickrage.app.data_dir, '{}.db'.format(self.name))
self.db_repository = os.path.join(os.path.dirname(__file__), self.name, 'db_repository')
self.session.configure(bind=self.engine)
if not self.version:
api.version_control(self.engine, self.db_repository, api.version(self.db_repository))
else:
......@@ -119,6 +136,28 @@ class SRDatabase(object):
except DatabaseAlreadyControlledError:
pass
@classmethod
def with_session(cls, *args, **kwargs):
def decorator(func):
def wrapper(*args, **kwargs):
if kwargs.get('session'):
return func(*args, **kwargs)
with _Session() as session:
kwargs['session'] = session
return func(*args, **kwargs)
return wrapper
if len(args) == 1 and not kwargs and callable(args[0]):
# Used without arguments, e.g. @with_session
# We default to expire_on_commit being false, in case the decorated function returns db instances
_Session = functools.partial(cls.session, expire_on_commit=False)
return decorator(args[0])
else:
# Arguments were specified, turn them into arguments for Session creation e.g. @with_session(autocommit=True)
_Session = functools.partial(cls.session, *args, **kwargs)
return decorator
@property
def engine(self):
if self.db_type == 'sqlite':
......@@ -131,10 +170,6 @@ class SRDatabase(object):
'mysql+pymysql://{}:{}@{}:{}/{}_{}'.format(self.db_username, self.db_password, self.db_host, self.db_port, self.db_prefix, self.name),
echo=False)
@property
def session(self):
return self.session
@property
def version(self):
try:
......
......@@ -15,61 +15,29 @@
#
# You should have received a copy of the GNU General Public License
# along with SiCKRAGE. If not, see <http://www.gnu.org/licenses/>.
import functools
import time
from sqlalchemy import Column, Integer, Text, String
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.orm import sessionmaker
from sickrage.core.databases import SRDatabase, ContextSession
from sickrage.core.databases import SRDatabase, SRDatabaseBase, ContextSession
@as_declarative()
class CacheDBBase(object):
def as_dict(self):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
def update(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class CacheDBBase(SRDatabaseBase):
pass
class CacheDB(SRDatabase):
db_version = 4
session = sessionmaker(class_=ContextSession)
def __init__(self, db_type, db_prefix, db_host, db_port, db_username, db_password):
super(CacheDB, self).__init__('cache', db_type, db_prefix, db_host, db_port, db_username, db_password)
CacheDB.session.configure(bind=self.engine)
super(CacheDB, self).__init__('cache', 4, db_type, db_prefix, db_host, db_port, db_username, db_password)
CacheDBBase.metadata.create_all(self.engine)
for model in CacheDBBase._decl_class_registry.values():
if hasattr(model, '__tablename__'):
self.tables[model.__tablename__] = model
@staticmethod
def with_session(*args, **kwargs):
def decorator(func):
def wrapper(*args, **kwargs):
if kwargs.get('session'):
return func(*args, **kwargs)
with _Session() as session:
kwargs['session'] = session
return func(*args, **kwargs)
return wrapper
if len(args) == 1 and not kwargs and callable(args[0]):
# Used without arguments, e.g. @with_session
# We default to expire_on_commit being false, in case the decorated function returns db instances
_Session = functools.partial(CacheDB.session, expire_on_commit=False)
return decorator(args[0])
else:
# Arguments were specified, turn them into arguments for Session creation e.g. @with_session(autocommit=True)
_Session = functools.partial(CacheDB.session, *args, **kwargs)
return decorator
def cleanup(self):
def remove_duplicates_from_last_search_table():
found = []
......
......@@ -15,62 +15,29 @@
#
# You should have received a copy of the GNU General Public License
# along with SiCKRAGE. If not, see <http://www.gnu.org/licenses/>.
import functools
from sqlalchemy import Column, Integer, Text, ForeignKeyConstraint, String, DateTime
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import sessionmaker
from sickrage.core.databases import SRDatabase, ContextSession
from sickrage.core.databases import SRDatabase, SRDatabaseBase, ContextSession
@as_declarative()
class MainDBBase(object):
def as_dict(self):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
def update(self, **kwargs):
primary_keys = [pk.name for pk in self.__table__.primary_key]
for key, value in kwargs.items():
if key not in primary_keys:
setattr(self, key, value)
class MainDBBase(SRDatabaseBase):
pass
class MainDB(SRDatabase):
db_version = 10
session = sessionmaker(class_=ContextSession)
def __init__(self, db_type, db_prefix, db_host, db_port, db_username, db_password):
super(MainDB, self).__init__('main', db_type, db_prefix, db_host, db_port, db_username, db_password)
MainDB.session.configure(bind=self.engine)
super(MainDB, self).__init__('main', 10, db_type, db_prefix, db_host, db_port, db_username, db_password)
MainDBBase.metadata.create_all(self.engine)
for model in MainDBBase._decl_class_registry.values():
if hasattr(model, '__tablename__'):
self.tables[model.__tablename__] = model
@staticmethod
def with_session(*args, **kwargs):
def decorator(func):
def wrapper(*args, **kwargs):
if kwargs.get('session'):
return func(*args, **kwargs)
with _Session() as session:
kwargs['session'] = session
return func(*args, **kwargs)
return wrapper
if len(args) == 1 and not kwargs and callable(args[0]):
# Used without arguments, e.g. @with_session
# We default to expire_on_commit being false, in case the decorated function returns db instances
_Session = functools.partial(MainDB.session, expire_on_commit=False)
return decorator(args[0])
else:
# Arguments were specified, turn them into arguments for Session creation e.g. @with_session(autocommit=True)
_Session = functools.partial(MainDB.session, *args, **kwargs)
return decorator
class IMDbInfo(MainDBBase):
__tablename__ = 'imdb_info'
__table_args__ = (
......
......@@ -103,7 +103,6 @@ def snatch_episode(result, end_status=SNATCHED, session=None):
sickrage.app.alerts.message(_('Episode snatched'), result.name)
# don't notify when we re-download an episode
trakt_data = []
for episode_number in result.episodes:
episode_obj = show_object.get_episode(result.season, episode_number)
......@@ -115,6 +114,7 @@ def snatch_episode(result, end_status=SNATCHED, session=None):
session.safe_commit()
# don't notify when we re-download an episode
if episode_obj.status not in Quality.DOWNLOADED:
try:
Notifiers.mass_notify_snatch(episode_obj._format_pattern('%SN - %Sx%0E - %EN - %QN') + " from " + result.provider.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment