Skip to content

Commit

Permalink
feat(construct): implemented review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dinsajwa committed Feb 21, 2025
1 parent 995d808 commit 964c3b9
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 132 deletions.
131 changes: 66 additions & 65 deletions lambda/aws-bedrock-data-automation/create-blueprint/lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
tracer = Tracer()
metrics = Metrics(namespace="CREATE_BLUEPRINT")

input_bucket = os.environ.get('INPUT_BUCKET')

class OperationType(str, Enum):
CREATE_BLUEPRINT = "CREATE"
DELETE_BLUEPRINT = "DELETE"
LIST_BLUEPRINTS = "LIST"
UPDATE_BLUEPRINT = "UPDATE"
GET_BLUEPRINT = "GET"

input_bucket = os.environ.get('INPUT_BUCKET')


def process_event_bridge_event(event: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -109,14 +109,13 @@ def get_schema(bucket_name: str, schema_key: str) -> Dict[str, Any]:
})
raise

#@logger.inject_lambda_context
@logger.inject_lambda_context
def handler(event, context: LambdaContext):
"""
Lambda handler function
"""
try:
logger.info(f"Received event: {json.dumps(event)}")
# Determine event source and process accordingly
if event.get("source") and event.get("detail-type"):
blueprint_details = process_event_bridge_event(event)
else:
Expand All @@ -128,77 +127,79 @@ def handler(event, context: LambdaContext):
if operation_type not in [stage.value for stage in OperationType]:
raise ValueError(f"Invalid operation type: {operation_type}. Must be one of {[stage.value for stage in OperationType]}")

if operation_type.lower() == 'delete':
logger.info("delete blueprint")

blueprint_arn = blueprint_details.get('blueprint_arn')
blueprint_version = blueprint_details.get('blueprint_version')

if not blueprint_arn:
raise ValueError("blueprint_arn is required for delete operation")
match operation_type.lower():
case "delete":
logger.info(f"deleteing blueprint {blueprint_details}")
blueprint_arn = blueprint_details.get('blueprint_arn')
blueprint_version = blueprint_details.get('blueprint_version')

return delete_blueprint(blueprint_arn, blueprint_version)
if not blueprint_arn:
raise ValueError("blueprint_arn is required for delete operation")

return delete_blueprint(blueprint_arn, blueprint_version)

elif operation_type.lower() == 'list':
logger.info("Listing all blueprints")
return list_blueprints(blueprint_details)
case "list":
logger.info("Listing all blueprints")
return list_blueprints(blueprint_details)

elif operation_type.lower() == 'get':
logger.info("Get blueprint")
return get_blueprint(blueprint_details)
case "get":
logger.info(f"Get blueprint {blueprint_details}")
return get_blueprint(blueprint_details)

elif operation_type.lower() == 'update':
logger.info("update blueprints")
return update_blueprint(blueprint_details)
case "update":
logger.info(f"update blueprint {blueprint_details}")
return update_blueprint(blueprint_details)

elif operation_type.lower() == 'create':
logger.info("create blueprint")
case "create":
logger.info("create blueprint")

# Check if schema_file_name is present
if 'schema_file_name' in blueprint_details:
input_key = blueprint_details['schema_file_name']
# Get schema from S3
logger.info(f"Retrieving schema from S3: {input_bucket}/{input_key}")
schema_content = get_schema(input_bucket, input_key)
if isinstance(schema_content, dict) and 'statusCode' in schema_content:
return schema_content
if 'schema_file_name' in blueprint_details:
input_key = blueprint_details['schema_file_name']

logger.info(f"Retrieving schema from S3: {input_bucket}/{input_key}")
schema_content = get_schema(input_bucket, input_key)
if isinstance(schema_content, dict) and 'statusCode' in schema_content:
return schema_content


# Check if schema_fields is present
if 'schema_fields' in blueprint_details:
schema_fields = blueprint_details['schema_fields']

# Validate schema_fields format
if not isinstance(schema_fields, list):
raise ValueError("schema_fields must be a list of field configurations")

# Validate each field has required properties
for field in schema_fields:
if not all(key in field for key in ['name', 'description', 'alias']):
raise ValueError("Each field must contain 'name', 'description', and 'alias'")

# Create schema using the fields
try:
DynamicSchema = create_schema(schema_fields)
schema_instance = DynamicSchema()
schema_content = json.dumps(schema_instance.model_json_schema())
if 'schema_fields' in blueprint_details:
schema_fields = blueprint_details['schema_fields']

# Validate schema_fields format
if not isinstance(schema_fields, list):
raise ValueError("schema_fields must be a list of field configurations")

# Validate each field has required properties
for field in schema_fields:
if not all(key in field for key in ['name', 'description', 'alias']):
raise ValueError("Each field must contain 'name', 'description', and 'alias'")

# Create schema using the fields
try:
DynamicSchema = create_schema(schema_fields)
schema_instance = DynamicSchema()
schema_content = json.dumps(schema_instance.model_json_schema())

except Exception as e:
print("Error creating schema")
return {
'statusCode': 500,
'body': json.dumps({
'message': 'Error creating schema',
'error': str(e)
})
}
except Exception as e:
print("Error creating schema")
return {
'statusCode': 500,
'body': json.dumps({
'message': 'Error creating schema',
'error': str(e)
})
}

return create_blueprint(schema_content,blueprint_details)

# Create blueprint with schema content
return create_blueprint(schema_content,blueprint_details)

else:
logger.warning(f"Unknown operation type: {operation_type}")

case _:
logger.warning(f"Unknown operation type: {operation_type}")
return {
'statusCode': 400,
'body': json.dumps({
'message': f'Unknown operation type: {operation_type}'
})
}

