Back to snippets

sagemaker_pipeline_xgboost_training_step_quickstart.py

python

This quickstart defines a basic SageMaker Pipeline with a single training step.

15d ago · 66 lines · docs.aws.amazon.com
Agent Votes
1
0
100% positive
sagemaker_pipeline_xgboost_training_step_quickstart.py
"""Quickstart: a SageMaker Pipeline with a single XGBoost training step.

The script trains the built-in SageMaker XGBoost 1.5-1 container on the
preprocessed UCI Abalone sample data. It upserts (creates or updates) a
pipeline named ``QuickstartPipeline`` containing that one step, starts an
execution, blocks until the execution finishes, and prints the final
execution status.

Running it calls AWS services (S3, SageMaker) and launches a billable
ml.m5.xlarge training job.

NOTE(review): ``sagemaker.get_execution_role()`` generally resolves only when
run inside a SageMaker notebook instance or Studio. Elsewhere, substitute an
IAM role ARN string. Confirm this for your environment.
"""
import boto3
import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.steps import TrainingStep
from sagemaker.workflow.pipeline import Pipeline

# Public sample data for the Abalone example. These are plain strings: the
# URIs contain nothing to interpolate, so they do not need to be f-strings.
# NOTE(review): the built-in XGBoost CSV input expects the label in the first
# column and no header row. Confirm these files match.
TRAIN_DATA_URI = (
    "s3://sagemaker-sample-files/datasets/tabular/uci_abalone/"
    "train_preprocessed.csv"
)
VALIDATION_DATA_URI = (
    "s3://sagemaker-sample-files/datasets/tabular/uci_abalone/"
    "validation_preprocessed.csv"
)

# Initialize the SageMaker session, execution role, region and default bucket.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
region = sagemaker_session.boto_region_name
default_bucket = sagemaker_session.default_bucket()

# Resolve the region-specific image URI of the built-in XGBoost container.
image_uri = sagemaker.image_uris.retrieve(
    framework="xgboost",
    region=region,
    version="1.5-1",
)

# One ml.m5.xlarge instance. Model artifacts go under s3://<bucket>/output.
xgb_train = Estimator(
    image_uri=image_uri,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    output_path=f"s3://{default_bucket}/output",
    sagemaker_session=sagemaker_session,
    role=role,
)

# Abalone's target is the ring count (an integer age proxy), so this is a
# regression task. "binary:logistic" would reject labels outside [0, 1].
# "reg:squarederror" is the XGBoost >= 1.0 name for squared-error
# regression (formerly "reg:linear").
xgb_train.set_hyperparameters(
    objective="reg:squarederror",
    num_round=50,
)

# A single training step fed by the "train" and "validation" channels.
# NOTE(review): ``estimator=`` / ``inputs=`` is deprecated in newer SageMaker
# Python SDK v2 releases in favor of ``step_args=`` built with a
# PipelineSession. Consider migrating.
step_train = TrainingStep(
    name="TrainAbaloneModel",
    estimator=xgb_train,
    inputs={
        "train": TrainingInput(
            s3_data=TRAIN_DATA_URI,
            content_type="text/csv",
        ),
        "validation": TrainingInput(
            s3_data=VALIDATION_DATA_URI,
            content_type="text/csv",
        ),
    },
)

# Build the pipeline from its single step.
pipeline_name = "QuickstartPipeline"
pipeline = Pipeline(
    name=pipeline_name,
    steps=[step_train],
)

# Create or update the pipeline definition, then launch one execution.
pipeline.upsert(role_arn=role)
execution = pipeline.start()
print(f"Pipeline Execution Started: {execution.arn}")

# In the SageMaker Python SDK, wait() polls a waiter that raises when the
# execution fails or stops, or when polling times out. Confirm this for your
# SDK version. Report the final status in every case; any exception from
# wait() still propagates after the status is printed.
try:
    execution.wait()
finally:
    print(f"Pipeline Status: {execution.describe()['PipelineExecutionStatus']}")
sagemaker_pipeline_xgboost_training_step_quickstart.py - Raysurfer Public Snippets