Commit a7bffda5 authored by echel0n's avatar echel0n
Browse files

Refactored web handlers to return data and call tornado finish on resp from run_async method

Merged base handler render_string and render methods into render method
Refactored base handlers for API v1 and v2
Fixed redirect issues for auth handlers
parent 2870a93c
......@@ -39,7 +39,7 @@ from sickrage.core.helpers import create_https_certificates
from sickrage.core.webserver.handlers.account import AccountLinkHandler, AccountUnlinkHandler, AccountIsLinkedHandler
from sickrage.core.webserver.handlers.announcements import AnnouncementsHandler, MarkAnnouncementSeenHandler, AnnouncementCountHandler
from sickrage.core.webserver.handlers.api import ApiSwaggerDotJsonHandler, ApiPingHandler, ApiProfileHandler
from sickrage.core.webserver.handlers.api.v1 import ApiHandler
from sickrage.core.webserver.handlers.api.v1 import ApiV1BaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2RetrieveSeriesMetadataHandler
from sickrage.core.webserver.handlers.api.v2.config import ApiV2ConfigHandler
from sickrage.core.webserver.handlers.api.v2.file_browser import ApiV2FileBrowserHandler
......@@ -204,7 +204,7 @@ class WebServer(object):
# API v1 Handlers
self.handlers['api_v1_handlers'] = [
# api
(fr'{self.api_v1_root}(/?.*)', ApiHandler),
(fr'{self.api_v1_root}(/?.*)', ApiV1BaseHandler),
# api builder
(fr'{sickrage.app.config.general.web_root}/api/builder', RedirectHandler,
......
......@@ -92,7 +92,7 @@ class AccountLinkHandler(BaseHandler):
else:
authorization_url = sickrage.app.auth_server.authorization_url(redirect_uri=redirect_uri, scope="profile email")
if authorization_url:
return super(BaseHandler, self).redirect(authorization_url)
return self.redirect(authorization_url, add_web_root=False)
return self.redirect('/account/link')
......@@ -117,4 +117,4 @@ class AccountUnlinkHandler(BaseHandler):
class AccountIsLinkedHandler(BaseHandler):
@authenticated
def get(self, *args, **kwargs):
return self.write(json.dumps({'linked': ('true', 'false')[not sickrage.app.api.userinfo]}))
return json.dumps({'linked': ('true', 'false')[not sickrage.app.api.userinfo]})
......@@ -46,10 +46,10 @@ class MarkAnnouncementSeenHandler(BaseHandler):
if announcement:
announcement.seen = True
return self.write(json.dumps({'success': True}))
return json.dumps({'success': True})
class AnnouncementCountHandler(BaseHandler):
@authenticated
def get(self, *args, **kwargs):
return self.write(json.dumps({'count': sickrage.app.announcements.count()}))
return json.dumps({'count': sickrage.app.announcements.count()})
......@@ -18,9 +18,10 @@
# You should have received a copy of the GNU General Public License
# along with SiCKRAGE. If not, see <http://www.gnu.org/licenses/>.
# ##############################################################################
import ipaddress
import functools
import json
import traceback
import types
import sentry_sdk
from apispec import APISpec
......@@ -28,11 +29,12 @@ from apispec.exceptions import APISpecError
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec_webframeworks.tornado import TornadoPlugin
from tornado.escape import to_basestring
from tornado.ioloop import IOLoop
from tornado.web import HTTPError
import sickrage
from sickrage.core.enums import UserPermission
from sickrage.core.helpers import get_external_ip, get_internal_ip, get_ip_address
from sickrage.core.helpers import get_internal_ip
from sickrage.core.webserver.handlers.base import BaseHandler
......@@ -104,6 +106,9 @@ class APIBaseHandler(BaseHandler):
if sickrage.app.config.general.server_id:
sentry_sdk.set_tag('server_id', sickrage.app.config.general.server_id)
method = self.run_async(getattr(self, method_name))
setattr(self, method_name, method)
except Exception:
return self.send_error(401, error='failed to decode token')
else:
......@@ -111,6 +116,14 @@ class APIBaseHandler(BaseHandler):
else:
return self.send_error(401, error='authorization header missing')
def run_async(self, method):
@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
resp = await IOLoop.current().run_in_executor(self.executor, functools.partial(method, *args, **kwargs))
self.finish(resp)
return types.MethodType(wrapper, self)
def get_current_user(self):
auth_header = self.request.headers.get('Authorization')
if 'bearer' in auth_header.lower():
......@@ -135,14 +148,14 @@ class APIBaseHandler(BaseHandler):
sickrage.app.log.error(error_msg)
self.write_json({'error': error_msg})
return self.finish(self.to_json({'error': error_msg}))
def set_default_headers(self):
super(APIBaseHandler, self).set_default_headers()
self.set_header('Content-Type', 'application/json')
def write_json(self, response):
self.write(json.dumps(response))
def to_json(self, response):
return json.dumps(response)
def _validate_schema(self, schema, arguments):
return schema().validate({k: to_basestring(v[0]) if len(v) <= 1 else to_basestring(v) for k, v in arguments.items()})
......@@ -183,12 +196,12 @@ class APIBaseHandler(BaseHandler):
class ApiProfileHandler(APIBaseHandler):
def get(self):
return self.write_json(self.current_user)
return self.to_json(self.current_user)
class ApiPingHandler(APIBaseHandler):
def get(self):
return self.write_json({'message': 'pong'})
return self.to_json({'message': 'pong'})
class ApiSwaggerDotJsonHandler(APIBaseHandler):
......@@ -199,4 +212,4 @@ class ApiSwaggerDotJsonHandler(APIBaseHandler):
def get(self):
""" Get swagger.json """
return self.write_json(self.generate_swagger_json(self.api_handlers, self.api_version))
return self.to_json(self.generate_swagger_json(self.api_handlers, self.api_version))
......@@ -19,12 +19,19 @@
# along with SiCKRAGE. If not, see <http://www.gnu.org/licenses/>.
# ##############################################################################
import os
from concurrent.futures.thread import ThreadPoolExecutor
import sickrage
from sickrage.core.webserver.handlers.api import APIBaseHandler
class ApiV2RetrieveSeriesMetadataHandler(APIBaseHandler):
class ApiV2BaseHandler(APIBaseHandler):
def __init__(self, application, request, **kwargs):
super(APIBaseHandler, self).__init__(application, request, **kwargs)
self.executor = ThreadPoolExecutor(thread_name_prefix='APIv2-Thread')
class ApiV2RetrieveSeriesMetadataHandler(ApiV2BaseHandler):
def get(self):
series_directory = self.get_argument('seriesDirectory', None)
if not series_directory:
......@@ -54,4 +61,4 @@ class ApiV2RetrieveSeriesMetadataHandler(APIBaseHandler):
if not json_data['seriesSlug'] and series_id and series_provider_id:
json_data['seriesSlug'] = f'{series_id}-{series_provider_id.slug}'
self.write_json(json_data)
return self.to_json(json_data)
......@@ -24,10 +24,10 @@ import sickrage
from sickrage.core.common import Overview
from sickrage.core.common import Qualities, EpisodeStatus
from sickrage.core.enums import SearchFormat
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
class ApiV2ConfigHandler(APIBaseHandler):
class ApiV2ConfigHandler(ApiV2BaseHandler):
def get(self, *args, **kwargs):
config_data = sickrage.app.config.to_json()
......@@ -59,4 +59,4 @@ class ApiV2ConfigHandler(APIBaseHandler):
}
}
return self.write_json(config_data)
return self.to_json(config_data)
......@@ -20,15 +20,15 @@
# ##############################################################################
import os
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
class ApiV2FileBrowserHandler(APIBaseHandler):
class ApiV2FileBrowserHandler(ApiV2BaseHandler):
def get(self):
path = self.get_argument('path', None)
include_files = self.get_argument('includeFiles', None)
return self.write_json(self.get_path(path, bool(include_files)))
return self.to_json(self.get_path(path, bool(include_files)))
def get_path(self, path, include_files=False):
entries = {
......
......@@ -26,10 +26,10 @@ from sickrage.core import Quality
from sickrage.core.common import dateTimeFormat
from sickrage.core.helpers import convert_dict_keys_to_camelcase
from sickrage.core.tv.show.history import History
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
class ApiV2HistoryHandler(APIBaseHandler):
class ApiV2HistoryHandler(ApiV2BaseHandler):
def get(self):
"""Get snatch and download history"
---
......@@ -88,4 +88,4 @@ class ApiV2HistoryHandler(APIBaseHandler):
results.append(row)
return self.write_json(results)
return self.to_json(results)
......@@ -22,11 +22,11 @@
import sickrage
from sickrage.core.enums import ProcessMethod
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
from sickrage.core.webserver.handlers.api.v2.postprocess.schemas import PostProcessSchema
class Apiv2PostProcessHandler(APIBaseHandler):
class Apiv2PostProcessHandler(ApiV2BaseHandler):
def get(self):
"""Postprocess TV show video files"
---
......@@ -81,4 +81,4 @@ class Apiv2PostProcessHandler(APIBaseHandler):
if 'Processing succeeded' not in json_data:
return self.send_error(400, error=json_data)
self.write_json({'data': json_data if return_data else ''})
return self.to_json({'data': json_data if return_data else ''})
......@@ -23,10 +23,10 @@ import datetime
import sickrage
from sickrage.core.helpers import convert_dict_keys_to_camelcase
from sickrage.core.tv.show.coming_episodes import ComingEpisodes
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
class ApiV2ScheduleHandler(APIBaseHandler):
class ApiV2ScheduleHandler(ApiV2BaseHandler):
def get(self):
"""Get TV show schedule information"
---
......@@ -74,4 +74,4 @@ class ApiV2ScheduleHandler(APIBaseHandler):
results[i]['localtime'] = result['localtime'].timestamp()
results[i] = convert_dict_keys_to_camelcase(results[i])
return self.write_json({'episodes': results, 'today': today.timestamp(), 'nextWeek': next_week.timestamp()})
return self.to_json({'episodes': results, 'today': today.timestamp(), 'nextWeek': next_week.timestamp()})
......@@ -35,12 +35,12 @@ from sickrage.core.media.util import series_image, SeriesImageType
from sickrage.core.queues.search import ManualSearchTask
from sickrage.core.tv.episode.helpers import find_episode
from sickrage.core.tv.show.helpers import get_show_list, find_show, find_show_by_slug
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
from sickrage.core.websocket import WebSocketMessage
from .schemas import *
class ApiV2SeriesHandler(APIBaseHandler):
class ApiV2SeriesHandler(ApiV2BaseHandler):
def get(self, series_slug=None):
"""Get list of series or specific series information"
---
......@@ -87,13 +87,13 @@ class ApiV2SeriesHandler(APIBaseHandler):
all_series.append(show.to_json(progress=True))
return self.write_json(all_series)
return self.to_json(all_series)
series = find_show_by_slug(series_slug)
if series is None:
return self.send_error(404, error=f"Unable to find the specified series using slug: {series_slug}")
return self.write_json(series.to_json(episodes=True, details=True))
return self.to_json(series.to_json(episodes=True, details=True))
def post(self):
data = json_decode(self.request.body)
......@@ -181,7 +181,7 @@ class ApiV2SeriesHandler(APIBaseHandler):
sickrage.app.alerts.message(_('Adding Show'), _(f'Adding the specified show into {series_directory}'))
return self.write_json({'message': True})
return self.to_json({'message': True})
def patch(self, series_slug):
warnings, errors = [], []
......@@ -307,7 +307,7 @@ class ApiV2SeriesHandler(APIBaseHandler):
# commit changes to database
series.save()
return self.write_json(series.to_json(episodes=True, details=True))
return self.to_json(series.to_json(episodes=True, details=True))
def delete(self, series_slug):
data = json_decode(self.request.body)
......@@ -318,10 +318,10 @@ class ApiV2SeriesHandler(APIBaseHandler):
sickrage.app.show_queue.remove_show(series.series_id, series.series_provider_id, checkbox_to_value(data.get('delete')))
return self.write_json({'message': True})
return self.to_json({'message': True})
class ApiV2SeriesEpisodesHandler(APIBaseHandler):
class ApiV2SeriesEpisodesHandler(ApiV2BaseHandler):
def get(self, series_slug, *args, **kwargs):
series = find_show_by_slug(series_slug)
if series is None:
......@@ -331,20 +331,20 @@ class ApiV2SeriesEpisodesHandler(APIBaseHandler):
for episode in series.episodes:
episodes.append(episode.to_json())
return self.write_json(episodes)
return self.to_json(episodes)
class ApiV2SeriesImagesHandler(APIBaseHandler):
class ApiV2SeriesImagesHandler(ApiV2BaseHandler):
def get(self, series_slug, *args, **kwargs):
series = find_show_by_slug(series_slug)
if series is None:
return self.send_error(404, error=f"Unable to find the specified series using slug: {series_slug}")
image = series_image(series.series_id, series.series_provider_id, SeriesImageType.POSTER_THUMB)
return self.write_json({'poster': image.url})
return self.to_json({'poster': image.url})
class ApiV2SeriesImdbInfoHandler(APIBaseHandler):
class ApiV2SeriesImdbInfoHandler(ApiV2BaseHandler):
def get(self, series_slug, *args, **kwargs):
series = find_show_by_slug(series_slug)
if series is None:
......@@ -354,10 +354,10 @@ class ApiV2SeriesImdbInfoHandler(APIBaseHandler):
imdb_info = session.query(MainDB.IMDbInfo).filter_by(imdb_id=series.imdb_id).one_or_none()
json_data = IMDbInfoSchema().dump(imdb_info)
return self.write_json(json_data)
return self.to_json(json_data)
class ApiV2SeriesBlacklistHandler(APIBaseHandler):
class ApiV2SeriesBlacklistHandler(ApiV2BaseHandler):
def get(self, series_slug, *args, **kwargs):
series = find_show_by_slug(series_slug)
if series is None:
......@@ -367,10 +367,10 @@ class ApiV2SeriesBlacklistHandler(APIBaseHandler):
blacklist = session.query(MainDB.Blacklist).filter_by(series_id=series.series_id, series_provider_id=series.series_provider_id).one_or_none()
json_data = BlacklistSchema().dump(blacklist)
return self.write_json(json_data)
return self.to_json(json_data)
class ApiV2SeriesWhitelistHandler(APIBaseHandler):
class ApiV2SeriesWhitelistHandler(ApiV2BaseHandler):
def get(self, series_slug, *args, **kwargs):
series = find_show_by_slug(series_slug)
if series is None:
......@@ -380,10 +380,10 @@ class ApiV2SeriesWhitelistHandler(APIBaseHandler):
whitelist = session.query(MainDB.Whitelist).filter_by(series_id=series.series_id, series_provider_id=series.series_provider_id).one_or_none()
json_data = WhitelistSchema().dump(whitelist)
return self.write_json(json_data)
return self.to_json(json_data)
class ApiV2SeriesRefreshHandler(APIBaseHandler):
class ApiV2SeriesRefreshHandler(ApiV2BaseHandler):
def get(self, series_slug):
force = self.get_argument('force', None)
......@@ -397,7 +397,7 @@ class ApiV2SeriesRefreshHandler(APIBaseHandler):
return self.send_error(400, error=_(f"Unable to refresh this show, error: {e}"))
class ApiV2SeriesUpdateHandler(APIBaseHandler):
class ApiV2SeriesUpdateHandler(ApiV2BaseHandler):
def get(self, series_slug):
force = self.get_argument('force', None)
......@@ -411,7 +411,7 @@ class ApiV2SeriesUpdateHandler(APIBaseHandler):
return self.send_error(400, error=_(f"Unable to update this show, error: {e}"))
class ApiV2SeriesEpisodesRenameHandler(APIBaseHandler):
class ApiV2SeriesEpisodesRenameHandler(ApiV2BaseHandler):
def get(self, series_slug):
"""Get list of episodes to rename"
---
......@@ -470,7 +470,7 @@ class ApiV2SeriesEpisodesRenameHandler(APIBaseHandler):
'newLocation': new_location,
})
return self.write_json(rename_data)
return self.to_json(rename_data)
def post(self, series_slug):
"""Rename list of episodes"
......@@ -522,10 +522,10 @@ class ApiV2SeriesEpisodesRenameHandler(APIBaseHandler):
if len(renamed_episodes) > 0:
WebSocketMessage('SHOW_RENAMED', {'seriesSlug': series.slug}).push()
return self.write_json(renamed_episodes)
return self.to_json(renamed_episodes)
class ApiV2SeriesEpisodesManualSearchHandler(APIBaseHandler):
class ApiV2SeriesEpisodesManualSearchHandler(ApiV2BaseHandler):
def get(self, series_slug, episode_slug):
"""Episode Manual Search"
---
......@@ -597,7 +597,7 @@ class ApiV2SeriesEpisodesManualSearchHandler(APIBaseHandler):
sickrage.app.search_queue.put(ep_queue_item)
if not all([ep_queue_item.started, ep_queue_item.success]):
return self.write_json({'success': True})
return self.to_json({'success': True})
return self.send_error(
status_code=404,
......
......@@ -22,15 +22,15 @@
import sickrage
from sickrage.core.enums import SeriesProviderID
from sickrage.core.webserver.handlers.api import APIBaseHandler
from sickrage.core.webserver.handlers.api.v2 import ApiV2BaseHandler
class ApiV2SeriesProvidersHandler(APIBaseHandler):
class ApiV2SeriesProvidersHandler(ApiV2BaseHandler):
def get(self):
self.write_json([{'displayName': x.display_name, 'slug': x.slug} for x in SeriesProviderID])
return self.to_json([{'displayName': x.display_name, 'slug': x.slug} for x in SeriesProviderID])
class ApiV2SeriesProvidersSearchHandler(APIBaseHandler):
class ApiV2SeriesProvidersSearchHandler(ApiV2BaseHandler):
def get(self, series_provider_slug):
search_term = self.get_argument('searchTerm', None)
lang = self.get_argument('seriesProviderLanguage', None)
......@@ -46,13 +46,13 @@ class ApiV2SeriesProvidersSearchHandler(APIBaseHandler):
if not results:
return self.send_error(404, reason=f"Unable to find the series using the search term: {search_term}")
return self.write_json(results)
return self.to_json(results)
class ApiV2SeriesProvidersLanguagesHandler(APIBaseHandler):
class ApiV2SeriesProvidersLanguagesHandler(ApiV2BaseHandler):
def get(self, series_provider_slug):
series_provider_id = SeriesProviderID.by_slug(series_provider_slug)
if not series_provider_id:
return self.send_error(404, reason="Unable to identify a series provider using provided slug")
self.write_json(sickrage.app.series_providers[series_provider_id].languages())
return self.to_json(sickrage.app.series_providers[series_provider_id].languages())
......@@ -43,7 +43,7 @@ class BaseHandler(RequestHandler):
def __init__(self, application, request, **kwargs):
super(BaseHandler, self).__init__(application, request, **kwargs)
self.executor = ThreadPoolExecutor(thread_name_prefix='TORNADO-Thread')
self.executor = ThreadPoolExecutor(thread_name_prefix='WEB-Thread')
self.startTime = time.time()
......@@ -65,7 +65,7 @@ class BaseHandler(RequestHandler):
request_info = ''.join([f"<strong>{k}</strong>: {v}<br>" for k, v in self.request.__dict__.items()])
self.set_header('Content-Type', 'text/html')
return self.write(f"""<html>
return self.finish(f"""<html>
<title>{error}</title>
<body>
<button onclick="window.location='{sickrage.app.config.general.web_root}/logs/';">View Log(Errors)</button>
......@@ -111,7 +111,7 @@ class BaseHandler(RequestHandler):
if cookie == sickrage.app.config.general.api_v1_key:
return True
def render_string(self, template_name, **kwargs):
def render(self, template_name, **kwargs):
template_kwargs = {
'title': "",
'header': "",
......@@ -152,17 +152,14 @@ class BaseHandler(RequestHandler):
return self.application.settings['templates']['errors/500.mako'].render_unicode(**template_kwargs)
def render(self, template_name, **kwargs):
self.write(self.render_string(template_name, **kwargs))
def set_default_headers(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.set_header("Access-Control-Allow-Headers", "Content-Type, Access-Control-Allow-Headers, Authorization, X-Requested-With")
self.set_header('Access-Control-Allow-Methods', 'POST, GET, PUT, PATCH, DELETE, OPTIONS')
self.set_header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0')
def redirect(self, url, permanent=True, status=None):
if sickrage.app.config.general.web_root not in url:
def redirect(self, url, permanent=True, status=None, add_web_root=True):
if add_web_root and sickrage.app.config.general.web_root not in url:
url = urljoin(sickrage.app.config.general.web_root + '/', url.lstrip('/'))
if self._headers_written:
......@@ -195,7 +192,8 @@ class BaseHandler(RequestHandler):
def run_async(self, method):
@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
await IOLoop.current().run_in_executor(self.executor, functools.partial(method, *args, **kwargs))
resp = await IOLoop.current().run_in_executor(self.executor, functools.partial(method, *args, **kwargs))
self.finish(resp)
return types.MethodType(wrapper, self)
......
......@@ -21,7 +21,7 @@
import datetime
import dateutil
from dateutil.tz import gettz
from tornado.web import authenticated
import sickrage
......@@ -33,19 +33,19 @@ from sickrage.core.webserver.handlers.base import BaseHandler
class CalendarHandler(BaseHandler):
def get(self, *args, **kwargs):
if sickrage.app.config.general.calendar_unprotected:
self.write(self.calendar())
return self.calendar()
else:
self.calendar_auth()
return self.calendar_auth()
@authenticated
def calendar_auth(self):
self.write(self.calendar())
return self.calendar()
def calendar(self):
""" Provides a subscribeable URL for iCal subscriptions
"""
utc = dateutil.tz.gettz('GMT')
utc = gettz('GMT')
sickrage.app.log.info("Receiving iCal request from %s" % self.request.remote_ip)
......
......@@ -35,4 +35,4 @@ class ChangelogHandler(BaseHandler):
except Exception:
data = ''
return self.write(data)
return data
......@@ -58,7 +58,7 @@ class ConfigBackupHandler(BaseHandler):
final_result += "<br>\n"
return self.write(final_result)
return final_result
class ConfigRestoreHandler(BaseHandler):
......@@ -89,7 +89,7 @@ class ConfigRestoreHandler(BaseHandler):
final_result += "<br>\n"
return self.write(final_result)
return final_result
class SaveBackupRestoreHandler(BaseHandler):
......
......@@ -48,7 +48,7 @@ class ConfigGeneralHandler(BaseHandler):
class GenerateApiKeyHandler(BaseHandler):
@authenticated