You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			106 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
			
		
		
	
	
			106 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
import pytest
 | 
						|
from numpy.testing import assert_allclose, assert_array_equal
 | 
						|
 | 
						|
from matplotlib.sankey import Sankey
 | 
						|
from matplotlib.testing.decorators import check_figures_equal
 | 
						|
 | 
						|
 | 
						|
def test_sankey():
    """Smoke test: build a default Sankey and add a default diagram."""
    # Nothing is asserted; the test passes if neither call raises.
    Sankey().add()
 | 
						|
 | 
						|
 | 
						|
def test_label():
    """Check the text of the flow label created by the constructor."""
    sankey = Sankey(orientations=[-1], labels=['First'], flows=[0.25])
    first_text = sankey.diagrams[0].texts[0]
    # The expected text is the label, a newline, then the flow value.
    assert first_text.get_text() == 'First\n0.25'
 | 
						|
 | 
						|
 | 
						|
def test_format_using_callable():
    """Check that a callable given as ``format`` renders the flow value."""
    # Variation on the plain label test: the flow value is rendered by a
    # user-supplied callable instead of the default format.

    def three_decimals(value):
        return format(value, '.3f')

    sankey = Sankey(
        flows=[0.25],
        labels=['First'],
        orientations=[-1],
        format=three_decimals,
    )
    first_text = sankey.diagrams[0].texts[0]

    assert first_text.get_text() == 'First\n0.250'
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize('kwargs, msg', [
    # Each case: constructor keyword arguments and a regex that the
    # resulting ValueError message must match.
    (
        {'gap': -1},
        "'gap' is negative",
    ),
    (
        {'gap': 1, 'radius': 2},
        "'radius' is greater than 'gap'",
    ),
    (
        {'head_angle': -1},
        "'head_angle' is negative",
    ),
    (
        {'tolerance': -1},
        "'tolerance' is negative",
    ),
    (
        {'flows': [1, -1], 'orientations': [-1, 0, 1]},
        r"The shapes of 'flows' \(2,\) and 'orientations'",
    ),
    (
        {'flows': [1, -1], 'labels': ['a', 'b', 'c']},
        r"The shapes of 'flows' \(2,\) and 'labels'",
    ),
])
def test_sankey_errors(kwargs, msg):
    """Check that invalid constructor arguments raise ``ValueError``.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments forwarded to ``Sankey``.
    msg : str
        Regular expression the error message is matched against.
    """
    with pytest.raises(ValueError, match=msg):
        Sankey(**kwargs)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize('kwargs, msg', (
    ({'trunklength': -1}, "'trunklength' is negative"),
    ({'flows': [0.2, 0.3], 'prior': 0}, "The scaled sum of the connected"),
    ({'prior': -1}, "The index of the prior diagram is negative"),
    ({'prior': 1}, "The index of the prior diagram is 1"),
    ({'connect': (-1, 1), 'prior': 0}, "At least one of the connection"),
    ({'connect': (2, 1), 'prior': 0}, "The connection index to the source"),
    ({'connect': (1, 3), 'prior': 0}, "The connection index to this dia"),
    ({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
      'orientations': [2]}, "The value of orientations"),
    ({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
      'pathlengths': [2]}, "The lengths of 'flows'"),
    ))
def test_sankey_add_errors(kwargs, msg):
    """Check that invalid arguments to ``Sankey.add`` raise ``ValueError``.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments forwarded to the ``add`` call under test.
    msg : str
        Regular expression the error message is matched against.
    """
    sankey = Sankey()
    # Setup: add a valid first diagram so that the cases using ``prior=0``
    # have an existing diagram to refer to.  This is kept outside the
    # ``pytest.raises`` block so that an error raised by the setup itself
    # fails the test instead of being mistaken for the expected one.
    sankey.add(flows=[0.2, -0.2])
    with pytest.raises(ValueError, match=msg):
        sankey.add(**kwargs)
 | 
						|
 | 
						|
 | 
						|
def test_sankey2():
    """Check flows, angles, texts and tips of finished diagrams.

    Two cases are covered: four non-zero flows, and the same flows with an
    extra zero flow in the middle.
    """
    # Case 1: four non-zero flows.
    sankey = Sankey(
        flows=[0.25, -0.25, 0.5, -0.5],
        labels=['Foo'],
        orientations=[-1],
        unit='Bar',
    )
    diagram = sankey.finish()[0]

    assert_array_equal(diagram.flows, [0.25, -0.25, 0.5, -0.5])
    assert diagram.angles == [1, 3, 1, 3]

    # Every flow text starts with the label and ends with the unit.
    flow_strings = [text.get_text() for text in diagram.texts]
    assert all(string.startswith('Foo') for string in flow_strings)
    assert all(string.endswith('Bar') for string in flow_strings)

    assert diagram.text.get_text() == ''

    expected_tips = [
        (-1.375, -0.52011255),
        (1.375, -0.75506044),
        (-0.75, -0.41522509),
        (0.75, -0.8599479),
    ]
    assert_allclose(diagram.tips, expected_tips)

    # Case 2: same flows with a zero flow inserted at index 2; the expected
    # angle for it is None and the expected tip is (0, 0).
    sankey = Sankey(
        flows=[0.25, -0.25, 0, 0.5, -0.5],
        labels=['Foo'],
        orientations=[-1],
        unit='Bar',
    )
    diagram = sankey.finish()[0]

    assert_array_equal(diagram.flows, [0.25, -0.25, 0, 0.5, -0.5])
    assert diagram.angles == [1, 3, None, 1, 3]

    expected_tips = [
        (-1.375, -0.52011255),
        (1.375, -0.75506044),
        (0, 0),
        (-0.75, -0.41522509),
        (0.75, -0.8599479),
    ]
    assert_allclose(diagram.tips, expected_tips)
 | 
						|
 | 
						|
 | 
						|
@check_figures_equal(extensions=['png'])
def test_sankey3(fig_test, fig_ref):
    """Compare a diagram built by the constructor with one built by add().

    The test figure passes the flows and orientations straight to
    ``Sankey``; the reference figure creates an empty ``Sankey`` and passes
    the same arguments to ``add``.
    """
    # Test figure: diagram supplied through the constructor.
    from_constructor = Sankey(
        ax=fig_test.gca(),
        flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
        orientations=[1, -1, 1, -1, 0, 0],
    )
    from_constructor.finish()

    # Reference figure: same diagram supplied through an explicit add().
    from_add = Sankey(ax=fig_ref.gca())
    from_add.add(
        flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
        orientations=[1, -1, 1, -1, 0, 0],
    )
    from_add.finish()
 |