This post focuses on common hurdles you may run into when unit testing.
Testing Values During Run
Add the following lines anywhere you want to pause the unit test and inspect values.
- import pdb
- pdb.set_trace()
How to Patch a Function
- from unittest.mock import patch
- @patch('src.path.to.file.my_function')
- @patch('src.path.to.file.my_function_add')
- def test_some_function(mock_my_function_add, mock_my_function):
- mock_my_function_add.return_value = <something>
- .......
How to Patch a Function With No Return Value
- from unittest.mock import patch
- def test_some_function():
- with patch('src.path.to.file.my_function'):
- ...
How to Patch a Function With 1 Return Value
- from unittest.mock import MagicMock, patch
- def test_some_function():
- with patch('src.path.to.file.my_function', MagicMock(return_value=[<MY_VALUES>])):
- ...
How to Patch a Function With Multiple Return Values
- from unittest.mock import MagicMock, patch
- def test_some_function():
- with patch('src.path.to.file.my_function', MagicMock(side_effect=[[<MY_VALUES>], [<OTHER_VALUES>]])):
- ...
How to Create a Test Module
- from unittest import TestCase
- class MyModule(TestCase):
- def setUp(self):
- some_class.my_variable = <something>
- ... DO OTHER STUFF
- def test_my_function(self):
- ... DO Function Test Stuff
How to Patch a Method
- patch_methods = [
- "pyodbc.connect"
- ]
- for method in patch_methods:
- patch(method).start()
How to Create a PySpark Session
Once this fixture is defined, any test can take spark as an argument and pytest will supply the session.
- import pytest
- from pyspark.sql import SparkSession
- @pytest.fixture(scope='module')
- def spark():
- return (SparkSession.builder.appName('pyspark_test').getOrCreate())
How to Create a Spark SQL Example
- import pytest
- from pyspark.sql import SparkSession, Row
- from pyspark.sql.types import StructType, StructField, StringType
- @pytest.fixture(scope='module')
- def spark():
- return (SparkSession.builder.appName('pyspark_test').getOrCreate())
- def test_function(spark):
- query = 'SELECT * FROM SOMETHING'
- schema = StructType([
- StructField('column_a', StringType()),
- StructField('column_b', StringType()),
- StructField('column_c', StringType()),
- ])
- data = [Row(column_a='a', column_b='b', column_c='c')]
- table = spark.createDataFrame(data, schema=schema)
- table.createOrReplaceTempView('<table_name>')
- df = spark.sql(query).toPandas()
- assert not df.empty
- assert df.shape[0] == 1
- assert df.shape[1] == 3
- spark.catalog.dropTempView('<table_name>')
How to Mock a Database Call
First, let’s assume you have an execute SQL function:
- def execute_sql(cursor, sql, params):
- result = cursor.execute(sql, params).fetchone()
- connection.commit()
- return result
Next, in your unit tests, you want to test code that calls that function:
- def test_execute_sql():
- val = <YOUR_RETURN_VALUE>
- with patch('path.to.code.execute_sql', MagicMock(return_value=val)) as mock_execute:
- return_val = some_other_function_that_calls_execute_sql(....)
- assert return_val == val
If the code under test also needs to close a cursor or DB connection, mock those too:
- def test_execute_sql():
- val = <YOUR_RETURN_VALUE>
- mock_cursor = MagicMock()
- mock_cursor.configure_mock(
- **{
- "close": MagicMock()
- }
- )
- mock_connection = MagicMock()
- mock_connection.configure_mock(
- **{
- "close": MagicMock()
- }
- )
- with patch('path.to.code.cursor', MagicMock(return_value=mock_cursor)) as mock_cursor_close:
- with patch('path.to.code.connection', MagicMock(return_value=mock_connection)) as mock_connection_close:
- return_val = some_other_function_that_calls_execute_sql(....)
- assert return_val == val
How to Mock Open a File Example 1
- @patch('builtins.open', new_callable=mock_open, read_data='my_data')
- def test_file_open(mock_file):
- assert open("my/file/path/filename.extension").read() == 'my_data'
- mock_file.assert_called_with("my/file/path/filename.extension")
- val = function_to_test(....)
- assert 'my_data' == val
How to Mock Open a File Example 2
- def test_file_open():
- fake_file_path = 'file/path/to/mock'
- file_content_mock = 'test'
- with patch('path.to.code.function'.format(__name__), new=mock_open(read_data=file_content_mock)) as mock_file:
- with patch('os.utime') as mock_utime:
- actual = function_to_test(fake_file_path)
- mock_file.assert_called_once_with(fake_file_path)
- assert actual is not None
Compare DataFrames
- def as_dicts(df):
- df = [row.asDict() for row in df.collect()]
- return sorted(df, key=lambda row: str(row))
- assert as_dicts(df1) == as_dicts(df2)