From faa50e396ff21a75767ef52f61b8bcd607e73202 Mon Sep 17 00:00:00 2001 From: guohelu <19503896967@163.com> Date: Fri, 13 Dec 2024 12:54:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=96=B9=E6=B3=95=20#7626?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gcloud/contrib/template_market/clients.py | 6 +++ gcloud/contrib/template_market/models.py | 50 +++++++++++++++++++ gcloud/contrib/template_market/permission.py | 2 +- gcloud/contrib/template_market/serializers.py | 37 +------------- gcloud/contrib/template_market/urls.py | 4 +- gcloud/contrib/template_market/viewsets.py | 17 ++++--- 6 files changed, 70 insertions(+), 46 deletions(-) diff --git a/gcloud/contrib/template_market/clients.py b/gcloud/contrib/template_market/clients.py index 97cafd2da..f1da852d9 100644 --- a/gcloud/contrib/template_market/clients.py +++ b/gcloud/contrib/template_market/clients.py @@ -23,6 +23,12 @@ def __init__(self): def _get_url(self, endpoint): return f"{self.base_url}{endpoint}" + def get_template_detail(self, market_record_id): + url = self._get_url(f"/sre_scene/flow_template_scene/{market_record_id}/") + cookies = {"bk_ticket": "5IT20mfD9mI_uTitTrUuDKTp9GIif7ZaBivi9E6k5qw"} + response = requests.get(url, cookies=cookies) + return response.json() + def get_market_template_list(self): url = self._get_url("/sre_scene/flow_template_scene/?is_all=true") response = requests.get(url) diff --git a/gcloud/contrib/template_market/models.py b/gcloud/contrib/template_market/models.py index 1adc82ce1..04d02eae0 100644 --- a/gcloud/contrib/template_market/models.py +++ b/gcloud/contrib/template_market/models.py @@ -14,6 +14,54 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ +from gcloud import err_code + + +class TemplateSharedManager(models.Manager): + def update_shared_record(self, new_template_ids, market_record_id, project_id, creator, existing_template_ids=None): + market_record_id = int(market_record_id) + + if existing_template_ids: + templates_to_remove = existing_template_ids - set(new_template_ids) + if templates_to_remove: + for template_id in templates_to_remove: + current_template_record = TemplateSharedRecord.objects.get(template_id=template_id) + current_market_ids = current_template_record.extra_info.get("market_record_ids", []) + if market_record_id in current_market_ids: + current_market_ids.remove(market_record_id) + current_template_record.extra_info["market_record_ids"] = current_market_ids + current_template_record.save() + if not current_template_record.extra_info["market_record_ids"]: + current_template_record.delete() + else: + return { + "result": False, + "message": "template {} is not in record {}".format(template_id, market_record_id), + "code": err_code.REQUEST_PARAM_INVALID.code, + } + + templates_to_add = set(new_template_ids) - existing_template_ids + if templates_to_add: + new_template_ids = list(templates_to_add) + + new_records = [] + for template_id in new_template_ids: + existing_record, created = TemplateSharedRecord.objects.get_or_create( + project_id=project_id, + template_id=template_id, + defaults={"creator": creator, "extra_info": {"market_record_ids": [market_record_id]}}, + ) + if not created: + market_ids = existing_record.extra_info.setdefault("market_record_ids", []) + if market_record_id not in market_ids: + market_ids.append(market_record_id) + new_records.append(existing_record) + + if new_records: + TemplateSharedRecord.objects.bulk_update(new_records, ["extra_info"]) + + return {"result": True, "message": "update shared record successfully", "code": err_code.SUCCESS.code} + class TemplateSharedRecord(models.Model): project_id = models.IntegerField(_("项目 ID"), default=-1, help_text="项目 ID") @@ -23,6 +71,8 @@ class TemplateSharedRecord(models.Model): update_at = models.DateTimeField(verbose_name=_("更新时间"), auto_now=True) extra_info = models.JSONField(_("额外信息"), blank=True, null=True) + objects = TemplateSharedManager() + class Meta: verbose_name = _("模板共享记录 TemplateSharedRecord") verbose_name_plural = _("模板共享记录 TemplateSharedRecord") diff --git a/gcloud/contrib/template_market/permission.py b/gcloud/contrib/template_market/permission.py index 8ac657d34..d2f684701 100644 --- a/gcloud/contrib/template_market/permission.py +++ b/gcloud/contrib/template_market/permission.py @@ -24,7 +24,7 @@ class TemplatePreviewPermission(permissions.BasePermission): def has_permission(self, request, view): - serializer = TemplateProjectBaseSerializer(data=request.GET) + serializer = TemplateProjectBaseSerializer(data=request.query_params) serializer.is_valid(raise_exception=True) template_id = int(serializer.validated_data["template_id"]) diff --git a/gcloud/contrib/template_market/serializers.py b/gcloud/contrib/template_market/serializers.py index ebcab5571..89be644f1 100644 --- a/gcloud/contrib/template_market/serializers.py +++ b/gcloud/contrib/template_market/serializers.py @@ -11,9 +11,8 @@ specific language governing permissions and limitations under the License. """ import json -from rest_framework import serializers -from gcloud.contrib.template_market.models import TemplateSharedRecord +from rest_framework import serializers class TemplatePreviewSerializer(serializers.Serializer): @@ -42,37 +41,3 @@ class TemplateSharedRecordSerializer(serializers.Serializer): usage_id = serializers.IntegerField(required=True, help_text="使用说明id") labels = serializers.ListField(child=serializers.IntegerField(), required=True, help_text="共享标签列表") usage_content = serializers.JSONField(required=True, help_text="使用说明") - - def create_shared_record(self, project_id, market_record_id, template_ids, creator): - for template_id in template_ids: - existing_record, created = TemplateSharedRecord.objects.get_or_create( - project_id=project_id, - template_id=template_id, - defaults={"creator": creator, "extra_info": {"market_record_ids": [market_record_id]}}, - ) - if not created: - market_ids = existing_record.extra_info.setdefault("market_record_ids", []) - if market_record_id not in market_ids: - market_ids.append(market_record_id) - existing_record.save() - - def update_shared_record(self, new_template_ids, market_record_id, project_id, creator): - market_record_id = int(market_record_id) - - existing_records = TemplateSharedRecord.objects.filter( - project_id=project_id, extra_info__market_record_ids__contains=[market_record_id] - ) - existing_template_ids = set(existing_records.values_list("template_id", flat=True)) - templates_to_remove = existing_template_ids - set(new_template_ids) - - for template_id in templates_to_remove: - current_template_record = existing_records.get(template_id=template_id) - current_market_ids = current_template_record.extra_info.get("market_record_ids", []) - if market_record_id in current_market_ids: - current_market_ids.remove(market_record_id) - current_template_record.extra_info["market_record_ids"] = current_market_ids - current_template_record.save() - - templates_to_add = set(new_template_ids) - existing_template_ids - if templates_to_add: - self.create_shared_record(project_id, market_record_id, list(templates_to_add), creator) diff --git a/gcloud/contrib/template_market/urls.py b/gcloud/contrib/template_market/urls.py index 250be933f..a85a34312 100644 --- a/gcloud/contrib/template_market/urls.py +++ b/gcloud/contrib/template_market/urls.py @@ -13,13 +13,13 @@ from django.conf.urls import include, url from rest_framework.routers import DefaultRouter -from gcloud.contrib.template_market.viewsets import TemplatePreviewViewSet, SharedTemplateRecordsViewSet +from gcloud.contrib.template_market.viewsets import TemplatePreviewAPIView, SharedTemplateRecordsViewSet template_market_router = DefaultRouter() -template_market_router.register(r"template_preview", TemplatePreviewViewSet) template_market_router.register(r"shared_templates_records", SharedTemplateRecordsViewSet) urlpatterns = [ url(r"^api/", include(template_market_router.urls)), + url(r"^api/template_preview/$", TemplatePreviewAPIView.as_view()), ] diff --git a/gcloud/contrib/template_market/viewsets.py b/gcloud/contrib/template_market/viewsets.py index 4bbf52132..9ec4849e2 100644 --- a/gcloud/contrib/template_market/viewsets.py +++ b/gcloud/contrib/template_market/viewsets.py @@ -14,6 +14,7 @@ import logging from rest_framework import viewsets +from rest_framework.views import APIView from rest_framework.response import Response from rest_framework import permissions @@ -31,13 +32,13 @@ from gcloud.contrib.template_market.permission import TemplatePreviewPermission, SharedTemplateRecordPermission -class TemplatePreviewViewSet(viewsets.ViewSet): +class TemplatePreviewAPIView(APIView): queryset = TaskTemplate.objects.filter(pipeline_template__isnull=False, is_deleted=False) serializer_class = TemplatePreviewSerializer permission_classes = [permissions.IsAuthenticated, TemplatePreviewPermission] - def retrieve(self, request, *args, **kwargs): - request_serializer = TemplateProjectBaseSerializer(data=request.GET) + def get(self, request, *args, **kwargs): + request_serializer = TemplateProjectBaseSerializer(data=request.query_params) request_serializer.is_valid(raise_exception=True) template_id = request_serializer.validated_data["template_id"] @@ -105,9 +106,9 @@ def create(self, request, *args, **kwargs): "code": err_code.OPERATION_FAIL.code, } ) - serializer.create_shared_record( + TemplateSharedRecord.objects.update_shared_record( project_id=int(serializer.validated_data["project_id"]), - template_ids=serializer.validated_data["template_ids"], + new_template_ids=serializer.validated_data["template_ids"], market_record_id=response_data["data"]["id"], creator=serializer.validated_data["creator"], ) @@ -118,7 +119,8 @@ def partial_update(self, request, *args, **kwargs): market_record_id = kwargs["pk"] serializer = self.serializer_class(data=request.data, partial=True) serializer.is_valid(raise_exception=True) - + existing_records = self.market_client.get_template_detail(market_record_id) + existing_template_ids = set([template["id"] for template in json.loads(existing_records["data"]["templates"])]) data = self._build_template_data(serializer, market_record_id=market_record_id) response_data = self.market_client.patch_market_template_record(data, market_record_id) if not response_data.get("result"): @@ -129,10 +131,11 @@ def partial_update(self, request, *args, **kwargs): "code": err_code.OPERATION_FAIL.code, } ) - serializer.update_shared_record( + TemplateSharedRecord.objects.update_shared_record( project_id=int(serializer.validated_data["project_id"]), new_template_ids=serializer.validated_data["template_ids"], market_record_id=market_record_id, creator=serializer.validated_data["creator"], + existing_template_ids=existing_template_ids, ) return Response({"result": True, "data": response_data, "code": err_code.SUCCESS.code})