Merged
1 change: 1 addition & 0 deletions config.yml.example
@@ -4,6 +4,7 @@ postgresql:
  dbname: green-coding
  password: PLEASE_CHANGE_THIS
  port: 9573
  retry_timeout: 300 # Total time to retry database connections on outage/failure (5 minutes)

redis:
  host: green-coding-redis-container
145 changes: 113 additions & 32 deletions lib/db.py
@@ -1,13 +1,74 @@
#pylint: disable=consider-using-enumerate
import os
import time
import random
from functools import wraps
from psycopg_pool import ConnectionPool
import psycopg.rows
import psycopg
import pytest
from lib.global_config import GlobalConfig

def is_pytest_session():
    return "pytest" in os.environ.get('_', '')

def with_db_retry(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
Member

The function looks good, but I think it has a design issue.

If the connection is not on auto-commit, you might be reconnecting and retrying only a part of the query. This can be a tricky problem to debug.

Example code:

conn.autocommit = False
query('INSERT 1 INTO table')
## Connection drops here
## Connection gets re-instantiated
query('DELETE 1 FROM table') # Only the DELETE is executed now, as the first statement was dropped

In order to find a solution here, I believe either the queries need to be buffered, or we could decide to still fail; this, however, would defeat the purpose of this PR entirely, I feel.

It is currently unclear what the best implementation is without a clear view of how complex query buffering would be. Please investigate / make a concept.

  1. Also, the copy() statement is currently not profiting from the retry mechanism. This would need to be covered too.
    To extend this question: are there any other query methods we should think of covering to future-proof this?

  2. How does the DB behave on keep-alive queries? Will they also lead to a reconnect?
    To my understanding, the postgres adapter does regular keep-alives. If they fail, it will also complain. Can we hook into this?

Collaborator Author

1. Multi-Statement Queries

The issue you describe is only relevant if GMT uses multi-statement queries, correct?
I did not find any usage of the multi-statement query feature. Did I overlook something?
If multi-statement queries are really not relevant for GMT, the easiest solution would be to remove the following code block:

if isinstance(query, list) and isinstance(params, list) and len(query) == len(params):
    for i in range(len(query)):
        # In error case the context manager will ROLLBACK the whole transaction
        cur.execute(query[i], params[i])

And to consider enabling the autocommit feature.

If the multi-statement query feature is needed, I would do a more comprehensive analysis on how to become transactional.
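
For reference, a rough sketch of what such buffering could look like (hypothetical and untested, not part of this PR; class and method names are illustrative): record every statement of an open transaction and replay the whole buffer after a reconnect, so a retry can never re-apply only half a transaction.

# Hypothetical sketch only - not part of this PR
class BufferedTransaction:
    def __init__(self, pool):
        self._pool = pool
        self._buffer = []  # [(query, params), ...] of the current transaction

    def execute(self, query, params=None):
        # Only record the statement; nothing is sent to the server yet
        self._buffer.append((query, params))

    def commit(self):
        # Replay the full buffer inside one transaction. If the connection
        # drops, the context manager ROLLBACKs and commit() can simply be
        # retried - no statement is ever applied twice or half-applied.
        with self._pool.connection() as conn:
            conn.autocommit = False
            cur = conn.cursor()
            for query, params in self._buffer:
                cur.execute(query, params)
            conn.commit()
        self._buffer.clear()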

2. Cover copy method

Done.

3. PostgreSQL keep-alive queries?

There are no automatic keep-alive queries:

The pool doesn’t actively check the state of the connections held in its state. This means that, if communication with the server is lost, or if a connection is closed for other reasons [...], the application might be served a connection in broken state.
https://www.psycopg.org/psycopg3/docs/advanced/pool.html

Collaborator Author

Amendment to point 3, "PostgreSQL keep-alive queries?"

psycopg3 includes a built-in ConnectionPool retry mechanism. It is disabled by default, but can be enabled with the check parameter:

self._pool = ConnectionPool(
    # ...
    check=ConnectionPool.check_connection
)

Comparison

How does it compare to the custom @with_db_retry decorator implementation?

psycopg3 Built-in Limitations

Measurement Interference:

  • check=ConnectionPool.check_connection executes conn.execute("") on every connection borrow
  • Continuous background pool maintenance network calls
  • Unnecessary network calls during measurement runs should be avoided

Limited Error Scope:

  • No checks performed during actual query execution (cur.execute())
  • Cannot handle connections that fail mid-query
  • Generic PoolTimeout exceptions with limited context

Custom @with_db_retry Advantages

Measurement Integrity:

  • Zero network overhead during normal operation
  • Only activates during actual database failures
  • No background connection health checks

Superior Error Handling:

  • Detects and retries failures during cur.execute()
  • Classifies retryable vs non-retryable errors
  • Complete pool recreation ensures clean state
  • Detailed retry logging with error context

Similar Retry Performance:
Both mechanisms use comparable exponential backoff (1s, 2s, 4s, 8s...) with jitter.
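
As a quick illustration (a sanity check mirroring the constants in the decorator, jitter omitted), the capped schedule works out as follows:

# Backoff per attempt: base 1s, doubled each retry, capped at 30s
retry_interval = 1
for attempt in range(1, 8):
    print(min(retry_interval * (2 ** (attempt - 1)), 30))
# -> 1, 2, 4, 8, 16, 30, 30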

Conclusion

Use the custom @with_db_retry mechanism because:

  • Preserves measurement accuracy by avoiding unnecessary network calls
  • Better error detection scope (covers mid-query failures)
  • Superior observability and debugging capabilities
  • No compromise in retry performance

To be explicit that we don't want to use the check functionality of psycopg3, we should set it to None:

self._pool = ConnectionPool(
    # ...
    check=None
)

        config = GlobalConfig().config
        retry_timeout = config.get('postgresql', {}).get('retry_timeout', 300)
        retry_interval = 1 # Base interval for exponential backoff

        start_time = time.time()
        attempt = 0

        while time.time() - start_time < retry_timeout:
            attempt += 1
            try:
                return func(self, *args, **kwargs)
            except (psycopg.OperationalError, psycopg.DatabaseError) as e:
                # Check if this is a connection-related error that we should retry
                error_str = str(e).lower()
                retryable_errors = [
                    'connection', 'closed', 'terminated', 'timeout', 'network',
                    'server', 'unavailable', 'refused', 'reset', 'broken pipe'
                ]

                is_retryable = any(keyword in error_str for keyword in retryable_errors)

                if not is_retryable:
                    # Non-retryable error (e.g., SQL syntax error)
                    print(f"Database error (non-retryable): {e}")
                    raise

                time_elapsed = time.time() - start_time
                if time_elapsed >= retry_timeout:
                    print(f"Database retry timeout after {attempt} attempts over {time_elapsed:.1f} seconds. Last error: {e}")
                    raise

                # Exponential backoff with jitter
                backoff_time = min(retry_interval * (2 ** (attempt - 1)), 30) # Cap at 30 seconds
                jitter = random.uniform(0.1, 0.5) * backoff_time
                sleep_time = backoff_time + jitter

                print(f"Database connection error (attempt {attempt}): {e}. Retrying in {sleep_time:.2f} seconds...")

                # Try to recreate the connection pool if it's corrupted
                try:
                    if hasattr(self, '_pool'):
                        self._pool.close()
                        del self._pool
                    self._create_pool()
                except (psycopg.OperationalError, psycopg.DatabaseError, AttributeError) as pool_error:
                    print(f"Failed to recreate connection pool: {pool_error}")

                time.sleep(sleep_time)

        # If we get here, we've exhausted all retries
        raise psycopg.OperationalError(f"Database connection failed after {attempt} attempts over {time.time() - start_time:.1f} seconds")

    return wrapper

class DB:

    def __new__(cls):
@@ -19,49 +80,62 @@ def __new__(cls):
        return cls.instance

    def __init__(self):

        if not hasattr(self, '_pool'):
            config = GlobalConfig().config

            # Important note: We are not using cursor_factory = psycopg2.extras.RealDictCursor
            # as an argument, because this would increase the size of a single API request
            # from 50 kB to 100kB.
            # Users are required to use the mask of the API requests to read the data.
            # force domain socket connection by not supplying host
            # pylint: disable=consider-using-f-string

            self._pool = ConnectionPool(
                "user=%s password=%s host=%s port=%s dbname=%s sslmode=require" % (
                    config['postgresql']['user'],
                    config['postgresql']['password'],
                    config['postgresql']['host'],
                    config['postgresql']['port'],
                    config['postgresql']['dbname'],
                ),
                min_size=1,
                max_size=2,
                open=True,
            )
            self._create_pool()

    def _create_pool(self):
        config = GlobalConfig().config

        # Important note: We are not using cursor_factory = psycopg2.extras.RealDictCursor
        # as an argument, because this would increase the size of a single API request
        # from 50 kB to 100kB.
        # Users are required to use the mask of the API requests to read the data.
        # force domain socket connection by not supplying host
        # pylint: disable=consider-using-f-string

        self._pool = ConnectionPool(
            "user=%s password=%s host=%s port=%s dbname=%s sslmode=require" % (
                config['postgresql']['user'],
                config['postgresql']['password'],
                config['postgresql']['host'],
                config['postgresql']['port'],
                config['postgresql']['dbname'],
            ),
            min_size=1,
            max_size=2,
            open=True,
            # Explicitly disabled (default) to prevent measurement interference
            # from conn.execute("") calls, using @with_db_retry instead
            check=None
        )

    def shutdown(self):
        if hasattr(self, '_pool'):
            self._pool.close()
            del self._pool


    def __query(self, query, params=None, return_type=None, fetch_mode=None):
    # Query list only supports SELECT queries
    # If we ever need complex queries in the future where we have a transaction that mixes SELECTs and INSERTS
    # then this class needs a refactoring. Until then we can KISS it
    def __query_multi(self, query, params=None):
        with self._pool.connection() as conn:
            conn.autocommit = False # should be default, but we are explicit
            cur = conn.cursor(row_factory=None) # None is actually the default cursor factory
            for i in range(len(query)):
                # In error case the context manager will ROLLBACK the whole transaction
                cur.execute(query[i], params[i])
            conn.commit()

    @with_db_retry
    def __query_single(self, query, params=None, return_type=None, fetch_mode=None):
        ret = False
        row_factory = psycopg.rows.dict_row if fetch_mode == 'dict' else None

        with self._pool.connection() as conn:
            conn.autocommit = False # should be default, but we are explicit
            cur = conn.cursor(row_factory=row_factory) # None is actually the default cursor factory
            if isinstance(query, list) and isinstance(params, list) and len(query) == len(params):
                for i in range(len(query)):
                    # In error case the context manager will ROLLBACK the whole transaction
                    cur.execute(query[i], params[i])
            else:
                cur.execute(query, params)
            cur.execute(query, params)
            conn.commit()
            if return_type == 'one':
                ret = cur.fetchone()
@@ -72,15 +146,21 @@ def __query(self, query, params=None, return_type=None, fetch_mode=None):

        return ret



    def query(self, query, params=None, fetch_mode=None):
        return self.__query(query, params=params, return_type=None, fetch_mode=fetch_mode)
        return self.__query_single(query, params=params, return_type=None, fetch_mode=fetch_mode)

    def query_multi(self, query, params=None):
        return self.__query_multi(query, params=params)

    def fetch_one(self, query, params=None, fetch_mode=None):
        return self.__query(query, params=params, return_type='one', fetch_mode=fetch_mode)
        return self.__query_single(query, params=params, return_type='one', fetch_mode=fetch_mode)

    def fetch_all(self, query, params=None, fetch_mode=None):
        return self.__query(query, params=params, return_type='all', fetch_mode=fetch_mode)
        return self.__query_single(query, params=params, return_type='all', fetch_mode=fetch_mode)

    @with_db_retry
    def import_csv(self, filename):
        raise NotImplementedError('Code still flakes on ; in data. Please rework')
        # pylint: disable=unreachable
@@ -94,6 +174,7 @@ def import_csv(self, filename):
            cur.execute(statement)
            conn.autocommit = False

    @with_db_retry
    def copy_from(self, file, table, columns, sep=','):
        with self._pool.connection() as conn:
            conn.autocommit = False # is implicit default
144 changes: 144 additions & 0 deletions tests/lib/test_db.py
@@ -0,0 +1,144 @@
import unittest
from unittest.mock import Mock, patch
import io
import psycopg
from lib.db import with_db_retry, DB


class TestWithDbRetryDecorator(unittest.TestCase):
    """Test the @with_db_retry decorator using mocks to simulate various error conditions.

    These tests verify the retry logic, timeout behavior, and error classification
    without requiring actual database connections.
    """

    def setUp(self):
        class MockDB:
            def __init__(self):
                self._pool = Mock()

            def _create_pool(self):
                self._pool = Mock()

            @with_db_retry
            def failing_method(self):
                raise psycopg.OperationalError("connection refused")

            @with_db_retry
            def non_retryable_method(self):
                raise psycopg.DatabaseError("syntax error at or near")

        self.mock_db = MockDB()

    @patch('time.time')
    @patch('time.sleep')
    def test_retry_on_retryable_errors(self, mock_sleep, mock_time):
        # Mock time progression: start=0, while_loop=1, elapsed_check=2, timeout_while=350, final_msg=350
        mock_time.side_effect = [0, 1, 2, 350, 350]

        with self.assertRaises(psycopg.OperationalError):
            self.mock_db.failing_method()

        # Verify that sleep was called (indicating a retry attempt)
        self.assertTrue(mock_sleep.called)

    def test_non_retryable_errors(self):
        with self.assertRaises(psycopg.DatabaseError) as cm:
            self.mock_db.non_retryable_method()

        self.assertIn("syntax error", str(cm.exception))

    @patch('time.time')
    @patch('time.sleep')
    @patch('builtins.print')
    def test_timeout_behavior(self, mock_print, mock_sleep, mock_time):
        # Mock time: start=0, while_check=1, elapsed_check=350 (timeout)
        mock_time.side_effect = [0, 1, 350]

        with self.assertRaises(psycopg.OperationalError) as cm:
            self.mock_db.failing_method()

        # Original error is raised
        self.assertIn("connection refused", str(cm.exception))

        # But timeout message is printed
        timeout_call = None
        for call in mock_print.call_args_list:
            if "Database retry timeout" in str(call):
                timeout_call = call
                break
        self.assertIsNotNone(timeout_call)

        # Sleep should not be called since timeout occurs before sleep
        self.assertFalse(mock_sleep.called)


class TestDbIntegration(unittest.TestCase):
    """Integration tests for DB class methods using real database connections.

    These tests verify actual database operations against a test PostgreSQL database without mocking.
    """

    def setUp(self):
        self.db = DB()
        self.table_name = "test_integration_table"

    def test_basic_query_execution(self):
        result = self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)")
        self.assertIn("CREATE TABLE", result)

    def test_fetch_one_operation(self):
        self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)")
        self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'test')")

        result = self.db.fetch_one(f"SELECT id, name FROM {self.table_name} WHERE id = 1")
        self.assertEqual(result[0], 1)
        self.assertEqual(result[1], 'test')

    def test_fetch_all_operation(self):
        self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)")
        self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'test1'), (2, 'test2')")

        results = self.db.fetch_all(f"SELECT id, name FROM {self.table_name} ORDER BY id")
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0][0], 1)
        self.assertEqual(results[1][0], 2)

    def test_parameter_binding(self):
        self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)")

        self.db.query(f"INSERT INTO {self.table_name} VALUES (%s, %s)", (1, 'param_test'))
        result = self.db.fetch_one(f"SELECT name FROM {self.table_name} WHERE id = %s", (1,))
        self.assertEqual(result[0], 'param_test')

    def test_fetch_mode_dict(self):
        self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT)")
        self.db.query(f"INSERT INTO {self.table_name} VALUES (1, 'dict_test')")

        result = self.db.fetch_one(f"SELECT id, name FROM {self.table_name}", fetch_mode='dict')
        self.assertIsInstance(result, dict)
        self.assertEqual(result['id'], 1)
        self.assertEqual(result['name'], 'dict_test')

    def test_error_handling_invalid_sql(self):
        with self.assertRaises(psycopg.DatabaseError):
            self.db.query("INVALID SQL STATEMENT")

    def test_copy_from_csv_data(self):
        self.db.query(f"CREATE TABLE {self.table_name} (id INT, name TEXT, value NUMERIC)")

        csv_data = io.StringIO("1,test1,10.5\n2,test2,20.7\n")
        columns = ['id', 'name', 'value']

        self.db.copy_from(csv_data, self.table_name, columns)

        results = self.db.fetch_all(f"SELECT id, name, value FROM {self.table_name} ORDER BY id")
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0][0], 1)
        self.assertEqual(results[0][1], 'test1')
        self.assertEqual(results[1][0], 2)
        self.assertEqual(results[1][1], 'test2')


if __name__ == '__main__':
    unittest.main()