diff --git a/config.yml.example b/config.yml.example
index d256d0ccf..f5c1ceb5e 100644
--- a/config.yml.example
+++ b/config.yml.example
@@ -4,6 +4,7 @@ postgresql:
   dbname: green-coding
   password: PLEASE_CHANGE_THIS
   port: 9573
+  retry_timeout: 300 # Total time to retry database connections on outage/failure (5 minutes)

 redis:
   host: green-coding-redis-container
diff --git a/lib/db.py b/lib/db.py
index 3bece8ab1..c75a0cde3 100644
--- a/lib/db.py
+++ b/lib/db.py
@@ -1,13 +1,74 @@
 #pylint: disable=consider-using-enumerate
 import os
+import time
+import random
+from functools import wraps
 from psycopg_pool import ConnectionPool
 import psycopg.rows
+import psycopg
 import pytest
 from lib.global_config import GlobalConfig

 def is_pytest_session():
     return "pytest" in os.environ.get('_', '')

+def with_db_retry(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        config = GlobalConfig().config
+        retry_timeout = config.get('postgresql', {}).get('retry_timeout', 300)
+        retry_interval = 1 # Base interval for exponential backoff
+
+        start_time = time.time()
+        attempt = 0
+
+        while time.time() - start_time < retry_timeout:
+            attempt += 1
+            try:
+                return func(self, *args, **kwargs)
+            except (psycopg.OperationalError, psycopg.DatabaseError) as e:
+                # Check if this is a connection-related error that we should retry
+                error_str = str(e).lower()
+                retryable_errors = [
+                    'connection', 'closed', 'terminated', 'timeout', 'network',
+                    'server', 'unavailable', 'refused', 'reset', 'broken pipe'
+                ]
+
+                is_retryable = any(keyword in error_str for keyword in retryable_errors)
+
+                if not is_retryable:
+                    # Non-retryable error (e.g., SQL syntax error)
+                    print(f"Database error (non-retryable): {e}")
+                    raise
+
+                time_elapsed = time.time() - start_time
+                if time_elapsed >= retry_timeout:
+                    print(f"Database retry timeout after {attempt} attempts over {time_elapsed:.1f} seconds. Last error: {e}")
+                    raise
+
+                # Exponential backoff with jitter
+                backoff_time = min(retry_interval * (2 ** (attempt - 1)), 30) # Cap at 30 seconds
+                jitter = random.uniform(0.1, 0.5) * backoff_time
+                sleep_time = backoff_time + jitter
+
+                print(f"Database connection error (attempt {attempt}): {e}. Retrying in {sleep_time:.2f} seconds...")
+
+                # Try to recreate the connection pool if it's corrupted
+                try:
+                    if hasattr(self, '_pool'):
+                        self._pool.close()
+                        del self._pool
+                    self._create_pool()
+                except (psycopg.OperationalError, psycopg.DatabaseError, AttributeError) as pool_error:
+                    print(f"Failed to recreate connection pool: {pool_error}")
+
+                time.sleep(sleep_time)
+
+        # If we get here, we've exhausted all retries
+        raise psycopg.OperationalError(f"Database connection failed after {attempt} attempts over {time.time() - start_time:.1f} seconds")
+
+    return wrapper
+
 class DB:

     def __new__(cls):
@@ -19,29 +80,34 @@ def __new__(cls):
         return cls.instance

     def __init__(self):
-        if not hasattr(self, '_pool'):
-            config = GlobalConfig().config
-
-            # Important note: We are not using cursor_factory = psycopg2.extras.RealDictCursor
-            # as an argument, because this would increase the size of a single API request
-            # from 50 kB to 100kB.
-            # Users are required to use the mask of the API requests to read the data.
-            # force domain socket connection by not supplying host
-            # pylint: disable=consider-using-f-string
-
-            self._pool = ConnectionPool(
-                "user=%s password=%s host=%s port=%s dbname=%s sslmode=require" % (
-                    config['postgresql']['user'],
-                    config['postgresql']['password'],
-                    config['postgresql']['host'],
-                    config['postgresql']['port'],
-                    config['postgresql']['dbname'],
-                ),
-                min_size=1,
-                max_size=2,
-                open=True,
-            )
+        self._create_pool()
+
+    def _create_pool(self):
+        config = GlobalConfig().config
+
+        # Important note: We are not using cursor_factory = psycopg2.extras.RealDictCursor
+        # as an argument, because this would increase the size of a single API request
+        # from 50 kB to 100kB.
+        # Users are required to use the mask of the API requests to read the data.
+        # force domain socket connection by not supplying host
+        # pylint: disable=consider-using-f-string
+
+        self._pool = ConnectionPool(
+            "user=%s password=%s host=%s port=%s dbname=%s sslmode=require" % (
+                config['postgresql']['user'],
+                config['postgresql']['password'],
+                config['postgresql']['host'],
+                config['postgresql']['port'],
+                config['postgresql']['dbname'],
+            ),
+            min_size=1,
+            max_size=2,
+            open=True,
+            # Explicitly disabled (default) to prevent measurement interference
+            # from conn.execute("") calls, using @with_db_retry instead
+            check=None
+        )

     def shutdown(self):
         if hasattr(self, '_pool'):
@@ -49,19 +115,27 @@ def shutdown(self):
             del self._pool

-    def __query(self, query, params=None, return_type=None, fetch_mode=None):
+    # Query list only supports SELECT queries
+    # If we ever need complex queries in the future where we have a transaction that mixes SELECTs and INSERTS
+    # then this class needs a refactoring. Until then we can KISS it
+    def __query_multi(self, query, params=None):
+        with self._pool.connection() as conn:
+            conn.autocommit = False # should be default, but we are explicit
+            cur = conn.cursor(row_factory=None) # None is actually the default cursor factory
+            for i in range(len(query)):
+                # In error case the context manager will ROLLBACK the whole transaction
+                cur.execute(query[i], params[i])
+            conn.commit()
+
+    @with_db_retry
+    def __query_single(self, query, params=None, return_type=None, fetch_mode=None):
         ret = False
         row_factory = psycopg.rows.dict_row if fetch_mode == 'dict' else None
         with self._pool.connection() as conn:
             conn.autocommit = False # should be default, but we are explicit
             cur = conn.cursor(row_factory=row_factory) # None is actually the default cursor factory
-            if isinstance(query, list) and isinstance(params, list) and len(query) == len(params):
-                for i in range(len(query)):
-                    # In error case the context manager will ROLLBACK the whole transaction
-                    cur.execute(query[i], params[i])
-            else:
-                cur.execute(query, params)
+            cur.execute(query, params)
             conn.commit()
             if return_type == 'one':
                 ret = cur.fetchone()
@@ -72,15 +146,21 @@ def __query(self, query, params=None, return_type=None, fetch_mode=None):

         return ret

+
+
     def query(self, query, params=None, fetch_mode=None):
-        return self.__query(query, params=params, return_type=None, fetch_mode=fetch_mode)
+        return self.__query_single(query, params=params, return_type=None, fetch_mode=fetch_mode)
+
+    def query_multi(self, query, params=None):
+        return self.__query_multi(query, params=params)

     def fetch_one(self, query, params=None, fetch_mode=None):
-        return self.__query(query, params=params, return_type='one', fetch_mode=fetch_mode)
+        return self.__query_single(query, params=params, return_type='one', fetch_mode=fetch_mode)

     def fetch_all(self, query, params=None, fetch_mode=None):
-        return self.__query(query, params=params, return_type='all', fetch_mode=fetch_mode)
+        return self.__query_single(query, params=params, return_type='all', fetch_mode=fetch_mode)

+    @with_db_retry
     def import_csv(self, filename):
         raise NotImplementedError('Code still flakes on ; in data. Please rework')
         # pylint: disable=unreachable
@@ -94,6 +174,7 @@ def import_csv(self, filename):
             cur.execute(statement)
             conn.autocommit = False

+    @with_db_retry
     def copy_from(self, file, table, columns, sep=','):
         with self._pool.connection() as conn:
             conn.autocommit = False # is implicit default
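Reviewer note (not part of the patch): a minimal sketch of the sleep schedule the retry formula above produces, assuming the defaults from this diff (1 s base interval, 30 s cap, 10-50 % jitter). The helper name `backoff_schedule` is hypothetical and only mirrors the arithmetic in `with_db_retry`:

import random

def backoff_schedule(retry_interval=1, cap=30, attempts=8):
    # Mirrors with_db_retry: backoff = min(base * 2**(attempt-1), cap), plus 10-50% jitter
    for attempt in range(1, attempts + 1):
        backoff_time = min(retry_interval * (2 ** (attempt - 1)), cap)
        jitter = random.uniform(0.1, 0.5) * backoff_time
        yield attempt, backoff_time + jitter

for attempt, sleep_time in backoff_schedule():
    print(f"attempt {attempt}: sleep ~{sleep_time:.2f}s")
# Expected ranges: attempt 1 -> 1.1-1.5s, attempt 2 -> 2.2-3.0s, attempt 6 onwards -> 33-45s (capped)

With these defaults the decorator makes roughly a dozen attempts before the 300 s `retry_timeout` from config.yml.example expires.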
+ """ + + def setUp(self): + self.db = DB() + self.table_name = "test_integration_table" + + def test_basic_query_execution(self): + result = self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)") + self.assertIn("CREATE TABLE", result) + + def test_fetch_one_operation(self): + self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)") + self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'test')") + + result = self.db.fetch_one(f"SELECT id, name FROM {self.table_name} WHERE id = 1") + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 'test') + + def test_fetch_all_operation(self): + self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)") + self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'test1'), (2, 'test2')") + + results = self.db.fetch_all(f"SELECT id, name FROM {self.table_name} ORDER BY id") + self.assertEqual(len(results), 2) + self.assertEqual(results[0][0], 1) + self.assertEqual(results[1][0], 2) + + def test_parameter_binding(self): + self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)") + + self.db.query(f"INSERT INTO {self.table_name} VALUES (%s, %s)", (1, 'param_test')) + result = self.db.fetch_one(f"SELECT name FROM {self.table_name} WHERE id = %s", (1,)) + self.assertEqual(result[0], 'param_test') + + def test_fetch_mode_dict(self): + self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)") + self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'dict_test')") + + result = self.db.fetch_one(f"SELECT id, name FROM {self.table_name}", fetch_mode='dict') + self.assertIsInstance(result, dict) + self.assertEqual(result['id'], 1) + self.assertEqual(result['name'], 'dict_test') + + def test_error_handling_invalid_sql(self): + with self.assertRaises(psycopg.DatabaseError): + self.db.query("INVALID SQL STATEMENT") + + def test_copy_from_csv_data(self): + self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT, value NUMERIC)") + + csv_data = io.StringIO("1,test1,10.5\n2,test2,20.7\n") + columns = ['id', 'name', 'value'] + + self.db.copy_from(csv_data, self.table_name, columns) + + results = self.db.fetch_all(f"SELECT id, name, value FROM {self.table_name} ORDER BY id") + self.assertEqual(len(results), 2) + self.assertEqual(results[0][0], 1) + self.assertEqual(results[0][1], 'test1') + self.assertEqual(results[1][0], 2) + self.assertEqual(results[1][1], 'test2') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_functions.py b/tests/test_functions.py index af22cd242..7daede367 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -281,8 +281,54 @@ def __enter__(self): return self def run_until(self, step): + """ + Execute the runner pipeline until the specified step. + + Args: + step (str): The step name to stop at. Valid pause points: + 'import_metric_providers', 'initialize_run', 'save_image_and_volume_sizes', + 'setup_networks', 'setup_services' + + Raises: + RuntimeError: If called outside of the context manager. + + Note: + This is a convenience wrapper around run_steps(stop_at=step). + For more control and inspection capabilities, use run_steps() directly. + """ + for _ in self.run_steps(stop_at=step): + pass + + def run_steps(self, stop_at=None): + """ + Generator that executes the runner pipeline, yielding at predefined pause points. + + Args: + stop_at (str, optional): If provided, stops execution after reaching this pause point. 
+                Valid pause points: 'import_metric_providers', 'initialize_run',
+                'save_image_and_volume_sizes', 'setup_networks', 'setup_services'
+
+        Yields:
+            str: The name of the pause point that was just reached, allowing for inspection
+                before continuing execution.
+
+        Raises:
+            RuntimeError: If called outside of the context manager.
+
+        Example:
+            # Run with inspection at all pause points:
+            with RunUntilManager(runner) as context:
+                for pause_point in context.run_steps():
+                    print(f"Reached pause point: {pause_point}")
+
+            # Run until a specific pause point (with inspection at all pause points along the way):
+            with RunUntilManager(runner) as context:
+                for pause_point in context.run_steps(stop_at='initialize_run'):
+                    print(f"Reached pause point: {pause_point}")
+            # This will print both 'import_metric_providers' and 'initialize_run'
+        """
         if not getattr(self, '_active', False):
-            raise RuntimeError("run_until must be used within the context")
+            raise RuntimeError("run_steps must be used within the context")

         try:
             self.__runner._start_measurement()
@@ -294,7 +340,8 @@ def run_until(self, step):
             self.__runner._initial_parse()
             self.__runner._register_machine_id()
             self.__runner._import_metric_providers()
-            if step == 'import_metric_providers':
+            yield 'import_metric_providers'
+            if stop_at == 'import_metric_providers':
                 return
             self.__runner._populate_image_names()
             self.__runner._prepare_docker()
@@ -302,7 +349,9 @@ def run_until(self, step):
             self.__runner._remove_docker_images()
             self.__runner._download_dependencies()
             self.__runner._initialize_run()
-
+            yield 'initialize_run'
+            if stop_at == 'initialize_run':
+                return
             self.__runner._start_metric_providers(allow_other=True, allow_container=False)

             self.__runner._custom_sleep(self.__runner._measurement_pre_test_sleep)
@@ -315,13 +364,17 @@ def run_until(self, step):
             self.__runner._end_phase('[INSTALLATION]')

             self.__runner._save_image_and_volume_sizes()
-
+            yield 'save_image_and_volume_sizes'
+            if stop_at == 'save_image_and_volume_sizes':
+                return
             self.__runner._start_phase('[BOOT]')
             self.__runner._setup_networks()
-            if step == 'setup_networks':
+            yield 'setup_networks'
+            if stop_at == 'setup_networks':
                 return
             self.__runner._setup_services()
-            if step == 'setup_services':
+            yield 'setup_services'
+            if stop_at == 'setup_services':
                 return
             self.__runner._end_phase('[BOOT]')
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 59c9b77e2..d8b2350fe 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -598,3 +598,29 @@ def test_print_logs_flag_with_iterations():
         assert test_log_pos < test_error_pos

     assert ps.stderr == '', Tests.assertion_info('no errors', ps.stderr)
+
+## automatic database reconnection
+def test_database_reconnection_during_run():
+    """Verify GMT runner handles database reconnection during execution
+
+    This test simulates a database outage scenario:
+    1. A first successful database query occurs at step 'initialize_run'
+    2. After this step, a database restart is triggered to simulate an outage
+    3. The next database query occurs at step 'save_image_and_volume_sizes':
+       Initially it fails due to the outage, but the retry mechanism should recover it
+    """
+
+    out = io.StringIO()
+    err = io.StringIO()
+    runner = ScenarioRunner(uri=GMT_DIR, uri_type='folder', filename='tests/data/usage_scenarios/basic_stress.yml', skip_system_checks=True, dev_cache_build=True, dev_no_sleeps=True, dev_no_metrics=True, dev_no_optimizations=True)
+
+    with redirect_stdout(out), redirect_stderr(err):
+        with Tests.RunUntilManager(runner) as context:
+            for pause_point in context.run_steps(stop_at='save_image_and_volume_sizes'):
+                if pause_point == 'initialize_run':
+                    # Simulate a short db outage
+                    subprocess.run(['docker', 'restart', '-t', '0', 'test-green-coding-postgres-container'],
+                                   check=True, capture_output=True)
+
+    assert ('Database connection error' in out.getvalue() and 'Retrying in' in out.getvalue()), \
+        "No database retry messages found - test may not have properly simulated database outage"
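Reviewer note (not part of the patch): a quick manual way to exercise the same retry path outside the test suite, assuming a local GMT dev setup and the same test container name used in the test above; the snippet is illustrative only.

# Run from the GMT repo root with the test containers up. While this script waits,
# restart Postgres in another shell (same container name as in the test above):
#   docker restart -t 0 test-green-coding-postgres-container
from lib.db import DB

db = DB()
print(db.fetch_one('SELECT 1'))  # should print (1,) after a few
                                 # "Database connection error ... Retrying in ..." lines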