Commit f0b27988 authored by echel0n's avatar echel0n
Browse files

Added sso_api_key column to config general table

Added API login method for SiCKRAGE external API using new SSO API key
Removed SiCKRAGE SSO offline token usage in favour of SiCKRAGE SSO API key
Improved OAuth2 handling for external SiCKRAGE API
Added external series provider API support
Added migration code to convert lang column on tv_shows table to ISO639-3
Added migration code to convert offline token to apikey
parent 3935f29f
This diff is collapsed.
......@@ -49,12 +49,6 @@ class AMQPBase(object):
IOLoop.current().call_later(5, self.reconnect)
return
# refresh api token if needed
if sickrage.app.api.token_time_remaining < (int(sickrage.app.api.token['expires_in']) / 2):
if not sickrage.app.api.refresh_token():
IOLoop.current().call_later(5, self.reconnect)
return
# declare server amqp queue
if not sickrage.app.api.server.declare_amqp_queue(sickrage.app.config.general.server_id):
IOLoop.current().call_later(5, self.reconnect)
......
......@@ -8,13 +8,11 @@ import oauthlib.oauth2
import requests
import requests.exceptions
from jose import ExpiredSignatureError
from oauthlib.oauth2 import MissingTokenError, InvalidGrantError, InvalidClientIdError
from keycloak.exceptions import KeycloakClientError
from requests_oauthlib import OAuth2Session
from sqlalchemy import orm
import sickrage
from sickrage.core.api.exceptions import APIError
from sickrage.core.databases.cache import CacheDB
class API(object):
......@@ -22,11 +20,7 @@ class API(object):
self.name = 'SR-API'
self.api_base = 'https://www.sickrage.ca/api/'
self.api_version = 'v6'
self._session = None
@property
def is_enabled(self):
return self.token
self._token = {}
@property
def imdb(self):
......@@ -37,8 +31,12 @@ class API(object):
return self.ServerAPI(self)
@property
def provider(self):
return self.ProviderAPI(self)
def search_provider(self):
return self.SearchProviderAPI(self)
@property
def series_provider(self):
return self.SeriesProviderAPI(self)
@property
def announcement(self):
......@@ -62,63 +60,36 @@ class API(object):
@property
def session(self):
extra = {
'client_id': sickrage.app.auth_server.client_id,
}
if not self._session and self.token_url:
self._session = OAuth2Session(token=self.token, auto_refresh_kwargs=extra, auto_refresh_url=self.token_url, token_updater=self.token_updater)
if not self.token_url:
return
return self._session
return OAuth2Session(
token=self.token,
auto_refresh_kwargs={'client_id': sickrage.app.auth_server.client_id},
auto_refresh_url=self.token_url,
token_updater=self.token_updater
)
@property
def token(self):
session = sickrage.app.cache_db.session()
try:
token = session.query(CacheDB.OAuth2Token).one()
return token.as_dict()
except orm.exc.NoResultFound:
return {}
@token.setter
def token(self, value):
new_token = {
'access_token': value.get('access_token'),
'refresh_token': value.get('refresh_token'),
'expires_in': value.get('expires_in'),
'session_state': value.get('session_state'),
'token_type': value.get('token_type'),
'expires_at': value.get('expires_at', int(time.time() + value.get('expires_in'))),
'scope': value.scope if isinstance(value, oauthlib.oauth2.OAuth2Token) else value.get('scope'),
}
if not self._token:
self.login()
elif self.token_time_remaining < (int(self._token.get('expires_in')) / 2):
self.refresh_token()
session = sickrage.app.cache_db.session()
try:
token = session.query(CacheDB.OAuth2Token).one()
token.update(**new_token)
except orm.exc.NoResultFound:
session.add(CacheDB.OAuth2Token(**new_token))
finally:
session.commit()
self._session = None
@token.deleter
def token(self):
session = sickrage.app.cache_db.session()
session.query(CacheDB.OAuth2Token).delete()
session.commit()
return self._token
@property
def token_expiration(self):
try:
if not self._token:
return time.time()
certs = sickrage.app.auth_server.certs()
if not certs:
return time.time()
decoded_token = sickrage.app.auth_server.decode_token(self.token['access_token'], certs)
decoded_token = sickrage.app.auth_server.decode_token(self._token.get('access_token'), certs)
return decoded_token.get('exp', time.time())
except ExpiredSignatureError:
return time.time()
......@@ -158,31 +129,49 @@ class API(object):
return self.request('GET', 'userinfo')
def token_updater(self, value):
self.token = value
self._token = value
def logout(self):
sickrage.app.auth_server.logout(self.token.get('refresh_token'))
def login(self):
if not self.health:
return False
def refresh_token(self):
extra = {
if not self.token_url:
return False
session = requests.session()
data = {
'client_id': sickrage.app.auth_server.client_id,
'grant_type': 'password',
'apikey': sickrage.app.config.general.sso_api_key
}
if self.token_url:
try:
resp = session.post(self.token_url, data)
resp.raise_for_status()
self._token = resp.json()
except requests.exceptions.RequestException:
return False
return True
def logout(self):
if self._token:
try:
client = OAuth2Session(sickrage.app.auth_server.client_id, token=self.token)
self.token = client.refresh_token(self.token_url, **extra)
return True
except (InvalidGrantError, MissingTokenError, InvalidClientIdError, requests.exceptions.RequestException):
return False
sickrage.app.auth_server.logout(self._token.get('refresh_token'))
except KeycloakClientError:
pass
def refresh_token(self):
try:
if not self._token:
return self.login()
return False
self._token = sickrage.app.auth_server.refresh_token(self._token.get('refresh_token'))
except KeycloakClientError:
return self.login()
def exchange_token(self, access_token, scope='offline_access'):
exchange = {'scope': scope, 'subject_token': access_token}
exchanged_token = sickrage.app.auth_server.token_exchange(**exchange)
if exchanged_token:
self.token = exchanged_token
return True
def allowed_usernames(self):
return self.request('GET', 'allowed-usernames')
......@@ -197,7 +186,7 @@ class API(object):
return self.request('GET', 'network-timezones')
def request(self, method, url, timeout=120, **kwargs):
if not self.is_enabled or not self.session:
if not self.session:
return
url = urljoin(self.api_base, "/".join([self.api_version, url]))
......@@ -211,10 +200,6 @@ class API(object):
return None
continue
if self.token_time_remaining < (int(self.token['expires_in']) / 2):
if not self.refresh_token():
continue
resp = self.session.request(method, url, timeout=timeout, verify=False, hooks={'response': self.throttle_hook}, **kwargs)
resp.raise_for_status()
......@@ -225,12 +210,12 @@ class API(object):
return resp.json()
except ValueError:
return resp.content
except oauthlib.oauth2.TokenExpiredError:
except (oauthlib.oauth2.TokenExpiredError, oauthlib.oauth2.InvalidGrantError):
self.refresh_token()
time.sleep(1)
except (oauthlib.oauth2.InvalidClientIdError, oauthlib.oauth2.MissingTokenError) as e:
self.refresh_token()
time.sleep(1)
except (oauthlib.oauth2.InvalidClientIdError, oauthlib.oauth2.MissingTokenError, oauthlib.oauth2.InvalidGrantError) as e:
sickrage.app.log.warning("Invalid token error, please re-link your SiCKRAGE account from `settings->general->advanced->sickrage api`")
return
except requests.exceptions.ReadTimeout as e:
if i > 3:
sickrage.app.log.debug(f'Error connecting to url {url} Error: {e}')
......@@ -346,21 +331,21 @@ class API(object):
def get_announcements(self):
return self.api.request('GET', 'announcements')
class ProviderAPI:
class SearchProviderAPI:
def __init__(self, api):
self.api = api
def get_urls(self, provider):
query = f'provider/{provider}/urls'
return self.api.request('GET', query)
endpoint = f'provider/{provider}/urls'
return self.api.request('GET', endpoint)
def get_status(self, provider):
query = f'provider/{provider}/status'
return self.api.request('GET', query)
endpoint = f'provider/{provider}/status'
return self.api.request('GET', endpoint)
def get_search_result(self, provider, series_id, season, episode):
query = f'provider/{provider}/series-id/{series_id}/season/{season}/episode/{episode}'
return self.api.request('GET', query)
endpoint = f'provider/{provider}/series-id/{series_id}/season/{season}/episode/{episode}'
return self.api.request('GET', endpoint)
def add_search_result(self, provider, data):
return self.api.request('POST', f'provider/{provider}', json=data)
......@@ -370,12 +355,12 @@ class API(object):
self.api = api
def get_trackers(self):
query = f'torrent/trackers'
return self.api.request('GET', query)
endpoint = f'torrent/trackers'
return self.api.request('GET', endpoint)
def get_torrent(self, hash):
query = f'torrent/{hash}'
return self.api.request('GET', query)
endpoint = f'torrent/{hash}'
return self.api.request('GET', endpoint)
def add_torrent(self, url):
return self.api.request('POST', 'torrent', json={'url': url})
......@@ -385,56 +370,56 @@ class API(object):
self.api = api
def search_by_imdb_title(self, title):
query = f'imdb/search-by-title/{title}'
return self.api.request('GET', query)
endpoint = f'imdb/search-by-title/{title}'
return self.api.request('GET', endpoint)
def search_by_imdb_id(self, imdb_id):
query = f'imdb/search-by-id/{imdb_id}'
return self.api.request('GET', query)
endpoint = f'imdb/search-by-id/{imdb_id}'
return self.api.request('GET', endpoint)
class GoogleDriveAPI:
def __init__(self, api):
self.api = api
def is_connected(self):
query = 'google-drive/is-connected'
return self.api.request('GET', query)
endpoint = 'google-drive/is-connected'
return self.api.request('GET', endpoint)
def upload(self, file, folder):
query = 'google-drive/upload'
return self.api.request('POST', query, files={'file': open(file, 'rb')}, params={'folder': folder})
endpoint = 'google-drive/upload'
return self.api.request('POST', endpoint, files={'file': open(file, 'rb')}, params={'folder': folder})
def download(self, id):
query = f'google-drive/download/{id}'
return self.api.request('GET', query)
endpoint = f'google-drive/download/{id}'
return self.api.request('GET', endpoint)
def delete(self, id):
query = f'google-drive/delete/{id}'
return self.api.request('GET', query)
endpoint = f'google-drive/delete/{id}'
return self.api.request('GET', endpoint)
def search_files(self, id, term):
query = f'google-drive/search-files/{id}/{term}'
return self.api.request('GET', query)
endpoint = f'google-drive/search-files/{id}/{term}'
return self.api.request('GET', endpoint)
def list_files(self, id):
query = f'google-drive/list-files/{id}'
return self.api.request('GET', query)
endpoint = f'google-drive/list-files/{id}'
return self.api.request('GET', endpoint)
def clear_folder(self, id):
query = f'google-drive/clear-folder/{id}'
return self.api.request('GET', query)
endpoint = f'google-drive/clear-folder/{id}'
return self.api.request('GET', endpoint)
class SceneExceptions:
def __init__(self, api):
self.api = api
def get(self, *args, **kwargs):
query = 'scene-exceptions'
return self.api.request('GET', query)
endpoint = 'scene-exceptions'
return self.api.request('GET', endpoint)
def search_by_id(self, series_id):
query = f'scene-exceptions/search-by-id/{series_id}'
return self.api.request('GET', query)
endpoint = f'scene-exceptions/search-by-id/{series_id}'
return self.api.request('GET', endpoint)
class AlexaAPI:
def __init__(self, api):
......@@ -442,3 +427,27 @@ class API(object):
def send_notification(self, message):
return self.api.request('POST', 'alexa/notification', json={'message': message})
class SeriesProviderAPI:
def __init__(self, api):
self.api = api
def search(self, provider, query, language='eng'):
endpoint = f'series-provider/{provider}/search/{query}/{language}'
return self.api.request('GET', endpoint)
def get_series_info(self, provider, series_id, language='eng'):
endpoint = f'series-provider/{provider}/series/{series_id}/{language}'
return self.api.request('GET', endpoint)
def get_episodes_info(self, provider, series_id, season_type='default', language='eng'):
endpoint = f'series-provider/{provider}/series/{series_id}/episodes/{season_type}/{language}'
return self.api.request('GET', endpoint)
def languages(self, provider):
endpoint = f'series-provider/{provider}/languages'
return self.api.request('GET', endpoint)
def updates(self, provider, since):
endpoint = f'series-provider/{provider}/updates/{since}'
return self.api.request('GET', endpoint)
......@@ -36,9 +36,9 @@ class ImageCache(object):
FANART_THUMB = 6
IMAGE_TYPES = {
BANNER: 'series',
BANNER: 'banner',
POSTER: 'poster',
BANNER_THUMB: 'series_thumb',
BANNER_THUMB: 'banner_thumb',
POSTER_THUMB: 'poster_thumb',
FANART: 'fanart',
FANART_THUMB: 'fanart_thumb'
......@@ -266,6 +266,9 @@ class ImageCache(object):
# retrieve the image from a series provider using the generic metadata class
metadata_generator = MetadataProvider()
img_data = metadata_generator._retrieve_show_image(self.IMAGE_TYPES[img_type], show_obj)
if not img_data:
return False
result = metadata_generator._write_image(img_data, dest_path, force)
return result
......
......@@ -257,7 +257,7 @@ class TVCache(object):
from sickrage.search_providers import SearchProviderType
if not self.provider.private and self.provider.provider_type in [SearchProviderType.NZB, SearchProviderType.TORRENT]:
try:
sickrage.app.api.provider.add_search_result(provider=self.providerID, data=dbData)
sickrage.app.api.search_provider.add_search_result(provider=self.providerID, data=dbData)
except Exception as e:
pass
except (InvalidShowException, InvalidNameException):
......@@ -269,7 +269,7 @@ class TVCache(object):
# get data from external database
if sickrage.app.config.general.enable_sickrage_api and not self.provider.private:
resp = sickrage.app.api.provider.get_search_result(self.providerID, series_id, season, episode)
resp = sickrage.app.api.search_provider.get_search_result(self.providerID, series_id, season, episode)
if resp and 'data' in resp:
dbData += resp['data']
......
......@@ -103,18 +103,6 @@ class CacheDB(SRDatabase):
leechers = Column(Integer)
size = Column(Integer)
class OAuth2Token(base):
__tablename__ = 'oauth2_token'
id = Column(Integer, primary_key=True)
access_token = Column(String(255), unique=True, nullable=False)
refresh_token = Column(String(255), index=True)
expires_in = Column(Integer, nullable=False, default=0)
expires_at = Column(Integer, nullable=False, default=0)
scope = Column(Text, default="")
session_state = Column(Text, default="")
token_type = Column(Text, default="bearer")
class Announcements(base):
__tablename__ = 'announcements'
......
"""Initial migration
Revision ID: 10
Revises:
Create Date: 2017-12-29 14:39:27.854291
"""
import json
import os
from json import JSONDecodeError
import sqlalchemy as sa
from alembic import op
from sqlalchemy import orm
import sickrage
# revision identifiers, used by Alembic.
from sickrage.core import ConfigDB
revision = '10'
down_revision = '9'
def upgrade():
conn = op.get_bind()
meta = sa.MetaData(bind=conn)
oauth2_token = sa.Table('oauth2_token', meta, autoload=True)
certs = sickrage.app.auth_server.certs()
with op.get_context().begin_transaction():
for row in conn.execute(oauth2_token.select()):
refresh_token = row.refresh_token
new_token = sickrage.app.auth_server.refresh_token(refresh_token)
decoded_token = sickrage.app.auth_server.decode_token(new_token['access_token'], certs)
apikey = decoded_token['apikey']
try:
session = sickrage.app.config.db.session()
general = session.query(ConfigDB.General).one()
general.sso_api_key = apikey
session.commit()
except orm.exc.NoResultFound:
pass
if conn.engine.dialect.has_table(conn.engine, 'oauth2_token'):
op.drop_table('oauth2_token')
def downgrade():
# Operations to reverse the above upgrade go here.
pass
......@@ -40,6 +40,7 @@ from sickrage.core.tv.show.coming_episodes import ComingEpsLayout, ComingEpsSort
from sickrage.notification_providers.nmjv2 import NMJv2Location
from sickrage.search_providers import SearchProviderType
def encryption_key():
try:
return getattr(sickrage.app.config.user, 'sub_id', None) or 'sickrage'
......@@ -117,7 +118,7 @@ class CustomStringEncryptedType(StringEncryptedType):
class ConfigDB(SRDatabase):
base = declarative_base(cls=SRDatabaseBase)
def __init__(self, db_type, db_prefix, db_host, db_port, db_username, db_password):
super(ConfigDB, self).__init__('config', db_type, db_prefix, db_host, db_port, db_username, db_password)
......@@ -150,6 +151,7 @@ class ConfigDB(SRDatabase):
daily_searcher_freq = Column(Integer, default=40)
ignore_words = Column(Text, default=','.join(['german', 'french', 'core2hd', 'dutch', 'swedish', 'reenc', 'MrLss']))
api_v1_key = Column(Text, default=generate_api_key())
sso_api_key = Column(Text, default='')
sso_auth_enabled = Column(Boolean, default=True)
local_auth_enabled = Column(Boolean, default=False)
ip_whitelist_enabled = Column(Boolean, default=False)
......@@ -239,7 +241,7 @@ class ConfigDB(SRDatabase):
quality_default = Column(IntFlag(Qualities), default=Qualities.SD)
extra_scripts = Column(Text, default='')
flatten_folders_default = Column(Boolean, default=False)
series_provider_default_language = Column(Text, default='en')
series_provider_default_language = Column(Text, default='eng')
show_update_stale = Column(Boolean, default=True)
ep_default_deleted_status = Column(Enum(EpisodeStatus), default=EpisodeStatus.ARCHIVED)
no_restart = Column(Boolean, default=False)
......
"""Initial migration
Revision ID: 4
Revises:
Create Date: 2017-12-29 14:39:27.854291
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '4'
down_revision = '3'
def upgrade():
op.add_column('general', sa.Column('sso_api_key', sa.Text, default=''))
def downgrade():
pass
......@@ -27,7 +27,7 @@ def upgrade():
with op.get_context().begin_transaction():
for row in conn.execute(history.select()):
date = datetime.datetime.strptime(str(row.date), date_format)
conn.execute(f'UPDATE history SET date = {date} WHERE history.id = {row.id}')
conn.execute(f'UPDATE history SET date = "{date}" WHERE history.id = {row.id}')
op.alter_column('history', 'date', type_=sa.DateTime)
......
"""Initial migration
Revision ID: 22
Revises:
Create Date: 2017-12-29 14:39:27.854291
"""
import datetime
import babelfish
import pycountry
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.