summaryrefslogtreecommitdiff
path: root/youtube_dl/jsinterp.py
blob: 7bda596102a40775b06fe4318c3915633d586a67 (plain)
    1 from __future__ import unicode_literals
    2 
    3 import json
    4 import operator
    5 import re
    6 
    7 from .utils import (
    8     ExtractorError,
    9     remove_quotes,
   10 )
   11 
   12 _OPERATORS = [
   13     ('|', operator.or_),
   14     ('^', operator.xor),
   15     ('&', operator.and_),
   16     ('>>', operator.rshift),
   17     ('<<', operator.lshift),
   18     ('-', operator.sub),
   19     ('+', operator.add),
   20     ('%', operator.mod),
   21     ('/', operator.truediv),
   22     ('*', operator.mul),
   23 ]
   24 _ASSIGN_OPERATORS = [(op + '=', opfunc) for op, opfunc in _OPERATORS]
   25 _ASSIGN_OPERATORS.append(('=', lambda cur, right: right))
   26 
   27 _NAME_RE = r'[a-zA-Z_$][a-zA-Z_$0-9]*'
   28 
   29 
   30 class JSInterpreter(object):
   31     def __init__(self, code, objects=None):
   32         if objects is None:
   33             objects = {}
   34         self.code = code
   35         self._functions = {}
   36         self._objects = objects
   37 
   38     def interpret_statement(self, stmt, local_vars, allow_recursion=100):
   39         if allow_recursion < 0:
   40             raise ExtractorError('Recursion limit reached')
   41 
   42         should_abort = False
   43         stmt = stmt.lstrip()
   44         stmt_m = re.match(r'var\s', stmt)
   45         if stmt_m:
   46             expr = stmt[len(stmt_m.group(0)):]
   47         else:
   48             return_m = re.match(r'return(?:\s+|$)', stmt)
   49             if return_m:
   50                 expr = stmt[len(return_m.group(0)):]
   51                 should_abort = True
   52             else:
   53                 # Try interpreting it as an expression
   54                 expr = stmt
   55 
   56         v = self.interpret_expression(expr, local_vars, allow_recursion)
   57         return v, should_abort
   58 
    def interpret_expression(self, expr, local_vars, allow_recursion):
        """Evaluate a single JavaScript expression and return its value.

        *expr* is the expression source, *local_vars* is the dict of local
        variables (mutated by assignments), and *allow_recursion* is the
        remaining recursion budget, which is checked by interpret_statement
        and decremented on some of the nested evaluations below.

        The expression forms are tried in a fixed order: a leading
        parenthesised group, assignment, integer literal, variable name,
        JSON literal, indexing, member access / method call, binary
        operator, and finally a plain function call.  The order matters
        because several of the patterns overlap.

        Returns None for an empty expression.  Raises ExtractorError for
        unbalanced parentheses or when no form matches.
        """
        expr = expr.strip()
        if expr == '':  # Empty expression
            return None

        if expr.startswith('('):
            # Find the parenthesis that closes the leading one, evaluate the
            # enclosed sub-expression first, then continue with what follows.
            parens_count = 0
            for m in re.finditer(r'[()]', expr):
                if m.group(0) == '(':
                    parens_count += 1
                else:
                    parens_count -= 1
                    if parens_count == 0:
                        sub_expr = expr[1:m.start()]
                        sub_result = self.interpret_expression(
                            sub_expr, local_vars, allow_recursion)
                        remaining_expr = expr[m.end():].strip()
                        if not remaining_expr:
                            return sub_result
                        else:
                            # Splice the result back in as a JSON literal so
                            # the rest of the expression can operate on it.
                            expr = json.dumps(sub_result) + remaining_expr
                        break
            else:
                # Loop ended without a break: the opening paren never closed.
                raise ExtractorError('Premature end of parens in %r' % expr)

        # Assignment: "name OP= expr" or "name[index] OP= expr".
        for op, opfunc in _ASSIGN_OPERATORS:
            m = re.match(r'''(?x)
                (?P<out>%s)(?:\[(?P<index>[^\]]+?)\])?
                \s*%s
                (?P<expr>.*)$''' % (_NAME_RE, re.escape(op)), expr)
            if not m:
                continue
            right_val = self.interpret_expression(
                m.group('expr'), local_vars, allow_recursion - 1)

            if m.groupdict().get('index'):
                # Element assignment: update the container in place.
                lvar = local_vars[m.group('out')]
                idx = self.interpret_expression(
                    m.group('index'), local_vars, allow_recursion)
                assert isinstance(idx, int)
                cur = lvar[idx]
                val = opfunc(cur, right_val)
                lvar[idx] = val
                return val
            else:
                # Plain variable assignment; an unset variable reads as None.
                cur = local_vars.get(m.group('out'))
                val = opfunc(cur, right_val)
                local_vars[m.group('out')] = val
                return val

        # Non-negative integer literal.
        if expr.isdigit():
            return int(expr)

        # Bare variable name.  NOTE: the lookahead is not anchored, so any
        # name that merely starts with if/return/true/false is rejected here.
        var_m = re.match(
            r'(?!if|return|true|false)(?P<name>%s)$' % _NAME_RE,
            expr)
        if var_m:
            return local_vars[var_m.group('name')]

        # Anything that parses as JSON is taken as a literal value.
        try:
            return json.loads(expr)
        except ValueError:
            pass

        # Indexing: "name[expr]".
        m = re.match(
            r'(?P<in>%s)\[(?P<idx>.+)\]$' % _NAME_RE, expr)
        if m:
            val = local_vars[m.group('in')]
            idx = self.interpret_expression(
                m.group('idx'), local_vars, allow_recursion - 1)
            return val[idx]

        # Member access or method call: "var.member", "var[member]",
        # optionally followed by an argument list without nested parens.
        m = re.match(
            r'(?P<var>%s)(?:\.(?P<member>[^(]+)|\[(?P<member2>[^]]+)\])\s*(?:\(+(?P<args>[^()]*)\))?$' % _NAME_RE,
            expr)
        if m:
            variable = m.group('var')
            member = remove_quotes(m.group('member') or m.group('member2'))
            arg_str = m.group('args')

            # Locals take precedence; otherwise the object is extracted from
            # the script source on first use and cached.
            if variable in local_vars:
                obj = local_vars[variable]
            else:
                if variable not in self._objects:
                    self._objects[variable] = self.extract_object(variable)
                obj = self._objects[variable]

            if arg_str is None:
                # Member access
                if member == 'length':
                    return len(obj)
                return obj[member]

            assert expr.endswith(')')
            # Function call
            if arg_str == '':
                argvals = tuple()
            else:
                argvals = tuple([
                    self.interpret_expression(v, local_vars, allow_recursion)
                    for v in arg_str.split(',')])

            # Emulation of the handful of built-in methods that are supported.
            if member == 'split':
                # Only splitting on the empty string (into characters).
                assert argvals == ('',)
                return list(obj)
            if member == 'join':
                assert len(argvals) == 1
                return argvals[0].join(obj)
            if member == 'reverse':
                assert len(argvals) == 0
                # Reverses in place and returns the same object.
                obj.reverse()
                return obj
            if member == 'slice':
                # Only the single-argument form (start index) is supported.
                assert len(argvals) == 1
                return obj[argvals[0]:]
            if member == 'splice':
                assert isinstance(obj, list)
                index, howMany = argvals
                res = []
                # Repeatedly pop at the same position: each pop shifts the
                # following elements down by one.
                for i in range(index, min(index + howMany, len(obj))):
                    res.append(obj.pop(index))
                return res

            # Otherwise the member is a function built from the script; it
            # receives all arguments as a single tuple.
            return obj[member](argvals)

        # Binary operators, tried in the order of _OPERATORS.  The left
        # operand is the shortest non-empty prefix preceding the operator.
        for op, opfunc in _OPERATORS:
            m = re.match(r'(?P<x>.+?)%s(?P<y>.+)' % re.escape(op), expr)
            if not m:
                continue
            x, abort = self.interpret_statement(
                m.group('x'), local_vars, allow_recursion - 1)
            if abort:
                raise ExtractorError(
                    'Premature left-side return of %s in %r' % (op, expr))
            y, abort = self.interpret_statement(
                m.group('y'), local_vars, allow_recursion - 1)
            if abort:
                raise ExtractorError(
                    'Premature right-side return of %s in %r' % (op, expr))
            return opfunc(x, y)

        # Call of a top-level function; arguments may only be integer
        # literals or local variable names.
        m = re.match(
            r'^(?P<func>%s)\((?P<args>[a-zA-Z0-9_$,]*)\)$' % _NAME_RE, expr)
        if m:
            fname = m.group('func')
            argvals = tuple([
                int(v) if v.isdigit() else local_vars[v]
                for v in m.group('args').split(',')]) if len(m.group('args')) > 0 else tuple()
            # Extract the function from the script on first use, then cache.
            if fname not in self._functions:
                self._functions[fname] = self.extract_function(fname)
            return self._functions[fname](argvals)

        raise ExtractorError('Unsupported JS expression %r' % expr)
  212 
  213     def extract_object(self, objname):
  214         _FUNC_NAME_RE = r'''(?:[a-zA-Z$0-9]+|"[a-zA-Z$0-9]+"|'[a-zA-Z$0-9]+')'''
  215         obj = {}
  216         obj_m = re.search(
  217             r'''(?x)
  218                 (?<!this\.)%s\s*=\s*{\s*
  219                     (?P<fields>(%s\s*:\s*function\s*\(.*?\)\s*{.*?}(?:,\s*)?)*)
  220                 }\s*;
  221             ''' % (re.escape(objname), _FUNC_NAME_RE),
  222             self.code)
  223         fields = obj_m.group('fields')
  224         # Currently, it only supports function definitions
  225         fields_m = re.finditer(
  226             r'''(?x)
  227                 (?P<key>%s)\s*:\s*function\s*\((?P<args>[a-z,]+)\){(?P<code>[^}]+)}
  228             ''' % _FUNC_NAME_RE,
  229             fields)
  230         for f in fields_m:
  231             argnames = f.group('args').split(',')
  232             obj[remove_quotes(f.group('key'))] = self.build_function(argnames, f.group('code'))
  233 
  234         return obj
  235 
  236     def extract_function(self, funcname):
  237         func_m = re.search(
  238             r'''(?x)
  239                 (?:function\s+%s|[{;,]\s*%s\s*=\s*function|var\s+%s\s*=\s*function)\s*
  240                 \((?P<args>[^)]*)\)\s*
  241                 \{(?P<code>[^}]+)\}''' % (
  242                 re.escape(funcname), re.escape(funcname), re.escape(funcname)),
  243             self.code)
  244         if func_m is None:
  245             raise ExtractorError('Could not find JS function %r' % funcname)
  246         argnames = func_m.group('args').split(',')
  247 
  248         return self.build_function(argnames, func_m.group('code'))
  249 
  250     def call_function(self, funcname, *args):
  251         f = self.extract_function(funcname)
  252         return f(args)
  253 
  254     def build_function(self, argnames, code):
  255         def resf(args):
  256             local_vars = dict(zip(argnames, args))
  257             for stmt in code.split(';'):
  258                 res, abort = self.interpret_statement(stmt, local_vars)
  259                 if abort:
  260                     break
  261             return res
  262         return resf

Generated by cgit