Skip to content

Commit

Permalink
ECS: Tagging is now supported for Tasks (getmoto#7029)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Nov 15, 2023
1 parent d3efa2a commit ed56ffd
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 64 deletions.
36 changes: 16 additions & 20 deletions moto/ecs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,10 @@ def task_arn(self) -> str:
return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.cluster_name}/{self.id}"
return f"arn:aws:ecs:{self.region_name}:{self._account_id}:task/{self.id}"

@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
def response_object(self, include_tags: bool = True) -> Dict[str, Any]: # type: ignore
response_object = self.gen_response_object()
if not include_tags:
response_object.pop("tags", None)
response_object["taskArn"] = self.task_arn
response_object["lastStatus"] = self.last_status
return response_object
Expand Down Expand Up @@ -1471,12 +1472,7 @@ def start_task(
self.tasks[cluster.name][task.task_arn] = task
return tasks

def describe_tasks(
self,
cluster_str: str,
tasks: Optional[str],
include: Optional[List[str]] = None,
) -> List[Task]:
def describe_tasks(self, cluster_str: str, tasks: Optional[str]) -> List[Task]:
"""
Only include=TAGS is currently supported.
"""
Expand All @@ -1495,22 +1491,18 @@ def describe_tasks(
):
task.advance()
response.append(task)
if "TAGS" in (include or []):
return response

for task in response:
task.tags = []
return response

def list_tasks(
self,
cluster_str: str,
container_instance: Optional[str],
family: str,
started_by: str,
service_name: str,
desiredStatus: str,
) -> List[str]:
cluster_str: Optional[str] = None,
container_instance: Optional[str] = None,
family: Optional[str] = None,
started_by: Optional[str] = None,
service_name: Optional[str] = None,
desiredStatus: Optional[str] = None,
) -> List[Task]:
filtered_tasks = []
for tasks in self.tasks.values():
for task in tasks.values():
Expand Down Expand Up @@ -1554,7 +1546,7 @@ def list_tasks(
filter(lambda t: t.desired_status == desiredStatus, filtered_tasks)
)

return [t.task_arn for t in filtered_tasks]
return filtered_tasks

def stop_task(self, cluster_str: str, task_str: str, reason: str) -> Task:
cluster = self._get_cluster(cluster_str)
Expand Down Expand Up @@ -2080,6 +2072,10 @@ def _get_resource(self, resource_arn: str, parsed_arn: Dict[str, str]) -> Any:
return task_def
elif parsed_arn["service"] == "capacity-provider":
return self._get_provider(parsed_arn["id"])
elif parsed_arn["service"] == "task":
for task in self.list_tasks():
if task.task_arn == resource_arn:
return task
raise NotImplementedError()

def list_tags_for_resource(self, resource_arn: str) -> List[Dict[str, str]]:
Expand Down
19 changes: 11 additions & 8 deletions moto/ecs/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,19 @@ def run_task(self) -> str:
network_configuration,
)
return json.dumps(
{"tasks": [task.response_object for task in tasks], "failures": []}
{"tasks": [task.response_object() for task in tasks], "failures": []}
)

def describe_tasks(self) -> str:
cluster = self._get_param("cluster", "default")
tasks = self._get_param("tasks")
include = self._get_param("include")
data = self.ecs_backend.describe_tasks(cluster, tasks, include)
include_tags = "TAGS" in self._get_param("include", [])
data = self.ecs_backend.describe_tasks(cluster, tasks)
return json.dumps(
{"tasks": [task.response_object for task in data], "failures": []}
{
"tasks": [task.response_object(include_tags) for task in data],
"failures": [],
}
)

def start_task(self) -> str:
Expand All @@ -221,7 +224,7 @@ def start_task(self) -> str:
tags,
)
return json.dumps(
{"tasks": [task.response_object for task in tasks], "failures": []}
{"tasks": [task.response_object() for task in tasks], "failures": []}
)

