summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAndrei Lebedev <lebdron@gmail.com>2022-11-03 11:09:37 +0100
committerGitHub <noreply@github.com>2022-11-03 10:09:37 +0000
commit27ed77aabba8c9eb08d66f34092b1bfcc22c482e (patch)
tree7cc41fc5e398009a5cf8e7e4156afb0246aa34d3 /test
parentc4b19a88169fa76c5eb665d274e7270a0fe452c4 (diff)
downloadyoutube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.tar.gz
youtube-dl-27ed77aabba8c9eb08d66f34092b1bfcc22c482e.tar.xz
[utils] Backport traverse_obj (etc) from yt-dlp (#31156)
* Backport traverse_obj and closely related function from yt-dlp (code by pukkandan) * Backport LazyList, variadic(), try_call (code by pukkandan) * Recast using yt-dlp's newer traverse_obj() implementation and tests (code by grub4k) * Add tests for Unicode case folding support matching Py3.5+ (requires f102e3d) * Improve/add tests for variadic, try_call, join_nonempty Co-authored-by: dirkf <fieldhouse@gmx.net>
Diffstat (limited to 'test')
-rw-r--r--test/test_utils.py323
1 files changed, 323 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index f1a748dde..9d364c863 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -12,7 +12,9 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Various small unit tests
import io
+import itertools
import json
+import re
import xml.etree.ElementTree
from youtube_dl.utils import (
@@ -40,11 +42,14 @@ from youtube_dl.utils import (
get_element_by_attribute,
get_elements_by_class,
get_elements_by_attribute,
+ get_first,
InAdvancePagedList,
int_or_none,
intlist_to_bytes,
is_html,
+ join_nonempty,
js_to_json,
+ LazyList,
limit_length,
merge_dicts,
mimetype2ext,
@@ -79,6 +84,8 @@ from youtube_dl.utils import (
strip_or_none,
subtitles_filename,
timeconvert,
+ traverse_obj,
+ try_call,
unescapeHTML,
unified_strdate,
unified_timestamp,
@@ -92,6 +99,7 @@ from youtube_dl.utils import (
urlencode_postdata,
urshift,
update_url_query,
+ variadic,
version_tuple,
xpath_with_ns,
xpath_element,
@@ -112,12 +120,18 @@ from youtube_dl.compat import (
compat_getenv,
compat_os_name,
compat_setenv,
+ compat_str,
compat_urlparse,
compat_parse_qs,
)
class TestUtil(unittest.TestCase):
+
+ # yt-dlp shim
+ def assertCountEqual(self, expected, got, msg='count should be the same'):
+ return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg)
+
def test_timeconvert(self):
self.assertTrue(timeconvert('') is None)
self.assertTrue(timeconvert('bougrg') is None)
@@ -1478,6 +1492,315 @@ Line 1
self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3')
self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3')
+ def test_LazyList(self):
+ it = list(range(10))
+
+ self.assertEqual(list(LazyList(it)), it)
+ self.assertEqual(LazyList(it).exhaust(), it)
+ self.assertEqual(LazyList(it)[5], it[5])
+
+ self.assertEqual(LazyList(it)[5:], it[5:])
+ self.assertEqual(LazyList(it)[:5], it[:5])
+ self.assertEqual(LazyList(it)[::2], it[::2])
+ self.assertEqual(LazyList(it)[1::2], it[1::2])
+ self.assertEqual(LazyList(it)[5::-1], it[5::-1])
+ self.assertEqual(LazyList(it)[6:2:-2], it[6:2:-2])
+ self.assertEqual(LazyList(it)[::-1], it[::-1])
+
+ self.assertTrue(LazyList(it))
+ self.assertFalse(LazyList(range(0)))
+ self.assertEqual(len(LazyList(it)), len(it))
+ self.assertEqual(repr(LazyList(it)), repr(it))
+ self.assertEqual(compat_str(LazyList(it)), compat_str(it))
+
+ self.assertEqual(list(LazyList(it, reverse=True)), it[::-1])
+ self.assertEqual(list(reversed(LazyList(it))[::-1]), it)
+ self.assertEqual(list(reversed(LazyList(it))[1:3:7]), it[::-1][1:3:7])
+
+ def test_LazyList_laziness(self):
+
+ def test(ll, idx, val, cache):
+ self.assertEqual(ll[idx], val)
+ self.assertEqual(ll._cache, list(cache))
+
+ ll = LazyList(range(10))
+ test(ll, 0, 0, range(1))
+ test(ll, 5, 5, range(6))
+ test(ll, -3, 7, range(10))
+
+ ll = LazyList(range(10), reverse=True)
+ test(ll, -1, 0, range(1))
+ test(ll, 3, 6, range(10))
+
+ ll = LazyList(itertools.count())
+ test(ll, 10, 10, range(11))
+ ll = reversed(ll)
+ test(ll, -15, 14, range(15))
+
+ def test_try_call(self):
+ def total(*x, **kwargs):
+ return sum(x) + sum(kwargs.values())
+
+ self.assertEqual(try_call(None), None,
+ msg='not a fn should give None')
+ self.assertEqual(try_call(lambda: 1), 1,
+ msg='int fn with no expected_type should give int')
+ self.assertEqual(try_call(lambda: 1, expected_type=int), 1,
+ msg='int fn with expected_type int should give int')
+ self.assertEqual(try_call(lambda: 1, expected_type=dict), None,
+ msg='int fn with wrong expected_type should give None')
+ self.assertEqual(try_call(total, args=(0, 1, 0, ), expected_type=int), 1,
+ msg='fn should accept arglist')
+ self.assertEqual(try_call(total, kwargs={'a': 0, 'b': 1, 'c': 0}, expected_type=int), 1,
+ msg='fn should accept kwargs')
+ self.assertEqual(try_call(lambda: 1, expected_type=dict), None,
+ msg='int fn with no expected_type should give None')
+ self.assertEqual(try_call(lambda x: {}, total, args=(42, ), expected_type=int), 42,
+ msg='expect first int result with expected_type int')
+
+ def test_variadic(self):
+ self.assertEqual(variadic(None), (None, ))
+ self.assertEqual(variadic('spam'), ('spam', ))
+ self.assertEqual(variadic('spam', allowed_types=dict), 'spam')
+
+ def test_traverse_obj(self):
+ _TEST_DATA = {
+ 100: 100,
+ 1.2: 1.2,
+ 'str': 'str',
+ 'None': None,
+ '...': Ellipsis,
+ 'urls': [
+ {'index': 0, 'url': 'https://www.example.com/0'},
+ {'index': 1, 'url': 'https://www.example.com/1'},
+ ],
+ 'data': (
+ {'index': 2},
+ {'index': 3},
+ ),
+ 'dict': {},
+ }
+
+ # Test base functionality
+ self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
+ msg='allow tuple path')
+ self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str',
+ msg='allow list path')
+ self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str',
+ msg='allow iterable path')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str',
+ msg='single items should be treated as a path')
+ self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA)
+ self.assertEqual(traverse_obj(_TEST_DATA, 100), 100)
+ self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2)
+
+ # Test Ellipsis behavior
+ self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
+ (item for item in _TEST_DATA.values() if item is not None),
+ msg='`...` should give all values except `None`')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
+ msg='`...` selection for dicts should select all values')
+ self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='nested `...` queries should work')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
+ msg='`...` query result should be flattened')
+
+ # Test function as key
+ self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
+ [_TEST_DATA['urls']],
+ msg='function as query key should perform a filter based on (key, value)')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), {'str'},
+ msg='exceptions in the query function should be caught')
+
+ # Test alternative paths
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
+ msg='multiple `paths` should be treated as alternative paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
+ msg='alternatives should exit early')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
+ msg='alternatives should return `default` if exhausted')
+ self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, 'fail'), 100), 100,
+ msg='alternatives should track their own branching return')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('dict', Ellipsis), ('data', Ellipsis)), list(_TEST_DATA['data']),
+ msg='alternatives on empty objects should search further')
+
+ # Test branch and path nesting
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
+ msg='tuple as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'],
+ msg='list as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'],
+ msg='double nesting in path should be treated as paths')
+ self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1],
+ msg='do not fail early on branching')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='triple nesting in path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (Ellipsis, 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='ellipsis as branch path start gets flattened')
+
+ # Test dictionary as key
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2},
+ msg='dict key should result in a dict with the same keys')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}),
+ {0: 'https://www.example.com/0'},
+ msg='dict key should allow paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}),
+ {0: ['https://www.example.com/0']},
+ msg='tuple in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}),
+ {0: ['https://www.example.com/0']},
+ msg='double nesting in dict path should be treated as paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
+ {0: ['https://www.example.com/1', 'https://www.example.com/0']},
+ msg='triple nesting in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
+ msg='remove `None` values when dict key')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
+ msg='do not remove `None` values if `default`')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
+ msg='do not remove empty values when dict key')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}},
+ msg='do not remove empty values when dict key and a default')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []},
+ msg='if branch in dict key not successful, return `[]`')
+
+ # Testing default parameter behavior
+ _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None,
+ msg='default value should be `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=Ellipsis), Ellipsis,
+ msg='chained fails should result in default')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0,
+ msg='should not short cirquit on `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1,
+ msg='invalid dict key should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1,
+ msg='`None` is a deliberate sentinel and should become `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
+ msg='`IndexError` should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=1), 1,
+ msg='if branched but not successful return `default` if defined, not `[]`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=None), None,
+ msg='if branched but not successful return `default` even if `default` is `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail')), [],
+ msg='if branched but not successful return `[]`, not `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
+ msg='if branched but object is empty return `[]`, not `default`')
+
+ # Testing expected_type behavior
+ _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str',
+ msg='accept matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
+ msg='reject non matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0',
+ msg='transform type using type function')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
+ expected_type=lambda _: 1 / 0), None,
+ msg='wrap expected_type function in try_call')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'],
+ msg='eliminate items that expected_type fails on')
+
+ # Test get_all behavior
+ _GET_ALL_DATA = {'key': [0, 1, 2]}
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', Ellipsis), get_all=False), 0,
+ msg='if not `get_all`, return only first matching value')
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, Ellipsis, get_all=False), [0, 1, 2],
+ msg='do not overflatten if not `get_all`')
+
+ # Test casesense behavior
+ _CASESENSE_DATA = {
+ 'KeY': 'value0',
+ 0: {
+ 'KeY': 'value1',
+ 0: {'KeY': 'value2'},
+ },
+ # FULLWIDTH LATIN CAPITAL LETTER K
+ '\uff2bey': 'value3',
+ }
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None,
+ msg='dict keys should be case sensitive unless `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY',
+ casesense=False), 'value0',
+ msg='allow non matching key case if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, '\uff4bey', # FULLWIDTH LATIN SMALL LETTER K
+ casesense=False), 'value3',
+ msg='allow non matching Unicode key case if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)),
+ casesense=False), ['value1'],
+ msg='allow non matching key case in branch if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)),
+ casesense=False), ['value2'],
+ msg='allow non matching key case in branch path if `casesense`')
+
+ # Test traverse_string behavior
+ _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None,
+ msg='do not traverse into string if not `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0),
+ _traverse_string=True), 's',
+ msg='traverse into string if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1),
+ _traverse_string=True), '.',
+ msg='traverse into converted data if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
+ _traverse_string=True), list('str'),
+ msg='`...` branching into string should result in list')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
+ _traverse_string=True), ['s', 'r'],
+ msg='branching into string should result in list')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
+ _traverse_string=True), list('str'),
+ msg='function branching into string should result in list')
+
+ # Test is_user_input behavior
+ _IS_USER_INPUT_DATA = {'range8': list(range(8))}
+ self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
+ _is_user_input=True), 3,
+ msg='allow for string indexing if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
+ _is_user_input=True), tuple(range(8))[3:],
+ msg='allow for string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
+ _is_user_input=True), tuple(range(8))[:4:2],
+ msg='allow step in string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
+ _is_user_input=True), range(8),
+ msg='`:` should be treated as `...` if `is_user_input`')
+ with self.assertRaises(TypeError, msg='too many params should result in error'):
+ traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True)
+
+ # Test re.Match as input obj
+ mobj = re.match(r'^0(12)(?P<group>3)(4)?$', '0123')
+ self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None],
+ msg='`...` on a `re.Match` should give its `groups()`')
+ self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'],
+ msg='function on a `re.Match` should give groupno, value starting at 0')
+ self.assertEqual(traverse_obj(mobj, 'group'), '3',
+ msg='str key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 2), '3',
+ msg='int key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3',
+ msg='str key on a `re.Match` should respect casesense')
+ self.assertEqual(traverse_obj(mobj, 'fail'), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 8), None,
+ msg='failing int key on a `re.Match` should return `default`')
+
+ def test_get_first(self):
+ self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
+
+ def test_join_nonempty(self):
+ self.assertEqual(join_nonempty('a', 'b'), 'a-b')
+ self.assertEqual(join_nonempty(
+ 'a', 'b', 'c', 'd',
+ from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
+
if __name__ == '__main__':
unittest.main()

Generated by cgit