This post focuses on common hurdles encountered when writing unit tests.
Testing Values During Run
Add the following lines anywhere you want to pause the unit test and inspect values interactively.
# Drop into the interactive pdb debugger at this point in the test run.
import pdb

pdb.set_trace()
How to Patch a Function
from unittest.mock import path @patch('src.path.to.file.my_function') @path('src.path.to.file.my_function_add') def test_some_function(mock_my_function_add, mock_my_function): mock_function_add.return_value = <something> .......
How to Patch a Function With No Return Value
from unittest.mock import patch


def test_some_function():
    # No return value configured: calls to the patched function
    # return a default MagicMock.
    with patch('src.path.to.file.my_function'):
        ...
How to Patch a Function With 1 Return Value
from unittest.mock import patch def test_some_function(): with(patch('src.path.to.file.my_function', MagicMock(return_value=[<MY_VALUES>])): ...
How to Patch a Function With Multiple Return Values
from unittest.mock import patch def test_some_function(): with(patch('src.path.to.file.my_function', MagicMock(side-effect=[[<MY_VALUES>], [<OTHER_VALUES>]])): ...
How to Create a Test Module
from unittest import TestCase class MyModule(TestCase): def setUp(self): some_class.my_variable = <something> ... DO OTHER STUFF def test_my_function(self): ... DO Function Test Stuff
How to Patch a Method
from unittest.mock import patch

# Start a patch for each dotted method path. Patches started with
# .start() stay active until stopped — call patch.stopall() (e.g. in
# tearDown) so they do not leak between tests.
patch_methods = [
    "pyodbc.connect",
]
for method in patch_methods:
    patch(method).start()
How to create a PySpark Session
Once this fixture is defined, any test in the module can accept a `spark` argument and pytest will supply the shared session automatically.
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope='module')
def spark():
    # scope='module': one SparkSession is built and shared by every
    # test in this module (sessions are expensive to create).
    return SparkSession.builder.appName('pyspark_test').getOrCreate()
How to Create a Spark SQL Example
import pytest
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import StringType, StructField, StructType


@pytest.fixture(scope='module')
def spark():
    # One SparkSession shared by every test in the module.
    return SparkSession.builder.appName('pyspark_test').getOrCreate()


def test_function(spark):
    query = 'SELECT * FROM SOMETHING'  # point this at the temp view below
    schema = StructType([
        StructField('column_a', StringType()),
        StructField('column_b', StringType()),
        StructField('column_c', StringType()),
    ])
    data = [Row(column_a='a', column_b='b', column_c='c')]
    table = spark.createDataFrame(data, schema=schema)
    table.createOrReplaceTempView('<table_name>')

    df = spark.sql(query).toPandas()

    assert not df.empty
    assert df.shape[0] == 1  # one row was inserted
    # shape is a tuple, indexed with [], and the schema defines 3 columns.
    assert df.shape[1] == 3
    spark.catalog.dropTempView('<table_name>')
How to Mock a Database Call
First, let's assume you have an execute-SQL function:
def execute_sql(cursor, sql, params):
    """Execute *sql* with *params* on *cursor*, commit, and return the first row.

    Returns whatever fetchone() yields (a row object, or None for an
    empty result set).
    """
    result = cursor.execute(sql, params).fetchone()
    # Commit via the cursor's own connection — the original snippet
    # referenced an undefined global `connection`, which would raise
    # NameError at runtime.
    cursor.connection.commit()
    return result
Next, in your unit tests, you want to test that function:
def test_execute_sql(): val = <YOUR_RETURN_VALUE> with patch('path.to.code.execute_sql', MagicMock(return_value=val)) as mock_execute: return_val = some_other_function_that_calls_execute_sql(....) assert return_val == val
If you need to close a cursor or DB connection
def test_execute_sql(): val = <YOUR_RETURN_VALUE> mock_cursor = MagicMock() mock_cursor.configure_mock( **{ "close": MagicMock() } ) mock_connection = MagicMock() mock_connection.configure_mock( **{ "close": MagicMock() } ) with patch('path.to.code.cursor', MagicMock(return_value=mock_cursor)) as mock_cursor_close: with patch('path.to.code.connection', MagicMock(return_value=mock_connection)) as mock_connection_close: return_val = some_other_function_that_calls_execute_sql(....) assert return_val == val
How to Mock Open a File Example 1
@patch('builtins.open", new_callable=mock_open, read_data='my_data') def test_file_open(mock_file): assert open("my/file/path/filename.extension").read() == 'my_data' mock_file.assert_called_with("my/file/path/filename.extension") val = function_to_test(....) assert 'my_data' == val
How to Mock Open a File Example 2
from unittest.mock import mock_open, patch


def test_file_open():
    fake_file_path = 'file/path/to/mock'
    file_content_mock = 'test'
    # NOTE(review): '.format(__name__)' has no effect here — the target
    # string contains no '{}' placeholder; confirm the intended target.
    with patch('path.to.code.function'.format(__name__),
               new=mock_open(read_data=file_content_mock)) as mock_file:
        with patch('os.utime') as mock_utime:
            actual = function_to_test(fake_file_path)
    mock_file.assert_called_once_with(fake_file_path)
    # Plain assert — assertIsNotNone is a TestCase method and is not
    # available as a bare name in a pytest-style function.
    assert actual is not None
Compare DataFrames
def as_dicts(df):
    """Collect *df* into a list of row-dicts with a deterministic order.

    Sorting on the string form of each dict makes the comparison
    independent of row ordering in the two DataFrames.
    """
    rows = [row.asDict() for row in df.collect()]
    return sorted(rows, key=lambda row: str(row))


# Usage (inside a test, where df1 and df2 are defined):
#     assert as_dicts(df1) == as_dicts(df2)