def list_tasks(self) -> str:
Expand All @@ -231,22 +234,22 @@ def list_tasks(self) -> str:
started_by = self._get_param("startedBy")
service_name = self._get_param("serviceName")
desiredStatus = self._get_param("desiredStatus")
task_arns = self.ecs_backend.list_tasks(
tasks = self.ecs_backend.list_tasks(
cluster_str,
container_instance,
family,
started_by,
service_name,
desiredStatus,
)
return json.dumps({"taskArns": task_arns})
return json.dumps({"taskArns": [t.task_arn for t in tasks]})

def stop_task(self) -> str:
cluster_str = self._get_param("cluster", "default")
task = self._get_param("task")
reason = self._get_param("reason")
task = self.ecs_backend.stop_task(cluster_str, task, reason)
return json.dumps({"task": task.response_object})
return json.dumps({"task": task.response_object()})

def create_service(self) -> str:
cluster_str = self._get_param("cluster", "default")
Expand Down
36 changes: 0 additions & 36 deletions tests/test_ecs/test_ecs_boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2481,42 +2481,6 @@ def test_describe_tasks_empty_tags():
assert len(response["tasks"]) == 1


@mock_ec2
@mock_ecs
def test_describe_tasks_include_tags():
client = boto3.client("ecs", region_name=ECS_REGION)
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)

task_tags = [{"key": "Name", "value": "test_ecs_task"}]
tasks_arns = [
task["taskArn"]
for task in client.run_task(
cluster="test_ecs_cluster",
overrides={},
taskDefinition="test_ecs_task",
count=2,
startedBy="moto",
tags=task_tags,
)["tasks"]
]
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"]
)

assert len(response["tasks"]) == 2
assert set(
[response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]]
) == set(tasks_arns)
assert response["tasks"][0]["tags"] == task_tags

# Test we can pass task ids instead of ARNs
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]]
)
assert len(response["tasks"]) == 1


@mock_ecs
def test_describe_tasks_exceptions():
client = boto3.client("ecs", region_name=ECS_REGION)
Expand Down
71 changes: 71 additions & 0 deletions tests/test_ecs/test_ecs_task_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import boto3

from moto import mock_ec2, mock_ecs
from .test_ecs_boto3 import setup_ecs_cluster_with_ec2_instance


@mock_ec2
@mock_ecs
def test_describe_tasks_include_tags():
client = boto3.client("ecs", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)

task_tags = [{"key": "Name", "value": "test_ecs_task"}]
tasks_arns = [
task["taskArn"]
for task in client.run_task(
cluster="test_ecs_cluster",
taskDefinition="test_ecs_task",
count=2,
tags=task_tags,
)["tasks"]
]
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=tasks_arns, include=["TAGS"]
)

assert len(response["tasks"]) == 2
assert set(
[response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]]
) == set(tasks_arns)
assert response["tasks"][0]["tags"] == task_tags

# Test we can pass task ids instead of ARNs
response = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]]
)
assert len(response["tasks"]) == 1

tags = client.list_tags_for_resource(resourceArn=tasks_arns[0])["tags"]
assert tags == task_tags


@mock_ec2
@mock_ecs
def test_add_tags_to_task():
client = boto3.client("ecs", region_name="us-east-1")
test_cluster_name = "test_ecs_cluster"
setup_ecs_cluster_with_ec2_instance(client, test_cluster_name)

task_tags = [{"key": "k1", "value": "v1"}]
task_arn = client.run_task(
cluster="test_ecs_cluster",
taskDefinition="test_ecs_task",
count=1,
tags=task_tags,
)["tasks"][0]["taskArn"]

client.tag_resource(resourceArn=task_arn, tags=[{"key": "k2", "value": "v2"}])

tags = client.describe_tasks(
cluster="test_ecs_cluster", tasks=[task_arn], include=["TAGS"]
)["tasks"][0]["tags"]
assert len(tags) == 2
assert {"key": "k1", "value": "v1"} in tags
assert {"key": "k2", "value": "v2"} in tags

client.untag_resource(resourceArn=task_arn, tagKeys=["k2"])

resp = client.list_tags_for_resource(resourceArn=task_arn)
assert resp["tags"] == [{"key": "k1", "value": "v1"}]

0 comments on commit ed56ffd

Please sign in to comment.