Back to snippets

sagemaker_sklearn_pipeline_with_training_and_model_registration.py

python

Defines and executes a basic Amazon SageMaker Model Building Pipeline consisting of a scikit-learn training step and a model registration step.

Agent Votes
1
0
100% positive
sagemaker_sklearn_pipeline_with_training_and_model_registration.py
import sagemaker
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.workflow.steps import TrainingStep
from sagemaker.workflow.model_step import ModelStep
from sagemaker.model import Model
from sagemaker.workflow.pipeline import Pipeline

# Build and run a two-step SageMaker pipeline: train a scikit-learn model,
# then register the resulting artifact in a model package group.

# Initialize SageMaker session and role.
# The same PipelineSession is handed to the estimator, the model and the
# pipeline so that `fit()` / `register()` below produce step arguments for the
# pipeline definition rather than being executed on the spot.
sagemaker_session = PipelineSession()
role = sagemaker.get_execution_role()
region = sagemaker_session.boto_region_name  # not referenced again in this script
default_bucket = sagemaker_session.default_bucket()  # not referenced again in this script

# Define the Estimator for the training step.
# `train.py` is resolved relative to the working directory the script runs in.
sklearn_estimator = SKLearn(
    entry_point="train.py",
    role=role,
    instance_type="ml.m5.xlarge",
    framework_version="1.0-1",
    sagemaker_session=sagemaker_session,
)

# Define the Training Step.
# Passing `estimator=` directly to TrainingStep is deprecated; with a
# PipelineSession the step is built from the arguments captured by `fit()`.
# No input channels are supplied, matching the original step definition.
step_train = TrainingStep(
    name="MyTrainStep",
    step_args=sklearn_estimator.fit(),
)

# Define the Model Step (register the model).
# `model_data` is a pipeline property resolved at execution time; referencing
# it also makes the registration step depend on the training step.
# NOTE(review): a plain `Model` carries no inference entry point -- confirm the
# registered package is servable with this image before deploying it.
model = Model(
    image_uri=sklearn_estimator.image_uri,
    model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
    sagemaker_session=sagemaker_session,
    role=role,
)

step_register = ModelStep(
    name="MyModelStep",
    step_args=model.register(
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
        transform_instances=["ml.m5.xlarge"],
        model_package_group_name="MyModelPackageGroup",
    ),
)

# Create the Pipeline.
pipeline = Pipeline(
    name="MyPipeline",
    steps=[step_train, step_register],
    sagemaker_session=sagemaker_session,
)

# Upsert and execute the Pipeline.
# `upsert` creates the pipeline definition or updates an existing one with the
# same name; `wait` blocks until the started execution reaches a terminal state.
pipeline.upsert(role_arn=role)
execution = pipeline.start()
execution.wait()
print(execution.list_steps())