except Exception as e:
print(f"Unexpected error: {str(e)}")
Expand Down
94 changes: 40 additions & 54 deletions lambda/aws-bedrock-data-automation/create_project/lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,70 +65,56 @@ def handler(event: Dict[str, Any], context: LambdaContext) -> Dict[str, Any]:
else:
project_config = process_api_gateway_event(event)

operation = project_config.get('operation', '')
operation_type = project_config.get('operation_type', '')

logger.info("Project configuration", extra={"config": project_config})

if operation == 'create':
response = create_project(project_config)
return {
'statusCode': 200,
'body': json.dumps({
'message': 'Project created successfully',
'response': response,
})
}

elif operation == 'update':
# Validate project ARN for update
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for update operation")

response = update_project(project_config)

return {
'statusCode': 200,
'body': json.dumps({
'message': 'Project updated successfully',
'response': response
})
}

elif operation == 'delete':
# Validate project ARN for delete
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for delete operation")

delete_project(project_config['projectArn'])

match operation_type.lower():
case "create":
response = create_project(project_config)
response_msg='Project created successfully'

case "update":
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for update operation")

response = update_project(project_config)
response_msg='Project updated successfully'

return {
'statusCode': 200,
'body': json.dumps({
'message': 'Project deleted successfully',
'projectArn': project_config['projectArn']
})
}

elif operation == 'get':
# Validate project ARN for get
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for get operation")
case "delete":
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for delete operation")
delete_project(project_config['projectArn'])
response_msg='Project deleted successfully'

response = get_project(project_config )

return {
'statusCode': 200,
'body': json.dumps({
'message': 'project fetched',
'response': response
})
}
case "get":
if 'projectArn' not in project_config:
raise ValueError("projectArn is required for get operation")

response = get_project(project_config )
response_msg='Project fetched successfully'


case _:
logger.warning(f"Unknown operation type: {operation_type}")
response_msg=f'Unknown operation type: {operation_type}'
status_code=400

logger.info("Project configuration", extra={"config": project_config})

return {
'status_code': status_code if status_code else 200,
'body': json.dumps({
'message': response_msg,
'response': response
})
}

except Exception as e:
logger.error("Unexpected error", extra={"error": str(e)})
return {
'statusCode': 500,
'status_code': 500,
'body': json.dumps({
'message': 'Internal server error',
'error': str(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ class AudioGenerativeField(str, Enum):
CHAPTER_SUMMARY = "CHAPTER_SUMMARY"
IAB = "IAB"

def ensure_list(x):
"""
Ensures the input is always returned as a list.
If input is not a list, converts it to a single-item list.
If input is already a list, returns it unchanged.
Args:
x: Any type of input
Returns:
list: Input converted to or kept as list
"""
return [x] if not isinstance(x, list) else x
class ProjectConfig:
"""Configuration class for Bedrock Data Automation project settings"""

Expand Down Expand Up @@ -102,7 +115,7 @@ def _get_document_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
return {
'extraction': {
'granularity': {
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('extraction', {}).get('granularity', {}).get('types', [DocumentGranularity.DOCUMENT.value])
)
},
Expand All @@ -115,7 +128,7 @@ def _get_document_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
},
'outputFormat': {
'textFormat': {
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('document', {}).get('outputFormat', {}).get('textFormat', {}).get('types', ['PLAIN_TEXT'])
)
},
Expand All @@ -132,7 +145,7 @@ def _get_image_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
'extraction': {
'category': {
'state': config.get('extraction', {}).get('category', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('image', {}).get('extraction', {}).get('category', {}).get('types', ['CONTENT_MODERATION'])
)
},
Expand All @@ -142,7 +155,7 @@ def _get_image_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
},
'generativeField': {
'state': config.get('generativeField', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('image', {}).get('generativeField', {}).get('types', ['IMAGE_SUMMARY'])
)
}
Expand All @@ -155,7 +168,7 @@ def _get_video_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
'extraction': {
'category': {
'state': config.get('extraction', {}).get('category', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('video', {}).get('extraction', {}).get('category', {}).get('types', ['CONTENT_MODERATION'])
)
},
Expand All @@ -165,7 +178,7 @@ def _get_video_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
},
'generativeField': {
'state': config.get('generativeField', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('video', {}).get('generativeField', {}).get('types', ['VIDEO_SUMMARY'])
)
}
Expand All @@ -180,13 +193,13 @@ def _get_audio_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
'extraction': {
'category': {
'state': config.get('extraction', {}).get('category', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('audio', {}).get('extraction', {}).get('category', {}).get('types', ['TRANSCRIPT'])
) }
},
'generativeField': {
'state': config.get('generativeField', {}).get('state', State.DISABLED.value),
'types': (lambda x: [x] if not isinstance(x, list) else x)(
'types': ensure_list(
config.get('audio', {}).get('generativeField', {}).get('types', ['AUDIO_SUMMARY'])
) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def __init__(self, input_bucket: str, output_bucket: str, client=None):
self.client = boto3.client("bedrock-data-automation-runtime")
self.input_bucket = input_bucket
self.output_bucket = output_bucket
# self.max_retries = 60
# self.retry_delay = 10


def invoke_data_automation_async(
self,
Expand Down Expand Up @@ -120,9 +119,6 @@ def invoke_data_automation_async(

validate_configs(blueprint_config, data_automation_config)

# Create S3 URIs
#s3://cb-output-documents/noa.json
#s3://cb-input-documents/noa.pdf
input_s3_uri = f"s3://{self.input_bucket}/{input_filename}"
output_s3_uri = f"s3://{self.output_bucket}/{output_filename}"

Expand Down

0 comments on commit 964c3b9

Please sign in to comment.