Adding upstream version 4.64.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent ee08d9327c
commit 2da88b2fbc
89 changed files with 16770 additions and 0 deletions
0 tests/__init__.py Normal file
41 tests/conftest.py Normal file
@@ -0,0 +1,41 @@
"""Shared pytest config."""
import sys

from pytest import fixture

from tqdm import tqdm


@fixture(autouse=True)
def pretest_posttest():
    """Fixture for all tests ensuring environment cleanup"""
    try:
        sys.setswitchinterval(1)
    except AttributeError:
        sys.setcheckinterval(100)  # deprecated

    if getattr(tqdm, "_instances", False):
        n = len(tqdm._instances)
        if n:
            tqdm._instances.clear()
            raise EnvironmentError(
                "{0} `tqdm` instances still in existence PRE-test".format(n))
    yield
    if getattr(tqdm, "_instances", False):
        n = len(tqdm._instances)
        if n:
            tqdm._instances.clear()
            raise EnvironmentError(
                "{0} `tqdm` instances still in existence POST-test".format(n))


if sys.version_info[0] > 2:
    @fixture
    def capsysbin(capsysbinary):
        """alias for capsysbinary (py3)"""
        return capsysbinary
else:
    @fixture
    def capsysbin(capsys):
        """alias for capsys (py2)"""
        return capsys
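For context, an illustrative sketch (not part of this commit) of the contract the autouse fixture above enforces: any bar left open by a test trips the POST-test check, so every test must close or context-manage each `tqdm` instance it creates.

from tqdm import tqdm

def test_no_leaked_bars():
    """Passes the POST-test check because the bar is explicitly closed."""
    bar = tqdm(total=2)
    bar.update(2)
    bar.close()  # omitting this would leave an entry in tqdm._instances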
128 tests/py37_asyncio.py Normal file
@@ -0,0 +1,128 @@
import asyncio
from functools import partial
from sys import platform
from time import time

from tqdm.asyncio import tarange, tqdm_asyncio

from .tests_tqdm import StringIO, closing, mark

tqdm = partial(tqdm_asyncio, miniters=0, mininterval=0)
trange = partial(tarange, miniters=0, mininterval=0)
as_completed = partial(tqdm_asyncio.as_completed, miniters=0, mininterval=0)
gather = partial(tqdm_asyncio.gather, miniters=0, mininterval=0)


def count(start=0, step=1):
    i = start
    while True:
        new_start = yield i
        if new_start is None:
            i += step
        else:
            i = new_start


async def acount(*args, **kwargs):
    for i in count(*args, **kwargs):
        yield i


@mark.asyncio
async def test_break():
    """Test asyncio break"""
    pbar = tqdm(count())
    async for _ in pbar:
        break
    pbar.close()


@mark.asyncio
async def test_generators(capsys):
    """Test asyncio generators"""
    with tqdm(count(), desc="counter") as pbar:
        async for i in pbar:
            if i >= 8:
                break
    _, err = capsys.readouterr()
    assert '9it' in err

    with tqdm(acount(), desc="async_counter") as pbar:
        async for i in pbar:
            if i >= 8:
                break
    _, err = capsys.readouterr()
    assert '9it' in err


@mark.asyncio
async def test_range():
    """Test asyncio range"""
    with closing(StringIO()) as our_file:
        async for _ in tqdm(range(9), desc="range", file=our_file):
            pass
        assert '9/9' in our_file.getvalue()
        our_file.seek(0)
        our_file.truncate()

        async for _ in trange(9, desc="trange", file=our_file):
            pass
        assert '9/9' in our_file.getvalue()


@mark.asyncio
async def test_nested():
    """Test asyncio nested"""
    with closing(StringIO()) as our_file:
        async for _ in tqdm(trange(9, desc="inner", file=our_file),
                            desc="outer", file=our_file):
            pass
        assert 'inner: 100%' in our_file.getvalue()
        assert 'outer: 100%' in our_file.getvalue()


@mark.asyncio
async def test_coroutines():
    """Test asyncio coroutine.send"""
    with closing(StringIO()) as our_file:
        with tqdm(count(), file=our_file) as pbar:
            async for i in pbar:
                if i == 9:
                    pbar.send(-10)
                elif i < 0:
                    assert i == -9
                    break
        assert '10it' in our_file.getvalue()


@mark.slow
@mark.asyncio
@mark.parametrize("tol", [0.2 if platform.startswith("darwin") else 0.1])
async def test_as_completed(capsys, tol):
    """Test asyncio as_completed"""
    for retry in range(3):
        t = time()
        skew = time() - t
        for i in as_completed([asyncio.sleep(0.01 * i) for i in range(30, 0, -1)]):
            await i
        t = time() - t - 2 * skew
        try:
            assert 0.3 * (1 - tol) < t < 0.3 * (1 + tol), t
            _, err = capsys.readouterr()
            assert '30/30' in err
        except AssertionError:
            if retry == 2:
                raise


async def double(i):
    return i * 2


@mark.asyncio
async def test_gather(capsys):
    """Test asyncio gather"""
    res = await gather(*map(double, range(30)))
    _, err = capsys.readouterr()
    assert '30/30' in err
    assert res == list(range(0, 30 * 2, 2))
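As a usage sketch (not part of this commit; assumes Python >= 3.7), the `tqdm.asyncio` API exercised above wraps async iteration and `gather` directly:

import asyncio
from tqdm.asyncio import tarange, tqdm_asyncio

async def demo():
    # progress bar over an async-compatible range
    async for _ in tarange(10):
        await asyncio.sleep(0.01)
    # gather with a progress bar, as in test_gather above
    await tqdm_asyncio.gather(*[asyncio.sleep(0.01) for _ in range(10)])

asyncio.run(demo())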
11 tests/tests_asyncio.py Normal file
@@ -0,0 +1,11 @@
"""Tests `tqdm.asyncio` on `python>=3.7`."""
import sys

if sys.version_info[:2] > (3, 6):
    from .py37_asyncio import *  # NOQA, pylint: disable=wildcard-import
else:
    from .tests_tqdm import skip
    try:
        skip("async not supported", allow_module_level=True)
    except TypeError:
        pass
49 tests/tests_concurrent.py Normal file
@@ -0,0 +1,49 @@
"""
Tests for `tqdm.contrib.concurrent`.
"""
from pytest import warns

from tqdm.contrib.concurrent import process_map, thread_map

from .tests_tqdm import StringIO, TqdmWarning, closing, importorskip, mark, skip


def incr(x):
    """Dummy function"""
    return x + 1


def test_thread_map():
    """Test contrib.concurrent.thread_map"""
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [i + 1 for i in a]
        try:
            assert thread_map(lambda x: x + 1, a, file=our_file) == b
        except ImportError as err:
            skip(str(err))
        assert thread_map(incr, a, file=our_file) == b


def test_process_map():
    """Test contrib.concurrent.process_map"""
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [i + 1 for i in a]
        try:
            assert process_map(incr, a, file=our_file) == b
        except ImportError as err:
            skip(str(err))


@mark.parametrize("iterables,should_warn", [([], False), (['x'], False), ([()], False),
                                            (['x', ()], False), (['x' * 1001], True),
                                            (['x' * 100, ('x',) * 1001], True)])
def test_chunksize_warning(iterables, should_warn):
    """Test contrib.concurrent.process_map chunksize warnings"""
    patch = importorskip('unittest.mock').patch
    with patch('tqdm.contrib.concurrent._executor_map'):
        if should_warn:
            warns(TqdmWarning, process_map, incr, *iterables)
        else:
            process_map(incr, *iterables)
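A minimal usage sketch of the API under test (not part of this commit): `thread_map`/`process_map` are drop-in, progress-reporting versions of `map`:

from tqdm.contrib.concurrent import process_map, thread_map

def square(x):
    return x * x

if __name__ == '__main__':  # guard needed for process_map on spawn-based platforms
    print(thread_map(square, range(10)))   # threads: good for I/O-bound work
    print(process_map(square, range(10)))  # processes: good for CPU-bound work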
71 tests/tests_contrib.py Normal file
@@ -0,0 +1,71 @@
"""
Tests for `tqdm.contrib`.
"""
import sys

import pytest

from tqdm import tqdm
from tqdm.contrib import tenumerate, tmap, tzip

from .tests_tqdm import StringIO, closing, importorskip


def incr(x):
    """Dummy function"""
    return x + 1


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_enumerate(tqdm_kwargs):
    """Test contrib.tenumerate"""
    with closing(StringIO()) as our_file:
        a = range(9)
        assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a))
        assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list(
            enumerate(a, 42)
        )
    with closing(StringIO()) as our_file:
        _ = list(tenumerate(iter(a), file=our_file, **tqdm_kwargs))
        assert "100%" not in our_file.getvalue()
    with closing(StringIO()) as our_file:
        _ = list(tenumerate(iter(a), file=our_file, total=len(a), **tqdm_kwargs))
        assert "100%" in our_file.getvalue()


def test_enumerate_numpy():
    """Test contrib.tenumerate(numpy.ndarray)"""
    np = importorskip("numpy")
    with closing(StringIO()) as our_file:
        a = np.random.random((42, 7))
        assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a))


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_zip(tqdm_kwargs):
    """Test contrib.tzip"""
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [i + 1 for i in a]
        if sys.version_info[:1] < (3,):
            assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b)
        else:
            gen = tzip(a, b, file=our_file, **tqdm_kwargs)
            assert gen != list(zip(a, b))
            assert list(gen) == list(zip(a, b))


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_map(tqdm_kwargs):
    """Test contrib.tmap"""
    with closing(StringIO()) as our_file:
        a = range(9)
        b = [i + 1 for i in a]
        if sys.version_info[:1] < (3,):
            assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map(
                incr, a
            )
        else:
            gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs)
            assert gen != b
            assert list(gen) == b
173 tests/tests_contrib_logging.py Normal file
@@ -0,0 +1,173 @@
# pylint: disable=missing-module-docstring, missing-class-docstring
# pylint: disable=missing-function-docstring, no-self-use
from __future__ import absolute_import

import logging
import logging.handlers
import sys
from io import StringIO

import pytest

from tqdm import tqdm
from tqdm.contrib.logging import _get_first_found_console_logging_handler
from tqdm.contrib.logging import _TqdmLoggingHandler as TqdmLoggingHandler
from tqdm.contrib.logging import logging_redirect_tqdm, tqdm_logging_redirect

from .tests_tqdm import importorskip

LOGGER = logging.getLogger(__name__)

TEST_LOGGING_FORMATTER = logging.Formatter()


class CustomTqdm(tqdm):
    messages = []

    @classmethod
    def write(cls, s, **__):  # pylint: disable=arguments-differ
        CustomTqdm.messages.append(s)


class ErrorRaisingTqdm(tqdm):
    exception_class = RuntimeError

    @classmethod
    def write(cls, s, **__):  # pylint: disable=arguments-differ
        raise ErrorRaisingTqdm.exception_class('fail fast')


class TestTqdmLoggingHandler:
    def test_should_call_tqdm_write(self):
        CustomTqdm.messages = []
        logger = logging.Logger('test')
        logger.handlers = [TqdmLoggingHandler(CustomTqdm)]
        logger.info('test')
        assert CustomTqdm.messages == ['test']

    def test_should_call_handle_error_if_exception_was_thrown(self):
        patch = importorskip('unittest.mock').patch
        logger = logging.Logger('test')
        ErrorRaisingTqdm.exception_class = RuntimeError
        handler = TqdmLoggingHandler(ErrorRaisingTqdm)
        logger.handlers = [handler]
        with patch.object(handler, 'handleError') as mock:
            logger.info('test')
            assert mock.called

    @pytest.mark.parametrize('exception_class', [
        KeyboardInterrupt,
        SystemExit
    ])
    def test_should_not_swallow_certain_exceptions(self, exception_class):
        logger = logging.Logger('test')
        ErrorRaisingTqdm.exception_class = exception_class
        handler = TqdmLoggingHandler(ErrorRaisingTqdm)
        logger.handlers = [handler]
        with pytest.raises(exception_class):
            logger.info('test')


class TestGetFirstFoundConsoleLoggingHandler:
    def test_should_return_none_for_no_handlers(self):
        assert _get_first_found_console_logging_handler([]) is None

    def test_should_return_none_without_stream_handler(self):
        handler = logging.handlers.MemoryHandler(capacity=1)
        assert _get_first_found_console_logging_handler([handler]) is None

    def test_should_return_none_for_stream_handler_not_stdout_or_stderr(self):
        handler = logging.StreamHandler(StringIO())
        assert _get_first_found_console_logging_handler([handler]) is None

    def test_should_return_stream_handler_if_stream_is_stdout(self):
        handler = logging.StreamHandler(sys.stdout)
        assert _get_first_found_console_logging_handler([handler]) == handler

    def test_should_return_stream_handler_if_stream_is_stderr(self):
        handler = logging.StreamHandler(sys.stderr)
        assert _get_first_found_console_logging_handler([handler]) == handler


class TestRedirectLoggingToTqdm:
    def test_should_add_and_remove_tqdm_handler(self):
        logger = logging.Logger('test')
        with logging_redirect_tqdm(loggers=[logger]):
            assert len(logger.handlers) == 1
            assert isinstance(logger.handlers[0], TqdmLoggingHandler)
        assert not logger.handlers

    def test_should_remove_and_restore_console_handlers(self):
        logger = logging.Logger('test')
        stderr_console_handler = logging.StreamHandler(sys.stderr)
        stdout_console_handler = logging.StreamHandler(sys.stderr)
        logger.handlers = [stderr_console_handler, stdout_console_handler]
        with logging_redirect_tqdm(loggers=[logger]):
            assert len(logger.handlers) == 1
            assert isinstance(logger.handlers[0], TqdmLoggingHandler)
        assert logger.handlers == [stderr_console_handler, stdout_console_handler]

    def test_should_inherit_console_logger_formatter(self):
        logger = logging.Logger('test')
        formatter = logging.Formatter('custom: %(message)s')
        console_handler = logging.StreamHandler(sys.stderr)
        console_handler.setFormatter(formatter)
        logger.handlers = [console_handler]
        with logging_redirect_tqdm(loggers=[logger]):
            assert logger.handlers[0].formatter == formatter

    def test_should_not_remove_stream_handlers_not_for_stdout_or_stderr(self):
        logger = logging.Logger('test')
        stream_handler = logging.StreamHandler(StringIO())
        logger.addHandler(stream_handler)
        with logging_redirect_tqdm(loggers=[logger]):
            assert len(logger.handlers) == 2
            assert logger.handlers[0] == stream_handler
            assert isinstance(logger.handlers[1], TqdmLoggingHandler)
        assert logger.handlers == [stream_handler]


class TestTqdmWithLoggingRedirect:
    def test_should_add_and_remove_handler_from_root_logger_by_default(self):
        original_handlers = list(logging.root.handlers)
        with tqdm_logging_redirect(total=1) as pbar:
            assert isinstance(logging.root.handlers[-1], TqdmLoggingHandler)
            LOGGER.info('test')
            pbar.update(1)
        assert logging.root.handlers == original_handlers

    def test_should_add_and_remove_handler_from_custom_logger(self):
        logger = logging.Logger('test')
        with tqdm_logging_redirect(total=1, loggers=[logger]) as pbar:
            assert len(logger.handlers) == 1
            assert isinstance(logger.handlers[0], TqdmLoggingHandler)
            logger.info('test')
            pbar.update(1)
        assert not logger.handlers

    def test_should_not_fail_with_logger_without_console_handler(self):
        logger = logging.Logger('test')
        logger.handlers = []
        with tqdm_logging_redirect(total=1, loggers=[logger]):
            logger.info('test')
        assert not logger.handlers

    def test_should_format_message(self):
        logger = logging.Logger('test')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(logging.Formatter(
            r'prefix:%(message)s'
        ))
        logger.handlers = [console_handler]
        CustomTqdm.messages = []
        with tqdm_logging_redirect(loggers=[logger], tqdm_class=CustomTqdm):
            logger.info('test')
        assert CustomTqdm.messages == ['prefix:test']

    def test_use_root_logger_by_default_and_write_to_custom_tqdm(self):
        logger = logging.root
        CustomTqdm.messages = []
        with tqdm_logging_redirect(total=1, tqdm_class=CustomTqdm) as pbar:
            assert isinstance(pbar, CustomTqdm)
            logger.info('test')
            assert CustomTqdm.messages == ['test']
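For orientation, an illustrative sketch (not part of this commit) of the redirect tested above; it routes console log records through `tqdm.write()` so log lines don't mangle an active bar:

import logging
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

with logging_redirect_tqdm():
    for i in trange(9):
        if i == 4:
            LOG.info("console handlers are routed through tqdm.write()")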
20 tests/tests_dask.py Normal file
@@ -0,0 +1,20 @@
from __future__ import division

from time import sleep

from .tests_tqdm import importorskip, mark

pytestmark = mark.slow


def test_dask(capsys):
    """Test tqdm.dask.TqdmCallback"""
    ProgressBar = importorskip('tqdm.dask').TqdmCallback
    dask = importorskip('dask')

    schedule = [dask.delayed(sleep)(i / 10) for i in range(5)]
    with ProgressBar(desc="computing"):
        dask.compute(schedule)
    _, err = capsys.readouterr()
    assert "computing: " in err
    assert '5/5' in err
7 tests/tests_gui.py Normal file
@@ -0,0 +1,7 @@
"""Test `tqdm.gui`."""
from .tests_tqdm import importorskip


def test_gui_import():
    """Test `tqdm.gui` import"""
    importorskip('tqdm.gui')
26 tests/tests_itertools.py Normal file
@@ -0,0 +1,26 @@
"""
Tests for `tqdm.contrib.itertools`.
"""
import itertools as it

from tqdm.contrib.itertools import product

from .tests_tqdm import StringIO, closing


class NoLenIter(object):
    def __init__(self, iterable):
        self._it = iterable

    def __iter__(self):
        for i in self._it:
            yield i


def test_product():
    """Test contrib.itertools.product"""
    with closing(StringIO()) as our_file:
        a = range(9)
        assert list(product(a, a[::-1], file=our_file)) == list(it.product(a, a[::-1]))

        assert list(product(a, NoLenIter(a), file=our_file)) == list(it.product(a, NoLenIter(a)))
93 tests/tests_keras.py Normal file
@@ -0,0 +1,93 @@
from __future__ import division

from .tests_tqdm import importorskip, mark

pytestmark = mark.slow


@mark.filterwarnings("ignore:.*:DeprecationWarning")
def test_keras(capsys):
    """Test tqdm.keras.TqdmCallback"""
    TqdmCallback = importorskip('tqdm.keras').TqdmCallback
    np = importorskip('numpy')
    try:
        import keras as K
    except ImportError:
        K = importorskip('tensorflow.keras')

    # 1D autoencoder
    dtype = np.float32
    model = K.models.Sequential([
        K.layers.InputLayer((1, 1), dtype=dtype), K.layers.Conv1D(1, 1)])
    model.compile("adam", "mse")
    x = np.random.rand(100, 1, 1).astype(dtype)
    batch_size = 10
    batches = len(x) / batch_size
    epochs = 5

    # just epoch (no batch) progress
    model.fit(
        x,
        x,
        epochs=epochs,
        batch_size=batch_size,
        verbose=False,
        callbacks=[
            TqdmCallback(
                epochs,
                desc="training",
                data_size=len(x),
                batch_size=batch_size,
                verbose=0)])
    _, res = capsys.readouterr()
    assert "training: " in res
    assert "{epochs}/{epochs}".format(epochs=epochs) in res
    assert "{batches}/{batches}".format(batches=batches) not in res

    # full (epoch and batch) progress
    model.fit(
        x,
        x,
        epochs=epochs,
        batch_size=batch_size,
        verbose=False,
        callbacks=[
            TqdmCallback(
                epochs,
                desc="training",
                data_size=len(x),
                batch_size=batch_size,
                verbose=2)])
    _, res = capsys.readouterr()
    assert "training: " in res
    assert "{epochs}/{epochs}".format(epochs=epochs) in res
    assert "{batches}/{batches}".format(batches=batches) in res

    # auto-detect epochs and batches
    model.fit(
        x,
        x,
        epochs=epochs,
        batch_size=batch_size,
        verbose=False,
        callbacks=[TqdmCallback(desc="training", verbose=2)])
    _, res = capsys.readouterr()
    assert "training: " in res
    assert "{epochs}/{epochs}".format(epochs=epochs) in res
    assert "{batches}/{batches}".format(batches=batches) in res

    # continue training (start from epoch != 0)
    initial_epoch = 3
    model.fit(
        x,
        x,
        initial_epoch=initial_epoch,
        epochs=epochs,
        batch_size=batch_size,
        verbose=False,
        callbacks=[TqdmCallback(desc="training", verbose=0,
                                miniters=1, mininterval=0, maxinterval=0)])
    _, res = capsys.readouterr()
    assert "training: " in res
    assert "{epochs}/{epochs}".format(epochs=initial_epoch - 1) not in res
    assert "{epochs}/{epochs}".format(epochs=epochs) in res
245 tests/tests_main.py Normal file
@@ -0,0 +1,245 @@
"""Test CLI usage."""
import logging
import subprocess  # nosec
import sys
from functools import wraps
from os import linesep

from tqdm.cli import TqdmKeyError, TqdmTypeError, main
from tqdm.utils import IS_WIN

from .tests_tqdm import BytesIO, _range, closing, mark, raises


def restore_sys(func):
    """Decorates `func(capsysbin)` to save & restore `sys.(stdin|argv)`."""
    @wraps(func)
    def inner(capsysbin):
        """function requiring capsysbin which may alter `sys.(stdin|argv)`"""
        _SYS = sys.stdin, sys.argv
        try:
            res = func(capsysbin)
        finally:
            sys.stdin, sys.argv = _SYS
        return res

    return inner


def norm(bytestr):
    """Normalise line endings."""
    return bytestr if linesep == "\n" else bytestr.replace(linesep.encode(), b"\n")


@mark.slow
def test_pipes():
    """Test command line pipes"""
    ls_out = subprocess.check_output(['ls'])  # nosec
    ls = subprocess.Popen(['ls'], stdout=subprocess.PIPE)  # nosec
    res = subprocess.Popen(  # nosec
        [sys.executable, '-c', 'from tqdm.cli import main; main()'],
        stdin=ls.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = res.communicate()
    assert ls.poll() == 0

    # actual test:
    assert norm(ls_out) == norm(out)
    assert b"it/s" in err
    assert b"Error" not in err


if sys.version_info[:2] >= (3, 8):
    test_pipes = mark.filterwarnings("ignore:unclosed file:ResourceWarning")(
        test_pipes)


def test_main_import():
    """Test main CLI import"""
    N = 123
    _SYS = sys.stdin, sys.argv
    # test direct import
    sys.stdin = [str(i).encode() for i in _range(N)]
    sys.argv = ['', '--desc', 'Test CLI import',
                '--ascii', 'True', '--unit_scale', 'True']
    try:
        import tqdm.__main__  # NOQA, pylint: disable=unused-variable
    finally:
        sys.stdin, sys.argv = _SYS


@restore_sys
def test_main_bytes(capsysbin):
    """Test CLI --bytes"""
    N = 123

    # test --delim
    IN_DATA = '\0'.join(map(str, _range(N))).encode()
    with closing(BytesIO()) as sys.stdin:
        sys.stdin.write(IN_DATA)
        # sys.stdin.write(b'\xff')  # TODO
        sys.stdin.seek(0)
        main(sys.stderr, ['--desc', 'Test CLI delim', '--ascii', 'True',
                          '--delim', r'\0', '--buf_size', '64'])
        out, err = capsysbin.readouterr()
        assert out == IN_DATA
        assert str(N) + "it" in err.decode("U8")

    # test --bytes
    IN_DATA = IN_DATA.replace(b'\0', b'\n')
    with closing(BytesIO()) as sys.stdin:
        sys.stdin.write(IN_DATA)
        sys.stdin.seek(0)
        main(sys.stderr, ['--ascii', '--bytes=True', '--unit_scale', 'False'])
        out, err = capsysbin.readouterr()
        assert out == IN_DATA
        assert str(len(IN_DATA)) + "B" in err.decode("U8")


@mark.skipif(sys.version_info[0] == 2, reason="no caplog on py2")
def test_main_log(capsysbin, caplog):
    """Test CLI --log"""
    _SYS = sys.stdin, sys.argv
    N = 123
    sys.stdin = [(str(i) + '\n').encode() for i in _range(N)]
    IN_DATA = b''.join(sys.stdin)
    try:
        with caplog.at_level(logging.INFO):
            main(sys.stderr, ['--log', 'INFO'])
            out, err = capsysbin.readouterr()
            assert norm(out) == IN_DATA and b"123/123" in err
        assert not caplog.record_tuples
        with caplog.at_level(logging.DEBUG):
            main(sys.stderr, ['--log', 'DEBUG'])
            out, err = capsysbin.readouterr()
            assert norm(out) == IN_DATA and b"123/123" in err
        assert caplog.record_tuples
    finally:
        sys.stdin, sys.argv = _SYS


@restore_sys
def test_main(capsysbin):
    """Test misc CLI options"""
    N = 123
    sys.stdin = [(str(i) + '\n').encode() for i in _range(N)]
    IN_DATA = b''.join(sys.stdin)

    # test --tee
    main(sys.stderr, ['--mininterval', '0', '--miniters', '1'])
    out, err = capsysbin.readouterr()
    assert norm(out) == IN_DATA and b"123/123" in err
    assert N <= len(err.split(b"\r")) < N + 5

    len_err = len(err)
    main(sys.stderr, ['--tee', '--mininterval', '0', '--miniters', '1'])
    out, err = capsysbin.readouterr()
    assert norm(out) == IN_DATA and b"123/123" in err
    # spaces to clear intermediate lines could increase length
    assert len_err + len(norm(out)) <= len(err)

    # test --null
    main(sys.stderr, ['--null'])
    out, err = capsysbin.readouterr()
    assert not out and b"123/123" in err

    # test integer --update
    main(sys.stderr, ['--update'])
    out, err = capsysbin.readouterr()
    assert norm(out) == IN_DATA
    assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum formula"

    # test integer --update_to
    main(sys.stderr, ['--update-to'])
    out, err = capsysbin.readouterr()
    assert norm(out) == IN_DATA
    assert (str(N - 1) + "it").encode() in err
    assert (str(N) + "it").encode() not in err

    with closing(BytesIO()) as sys.stdin:
        sys.stdin.write(IN_DATA.replace(b'\n', b'D'))

        # test integer --update --delim
        sys.stdin.seek(0)
        main(sys.stderr, ['--update', '--delim', 'D'])
        out, err = capsysbin.readouterr()
        assert out == IN_DATA.replace(b'\n', b'D')
        assert (str(N // 2 * N) + "it").encode() in err, "expected arithmetic sum"

        # test integer --update_to --delim
        sys.stdin.seek(0)
        main(sys.stderr, ['--update-to', '--delim', 'D'])
        out, err = capsysbin.readouterr()
        assert out == IN_DATA.replace(b'\n', b'D')
        assert (str(N - 1) + "it").encode() in err
        assert (str(N) + "it").encode() not in err

    # test float --update_to
    sys.stdin = [(str(i / 2.0) + '\n').encode() for i in _range(N)]
    IN_DATA = b''.join(sys.stdin)
    main(sys.stderr, ['--update-to'])
    out, err = capsysbin.readouterr()
    assert norm(out) == IN_DATA
    assert (str((N - 1) / 2.0) + "it").encode() in err
    assert (str(N / 2.0) + "it").encode() not in err


@mark.slow
@mark.skipif(IS_WIN, reason="no manpages on windows")
def test_manpath(tmp_path):
    """Test CLI --manpath"""
    man = tmp_path / "tqdm.1"
    assert not man.exists()
    with raises(SystemExit):
        main(argv=['--manpath', str(tmp_path)])
    assert man.is_file()


@mark.slow
@mark.skipif(IS_WIN, reason="no completion on windows")
def test_comppath(tmp_path):
    """Test CLI --comppath"""
    man = tmp_path / "tqdm_completion.sh"
    assert not man.exists()
    with raises(SystemExit):
        main(argv=['--comppath', str(tmp_path)])
    assert man.is_file()

    # check most important options appear
    script = man.read_text()
    opts = {'--help', '--desc', '--total', '--leave', '--ncols', '--ascii',
            '--dynamic_ncols', '--position', '--bytes', '--nrows', '--delim',
            '--manpath', '--comppath'}
    assert all(args in script for args in opts)


@restore_sys
def test_exceptions(capsysbin):
    """Test CLI Exceptions"""
    N = 123
    sys.stdin = [str(i) + '\n' for i in _range(N)]
    IN_DATA = ''.join(sys.stdin).encode()

    with raises(TqdmKeyError, match="bad_arg_u_ment"):
        main(sys.stderr, argv=['-ascii', '-unit_scale', '--bad_arg_u_ment', 'foo'])
    out, _ = capsysbin.readouterr()
    assert norm(out) == IN_DATA

    with raises(TqdmTypeError, match="invalid_bool_value"):
        main(sys.stderr, argv=['-ascii', '-unit_scale', 'invalid_bool_value'])
    out, _ = capsysbin.readouterr()
    assert norm(out) == IN_DATA

    with raises(TqdmTypeError, match="invalid_int_value"):
        main(sys.stderr, argv=['-ascii', '--total', 'invalid_int_value'])
    out, _ = capsysbin.readouterr()
    assert norm(out) == IN_DATA

    with raises(TqdmKeyError, match="Can only have one of --"):
        main(sys.stderr, argv=['--update', '--update_to'])
    out, _ = capsysbin.readouterr()
    assert norm(out) == IN_DATA

    # test SystemExits
    for i in ('-h', '--help', '-v', '--version'):
        with raises(SystemExit):
            main(argv=[i])
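A usage sketch patterned on the tests above (not part of this commit; `main(fp, argv)` is how these tests drive the CLI in-process):

import sys
from tqdm.cli import main

# equivalent shell usage: `seq 100 | python -m tqdm --desc demo > /dev/null`
sys.stdin = [(str(i) + '\n').encode() for i in range(100)]
main(sys.stderr, ['--desc', 'demo', '--null'])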
7 tests/tests_notebook.py Normal file
@@ -0,0 +1,7 @@
from tqdm.notebook import tqdm as tqdm_notebook


def test_notebook_disabled_description():
    """Test that set_description works for disabled tqdm_notebook"""
    with tqdm_notebook(1, disable=True) as t:
        t.set_description("description")
219 tests/tests_pandas.py Normal file
@@ -0,0 +1,219 @@
from tqdm import tqdm

from .tests_tqdm import StringIO, closing, importorskip, mark, skip

pytestmark = mark.slow

random = importorskip('numpy.random')
rand = random.rand
randint = random.randint
pd = importorskip('pandas')


def test_pandas_setup():
    """Test tqdm.pandas()"""
    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=True, ascii=True, total=123)
        series = pd.Series(randint(0, 50, (100,)))
        series.progress_apply(lambda x: x + 10)
        res = our_file.getvalue()
        assert '100/123' in res


def test_pandas_rolling_expanding():
    """Test pandas.(Series|DataFrame).(rolling|expanding)"""
    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=True, ascii=True)

        series = pd.Series(randint(0, 50, (123,)))
        res1 = series.rolling(10).progress_apply(lambda x: 1, raw=True)
        res2 = series.rolling(10).apply(lambda x: 1, raw=True)
        assert res1.equals(res2)

        res3 = series.expanding(10).progress_apply(lambda x: 2, raw=True)
        res4 = series.expanding(10).apply(lambda x: 2, raw=True)
        assert res3.equals(res4)

        expects = ['114it']  # 123-10+1
        for exres in expects:
            our_file.seek(0)
            if our_file.getvalue().count(exres) < 2:
                our_file.seek(0)
                raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
                    exres + " at least twice.", our_file.read()))


def test_pandas_series():
    """Test pandas.Series.progress_apply and .progress_map"""
    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=True, ascii=True)

        series = pd.Series(randint(0, 50, (123,)))
        res1 = series.progress_apply(lambda x: x + 10)
        res2 = series.apply(lambda x: x + 10)
        assert res1.equals(res2)

        res3 = series.progress_map(lambda x: x + 10)
        res4 = series.map(lambda x: x + 10)
        assert res3.equals(res4)

        expects = ['100%', '123/123']
        for exres in expects:
            our_file.seek(0)
            if our_file.getvalue().count(exres) < 2:
                our_file.seek(0)
                raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
                    exres + " at least twice.", our_file.read()))


def test_pandas_data_frame():
    """Test pandas.DataFrame.progress_apply and .progress_applymap"""
    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=True, ascii=True)
        df = pd.DataFrame(randint(0, 50, (100, 200)))

        def task_func(x):
            return x + 1

        # applymap
        res1 = df.progress_applymap(task_func)
        res2 = df.applymap(task_func)
        assert res1.equals(res2)

        # apply unhashable
        res1 = []
        df.progress_apply(res1.extend)
        assert len(res1) == df.size

        # apply
        for axis in [0, 1, 'index', 'columns']:
            res3 = df.progress_apply(task_func, axis=axis)
            res4 = df.apply(task_func, axis=axis)
            assert res3.equals(res4)

        our_file.seek(0)
        if our_file.read().count('100%') < 3:
            our_file.seek(0)
            raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
                '100% at least three times', our_file.read()))

        # apply_map, apply axis=0, apply axis=1
        expects = ['20000/20000', '200/200', '100/100']
        for exres in expects:
            our_file.seek(0)
            if our_file.getvalue().count(exres) < 1:
                our_file.seek(0)
                raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format(
                    exres + " at least once.", our_file.read()))


def test_pandas_groupby_apply():
    """Test pandas.DataFrame.groupby(...).progress_apply"""
    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=False, ascii=True)

        df = pd.DataFrame(randint(0, 50, (500, 3)))
        df.groupby(0).progress_apply(lambda x: None)

        dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc'))
        dfs.groupby(['a']).progress_apply(lambda x: None)

        df2 = df = pd.DataFrame({'a': randint(1, 8, 10000), 'b': rand(10000)})
        res1 = df2.groupby("a").apply(max)
        res2 = df2.groupby("a").progress_apply(max)
        assert res1.equals(res2)

        our_file.seek(0)

        # don't expect final output since no `leave` and
        # high dynamic `miniters`
        nexres = '100%|##########|'
        if nexres in our_file.read():
            our_file.seek(0)
            raise AssertionError("\nDid not expect:\n{0}\nIn:{1}\n".format(
                nexres, our_file.read()))

    with closing(StringIO()) as our_file:
        tqdm.pandas(file=our_file, leave=True, ascii=True)

        dfs = pd.DataFrame(randint(0, 50, (500, 3)), columns=list('abc'))
        dfs.loc[0] = [2, 1, 1]
        dfs['d'] = 100

        expects = ['500/500', '1/1', '4/4', '2/2']
        dfs.groupby(dfs.index).progress_apply(lambda x: None)
        dfs.groupby('d').progress_apply(lambda x: None)
        dfs.groupby(dfs.columns, axis=1).progress_apply(lambda x: None)
        dfs.groupby([2, 2, 1, 1], axis=1).progress_apply(lambda x: None)

        our_file.seek(0)
        if our_file.read().count('100%') < 4:
            our_file.seek(0)
            raise AssertionError("\nExpected:\n{0}\nIn:\n{1}\n".format(
                '100% at least four times', our_file.read()))

        for exres in expects:
            our_file.seek(0)
            if our_file.getvalue().count(exres) < 1:
                our_file.seek(0)
                raise AssertionError("\nExpected:\n{0}\nIn:\n {1}\n".format(
                    exres + " at least once.", our_file.read()))


def test_pandas_leave():
    """Test pandas with `leave=True`"""
    with closing(StringIO()) as our_file:
        df = pd.DataFrame(randint(0, 100, (1000, 6)))
        tqdm.pandas(file=our_file, leave=True, ascii=True)
        df.groupby(0).progress_apply(lambda x: None)

        our_file.seek(0)

        exres = '100%|##########| 100/100'
        if exres not in our_file.read():
            our_file.seek(0)
            raise AssertionError("\nExpected:\n{0}\nIn:{1}\n".format(
                exres, our_file.read()))


def test_pandas_apply_args_deprecation():
    """Test warning info in
    `pandas.Dataframe(Series).progress_apply(func, *args)`"""
    try:
        from tqdm import tqdm_pandas
    except ImportError as err:
        skip(str(err))

    with closing(StringIO()) as our_file:
        tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20))
        df = pd.DataFrame(randint(0, 50, (500, 3)))
        df.progress_apply(lambda x: None, 1)  # 1 shall cause a warning
        # Check deprecation message
        res = our_file.getvalue()
        assert all(i in res for i in (
            "TqdmDeprecationWarning", "not supported",
            "keyword arguments instead"))


def test_pandas_deprecation():
    """Test bar object instance as argument deprecation"""
    try:
        from tqdm import tqdm_pandas
    except ImportError as err:
        skip(str(err))

    with closing(StringIO()) as our_file:
        tqdm_pandas(tqdm(file=our_file, leave=False, ascii=True, ncols=20))
        df = pd.DataFrame(randint(0, 50, (500, 3)))
        df.groupby(0).progress_apply(lambda x: None)
        # Check deprecation message
        assert "TqdmDeprecationWarning" in our_file.getvalue()
        assert "instead of `tqdm_pandas(tqdm(...))`" in our_file.getvalue()

    with closing(StringIO()) as our_file:
        tqdm_pandas(tqdm, file=our_file, leave=False, ascii=True, ncols=20)
        df = pd.DataFrame(randint(0, 50, (500, 3)))
        df.groupby(0).progress_apply(lambda x: None)
        # Check deprecation message
        assert "TqdmDeprecationWarning" in our_file.getvalue()
        assert "instead of `tqdm_pandas(tqdm, ...)`" in our_file.getvalue()
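As a usage sketch (not part of this commit; assumes numpy and pandas are installed): a single `tqdm.pandas()` call registers the `progress_apply`/`progress_map` methods exercised throughout these tests:

import pandas as pd
from numpy.random import randint
from tqdm import tqdm

tqdm.pandas(desc="rows")  # registers progress_apply & friends on pandas objects
df = pd.DataFrame(randint(0, 50, (1000, 3)))
df.progress_apply(lambda row: row.sum(), axis=1)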
325 tests/tests_perf.py Normal file
@@ -0,0 +1,325 @@
from __future__ import division, print_function

import sys
from contextlib import contextmanager
from functools import wraps
from time import sleep, time

# Use relative/cpu timer to have reliable timings when there is a sudden load
try:
    from time import process_time
except ImportError:
    from time import clock
    process_time = clock

from tqdm import tqdm, trange

from .tests_tqdm import _range, importorskip, mark, patch_lock, skip

pytestmark = mark.slow


def cpu_sleep(t):
    """Sleep the given amount of cpu time"""
    start = process_time()
    while (process_time() - start) < t:
        pass


def checkCpuTime(sleeptime=0.2):
    """Check if cpu time works correctly"""
    if checkCpuTime.passed:
        return True
    # First test that sleeping does not consume cputime
    start1 = process_time()
    sleep(sleeptime)
    t1 = process_time() - start1

    # secondly check by comparing to cpusleep (where we actually do something)
    start2 = process_time()
    cpu_sleep(sleeptime)
    t2 = process_time() - start2

    if abs(t1) < 0.0001 and t1 < t2 / 10:
        checkCpuTime.passed = True
        return True
    skip("cpu time not reliable on this machine")


checkCpuTime.passed = False


@contextmanager
def relative_timer():
    """yields a context timer function which stops ticking on exit"""
    start = process_time()

    def elapser():
        return process_time() - start

    yield lambda: elapser()
    spent = elapser()

    def elapser():  # NOQA
        return spent


def retry_on_except(n=3, check_cpu_time=True):
    """decorator for retrying `n` times before raising Exceptions"""
    def wrapper(func):
        """actual decorator"""
        @wraps(func)
        def test_inner(*args, **kwargs):
            """may skip if `check_cpu_time` fails"""
            for i in range(1, n + 1):
                try:
                    if check_cpu_time:
                        checkCpuTime()
                    func(*args, **kwargs)
                except Exception:
                    if i >= n:
                        raise
                else:
                    return
        return test_inner
    return wrapper


def simple_progress(iterable=None, total=None, file=sys.stdout, desc='',
                    leave=False, miniters=1, mininterval=0.1, width=60):
    """Simple progress bar reproducing tqdm's major features"""
    n = [0]  # use a closure
    start_t = [time()]
    last_n = [0]
    last_t = [0]
    if iterable is not None:
        total = len(iterable)

    def format_interval(t):
        mins, s = divmod(int(t), 60)
        h, m = divmod(mins, 60)
        if h:
            return '{0:d}:{1:02d}:{2:02d}'.format(h, m, s)
        else:
            return '{0:02d}:{1:02d}'.format(m, s)

    def update_and_print(i=1):
        n[0] += i
        if (n[0] - last_n[0]) >= miniters:
            last_n[0] = n[0]

            if (time() - last_t[0]) >= mininterval:
                last_t[0] = time()  # last_t[0] == current time

                spent = last_t[0] - start_t[0]
                spent_fmt = format_interval(spent)
                rate = n[0] / spent if spent > 0 else 0
                rate_fmt = "%.2fs/it" % (1.0 / rate) if 0.0 < rate < 1.0 else "%.2fit/s" % rate

                frac = n[0] / total
                percentage = int(frac * 100)
                eta = (total - n[0]) / rate if rate > 0 else 0
                eta_fmt = format_interval(eta)

                # full_bar = "#" * int(frac * width)
                barfill = " " * int((1.0 - frac) * width)
                bar_length, frac_bar_length = divmod(int(frac * width * 10), 10)
                full_bar = '#' * bar_length
                frac_bar = chr(48 + frac_bar_length) if frac_bar_length else ' '

                file.write("\r%s %i%%|%s%s%s| %i/%i [%s<%s, %s]" %
                           (desc, percentage, full_bar, frac_bar, barfill, n[0],
                            total, spent_fmt, eta_fmt, rate_fmt))

                if n[0] == total and leave:
                    file.write("\n")
                file.flush()

    def update_and_yield():
        for elt in iterable:
            yield elt
            update_and_print()

    update_and_print(0)
    if iterable is not None:
        return update_and_yield()
    else:
        return update_and_print


def assert_performance(thresh, name_left, time_left, name_right, time_right):
    """raises if time_left > thresh * time_right"""
    if time_left > thresh * time_right:
        raise ValueError(
            ('{name[0]}: {time[0]:f}, '
             '{name[1]}: {time[1]:f}, '
             'ratio {ratio:f} > {thresh:f}').format(
                 name=(name_left, name_right),
                 time=(time_left, time_right),
                 ratio=time_left / time_right, thresh=thresh))


@retry_on_except()
def test_iter_basic_overhead():
    """Test overhead of iteration based tqdm"""
    total = int(1e6)

    a = 0
    with trange(total) as t:
        with relative_timer() as time_tqdm:
            for i in t:
                a += i
    assert a == (total ** 2 - total) / 2.0

    a = 0
    with relative_timer() as time_bench:
        for i in _range(total):
            a += i
    sys.stdout.write(str(a))

    assert_performance(3, 'trange', time_tqdm(), 'range', time_bench())


@retry_on_except()
def test_manual_basic_overhead():
    """Test overhead of manual tqdm"""
    total = int(1e6)

    with tqdm(total=total * 10, leave=True) as t:
        a = 0
        with relative_timer() as time_tqdm:
            for i in _range(total):
                a += i
                t.update(10)

    a = 0
    with relative_timer() as time_bench:
        for i in _range(total):
            a += i
    sys.stdout.write(str(a))

    assert_performance(5, 'tqdm', time_tqdm(), 'range', time_bench())


def worker(total, blocking=True):
    def incr_bar(x):
        for _ in trange(total, lock_args=None if blocking else (False,),
                        miniters=1, mininterval=0, maxinterval=0):
            pass
        return x + 1
    return incr_bar


@retry_on_except()
@patch_lock(thread=True)
def test_lock_args():
    """Test overhead of nonblocking threads"""
    ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor

    total = 16
    subtotal = 10000

    with ThreadPoolExecutor() as pool:
        sys.stderr.write('block ... ')
        sys.stderr.flush()
        with relative_timer() as time_tqdm:
            res = list(pool.map(worker(subtotal, True), range(total)))
            assert sum(res) == sum(range(total)) + total
        sys.stderr.write('noblock ... ')
        sys.stderr.flush()
        with relative_timer() as time_noblock:
            res = list(pool.map(worker(subtotal, False), range(total)))
            assert sum(res) == sum(range(total)) + total

    assert_performance(0.5, 'noblock', time_noblock(), 'tqdm', time_tqdm())


@retry_on_except(10)
def test_iter_overhead_hard():
    """Test overhead of iteration based tqdm (hard)"""
    total = int(1e5)

    a = 0
    with trange(total, leave=True, miniters=1,
                mininterval=0, maxinterval=0) as t:
        with relative_timer() as time_tqdm:
            for i in t:
                a += i
    assert a == (total ** 2 - total) / 2.0

    a = 0
    with relative_timer() as time_bench:
        for i in _range(total):
            a += i
    sys.stdout.write(("%i" % a) * 40)

    assert_performance(130, 'trange', time_tqdm(), 'range', time_bench())


@retry_on_except(10)
def test_manual_overhead_hard():
    """Test overhead of manual tqdm (hard)"""
    total = int(1e5)

    with tqdm(total=total * 10, leave=True, miniters=1,
              mininterval=0, maxinterval=0) as t:
        a = 0
        with relative_timer() as time_tqdm:
            for i in _range(total):
                a += i
                t.update(10)

    a = 0
    with relative_timer() as time_bench:
        for i in _range(total):
            a += i
    sys.stdout.write(("%i" % a) * 40)

    assert_performance(130, 'tqdm', time_tqdm(), 'range', time_bench())


@retry_on_except(10)
def test_iter_overhead_simplebar_hard():
    """Test overhead of iteration based tqdm vs simple progress bar (hard)"""
    total = int(1e4)

    a = 0
    with trange(total, leave=True, miniters=1,
                mininterval=0, maxinterval=0) as t:
        with relative_timer() as time_tqdm:
            for i in t:
                a += i
    assert a == (total ** 2 - total) / 2.0

    a = 0
    s = simple_progress(_range(total), leave=True,
                        miniters=1, mininterval=0)
    with relative_timer() as time_bench:
        for i in s:
            a += i

    assert_performance(10, 'trange', time_tqdm(), 'simple_progress', time_bench())


@retry_on_except(10)
def test_manual_overhead_simplebar_hard():
    """Test overhead of manual tqdm vs simple progress bar (hard)"""
    total = int(1e4)

    with tqdm(total=total * 10, leave=True, miniters=1,
              mininterval=0, maxinterval=0) as t:
        a = 0
        with relative_timer() as time_tqdm:
            for i in _range(total):
                a += i
                t.update(10)

    simplebar_update = simple_progress(total=total * 10, leave=True,
                                       miniters=1, mininterval=0)
    a = 0
    with relative_timer() as time_bench:
        for i in _range(total):
            a += i
            simplebar_update(10)

    assert_performance(10, 'tqdm', time_tqdm(), 'simple_progress', time_bench())
10 tests/tests_rich.py Normal file
@@ -0,0 +1,10 @@
"""Test `tqdm.rich`."""
import sys

from .tests_tqdm import importorskip, mark


@mark.skipif(sys.version_info[:3] < (3, 6, 1), reason="`rich` needs py>=3.6.1")
def test_rich_import():
    """Test `tqdm.rich` import"""
    importorskip('tqdm.rich')
224 tests/tests_synchronisation.py Normal file
@@ -0,0 +1,224 @@
from __future__ import division

import sys
from functools import wraps
from threading import Event
from time import sleep, time

from tqdm import TMonitor, tqdm, trange

from .tests_perf import retry_on_except
from .tests_tqdm import StringIO, closing, importorskip, patch_lock, skip


class Time(object):
    """Fake time class providing an offset"""
    offset = 0

    @classmethod
    def reset(cls):
        """zeroes internal offset"""
        cls.offset = 0

    @classmethod
    def time(cls):
        """time.time() + offset"""
        return time() + cls.offset

    @staticmethod
    def sleep(dur):
        """identical to time.sleep()"""
        sleep(dur)

    @classmethod
    def fake_sleep(cls, dur):
        """adds `dur` to internal offset"""
        cls.offset += dur
        sleep(0.000001)  # sleep to allow interrupt (instead of pass)


def FakeEvent():
    """patched `threading.Event` where `wait()` uses `Time.fake_sleep()`"""
    event = Event()  # not a class in py2 so can't inherit

    def wait(timeout=None):
        """uses Time.fake_sleep"""
        if timeout is not None:
            Time.fake_sleep(timeout)
        return event.is_set()

    event.wait = wait
    return event


def patch_sleep(func):
    """Temporarily makes TMonitor use Time.fake_sleep"""
    @wraps(func)
    def inner(*args, **kwargs):
        """restores TMonitor on completion regardless of Exceptions"""
        TMonitor._test["time"] = Time.time
        TMonitor._test["Event"] = FakeEvent
        if tqdm.monitor:
            assert not tqdm.monitor.get_instances()
            tqdm.monitor.exit()
            del tqdm.monitor
            tqdm.monitor = None
        try:
            return func(*args, **kwargs)
        finally:
            # Check that class var monitor is deleted if no instance left
            tqdm.monitor_interval = 10
            if tqdm.monitor:
                assert not tqdm.monitor.get_instances()
                tqdm.monitor.exit()
                del tqdm.monitor
                tqdm.monitor = None
            TMonitor._test.pop("Event")
            TMonitor._test.pop("time")

    return inner


def cpu_timify(t, timer=Time):
    """Force tqdm to use the specified timer instead of system-wide time"""
    t._time = timer.time
    t._sleep = timer.fake_sleep
    t.start_t = t.last_print_t = t._time()
    return timer


class FakeTqdm(object):
    _instances = set()
    get_lock = tqdm.get_lock


def incr(x):
    return x + 1


def incr_bar(x):
    with closing(StringIO()) as our_file:
        for _ in trange(x, lock_args=(False,), file=our_file):
            pass
    return incr(x)


@patch_sleep
def test_monitor_thread():
    """Test dummy monitoring thread"""
    monitor = TMonitor(FakeTqdm, 10)
    # Test if alive, then killed
    assert monitor.report()
    monitor.exit()
    assert not monitor.report()
    assert not monitor.is_alive()
    del monitor


@patch_sleep
def test_monitoring_and_cleanup():
    """Test for stalled tqdm instance and monitor deletion"""
    # Note: should fix miniters for these tests, else with dynamic_miniters
    # it's too complicated to handle with monitoring update and maxinterval...
    maxinterval = tqdm.monitor_interval
    assert maxinterval == 10
    total = 1000

    with closing(StringIO()) as our_file:
        with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
                  maxinterval=maxinterval) as t:
            cpu_timify(t, Time)
            # Do a lot of iterations in a small timeframe
            # (smaller than monitor interval)
            Time.fake_sleep(maxinterval / 10)  # monitor won't wake up
            t.update(500)
            # check that our fixed miniters is still there
            assert t.miniters <= 500  # TODO: should really be == 500
            # Then do 1 it after monitor interval, so that monitor kicks in
            Time.fake_sleep(maxinterval)
            t.update(1)
            # Wait for the monitor to get out of sleep's loop and update tqdm.
            timeend = Time.time()
            while not (t.monitor.woken >= timeend and t.miniters == 1):
                Time.fake_sleep(1)  # force a wake-up if it woke too soon
            assert t.miniters == 1  # check that monitor corrected miniters
            # Note: at this point, there may be a race condition: monitor saved
            # the current woken time but Time.sleep() happens just before monitor
            # sleep. To fix that, either sleep here or increase time in a loop
            # to ensure that monitor wakes up at some point.

            # Try again but already at miniters = 1 so nothing will be done
            Time.fake_sleep(maxinterval)
            t.update(2)
            timeend = Time.time()
            while t.monitor.woken < timeend:
                Time.fake_sleep(1)  # force a wake-up if it woke too soon
                # Wait for the monitor to get out of sleep's loop and update
                # tqdm
            assert t.miniters == 1  # check that monitor corrected miniters


@patch_sleep
def test_monitoring_multi():
    """Test on multiple bars, one not needing miniters adjustment"""
    # Note: should fix miniters for these tests, else with dynamic_miniters
    # it's too complicated to handle with monitoring update and maxinterval...
    maxinterval = tqdm.monitor_interval
    assert maxinterval == 10
    total = 1000

    with closing(StringIO()) as our_file:
        with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
                  maxinterval=maxinterval) as t1:
            # Set high maxinterval for t2 so monitor does not need to adjust it
            with tqdm(total=total, file=our_file, miniters=500, mininterval=0.1,
                      maxinterval=1E5) as t2:
                cpu_timify(t1, Time)
                cpu_timify(t2, Time)
                # Do a lot of iterations in a small timeframe
                Time.fake_sleep(maxinterval / 10)
                t1.update(500)
                t2.update(500)
                assert t1.miniters <= 500  # TODO: should really be == 500
                assert t2.miniters == 500
                # Then do 1 it after monitor interval, so that monitor kicks in
                Time.fake_sleep(maxinterval)
                t1.update(1)
                t2.update(1)
                # Wait for the monitor to get out of sleep and update tqdm
                timeend = Time.time()
                while not (t1.monitor.woken >= timeend and t1.miniters == 1):
                    Time.fake_sleep(1)
                assert t1.miniters == 1  # check that monitor corrected miniters
                assert t2.miniters == 500  # check that t2 was not adjusted


def test_imap():
    """Test multiprocessing.Pool"""
    try:
        from multiprocessing import Pool
    except ImportError as err:
        skip(str(err))

    pool = Pool()
    res = list(tqdm(pool.imap(incr, range(100)), disable=True))
    pool.close()
    assert res[-1] == 100


# py2: locks won't propagate to incr_bar so may cause `AttributeError`
@retry_on_except(n=3 if sys.version_info < (3,) else 1, check_cpu_time=False)
@patch_lock(thread=True)
def test_threadpool():
    """Test concurrent.futures.ThreadPoolExecutor"""
    ThreadPoolExecutor = importorskip('concurrent.futures').ThreadPoolExecutor

    with ThreadPoolExecutor(8) as pool:
        try:
            res = list(tqdm(pool.map(incr_bar, range(100)), disable=True))
        except AttributeError:
            if sys.version_info < (3,):
                skip("not supported on py2")
            else:
                raise
    assert sum(res) == sum(range(1, 101))
7 tests/tests_tk.py Normal file
@@ -0,0 +1,7 @@
"""Test `tqdm.tk`."""
from .tests_tqdm import importorskip


def test_tk_import():
    """Test `tqdm.tk` import"""
    importorskip('tqdm.tk')
1996 tests/tests_tqdm.py Normal file
File diff suppressed because it is too large.
14 tests/tests_version.py Normal file
@@ -0,0 +1,14 @@
"""Test `tqdm.__version__`."""
import re
from ast import literal_eval


def test_version():
    """Test version string"""
    from tqdm import __version__
    version_parts = re.split('[.-]', __version__)
    if __version__ != "UNKNOWN":
        assert 3 <= len(version_parts), "must have at least Major.minor.patch"
        assert all(
            isinstance(literal_eval(i), int) for i in version_parts[:3]
        ), "Version Major.minor.patch must be 3 integers"