Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@
"""Placeholder docstring"""
from __future__ import absolute_import

import enum
import datetime
import enum
import json
import logging
import os
import re
import tempfile
import time
from uuid import uuid4
from copy import deepcopy
from uuid import uuid4

from botocore.exceptions import ClientError

import sagemaker.local.data

from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
from sagemaker.utils import DeferredError, get_config_value, format_tags
from sagemaker.local.exceptions import StepExecutionException
from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import copy_directory_structure, get_docker_host, move_to_destination
from sagemaker.utils import DeferredError, format_tags, get_config_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -272,9 +273,45 @@ def describe(self):
"AlgorithmSpecification": {
"ContainerEntrypoint": self.container.container_entrypoint,
},
"FinalMetricDataList": self._extract_final_metrics(),
}
return response

def _extract_final_metrics(self):
"""Extract metrics from container logs using metric definitions."""
if not hasattr(self.container, "logs") or not self.container.logs:
return []

# Get metric definitions from container
metric_definitions = getattr(self.container, "metric_definitions", [])
if not metric_definitions:
return []

final_metrics = []
logs = self.container.logs

for metric_def in metric_definitions:
metric_name = metric_def.get("Name")
regex_pattern = metric_def.get("Regex")

if not metric_name or not regex_pattern:
continue

# Find all matches in logs
matches = re.findall(regex_pattern, logs)
if matches:
# Use the last match as final metric
final_value = float(matches[-1])
final_metrics.append(
{
"MetricName": metric_name,
"Value": final_value,
"Timestamp": self.end_time or datetime.now(),
}
)

return final_metrics


class _LocalTransformJob(object):
"""Placeholder docstring"""
Expand Down Expand Up @@ -711,8 +748,8 @@ def __init__(
PipelineExecutionDisplayName=None,
local_session=None,
):
from sagemaker.workflow.pipeline import PipelineGraph
from sagemaker import LocalSession
from sagemaker.workflow.pipeline import PipelineGraph

self.pipeline = pipeline
self.pipeline_execution_name = execution_id
Expand Down Expand Up @@ -809,7 +846,7 @@ def mark_step_executing(self, step_name):

def _initialize_step_execution(self, steps):
"""Initialize step_execution dict."""
from sagemaker.workflow.steps import StepTypeEnum, Step
from sagemaker.workflow.steps import Step, StepTypeEnum

supported_steps_types = (
StepTypeEnum.TRAINING,
Expand Down
190 changes: 189 additions & 1 deletion tests/unit/sagemaker/local/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import urllib3
import os
from datetime import datetime
from botocore.exceptions import ClientError
from mock import Mock, patch
from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION
Expand All @@ -25,7 +26,7 @@
from sagemaker.workflow.pipeline import Pipeline
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.local.local_session import LocalSession
from sagemaker.local.entities import _LocalPipelineExecution
from sagemaker.local.entities import _LocalPipelineExecution, _LocalTrainingJob


OK_RESPONSE = urllib3.HTTPResponse()
Expand Down Expand Up @@ -1100,3 +1101,190 @@ def test_config_setter():

with pytest.raises(jsonschema.ValidationError):
session.config = INVALID_LOCAL_MODE_CONFIG


class TestLocalTrainingJobFinalMetrics:
"""Test cases for FinalMetricDataList functionality in _LocalTrainingJob."""

def test_describe_includes_final_metric_data_list(self):
"""Test that describe() includes FinalMetricDataList field."""
container = Mock()
container.logs = None
container.metric_definitions = []
job = _LocalTrainingJob(container)
job.training_job_name = "test-job"
job.state = "Completed"
job.start_time = datetime.now()
job.end_time = datetime.now()
job.model_artifacts = "/path/to/model"
job.output_data_config = {}
job.environment = {}

response = job.describe()

assert "FinalMetricDataList" in response
assert isinstance(response["FinalMetricDataList"], list)

def test_extract_final_metrics_no_logs(self):
"""Test _extract_final_metrics returns empty list when no logs."""
container = Mock()
container.logs = None
job = _LocalTrainingJob(container)

result = job._extract_final_metrics()

assert result == []

def test_extract_final_metrics_no_metric_definitions(self):
"""Test _extract_final_metrics returns empty list when no metric definitions."""
container = Mock()
container.logs = "some logs"
container.metric_definitions = []
job = _LocalTrainingJob(container)

result = job._extract_final_metrics()

assert result == []

def test_extract_final_metrics_with_valid_metrics(self):
"""Test _extract_final_metrics extracts metrics correctly."""
container = Mock()
container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete"
container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
]
job = _LocalTrainingJob(container)
job.end_time = datetime(2023, 1, 1, 12, 0, 0)

