fix sequential run issue (#3643)
This commit is contained in:
@@ -4,6 +4,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import aioboto3
|
||||
import structlog
|
||||
from types_boto3_batch.client import BatchClient
|
||||
from types_boto3_ec2.client import EC2Client
|
||||
from types_boto3_ecs.client import ECSClient
|
||||
from types_boto3_s3.client import S3Client
|
||||
@@ -31,6 +32,7 @@ class AWSClientType(StrEnum):
|
||||
SECRETS_MANAGER = "secretsmanager"
|
||||
ECS = "ecs"
|
||||
EC2 = "ec2"
|
||||
BATCH = "batch"
|
||||
|
||||
|
||||
class AsyncAWSClient:
|
||||
@@ -64,6 +66,9 @@ class AsyncAWSClient:
|
||||
def _ec2_client(self) -> EC2Client:
|
||||
return self.session.client(AWSClientType.EC2, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
def _batch_client(self) -> BatchClient:
|
||||
return self.session.client(AWSClientType.BATCH, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
def _create_tag_string(self, tags: dict[str, str]) -> str:
|
||||
return "&".join([f"{k}={v}" for k, v in tags.items()])
|
||||
|
||||
@@ -394,6 +399,61 @@ class AsyncAWSClient:
|
||||
async with self._ec2_client() as client:
|
||||
return await client.describe_network_interfaces(NetworkInterfaceIds=network_interface_ids)
|
||||
|
||||
###### Batch ######
|
||||
async def describe_job(self, job_id: str) -> dict:
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/describe_jobs.html
|
||||
async with self._batch_client() as client:
|
||||
response = await client.describe_jobs(jobs=[job_id])
|
||||
return response["jobs"][0] if response["jobs"] else {}
|
||||
|
||||
async def list_jobs(self, job_queue: str, job_status: str) -> list[dict]:
|
||||
# NOTE: AWS batch only records the latest 7 days jobs by default
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/list_jobs.html
|
||||
async with self._batch_client() as client:
|
||||
total_jobs = []
|
||||
async for page in client.get_paginator("list_jobs").paginate(jobQueue=job_queue, jobStatus=job_status):
|
||||
for job in page["jobSummaryList"]:
|
||||
total_jobs.append(job)
|
||||
|
||||
return total_jobs
|
||||
|
||||
async def submit_job(
|
||||
self,
|
||||
job_name: str,
|
||||
job_queue: str,
|
||||
job_definition: str,
|
||||
params: dict,
|
||||
job_priority: int | None = None,
|
||||
share_identifier: str | None = None,
|
||||
container_overrides: dict | None = None,
|
||||
depends_on_ids: list[str] | None = None,
|
||||
) -> str | None:
|
||||
container_overrides = container_overrides or {}
|
||||
depends_on = [{"jobId": job_id} for job_id in depends_on_ids or []]
|
||||
async with self._batch_client() as client:
|
||||
if job_priority is None or share_identifier is None:
|
||||
response = await client.submit_job(
|
||||
jobName=job_name,
|
||||
jobQueue=job_queue,
|
||||
jobDefinition=job_definition,
|
||||
parameters=params,
|
||||
containerOverrides=container_overrides,
|
||||
dependsOn=depends_on,
|
||||
)
|
||||
return response.get("jobId")
|
||||
else:
|
||||
response = await client.submit_job(
|
||||
jobName=job_name,
|
||||
jobQueue=job_queue,
|
||||
jobDefinition=job_definition,
|
||||
parameters=params,
|
||||
schedulingPriorityOverride=job_priority,
|
||||
shareIdentifier=share_identifier,
|
||||
containerOverrides=container_overrides,
|
||||
dependsOn=depends_on,
|
||||
)
|
||||
return response.get("jobId")
|
||||
|
||||
|
||||
class S3Uri:
|
||||
# From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
|
||||
|
||||
Reference in New Issue
Block a user