summaryrefslogtreecommitdiff
path: root/youtube_dl/jsinterp.py
blob: c60a9b3c234768d2853c977bb2153d9d138e8314 (plain)
    1 from __future__ import unicode_literals
    2 
    3 import itertools
    4 import json
    5 import math
    6 import operator
    7 import re
    8 
    9 from .utils import (
   10     NO_DEFAULT,
   11     ExtractorError,
   12     js_to_json,
   13     remove_quotes,
   14     unified_timestamp,
   15 )
   16 from .compat import (
   17     compat_collections_chain_map as ChainMap,
   18     compat_itertools_zip_longest as zip_longest,
   19     compat_str,
   20 )
   21 
   22 _NAME_RE = r'[a-zA-Z_$][\w$]*'
   23 
   24 # (op, definition) in order of binding priority, tightest first
   25 # avoid dict to maintain order
   26 # definition None => Defined in JSInterpreter._operator
   27 _DOT_OPERATORS = (
   28     ('.', None),
   29     # TODO: ('?.', None),
   30 )
   31 
   32 _OPERATORS = (
   33     ('|', operator.or_),
   34     ('^', operator.xor),
   35     ('&', operator.and_),
   36     ('>>', operator.rshift),
   37     ('<<', operator.lshift),
   38     ('+', operator.add),
   39     ('-', operator.sub),
   40     ('*', operator.mul),
   41     ('/', operator.truediv),
   42     ('%', operator.mod),
   43 )
   44 
   45 _COMP_OPERATORS = (
   46     ('===', operator.is_),
   47     ('==', operator.eq),
   48     ('!==', operator.is_not),
   49     ('!=', operator.ne),
   50     ('<=', operator.le),
   51     ('>=', operator.ge),
   52     ('<', operator.lt),
   53     ('>', operator.gt),
   54 )
   55 
   56 _LOG_OPERATORS = (
   57     ('&', operator.and_),
   58     ('|', operator.or_),
   59     ('^', operator.xor),
   60 )
   61 
   62 _SC_OPERATORS = (
   63     ('?', None),
   64     ('||', None),
   65     ('&&', None),
   66     # TODO: ('??', None),
   67 )
   68 
   69 _OPERATOR_RE = '|'.join(map(lambda x: re.escape(x[0]), _OPERATORS + _LOG_OPERATORS))
   70 
   71 _MATCHING_PARENS = dict(zip(*zip('()', '{}', '[]')))
   72 _QUOTES = '\'"'
   73 
   74 
   75 def _ternary(cndn, if_true=True, if_false=False):
   76     """Simulate JS's ternary operator (cndn?if_true:if_false)"""
   77     if cndn in (False, None, 0, ''):
   78         return if_false
   79     try:
   80         if math.isnan(cndn):  # NB: NaN cannot be checked by membership
   81             return if_false
   82     except TypeError:
   83         pass
   84     return if_true
   85 
   86 
   87 class JS_Break(ExtractorError):
   88     def __init__(self):
   89         ExtractorError.__init__(self, 'Invalid break')
   90 
   91 
   92 class JS_Continue(ExtractorError):
   93     def __init__(self):
   94         ExtractorError.__init__(self, 'Invalid continue')
   95 
   96 
   97 class LocalNameSpace(ChainMap):
   98     def __setitem__(self, key, value):
   99         for scope in self.maps:
  100             if key in scope:
  101                 scope[key] = value
  102                 return
  103         self.maps[0][key] = value
  104 
  105     def __delitem__(self, key):
  106         raise NotImplementedError('Deleting is not supported')
  107 
  108     def __repr__(self):
  109         return 'LocalNameSpace%s' % (self.maps, )
  110 
  111 
  112 class JSInterpreter(object):
  113     __named_object_counter = 0
  114 
  115     def __init__(self, code, objects=None):
  116         self.code, self._functions = code, {}
  117         self._objects = {} if objects is None else objects
  118 
  119     class Exception(ExtractorError):
  120         def __init__(self, msg, *args, **kwargs):
  121             expr = kwargs.pop('expr', None)
  122             if expr is not None:
  123                 msg = '{0} in: {1!r}'.format(msg.rstrip(), expr[:100])
  124             super(JSInterpreter.Exception, self).__init__(msg, *args, **kwargs)
  125 
  126     def _named_object(self, namespace, obj):
  127         self.__named_object_counter += 1
  128         name = '__youtube_dl_jsinterp_obj%d' % (self.__named_object_counter, )
  129         namespace[name] = obj
  130         return name
  131 
    @staticmethod
    def _separate(expr, delim=',', max_split=None, skip_delims=None):
        """Generate the pieces of expr split at top-level occurrences of delim

        An occurrence of delim is ignored when it is inside a quoted string
        or inside unbalanced (), {} or [].

        @param expr         the text to split; nothing is yielded if empty
        @param delim        the delimiter, which may have several characters
        @param max_split    if set, stop splitting after this many splits;
                            the rest of expr is yielded as the final piece
        @param skip_delims  a string, or list/tuple of strings: where one of
                            these starts at a potential delimiter position,
                            that position is not treated as a delimiter
        """
        if not expr:
            return
        # number of currently open brackets, keyed by the closing bracket
        counters = {k: 0 for k in _MATCHING_PARENS.values()}
        # pos: index in delim of the next character to match
        # delim_len: index of the last character of delim (ie, length - 1)
        start, splits, pos, skipping, delim_len = 0, 0, 0, 0, len(delim) - 1
        in_quote, escaping = None, False
        for idx, char in enumerate(expr):
            if not in_quote:
                if char in _MATCHING_PARENS:
                    counters[_MATCHING_PARENS[char]] += 1
                elif char in counters:
                    counters[char] -= 1
            if not escaping:
                # an unescaped quote opens a string, or closes the string
                # that was opened with the same quote character
                if char in _QUOTES and in_quote in (char, None):
                    in_quote = None if in_quote else char
                else:
                    # a backslash inside a string escapes the next character
                    escaping = in_quote and char == '\\'
            else:
                escaping = False

            # not (part of) a top-level delimiter: reset any partial match
            if char != delim[pos] or any(counters.values()) or in_quote:
                pos = skipping = 0
                continue
            elif skipping > 0:
                # consume the rest of a skip_delims match
                skipping -= 1
                continue
            elif pos == 0 and skip_delims:
                here = expr[idx:]
                for s in skip_delims if isinstance(skip_delims, (list, tuple)) else [skip_delims]:
                    if here.startswith(s) and s:
                        skipping = len(s) - 1
                        break
                if skipping > 0:
                    continue
            # matched, but not yet the whole of a multi-character delimiter
            if pos < delim_len:
                pos += 1
                continue
            # idx is at the last character of delim: yield the text before it
            yield expr[start: idx - delim_len]
            start, pos = idx + 1, 0
            splits += 1
            if max_split and splits >= max_split:
                break
        yield expr[start:]
  176 
  177     @classmethod
  178     def _separate_at_paren(cls, expr, delim):
  179         separated = list(cls._separate(expr, delim, 1))
  180 
  181         if len(separated) < 2:
  182             raise cls.Exception('No terminating paren {delim} in {expr}'.format(**locals()))
  183         return separated[0][1:].strip(), separated[1].strip()
  184 
  185     @staticmethod
  186     def _all_operators():
  187         return itertools.chain(
  188             _SC_OPERATORS, _LOG_OPERATORS, _COMP_OPERATORS, _OPERATORS)
  189 
  190     def _operator(self, op, left_val, right_expr, expr, local_vars, allow_recursion):
  191         if op in ('||', '&&'):
  192             if (op == '&&') ^ _ternary(left_val):
  193                 return left_val  # short circuiting
  194         elif op == '?':
  195             right_expr = _ternary(left_val, *self._separate(right_expr, ':', 1))
  196 
  197         right_val = self.interpret_expression(right_expr, local_vars, allow_recursion)
  198         opfunc = op and next((v for k, v in self._all_operators() if k == op), None)
  199         if not opfunc:
  200             return right_val
  201 
  202         try:
  203             return opfunc(left_val, right_val)
  204         except Exception as e:
  205             raise self.Exception('Failed to evaluate {left_val!r} {op} {right_val!r}'.format(**locals()), expr, cause=e)
  206 
  207     def _index(self, obj, idx):
  208         if idx == 'length':
  209             return len(obj)
  210         try:
  211             return obj[int(idx)] if isinstance(obj, list) else obj[idx]
  212         except Exception as e:
  213             raise self.Exception('Cannot get index {idx}'.format(**locals()), expr=repr(obj), cause=e)
  214 
  215     def _dump(self, obj, namespace):
  216         try:
  217             return json.dumps(obj)
  218         except TypeError:
  219             return self._named_object(namespace, obj)
  220 
  221     def interpret_statement(self, stmt, local_vars, allow_recursion=100):
  222         if allow_recursion < 0:
  223             raise self.Exception('Recursion limit reached')
  224         allow_recursion -= 1
  225 
  226         should_return = False
  227         sub_statements = list(self._separate(stmt, ';')) or ['']
  228         expr = stmt = sub_statements.pop().strip()
  229         for sub_stmt in sub_statements:
  230             ret, should_return = self.interpret_statement(sub_stmt, local_vars, allow_recursion)
  231             if should_return:
  232                 return ret, should_return
  233 
  234         m = re.match(r'(?P<var>(?:var|const|let)\s)|return(?:\s+|$)', stmt)
  235         if m:
  236             expr = stmt[len(m.group(0)):].strip()
  237             should_return = not m.group('var')
  238         if not expr:
  239             return None, should_return
  240 
  241         if expr[0] in _QUOTES:
  242             inner, outer = self._separate(expr, expr[0], 1)
  243             inner = json.loads(js_to_json(inner + expr[0]))  # , strict=True))
  244             if not outer:
  245                 return inner, should_return
  246             expr = self._named_object(local_vars, inner) + outer
  247 
  248         if expr.startswith('new '):
  249             obj = expr[4:]
  250             if obj.startswith('Date('):
  251                 left, right = self._separate_at_paren(obj[4:], ')')
  252                 left = self.interpret_expression(left, local_vars, allow_recursion)
  253                 expr = unified_timestamp(left, False)
  254                 if not expr:
  255                     raise self.Exception('Failed to parse date {left!r}'.format(**locals()), expr=expr)
  256                 expr = self._dump(int(expr * 1000), local_vars) + right
  257             else:
  258                 raise self.Exception('Unsupported object {obj}'.format(**locals()), expr=expr)
  259 
  260         if expr.startswith('void '):
  261             left = self.interpret_expression(expr[5:], local_vars, allow_recursion)
  262             return None, should_return
  263 
  264         if expr.startswith('{'):
  265             inner, outer = self._separate_at_paren(expr, '}')
  266             inner, should_abort = self.interpret_statement(inner, local_vars, allow_recursion)
  267             if not outer or should_abort:
  268                 return inner, should_abort or should_return
  269             else:
  270                 expr = self._dump(inner, local_vars) + outer
  271 
  272         if expr.startswith('('):
  273             inner, outer = self._separate_at_paren(expr, ')')
  274             inner, should_abort = self.interpret_statement(inner, local_vars, allow_recursion)
  275             if not outer or should_abort:
  276                 return inner, should_abort or should_return
  277             else:
  278                 expr = self._dump(inner, local_vars) + outer
  279 
  280         if expr.startswith('['):
  281             inner, outer = self._separate_at_paren(expr, ']')
  282             name = self._named_object(local_vars, [
  283                 self.interpret_expression(item, local_vars, allow_recursion)
  284                 for item in self._separate(inner)])
  285             expr = name + outer
  286 
  287         m = re.match(r'(?P<try>try|finally)\s*|(?:(?P<catch>catch)|(?P<for>for)|(?P<switch>switch))\s*\(', expr)
  288         md = m.groupdict() if m else {}
  289         if md.get('try'):
  290             if expr[m.end()] == '{':
  291                 try_expr, expr = self._separate_at_paren(expr[m.end():], '}')
  292             else:
  293                 try_expr, expr = expr[m.end() - 1:], ''
  294             ret, should_abort = self.interpret_statement(try_expr, local_vars, allow_recursion)
  295             if should_abort:
  296                 return ret, True
  297             ret, should_abort = self.interpret_statement(expr, local_vars, allow_recursion)
  298             return ret, should_abort or should_return
  299 
  300         elif md.get('catch'):
  301             # We ignore the catch block
  302             _, expr = self._separate_at_paren(expr, '}')
  303             ret, should_abort = self.interpret_statement(expr, local_vars, allow_recursion)
  304             return ret, should_abort or should_return
  305 
  306         elif md.get('for'):
  307             constructor, remaining = self._separate_at_paren(expr[m.end() - 1:], ')')
  308             if remaining.startswith('{'):
  309                 body, expr = self._separate_at_paren(remaining, '}')
  310             else:
  311                 switch_m = re.match(r'switch\s*\(', remaining)  # FIXME
  312                 if switch_m:
  313                     switch_val, remaining = self._separate_at_paren(remaining[switch_m.end() - 1:], ')')
  314                     body, expr = self._separate_at_paren(remaining, '}')
  315                     body = 'switch(%s){%s}' % (switch_val, body)
  316                 else:
  317                     body, expr = remaining, ''
  318             start, cndn, increment = self._separate(constructor, ';')
  319             self.interpret_expression(start, local_vars, allow_recursion)
  320             while True:
  321                 if not _ternary(self.interpret_expression(cndn, local_vars, allow_recursion)):
  322                     break
  323                 try:
  324                     ret, should_abort = self.interpret_statement(body, local_vars, allow_recursion)
  325                     if should_abort:
  326                         return ret, True
  327                 except JS_Break:
  328                     break
  329                 except JS_Continue:
  330                     pass
  331                 self.interpret_expression(increment, local_vars, allow_recursion)
  332             ret, should_abort = self.interpret_statement(expr, local_vars, allow_recursion)
  333             return ret, should_abort or should_return
  334 
  335         elif md.get('switch'):
  336             switch_val, remaining = self._separate_at_paren(expr[m.end() - 1:], ')')
  337             switch_val = self.interpret_expression(switch_val, local_vars, allow_recursion)
  338             body, expr = self._separate_at_paren(remaining, '}')
  339             items = body.replace('default:', 'case default:').split('case ')[1:]
  340             for default in (False, True):
  341                 matched = False
  342                 for item in items:
  343                     case, stmt = (i.strip() for i in self._separate(item, ':', 1))
  344                     if default:
  345                         matched = matched or case == 'default'
  346                     elif not matched:
  347                         matched = (case != 'default'
  348                                    and switch_val == self.interpret_expression(case, local_vars, allow_recursion))
  349                     if not matched:
  350                         continue
  351                     try:
  352                         ret, should_abort = self.interpret_statement(stmt, local_vars, allow_recursion)
  353                         if should_abort:
  354                             return ret
  355                     except JS_Break:
  356                         break
  357                 if matched:
  358                     break
  359             ret, should_abort = self.interpret_statement(expr, local_vars, allow_recursion)
  360             return ret, should_abort or should_return
  361 
  362         # Comma separated statements
  363         sub_expressions = list(self._separate(expr))
  364         if len(sub_expressions) > 1:
  365             for sub_expr in sub_expressions:
  366                 ret, should_abort = self.interpret_statement(sub_expr, local_vars, allow_recursion)
  367                 if should_abort:
  368                     return ret, True
  369             return ret, False
  370 
  371         for m in re.finditer(r'''(?x)
  372                 (?P<pre_sign>\+\+|--)(?P<var1>{_NAME_RE})|
  373                 (?P<var2>{_NAME_RE})(?P<post_sign>\+\+|--)'''.format(**globals()), expr):
  374             var = m.group('var1') or m.group('var2')
  375             start, end = m.span()
  376             sign = m.group('pre_sign') or m.group('post_sign')
  377             ret = local_vars[var]
  378             local_vars[var] += 1 if sign[0] == '+' else -1
  379             if m.group('pre_sign'):
  380                 ret = local_vars[var]
  381             expr = expr[:start] + self._dump(ret, local_vars) + expr[end:]
  382 
  383         if not expr:
  384             return None, should_return
  385 
  386         m = re.match(r'''(?x)
  387             (?P<assign>
  388                 (?P<out>{_NAME_RE})(?:\[(?P<index>[^\]]+?)\])?\s*
  389                 (?P<op>{_OPERATOR_RE})?
  390                 =(?P<expr>.*)$
  391             )|(?P<return>
  392                 (?!if|return|true|false|null|undefined)(?P<name>{_NAME_RE})$
  393             )|(?P<indexing>
  394                 (?P<in>{_NAME_RE})\[(?P<idx>.+)\]$
  395             )|(?P<attribute>
  396                 (?P<var>{_NAME_RE})(?:\.(?P<member>[^(]+)|\[(?P<member2>[^\]]+)\])\s*
  397             )|(?P<function>
  398                 (?P<fname>{_NAME_RE})\((?P<args>.*)\)$
  399             )'''.format(**globals()), expr)
  400         md = m.groupdict() if m else {}
  401         if md.get('assign'):
  402             left_val = local_vars.get(m.group('out'))
  403 
  404             if not m.group('index'):
  405                 local_vars[m.group('out')] = self._operator(
  406                     m.group('op'), left_val, m.group('expr'), expr, local_vars, allow_recursion)
  407                 return local_vars[m.group('out')], should_return
  408             elif left_val is None:
  409                 raise self.Exception('Cannot index undefined variable ' + m.group('out'), expr=expr)
  410 
  411             idx = self.interpret_expression(m.group('index'), local_vars, allow_recursion)
  412             if not isinstance(idx, (int, float)):
  413                 raise self.Exception('List index %s must be integer' % (idx, ), expr=expr)
  414             idx = int(idx)
  415             left_val[idx] = self._operator(
  416                 m.group('op'), left_val[idx], m.group('expr'), expr, local_vars, allow_recursion)
  417             return left_val[idx], should_return
  418 
  419         elif expr.isdigit():
  420             return int(expr), should_return
  421 
  422         elif expr == 'break':
  423             raise JS_Break()
  424         elif expr == 'continue':
  425             raise JS_Continue()
  426 
  427         elif md.get('return'):
  428             return local_vars[m.group('name')], should_return
  429 
  430         try:
  431             ret = json.loads(js_to_json(expr))  # strict=True)
  432             if not md.get('attribute'):
  433                 return ret, should_return
  434         except ValueError:
  435             pass
  436 
  437         if md.get('indexing'):
  438             val = local_vars[m.group('in')]
  439             idx = self.interpret_expression(m.group('idx'), local_vars, allow_recursion)
  440             return self._index(val, idx), should_return
  441 
  442         for op, _ in self._all_operators():
  443             # hackety: </> have higher priority than <</>>, but don't confuse them
  444             skip_delim = (op + op) if op in ('<', '>') else None
  445             separated = list(self._separate(expr, op, skip_delims=skip_delim))
  446             if len(separated) < 2:
  447                 continue
  448 
  449             right_expr = separated.pop()
  450             while op == '-' and len(separated) > 1 and not separated[-1].strip():
  451                 right_expr = '-' + right_expr
  452                 separated.pop()
  453             left_val = self.interpret_expression(op.join(separated), local_vars, allow_recursion)
  454             return self._operator(op, 0 if left_val is None else left_val,
  455                                   right_expr, expr, local_vars, allow_recursion), should_return
  456 
  457         if md.get('attribute'):
  458             variable = m.group('var')
  459             member = m.group('member')
  460             if not member:
  461                 member = self.interpret_expression(m.group('member2'), local_vars, allow_recursion)
  462             arg_str = expr[m.end():]
  463             if arg_str.startswith('('):
  464                 arg_str, remaining = self._separate_at_paren(arg_str, ')')
  465             else:
  466                 arg_str, remaining = None, arg_str
  467 
  468             def assertion(cndn, msg):
  469                 """ assert, but without risk of getting optimized out """
  470                 if not cndn:
  471                     raise ExtractorError('{member} {msg}'.format(**locals()), expr=expr)
  472 
  473             def eval_method():
  474                 if (variable, member) == ('console', 'debug'):
  475                     return
  476                 types = {
  477                     'String': compat_str,
  478                     'Math': float,
  479                 }
  480                 obj = local_vars.get(variable, types.get(variable, NO_DEFAULT))
  481                 if obj is NO_DEFAULT:
  482                     if variable not in self._objects:
  483                         self._objects[variable] = self.extract_object(variable)
  484                     obj = self._objects[variable]
  485 
  486                 # Member access
  487                 if arg_str is None:
  488                     return self._index(obj, member)
  489 
  490                 # Function call
  491                 argvals = [
  492                     self.interpret_expression(v, local_vars, allow_recursion)
  493                     for v in self._separate(arg_str)]
  494 
  495                 if obj == compat_str:
  496                     if member == 'fromCharCode':
  497                         assertion(argvals, 'takes one or more arguments')
  498                         return ''.join(map(chr, argvals))
  499                     raise self.Exception('Unsupported string method ' + member, expr=expr)
  500                 elif obj == float:
  501                     if member == 'pow':
  502                         assertion(len(argvals) == 2, 'takes two arguments')
  503                         return argvals[0] ** argvals[1]
  504                     raise self.Exception('Unsupported Math method ' + member, expr=expr)
  505 
  506                 if member == 'split':
  507                     assertion(argvals, 'takes one or more arguments')
  508                     assertion(len(argvals) == 1, 'with limit argument is not implemented')
  509                     return obj.split(argvals[0]) if argvals[0] else list(obj)
  510                 elif member == 'join':
  511                     assertion(isinstance(obj, list), 'must be applied on a list')
  512                     assertion(len(argvals) == 1, 'takes exactly one argument')
  513                     return argvals[0].join(obj)
  514                 elif member == 'reverse':
  515                     assertion(not argvals, 'does not take any arguments')
  516                     obj.reverse()
  517                     return obj
  518                 elif member == 'slice':
  519                     assertion(isinstance(obj, list), 'must be applied on a list')
  520                     assertion(len(argvals) == 1, 'takes exactly one argument')
  521                     return obj[argvals[0]:]
  522                 elif member == 'splice':
  523                     assertion(isinstance(obj, list), 'must be applied on a list')
  524                     assertion(argvals, 'takes one or more arguments')
  525                     index, howMany = map(int, (argvals + [len(obj)])[:2])
  526                     if index < 0:
  527                         index += len(obj)
  528                     add_items = argvals[2:]
  529                     res = []
  530                     for i in range(index, min(index + howMany, len(obj))):
  531                         res.append(obj.pop(index))
  532                     for i, item in enumerate(add_items):
  533                         obj.insert(index + i, item)
  534                     return res
  535                 elif member == 'unshift':
  536                     assertion(isinstance(obj, list), 'must be applied on a list')
  537                     assertion(argvals, 'takes one or more arguments')
  538                     for item in reversed(argvals):
  539                         obj.insert(0, item)
  540                     return obj
  541                 elif member == 'pop':
  542                     assertion(isinstance(obj, list), 'must be applied on a list')
  543                     assertion(not argvals, 'does not take any arguments')
  544                     if not obj:
  545                         return
  546                     return obj.pop()
  547                 elif member == 'push':
  548                     assertion(argvals, 'takes one or more arguments')
  549                     obj.extend(argvals)
  550                     return obj
  551                 elif member == 'forEach':
  552                     assertion(argvals, 'takes one or more arguments')
  553                     assertion(len(argvals) <= 2, 'takes at-most 2 arguments')
  554                     f, this = (argvals + [''])[:2]
  555                     return [f((item, idx, obj), {'this': this}, allow_recursion) for idx, item in enumerate(obj)]
  556                 elif member == 'indexOf':
  557                     assertion(argvals, 'takes one or more arguments')
  558                     assertion(len(argvals) <= 2, 'takes at-most 2 arguments')
  559                     idx, start = (argvals + [0])[:2]
  560                     try:
  561                         return obj.index(idx, start)
  562                     except ValueError:
  563                         return -1
  564 
  565                 idx = int(member) if isinstance(obj, list) else member
  566                 return obj[idx](argvals, allow_recursion=allow_recursion)
  567 
  568             if remaining:
  569                 ret, should_abort = self.interpret_statement(
  570                     self._named_object(local_vars, eval_method()) + remaining,
  571                     local_vars, allow_recursion)
  572                 return ret, should_return or should_abort
  573             else:
  574                 return eval_method(), should_return
  575 
  576         elif md.get('function'):
  577             fname = m.group('fname')
  578             argvals = [self.interpret_expression(v, local_vars, allow_recursion)
  579                        for v in self._separate(m.group('args'))]
  580             if fname in local_vars:
  581                 return local_vars[fname](argvals, allow_recursion=allow_recursion), should_return
  582             elif fname not in self._functions:
  583                 self._functions[fname] = self.extract_function(fname)
  584             return self._functions[fname](argvals, allow_recursion=allow_recursion), should_return
  585 
  586         raise self.Exception(
  587             'Unsupported JS expression ' + (expr[:40] if expr != stmt else ''), expr=stmt)
  588 
  589     def interpret_expression(self, expr, local_vars, allow_recursion):
  590         ret, should_return = self.interpret_statement(expr, local_vars, allow_recursion)
  591         if should_return:
  592             raise self.Exception('Cannot return from an expression', expr)
  593         return ret
  594 
  595     def extract_object(self, objname):
  596         _FUNC_NAME_RE = r'''(?:[a-zA-Z$0-9]+|"[a-zA-Z$0-9]+"|'[a-zA-Z$0-9]+')'''
  597         obj = {}
  598         obj_m = re.search(
  599             r'''(?x)
  600                 (?<!this\.)%s\s*=\s*{\s*
  601                     (?P<fields>(%s\s*:\s*function\s*\(.*?\)\s*{.*?}(?:,\s*)?)*)
  602                 }\s*;
  603             ''' % (re.escape(objname), _FUNC_NAME_RE),
  604             self.code)
  605         if not obj_m:
  606             raise self.Exception('Could not find object ' + objname)
  607         fields = obj_m.group('fields')
  608         # Currently, it only supports function definitions
  609         fields_m = re.finditer(
  610             r'''(?x)
  611                 (?P<key>%s)\s*:\s*function\s*\((?P<args>(?:%s|,)*)\){(?P<code>[^}]+)}
  612             ''' % (_FUNC_NAME_RE, _NAME_RE),
  613             fields)
  614         for f in fields_m:
  615             argnames = self.build_arglist(f.group('args'))
  616             obj[remove_quotes(f.group('key'))] = self.build_function(argnames, f.group('code'))
  617 
  618         return obj
  619 
  620     def extract_function_code(self, funcname):
  621         """ @returns argnames, code """
  622         func_m = re.search(
  623             r'''(?xs)
  624                 (?:
  625                     function\s+%(name)s|
  626                     [{;,]\s*%(name)s\s*=\s*function|
  627                     (?:var|const|let)\s+%(name)s\s*=\s*function
  628                 )\s*
  629                 \((?P<args>[^)]*)\)\s*
  630                 (?P<code>{.+})''' % {'name': re.escape(funcname)},
  631             self.code)
  632         code, _ = self._separate_at_paren(func_m.group('code'), '}')  # refine the match
  633         if func_m is None:
  634             raise self.Exception('Could not find JS function "{funcname}"'.format(**locals()))
  635         return self.build_arglist(func_m.group('args')), code
  636 
  637     def extract_function(self, funcname):
  638         return self.extract_function_from_code(*self.extract_function_code(funcname))
  639 
  640     def extract_function_from_code(self, argnames, code, *global_stack):
  641         local_vars = {}
  642         while True:
  643             mobj = re.search(r'function\((?P<args>[^)]*)\)\s*{', code)
  644             if mobj is None:
  645                 break
  646             start, body_start = mobj.span()
  647             body, remaining = self._separate_at_paren(code[body_start - 1:], '}')
  648             name = self._named_object(
  649                 local_vars,
  650                 self.extract_function_from_code(
  651                     self.build_arglist(mobj.group('args')),
  652                     body, local_vars, *global_stack))
  653             code = code[:start] + name + remaining
  654         return self.build_function(argnames, code, local_vars, *global_stack)
  655 
  656     def call_function(self, funcname, *args):
  657         return self.extract_function(funcname)(args)
  658 
  659     @classmethod
  660     def build_arglist(cls, arg_text):
  661         if not arg_text:
  662             return []
  663         return list(filter(None, (x.strip() or None for x in cls._separate(arg_text))))
  664 
  665     def build_function(self, argnames, code, *global_stack):
  666         global_stack = list(global_stack) or [{}]
  667         argnames = tuple(argnames)
  668 
  669         def resf(args, kwargs={}, allow_recursion=100):
  670             global_stack[0].update(
  671                 zip_longest(argnames, args, fillvalue=None))
  672             global_stack[0].update(kwargs)
  673             var_stack = LocalNameSpace(*global_stack)
  674             ret, should_abort = self.interpret_statement(code.replace('\n', ''), var_stack, allow_recursion - 1)
  675             if should_abort:
  676                 return ret
  677         return resf

Generated by cgit