pytest is a mature, full-featured Python testing framework that makes it easy to write small, readable tests and scales to support complex functional testing for applications and libraries.
Key Features
- Simple syntax: Plain `assert` statements, no special assertion methods
- Auto-discovery: Automatically finds test files and functions
- Fixtures: Powerful dependency injection for test setup/teardown
- Parametrization: Run same test with multiple input sets
- Rich plugin ecosystem: Extend functionality with plugins
- Detailed failure reports: Clear error messages with context
Getting Started
# test_analysis.py
import numpy as np
from scipy import signal
def butter_lowpass(data, cutoff, fs, order=4, settle_tol=1e-4):
    """Apply a zero-phase Butterworth lowpass filter.

    Parameters
    ----------
    data : array_like
        Signal to filter; filtering runs along the last axis.
    cutoff : float
        Cutoff frequency, in the same units as ``fs``.
    fs : float
        Sampling rate.
    order : int, optional
        Butterworth filter order (default 4). ``filtfilt`` applies the filter
        forward and backward, so the effective magnitude response is squared.
    settle_tol : float, optional
        Relative amplitude to which the filter's start-up transient must decay
        inside the edge padding before real samples are reached (default 1e-4).

    Returns
    -------
    numpy.ndarray
        Filtered signal with the same shape as ``data``.

    Raises
    ------
    ValueError
        From SciPy, e.g. when ``cutoff`` is not strictly between 0 and
        ``fs / 2`` or ``data`` is shorter than filtfilt's default padding.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = signal.butter(order, normal_cutoff, btype='low')

    # filtfilt's default padding is only 3 * max(len(a), len(b)) samples, far
    # shorter than the transient of a low-cutoff IIR filter, which leaves
    # artifacts at both ends of the output. Pad long enough for the slowest
    # pole to decay to ``settle_tol``.
    default_padlen = 3 * max(len(a), len(b))
    padlen = default_padlen
    poles = np.roots(a)
    if poles.size:
        slowest = np.max(np.abs(poles))
        if 0 < slowest < 1:
            settle = int(np.ceil(np.log(settle_tol) / np.log(slowest)))
            # filtfilt needs padlen < len - 1; never drop below the default so
            # too-short inputs are still rejected exactly as before.
            n = np.shape(data)[-1]
            padlen = max(default_padlen, min(settle, n - 2))
    return signal.filtfilt(b, a, data, padlen=padlen)
# Test functions start with 'test_'
def test_lowpass_reduces_noise():
    """Lowpass filtering a noisy 5 Hz sine brings it close to the clean sine."""
    # Create noisy signal
    fs = 1000
    t = np.linspace(0, 1, fs)
    clean = np.sin(2 * np.pi * 5 * t)
    # A seeded generator gives the same noise on every run, so the test
    # cannot pass or fail at random.
    rng = np.random.default_rng(42)
    noise = 0.5 * rng.standard_normal(len(t))
    noisy = clean + noise
    # Filter
    filtered = butter_lowpass(noisy, cutoff=10, fs=fs)
    # Check filtering worked: less spread than the input, and strongly
    # correlated with the clean signal
    assert np.std(filtered) < np.std(noisy)
    assert np.corrcoef(filtered, clean)[0, 1] > 0.95
def test_lowpass_preserves_low_frequencies():
    """A 5 Hz tone, below the 10 Hz cutoff, passes through almost unchanged."""
    sample_rate = 1000
    times = np.linspace(0, 1, sample_rate)
    tone_5hz = np.sin(2 * np.pi * 5 * times)
    result = butter_lowpass(tone_5hz, cutoff=10, fs=sample_rate)
    assert np.allclose(result, tone_5hz, atol=0.01)
def test_lowpass_removes_high_frequencies():
    """A 50 Hz tone, above the 10 Hz cutoff, is almost entirely removed."""
    sample_rate = 1000
    times = np.linspace(0, 1, sample_rate)
    tone_50hz = np.sin(2 * np.pi * 50 * times)
    residual = butter_lowpass(tone_50hz, cutoff=10, fs=sample_rate)
    peak = np.abs(residual).max()
    assert peak < 0.1
Run tests:
pytest test_analysis.py -v
Fixtures for Test Setup
import pytest
import numpy as np
from pathlib import Path
@pytest.fixture
def sample_recording():
    """Fixture providing reproducible synthetic multichannel data.

    Returns
    -------
    dict
        ``'data'``: ``(n_samples, n_channels)`` array of standard-normal
        values; ``'fs'``: the sampling rate used to size it (30000);
        ``'n_channels'``: number of columns (64).
    """
    fs = 30000
    duration = 1.0
    n_samples = int(fs * duration)
    n_channels = 64
    # A fixed seed makes every run see identical data, so any failure can be
    # reproduced exactly (np.random.randn would draw new values each run).
    rng = np.random.default_rng(0)
    data = rng.standard_normal((n_samples, n_channels))
    return {
        'data': data,
        'fs': fs,
        'n_channels': n_channels,
    }
@pytest.fixture
def temp_data_dir(tmp_path):
    """Create an empty ``test_data`` directory inside pytest's ``tmp_path``.

    Returns the new directory's path; pytest manages ``tmp_path`` itself.
    """
    target = tmp_path.joinpath("test_data")
    target.mkdir()
    return target
def test_save_load_recording(sample_recording, temp_data_dir):
    """Round-trip the recording array through np.save / np.load unchanged."""
    original = sample_recording['data']
    target_file = temp_data_dir / "recording.npy"
    np.save(target_file, original)
    restored = np.load(target_file)
    assert np.array_equal(restored, original)
Parametrized Tests
@pytest.mark.parametrize("cutoff,expected_max", [
    (10, 0.1),
    (20, 0.2),
    (50, 0.5),
])
def test_lowpass_different_cutoffs(cutoff, expected_max):
    """A 100 Hz tone lowpassed at ``cutoff`` peaks below ``expected_max``."""
    sample_rate = 1000
    times = np.linspace(0, 1, sample_rate)
    tone_100hz = np.sin(2 * np.pi * 100 * times)
    attenuated = butter_lowpass(tone_100hz, cutoff=cutoff, fs=sample_rate)
    assert np.abs(attenuated).max() < expected_max
Testing Notebooks with ipytest
# In Jupyter notebook
import ipytest
ipytest.autoconfig()
%%ipytest
def test_data_loaded():
assert len(df) > 0
assert 'subject_id' in df.columns
When to Use pytest
Best for:
- Testing analysis pipelines
- Validating data processing functions
- Regression testing (ensuring code changes don’t break existing functionality)
- Test-driven development
- Continuous integration workflows
Research Benefits:
- Reproducibility: Tests document expected behavior
- Confidence: Catch errors before they affect results
- Collaboration: Tests help others understand and modify code
- Refactoring: Safely improve code structure
Common Patterns in Research
# Test statistical functions
def test_z_score_normalization():
    """Z-scoring gives zero mean and unit (population) standard deviation."""
    values = np.array([1, 2, 3, 4, 5])
    mu, sigma = values.mean(), values.std()
    z = (values - mu) / sigma
    assert np.isclose(z.mean(), 0, atol=1e-10)
    assert np.isclose(z.std(), 1, atol=1e-10)
# Test for expected exceptions
def test_invalid_input_raises_error():
    """A negative sampling rate must be rejected with a descriptive ValueError.

    ``match`` is a regular expression searched for in the error message.
    """
    # NOTE: `process_signal` and `data` are placeholders for your own function
    # and input; they must be defined elsewhere for this test to run.
    negative_fs = -1000
    with pytest.raises(ValueError, match="sampling rate must be positive"):
        process_signal(data, fs=negative_fs)
# Test numerical accuracy
def test_numerical_precision():
    """The correlation result matches a reference value to rtol=1e-5."""
    # NOTE: `compute_correlation`, `x` and `y` are placeholders for your own
    # function and fixed inputs; 0.87654 stands in for a reference value you
    # obtained independently for those inputs.
    expected = 0.87654
    result = compute_correlation(x, y)
    assert np.isclose(result, expected, rtol=1e-5)
Integration with Research Workflows
- Run tests automatically with pre-commit hooks
- Include in CI/CD pipelines (GitHub Actions, GitLab CI)
- Test notebooks before converting to scripts
- Validate data before analysis