result = job._extract_final_metrics()

assert len(result) == 1
assert result[0]["MetricName"] == "ganloss"
assert result[0]["Value"] == 0.138318
assert result[0]["Timestamp"] == job.end_time

def test_extract_final_metrics_multiple_matches_uses_last(self):
"""Test _extract_final_metrics uses the last match for each metric."""
container = Mock()
container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;"
container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
]
job = _LocalTrainingJob(container)
job.end_time = datetime(2023, 1, 1, 12, 0, 0)

result = job._extract_final_metrics()

assert len(result) == 1
assert result[0]["Value"] == 0.138318

def test_extract_final_metrics_multiple_metrics(self):
"""Test _extract_final_metrics handles multiple different metrics."""
container = Mock()
container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;"
container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"},
{"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"},
{"Name": "loss", "Regex": r"Loss=([\d\.]+);"}
]
job = _LocalTrainingJob(container)
job.end_time = datetime(2023, 1, 1, 12, 0, 0)

result = job._extract_final_metrics()

assert len(result) == 3
metric_names = [m["MetricName"] for m in result]
assert "ganloss" in metric_names
assert "accuracy" in metric_names
assert "loss" in metric_names

def test_extract_final_metrics_no_matches(self):
"""Test _extract_final_metrics returns empty list when regex doesn't match."""
container = Mock()
container.logs = "Training started\nTraining complete"
container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
]
job = _LocalTrainingJob(container)

result = job._extract_final_metrics()

assert result == []

def test_extract_final_metrics_invalid_metric_definition(self):
"""Test _extract_final_metrics skips invalid metric definitions."""
container = Mock()
container.logs = "GAN_loss=0.138318;"
container.metric_definitions = [
{"Name": "ganloss"}, # Missing Regex
{"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name
{"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid
]
job = _LocalTrainingJob(container)
job.end_time = datetime(2023, 1, 1, 12, 0, 0)

result = job._extract_final_metrics()

assert len(result) == 1
assert result[0]["MetricName"] == "valid"

@patch("sagemaker.local.entities.datetime")
def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime):
"""Test _extract_final_metrics uses current time when end_time is None."""
container = Mock()
container.logs = "GAN_loss=0.138318;"
container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
]
job = _LocalTrainingJob(container)
job.end_time = None

mock_now = datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.now.return_value = mock_now

result = job._extract_final_metrics()

assert len(result) == 1
assert result[0]["Timestamp"] == mock_now

@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
def test_integration_describe_training_job_with_metrics(self, mock_train):
"""Integration test: describe_training_job includes FinalMetricDataList."""
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()

algo_spec = {"TrainingImage": "my-image:1.0"}
input_data_config = [{
"ChannelName": "training",
"DataSource": {
"S3DataSource": {
"S3DataDistributionType": "FullyReplicated",
"S3Uri": "s3://bucket/data"
}
}
}]
output_data_config = {}
resource_config = {"InstanceType": "local", "InstanceCount": 1}

# Create training job
local_sagemaker_client.create_training_job(
"test-job",
algo_spec,
output_data_config,
resource_config,
InputDataConfig=input_data_config,
HyperParameters={}
)

# Mock the container logs and metric definitions
training_job = local_sagemaker_client._training_jobs["test-job"]
training_job.container.logs = "GAN_loss=0.138318;"
training_job.container.metric_definitions = [
{"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}
]

response = local_sagemaker_client.describe_training_job("test-job")

assert "FinalMetricDataList" in response
assert len(response["FinalMetricDataList"]) == 1
assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss"
assert response["FinalMetricDataList"][0]["Value"] == 0.138318
Loading
Loading