diff --git a/django_roa/db/__init__.py b/django_roa/db/__init__.py index bc397b8..455053a 100644 --- a/django_roa/db/__init__.py +++ b/django_roa/db/__init__.py @@ -1,5 +1,8 @@ from threading import local from django.conf import settings +from django.utils.module_loading import import_string + +import requests ROA_SESSION_HEADERS_KEY = 'roa_session_headers_key' @@ -32,3 +35,11 @@ def get_roa_headers(): def reset_roa_headers(): if hasattr(_roa_headers, 'value'): del _roa_headers.value + + +def get_roa_client(): + client = getattr(settings, 'ROA_CLIENT', None) + if client is not None: + client_class = import_string(client) + return client_class() + return requests diff --git a/django_roa/db/models.py b/django_roa/db/models.py index 164b6f0..c764cc2 100644 --- a/django_roa/db/models.py +++ b/django_roa/db/models.py @@ -30,7 +30,7 @@ from rest_framework_yaml.renderers import YAMLRenderer from rest_framework_xml.renderers import XMLRenderer -from django_roa.db import get_roa_headers +from django_roa.db import get_roa_headers, get_roa_client from django_roa.db.exceptions import ROAException import requests @@ -725,6 +725,8 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, headers = get_roa_headers() headers.update(self.get_serializer_content_type()) + requests_client = get_roa_client() + # check if resource use custom primary key if not meta.pk.attname in ['pk', 'id']: # consider it might be inserting so check it first @@ -733,9 +735,9 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, try: if ROA_SSL_CA: - response=requests.get(self.get_resource_url_detail(),params=None,headers=headers,verify=ROA_SSL_CA) + response=requests_client.get(self.get_resource_url_detail(),params=None,headers=headers,verify=ROA_SSL_CA) else: - response=requests.get(self.get_resource_url_detail(),params=None,headers=headers) + response=requests_client.get(self.get_resource_url_detail(),params=None,headers=headers) response=response.text.encode("utf-8") except HTTPError: pk_is_set = False @@ -749,9 +751,9 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, force_text(payload), force_text(get_args))) if ROA_SSL_CA: - response=requests.put(self.get_resource_url_detail(),data=payload,headers=headers,verify=ROA_SSL_CA) + response=requests_client.put(self.get_resource_url_detail(),data=payload,headers=headers,verify=ROA_SSL_CA) else: - response=requests.put(self.get_resource_url_detail(),data=payload,headers=headers) + response=requests_client.put(self.get_resource_url_detail(),data=payload,headers=headers) response=response.text.encode("utf-8") except HTTPError as e: raise ROAException(e) @@ -764,9 +766,9 @@ def save_base(self, raw=False, cls=None, origin=None, force_insert=False, force_text(payload), force_text(get_args))) if ROA_SSL_CA: - response=requests.post(self.get_resource_url_list(),data=payload,headers=headers,verify=ROA_SSL_CA) + response=requests_client.post(self.get_resource_url_list(),data=payload,headers=headers,verify=ROA_SSL_CA) else: - response=requests.post(self.get_resource_url_list(),data=payload,headers=headers) + response=requests_client.post(self.get_resource_url_list(),data=payload,headers=headers) response=response.text.encode("utf-8") except HTTPError as e: raise ROAException(e) @@ -810,10 +812,12 @@ def delete(self): headers = get_roa_headers() headers.update(self.get_serializer_content_type()) + requests_client = get_roa_client() + if ROA_SSL_CA: - response=requests.delete(self.get_resource_url_detail(),headers=headers,verify=ROA_SSL_CA) + response=requests_client.delete(self.get_resource_url_detail(),headers=headers,verify=ROA_SSL_CA) else: - response=requests.delete(self.get_resource_url_detail(),headers=headers) + response=requests_client.delete(self.get_resource_url_detail(),headers=headers) if response.status_code in [200, 202, 204]: self.pk = None diff --git a/django_roa/db/query.py b/django_roa/db/query.py index 1eb9af6..28cba55 100644 --- a/django_roa/db/query.py +++ b/django_roa/db/query.py @@ -5,7 +5,7 @@ from django.db.models import query from django.core import serializers # Django >= 1.5 -from django_roa.db import get_roa_headers +from django_roa.db import get_roa_headers, get_roa_client try: from django.db.models.constants import LOOKUP_SEP @@ -209,9 +209,9 @@ def iterator(self): self.model.get_resource_url_list(), force_text(parameters))) if ROA_SSL_CA: - response = requests.get(self.model.get_resource_url_list(),params=parameters,headers=self._get_http_headers(),verify=ROA_SSL_CA) + response = self._get_requests_client().get(self.model.get_resource_url_list(),params=parameters,headers=self._get_http_headers(),verify=ROA_SSL_CA) else: - response = requests.get(self.model.get_resource_url_list(),params=parameters,headers=self._get_http_headers()) + response = self._get_requests_client().get(self.model.get_resource_url_list(),params=parameters,headers=self._get_http_headers()) except Exception as e: raise ROAException(e) @@ -250,6 +250,8 @@ def iterator(self): return result + return [] + def count(self): """ Returns the number of records as an integer. @@ -271,10 +273,10 @@ def count(self): self.model.get_resource_url_list(), force_text(parameters))) if ROA_SSL_CA: - response = requests.get(self.model.get_resource_url_list(), params=parameters, + response = self._get_requests_client().get(self.model.get_resource_url_list(), params=parameters, headers=self._get_http_headers(),verify=ROA_SSL_CA) else: - response = requests.get(self.model.get_resource_url_list(), params=parameters, + response = self._get_requests_client().get(self.model.get_resource_url_list(), params=parameters, headers=self._get_http_headers()) except Exception as e: raise ROAException(e) @@ -307,10 +309,10 @@ def _get_from_id_or_pk(self, id=None, pk=None, **kwargs): instance.get_resource_url_detail(), force_text(parameters))) if ROA_SSL_CA: - response = requests.get(instance.get_resource_url_detail(), params=parameters, + response = self._get_requests_client().get(instance.get_resource_url_detail(), params=parameters, headers=self._get_http_headers(),verify=ROA_SSL_CA) else: - response = requests.get(instance.get_resource_url_detail(), params=parameters, + response = self._get_requests_client().get(instance.get_resource_url_detail(), params=parameters, headers=self._get_http_headers()) except Exception as e: raise ROAException(e) @@ -548,3 +550,6 @@ def _as_url(self): def _get_http_headers(self): return get_roa_headers() + + def _get_requests_client(self): + return get_roa_client()