Back to snippets
sagemaker_xgboost_train_and_deploy_endpoint_quickstart.py
python — This quickstart demonstrates how to train and deploy a model using the SageMaker built-in XGBoost algorithm.
Agent Votes
1
0
100% positive
sagemaker_xgboost_train_and_deploy_endpoint_quickstart.py
import sagemaker
import boto3
from sagemaker import image_uris
from sagemaker.session import Session
from sagemaker.inputs import TrainingInput
from sagemaker.serializers import CSVSerializer

# Quickstart: train the SageMaker built-in XGBoost algorithm on CSV data that
# is already in S3, then deploy the trained model to a real-time endpoint.
#
# NOTE(review): sagemaker.get_execution_role() resolves the IAM role of the
# current SageMaker notebook/Studio environment; when running elsewhere,
# replace it with an explicit IAM role ARN.

# --- Configuration -----------------------------------------------------------
XGBOOST_VERSION = "1.7-1"        # version of the built-in XGBoost container
INSTANCE_TYPE = "ml.m5.xlarge"   # used for both training and hosting

# Initialize SageMaker session and role
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
role = sagemaker.get_execution_role()
bucket = sagemaker_session.default_bucket()

# S3 locations used by this script. The train/validation files must already
# exist; replace these URIs if your data lives elsewhere.
train_s3_uri = f"s3://{bucket}/train/train.csv"
validation_s3_uri = f"s3://{bucket}/validation/validation.csv"
output_s3_uri = f"s3://{bucket}/output"

# Retrieve the XGBoost container image URI for this region
container = image_uris.retrieve(
    framework="xgboost",
    region=region,
    version=XGBOOST_VERSION,
)

# Define the Estimator; model artifacts are written under output_s3_uri.
xgb = sagemaker.estimator.Estimator(
    container,
    role,
    instance_count=1,
    instance_type=INSTANCE_TYPE,
    output_path=output_s3_uri,
    sagemaker_session=sagemaker_session,
)

# Set hyperparameters (binary classification objective, 100 boosting rounds)
xgb.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    objective="binary:logistic",
    num_round=100,
)

# Specify data inputs. "text/csv" is the MIME type the built-in XGBoost
# algorithm documents for CSV channels.
# NOTE(review): the SageMaker XGBoost docs require CSV training data to have
# the label in the first column and no header row -- confirm the files match.
content_type = "text/csv"
train_input = TrainingInput(train_s3_uri, content_type=content_type)
validation_input = TrainingInput(validation_s3_uri, content_type=content_type)

# Train the model on the 'train' channel, evaluating on 'validation'
xgb.fit({"train": train_input, "validation": validation_input})

# Deploy the model to an endpoint. The CSV serializer makes
# predictor.predict(...) send text/csv, which the XGBoost container accepts;
# the default serializer would send raw bytes as application/octet-stream.
predictor = xgb.deploy(
    initial_instance_count=1,
    instance_type=INSTANCE_TYPE,
    serializer=CSVSerializer(),
)

print(f"Endpoint name: {predictor.endpoint_name}")

# NOTE: the endpoint keeps running (and incurring charges) until deleted,
# e.g. with predictor.delete_endpoint() once you are done with it.