diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/core/einsumfunc.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/numpy/core/einsumfunc.py | 1443 |
1 files changed, 1443 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/core/einsumfunc.py b/.venv/lib/python3.12/site-packages/numpy/core/einsumfunc.py new file mode 100644 index 00000000..01966f0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/numpy/core/einsumfunc.py @@ -0,0 +1,1443 @@ +""" +Implementation of optimized einsum. + +""" +import itertools +import operator + +from numpy.core.multiarray import c_einsum +from numpy.core.numeric import asanyarray, tensordot +from numpy.core.overrides import array_function_dispatch + +__all__ = ['einsum', 'einsum_path'] + +einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' +einsum_symbols_set = set(einsum_symbols) + + +def _flop_count(idx_contraction, inner, num_terms, size_dictionary): + """ + Computes the number of FLOPS in the contraction. + + Parameters + ---------- + idx_contraction : iterable + The indices involved in the contraction + inner : bool + Does this contraction require an inner product? + num_terms : int + The number of terms in a contraction + size_dictionary : dict + The size of each of the indices in idx_contraction + + Returns + ------- + flop_count : int + The total number of FLOPS required for the contraction. + + Examples + -------- + + >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5}) + 30 + + >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5}) + 60 + + """ + + overall_size = _compute_size_by_dict(idx_contraction, size_dictionary) + op_factor = max(1, num_terms - 1) + if inner: + op_factor += 1 + + return overall_size * op_factor + +def _compute_size_by_dict(indices, idx_dict): + """ + Computes the product of the elements in indices based on the dictionary + idx_dict. + + Parameters + ---------- + indices : iterable + Indices to base the product on. + idx_dict : dictionary + Dictionary of index sizes + + Returns + ------- + ret : int + The resulting product. + + Examples + -------- + >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) + 90 + + """ + ret = 1 + for i in indices: + ret *= idx_dict[i] + return ret + + +def _find_contraction(positions, input_sets, output_set): + """ + Finds the contraction for a given set of input and output sets. + + Parameters + ---------- + positions : iterable + Integer positions of terms used in the contraction. + input_sets : list + List of sets that represent the lhs side of the einsum subscript + output_set : set + Set that represents the rhs side of the overall einsum subscript + + Returns + ------- + new_result : set + The indices of the resulting contraction + remaining : list + List of sets that have not been contracted, the new set is appended to + the end of this list + idx_removed : set + Indices removed from the entire contraction + idx_contraction : set + The indices used in the current contraction + + Examples + -------- + + # A simple dot product test case + >>> pos = (0, 1) + >>> isets = [set('ab'), set('bc')] + >>> oset = set('ac') + >>> _find_contraction(pos, isets, oset) + ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) + + # A more complex case with additional terms in the contraction + >>> pos = (0, 2) + >>> isets = [set('abd'), set('ac'), set('bdc')] + >>> oset = set('ac') + >>> _find_contraction(pos, isets, oset) + ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) + """ + + idx_contract = set() + idx_remain = output_set.copy() + remaining = [] + for ind, value in enumerate(input_sets): + if ind in positions: + idx_contract |= value + else: + remaining.append(value) + idx_remain |= value + + new_result = idx_remain & idx_contract + idx_removed = (idx_contract - new_result) + remaining.append(new_result) + + return (new_result, remaining, idx_removed, idx_contract) + + +def _optimal_path(input_sets, output_set, idx_dict, memory_limit): + """ + Computes all possible pair contractions, sieves the results based + on ``memory_limit`` and returns the lowest cost path. This algorithm + scales factorial with respect to the elements in the list ``input_sets``. + + Parameters + ---------- + input_sets : list + List of sets that represent the lhs side of the einsum subscript + output_set : set + Set that represents the rhs side of the overall einsum subscript + idx_dict : dictionary + Dictionary of index sizes + memory_limit : int + The maximum number of elements in a temporary array + + Returns + ------- + path : list + The optimal contraction order within the memory limit constraint. + + Examples + -------- + >>> isets = [set('abd'), set('ac'), set('bdc')] + >>> oset = set() + >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} + >>> _optimal_path(isets, oset, idx_sizes, 5000) + [(0, 2), (0, 1)] + """ + + full_results = [(0, [], input_sets)] + for iteration in range(len(input_sets) - 1): + iter_results = [] + + # Compute all unique pairs + for curr in full_results: + cost, positions, remaining = curr + for con in itertools.combinations(range(len(input_sets) - iteration), 2): + + # Find the contraction + cont = _find_contraction(con, remaining, output_set) + new_result, new_input_sets, idx_removed, idx_contract = cont + + # Sieve the results based on memory_limit + new_size = _compute_size_by_dict(new_result, idx_dict) + if new_size > memory_limit: + continue + + # Build (total_cost, positions, indices_remaining) + total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict) + new_pos = positions + [con] + iter_results.append((total_cost, new_pos, new_input_sets)) + + # Update combinatorial list, if we did not find anything return best + # path + remaining contractions + if iter_results: + full_results = iter_results + else: + path = min(full_results, key=lambda x: x[0])[1] + path += [tuple(range(len(input_sets) - iteration))] + return path + + # If we have not found anything return single einsum contraction + if len(full_results) == 0: + return [tuple(range(len(input_sets)))] + + path = min(full_results, key=lambda x: x[0])[1] + return path + +def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost): + """Compute the cost (removed size + flops) and resultant indices for + performing the contraction specified by ``positions``. + + Parameters + ---------- + positions : tuple of int + The locations of the proposed tensors to contract. + input_sets : list of sets + The indices found on each tensors. + output_set : set + The output indices of the expression. + idx_dict : dict + Mapping of each index to its size. + memory_limit : int + The total allowed size for an intermediary tensor. + path_cost : int + The contraction cost so far. + naive_cost : int + The cost of the unoptimized expression. + + Returns + ------- + cost : (int, int) + A tuple containing the size of any indices removed, and the flop cost. + positions : tuple of int + The locations of the proposed tensors to contract. + new_input_sets : list of sets + The resulting new list of indices if this proposed contraction is performed. + + """ + + # Find the contraction + contract = _find_contraction(positions, input_sets, output_set) + idx_result, new_input_sets, idx_removed, idx_contract = contract + + # Sieve the results based on memory_limit + new_size = _compute_size_by_dict(idx_result, idx_dict) + if new_size > memory_limit: + return None + + # Build sort tuple + old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions) + removed_size = sum(old_sizes) - new_size + + # NB: removed_size used to be just the size of any removed indices i.e.: + # helpers.compute_size_by_dict(idx_removed, idx_dict) + cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict) + sort = (-removed_size, cost) + + # Sieve based on total cost as well + if (path_cost + cost) > naive_cost: + return None + + # Add contraction to possible choices + return [sort, positions, new_input_sets] + + +def _update_other_results(results, best): + """Update the positions and provisional input_sets of ``results`` based on + performing the contraction result ``best``. Remove any involving the tensors + contracted. + + Parameters + ---------- + results : list + List of contraction results produced by ``_parse_possible_contraction``. + best : list + The best contraction of ``results`` i.e. the one that will be performed. + + Returns + ------- + mod_results : list + The list of modified results, updated with outcome of ``best`` contraction. + """ + + best_con = best[1] + bx, by = best_con + mod_results = [] + + for cost, (x, y), con_sets in results: + + # Ignore results involving tensors just contracted + if x in best_con or y in best_con: + continue + + # Update the input_sets + del con_sets[by - int(by > x) - int(by > y)] + del con_sets[bx - int(bx > x) - int(bx > y)] + con_sets.insert(-1, best[2][-1]) + + # Update the position indices + mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by) + mod_results.append((cost, mod_con, con_sets)) + + return mod_results + +def _greedy_path(input_sets, output_set, idx_dict, memory_limit): + """ + Finds the path by contracting the best pair until the input list is + exhausted. The best pair is found by minimizing the tuple + ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing + matrix multiplication or inner product operations, then Hadamard like + operations, and finally outer operations. Outer products are limited by + ``memory_limit``. This algorithm scales cubically with respect to the + number of elements in the list ``input_sets``. + + Parameters + ---------- + input_sets : list + List of sets that represent the lhs side of the einsum subscript + output_set : set + Set that represents the rhs side of the overall einsum subscript + idx_dict : dictionary + Dictionary of index sizes + memory_limit : int + The maximum number of elements in a temporary array + + Returns + ------- + path : list + The greedy contraction order within the memory limit constraint. + + Examples + -------- + >>> isets = [set('abd'), set('ac'), set('bdc')] + >>> oset = set() + >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} + >>> _greedy_path(isets, oset, idx_sizes, 5000) + [(0, 2), (0, 1)] + """ + + # Handle trivial cases that leaked through + if len(input_sets) == 1: + return [(0,)] + elif len(input_sets) == 2: + return [(0, 1)] + + # Build up a naive cost + contract = _find_contraction(range(len(input_sets)), input_sets, output_set) + idx_result, new_input_sets, idx_removed, idx_contract = contract + naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict) + + # Initially iterate over all pairs + comb_iter = itertools.combinations(range(len(input_sets)), 2) + known_contractions = [] + + path_cost = 0 + path = [] + + for iteration in range(len(input_sets) - 1): + + # Iterate over all pairs on first step, only previously found pairs on subsequent steps + for positions in comb_iter: + + # Always initially ignore outer products + if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]): + continue + + result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, + naive_cost) + if result is not None: + known_contractions.append(result) + + # If we do not have a inner contraction, rescan pairs including outer products + if len(known_contractions) == 0: + + # Then check the outer products + for positions in itertools.combinations(range(len(input_sets)), 2): + result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, + path_cost, naive_cost) + if result is not None: + known_contractions.append(result) + + # If we still did not find any remaining contractions, default back to einsum like behavior + if len(known_contractions) == 0: + path.append(tuple(range(len(input_sets)))) + break + + # Sort based on first index + best = min(known_contractions, key=lambda x: x[0]) + + # Now propagate as many unused contractions as possible to next iteration + known_contractions = _update_other_results(known_contractions, best) + + # Next iteration only compute contractions with the new tensor + # All other contractions have been accounted for + input_sets = best[2] + new_tensor_pos = len(input_sets) - 1 + comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos)) + + # Update path and total cost + path.append(best[1]) + path_cost += best[0][1] + + return path + + +def _can_dot(inputs, result, idx_removed): + """ + Checks if we can use BLAS (np.tensordot) call and its beneficial to do so. + + Parameters + ---------- + inputs : list of str + Specifies the subscripts for summation. + result : str + Resulting summation. + idx_removed : set + Indices that are removed in the summation + + + Returns + ------- + type : bool + Returns true if BLAS should and can be used, else False + + Notes + ----- + If the operations is BLAS level 1 or 2 and is not already aligned + we default back to einsum as the memory movement to copy is more + costly than the operation itself. + + + Examples + -------- + + # Standard GEMM operation + >>> _can_dot(['ij', 'jk'], 'ik', set('j')) + True + + # Can use the standard BLAS, but requires odd data movement + >>> _can_dot(['ijj', 'jk'], 'ik', set('j')) + False + + # DDOT where the memory is not aligned + >>> _can_dot(['ijk', 'ikj'], '', set('ijk')) + False + + """ + + # All `dot` calls remove indices + if len(idx_removed) == 0: + return False + + # BLAS can only handle two operands + if len(inputs) != 2: + return False + + input_left, input_right = inputs + + for c in set(input_left + input_right): + # can't deal with repeated indices on same input or more than 2 total + nl, nr = input_left.count(c), input_right.count(c) + if (nl > 1) or (nr > 1) or (nl + nr > 2): + return False + + # can't do implicit summation or dimension collapse e.g. + # "ab,bc->c" (implicitly sum over 'a') + # "ab,ca->ca" (take diagonal of 'a') + if nl + nr - 1 == int(c in result): + return False + + # Build a few temporaries + set_left = set(input_left) + set_right = set(input_right) + keep_left = set_left - idx_removed + keep_right = set_right - idx_removed + rs = len(idx_removed) + + # At this point we are a DOT, GEMV, or GEMM operation + + # Handle inner products + + # DDOT with aligned data + if input_left == input_right: + return True + + # DDOT without aligned data (better to use einsum) + if set_left == set_right: + return False + + # Handle the 4 possible (aligned) GEMV or GEMM cases + + # GEMM or GEMV no transpose + if input_left[-rs:] == input_right[:rs]: + return True + + # GEMM or GEMV transpose both + if input_left[:rs] == input_right[-rs:]: + return True + + # GEMM or GEMV transpose right + if input_left[-rs:] == input_right[-rs:]: + return True + + # GEMM or GEMV transpose left + if input_left[:rs] == input_right[:rs]: + return True + + # Einsum is faster than GEMV if we have to copy data + if not keep_left or not keep_right: + return False + + # We are a matrix-matrix product, but we need to copy data + return True + + +def _parse_einsum_input(operands): + """ + A reproduction of einsum c side einsum parsing in python. + + Returns + ------- + input_strings : str + Parsed input strings + output_string : str + Parsed output string + operands : list of array_like + The operands to use in the numpy contraction + + Examples + -------- + The operand list is simplified to reduce printing: + + >>> np.random.seed(123) + >>> a = np.random.rand(4, 4) + >>> b = np.random.rand(4, 4, 4) + >>> _parse_einsum_input(('...a,...a->...', a, b)) + ('za,xza', 'xz', [a, b]) # may vary + + >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) + ('za,xza', 'xz', [a, b]) # may vary + """ + + if len(operands) == 0: + raise ValueError("No input operands") + + if isinstance(operands[0], str): + subscripts = operands[0].replace(" ", "") + operands = [asanyarray(v) for v in operands[1:]] + + # Ensure all characters are valid + for s in subscripts: + if s in '.,->': + continue + if s not in einsum_symbols: + raise ValueError("Character %s is not a valid symbol." % s) + + else: + tmp_operands = list(operands) + operand_list = [] + subscript_list = [] + for p in range(len(operands) // 2): + operand_list.append(tmp_operands.pop(0)) + subscript_list.append(tmp_operands.pop(0)) + + output_list = tmp_operands[-1] if len(tmp_operands) else None + operands = [asanyarray(v) for v in operand_list] + subscripts = "" + last = len(subscript_list) - 1 + for num, sub in enumerate(subscript_list): + for s in sub: + if s is Ellipsis: + subscripts += "..." + else: + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] + if num != last: + subscripts += "," + + if output_list is not None: + subscripts += "->" + for s in output_list: + if s is Ellipsis: + subscripts += "..." + else: + try: + s = operator.index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " + "either int or Ellipsis") from e + subscripts += einsum_symbols[s] + # Check for proper "->" + if ("-" in subscripts) or (">" in subscripts): + invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) + if invalid or (subscripts.count("->") != 1): + raise ValueError("Subscripts can only contain one '->'.") + + # Parse ellipses + if "." in subscripts: + used = subscripts.replace(".", "").replace(",", "").replace("->", "") + unused = list(einsum_symbols_set - set(used)) + ellipse_inds = "".join(unused) + longest = 0 + + if "->" in subscripts: + input_tmp, output_sub = subscripts.split("->") + split_subscripts = input_tmp.split(",") + out_sub = True + else: + split_subscripts = subscripts.split(',') + out_sub = False + + for num, sub in enumerate(split_subscripts): + if "." in sub: + if (sub.count(".") != 3) or (sub.count("...") != 1): + raise ValueError("Invalid Ellipses.") + + # Take into account numerical values + if operands[num].shape == (): + ellipse_count = 0 + else: + ellipse_count = max(operands[num].ndim, 1) + ellipse_count -= (len(sub) - 3) + + if ellipse_count > longest: + longest = ellipse_count + + if ellipse_count < 0: + raise ValueError("Ellipses lengths do not match.") + elif ellipse_count == 0: + split_subscripts[num] = sub.replace('...', '') + else: + rep_inds = ellipse_inds[-ellipse_count:] + split_subscripts[num] = sub.replace('...', rep_inds) + + subscripts = ",".join(split_subscripts) + if longest == 0: + out_ellipse = "" + else: + out_ellipse = ellipse_inds[-longest:] + + if out_sub: + subscripts += "->" + output_sub.replace("...", out_ellipse) + else: + # Special care for outputless ellipses + output_subscript = "" + tmp_subscripts = subscripts.replace(",", "") + for s in sorted(set(tmp_subscripts)): + if s not in (einsum_symbols): + raise ValueError("Character %s is not a valid symbol." % s) + if tmp_subscripts.count(s) == 1: + output_subscript += s + normal_inds = ''.join(sorted(set(output_subscript) - + set(out_ellipse))) + + subscripts += "->" + out_ellipse + normal_inds + + # Build output string if does not exist + if "->" in subscripts: + input_subscripts, output_subscript = subscripts.split("->") + else: + input_subscripts = subscripts + # Build output subscripts + tmp_subscripts = subscripts.replace(",", "") + output_subscript = "" + for s in sorted(set(tmp_subscripts)): + if s not in einsum_symbols: + raise ValueError("Character %s is not a valid symbol." % s) + if tmp_subscripts.count(s) == 1: + output_subscript += s + + # Make sure output subscripts are in the input + for char in output_subscript: + if char not in input_subscripts: + raise ValueError("Output character %s did not appear in the input" + % char) + + # Make sure number operands is equivalent to the number of terms + if len(input_subscripts.split(',')) != len(operands): + raise ValueError("Number of einsum subscripts must be equal to the " + "number of operands.") + + return (input_subscripts, output_subscript, operands) + + +def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None): + # NOTE: technically, we should only dispatch on array-like arguments, not + # subscripts (given as strings). But separating operands into + # arrays/subscripts is a little tricky/slow (given einsum's two supported + # signatures), so as a practical shortcut we dispatch on everything. + # Strings will be ignored for dispatching since they don't define + # __array_function__. + return operands + + +@array_function_dispatch(_einsum_path_dispatcher, module='numpy') +def einsum_path(*operands, optimize='greedy', einsum_call=False): + """ + einsum_path(subscripts, *operands, optimize='greedy') + + Evaluates the lowest cost contraction order for an einsum expression by + considering the creation of intermediate arrays. + + Parameters + ---------- + subscripts : str + Specifies the subscripts for summation. + *operands : list of array_like + These are the arrays for the operation. + optimize : {bool, list, tuple, 'greedy', 'optimal'} + Choose the type of path. If a tuple is provided, the second argument is + assumed to be the maximum intermediate size created. If only a single + argument is provided the largest input or output array size is used + as a maximum intermediate size. + + * if a list is given that starts with ``einsum_path``, uses this as the + contraction path + * if False no optimization is taken + * if True defaults to the 'greedy' algorithm + * 'optimal' An algorithm that combinatorially explores all possible + ways of contracting the listed tensors and chooses the least costly + path. Scales exponentially with the number of terms in the + contraction. + * 'greedy' An algorithm that chooses the best pair contraction + at each step. Effectively, this algorithm searches the largest inner, + Hadamard, and then outer products at each step. Scales cubically with + the number of terms in the contraction. Equivalent to the 'optimal' + path for most contractions. + + Default is 'greedy'. + + Returns + ------- + path : list of tuples + A list representation of the einsum path. + string_repr : str + A printable representation of the einsum path. + + Notes + ----- + The resulting path indicates which terms of the input contraction should be + contracted first, the result of this contraction is then appended to the + end of the contraction list. This list can then be iterated over until all + intermediate contractions are complete. + + See Also + -------- + einsum, linalg.multi_dot + + Examples + -------- + + We can begin with a chain dot example. In this case, it is optimal to + contract the ``b`` and ``c`` tensors first as represented by the first + element of the path ``(1, 2)``. The resulting tensor is added to the end + of the contraction and the remaining contraction ``(0, 1)`` is then + completed. + + >>> np.random.seed(123) + >>> a = np.random.rand(2, 2) + >>> b = np.random.rand(2, 5) + >>> c = np.random.rand(5, 2) + >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') + >>> print(path_info[0]) + ['einsum_path', (1, 2), (0, 1)] + >>> print(path_info[1]) + Complete contraction: ij,jk,kl->il # may vary + Naive scaling: 4 + Optimized scaling: 3 + Naive FLOP count: 1.600e+02 + Optimized FLOP count: 5.600e+01 + Theoretical speedup: 2.857 + Largest intermediate: 4.000e+00 elements + ------------------------------------------------------------------------- + scaling current remaining + ------------------------------------------------------------------------- + 3 kl,jk->jl ij,jl->il + 3 jl,ij->il il->il + + + A more complex index transformation example. + + >>> I = np.random.rand(10, 10, 10, 10) + >>> C = np.random.rand(10, 10) + >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, + ... optimize='greedy') + + >>> print(path_info[0]) + ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] + >>> print(path_info[1]) + Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary + Naive scaling: 8 + Optimized scaling: 5 + Naive FLOP count: 8.000e+08 + Optimized FLOP count: 8.000e+05 + Theoretical speedup: 1000.000 + Largest intermediate: 1.000e+04 elements + -------------------------------------------------------------------------- + scaling current remaining + -------------------------------------------------------------------------- + 5 abcd,ea->bcde fb,gc,hd,bcde->efgh + 5 bcde,fb->cdef gc,hd,cdef->efgh + 5 cdef,gc->defg hd,defg->efgh + 5 defg,hd->efgh efgh->efgh + """ + + # Figure out what the path really is + path_type = optimize + if path_type is True: + path_type = 'greedy' + if path_type is None: + path_type = False + + explicit_einsum_path = False + memory_limit = None + + # No optimization or a named path algorithm + if (path_type is False) or isinstance(path_type, str): + pass + + # Given an explicit path + elif len(path_type) and (path_type[0] == 'einsum_path'): + explicit_einsum_path = True + + # Path tuple with memory limit + elif ((len(path_type) == 2) and isinstance(path_type[0], str) and + isinstance(path_type[1], (int, float))): + memory_limit = int(path_type[1]) + path_type = path_type[0] + + else: + raise TypeError("Did not understand the path: %s" % str(path_type)) + + # Hidden option, only einsum should call this + einsum_call_arg = einsum_call + + # Python side parsing + input_subscripts, output_subscript, operands = _parse_einsum_input(operands) + + # Build a few useful list and sets + input_list = input_subscripts.split(',') + input_sets = [set(x) for x in input_list] + output_set = set(output_subscript) + indices = set(input_subscripts.replace(',', '')) + + # Get length of each unique dimension and ensure all dimensions are correct + dimension_dict = {} + broadcast_indices = [[] for x in range(len(input_list))] + for tnum, term in enumerate(input_list): + sh = operands[tnum].shape + if len(sh) != len(term): + raise ValueError("Einstein sum subscript %s does not contain the " + "correct number of indices for operand %d." + % (input_subscripts[tnum], tnum)) + for cnum, char in enumerate(term): + dim = sh[cnum] + + # Build out broadcast indices + if dim == 1: + broadcast_indices[tnum].append(char) + + if char in dimension_dict.keys(): + # For broadcasting cases we always want the largest dim size + if dimension_dict[char] == 1: + dimension_dict[char] = dim + elif dim not in (1, dimension_dict[char]): + raise ValueError("Size of label '%s' for operand %d (%d) " + "does not match previous terms (%d)." + % (char, tnum, dimension_dict[char], dim)) + else: + dimension_dict[char] = dim + + # Convert broadcast inds to sets + broadcast_indices = [set(x) for x in broadcast_indices] + + # Compute size of each input array plus the output array + size_list = [_compute_size_by_dict(term, dimension_dict) + for term in input_list + [output_subscript]] + max_size = max(size_list) + + if memory_limit is None: + memory_arg = max_size + else: + memory_arg = memory_limit + + # Compute naive cost + # This isn't quite right, need to look into exactly how einsum does this + inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 + naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict) + + # Compute the path + if explicit_einsum_path: + path = path_type[1:] + elif ( + (path_type is False) + or (len(input_list) in [1, 2]) + or (indices == output_set) + ): + # Nothing to be optimized, leave it to einsum + path = [tuple(range(len(input_list)))] + elif path_type == "greedy": + path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg) + elif path_type == "optimal": + path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg) + else: + raise KeyError("Path name %s not found", path_type) + + cost_list, scale_list, size_list, contraction_list = [], [], [], [] + + # Build contraction tuple (positions, gemm, einsum_str, remaining) + for cnum, contract_inds in enumerate(path): + # Make sure we remove inds from right to left + contract_inds = tuple(sorted(list(contract_inds), reverse=True)) + + contract = _find_contraction(contract_inds, input_sets, output_set) + out_inds, input_sets, idx_removed, idx_contract = contract + + cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict) + cost_list.append(cost) + scale_list.append(len(idx_contract)) + size_list.append(_compute_size_by_dict(out_inds, dimension_dict)) + + bcast = set() + tmp_inputs = [] + for x in contract_inds: + tmp_inputs.append(input_list.pop(x)) + bcast |= broadcast_indices.pop(x) + + new_bcast_inds = bcast - idx_removed + + # If we're broadcasting, nix blas + if not len(idx_removed & bcast): + do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) + else: + do_blas = False + + # Last contraction + if (cnum - len(path)) == -1: + idx_result = output_subscript + else: + sort_result = [(dimension_dict[ind], ind) for ind in out_inds] + idx_result = "".join([x[1] for x in sorted(sort_result)]) + + input_list.append(idx_result) + broadcast_indices.append(new_bcast_inds) + einsum_str = ",".join(tmp_inputs) + "->" + idx_result + + contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas) + contraction_list.append(contraction) + + opt_cost = sum(cost_list) + 1 + + if len(input_list) != 1: + # Explicit "einsum_path" is usually trusted, but we detect this kind of + # mistake in order to prevent from returning an intermediate value. + raise RuntimeError( + "Invalid einsum_path is specified: {} more operands has to be " + "contracted.".format(len(input_list) - 1)) + + if einsum_call_arg: + return (operands, contraction_list) + + # Return the path along with a nice string representation + overall_contraction = input_subscripts + "->" + output_subscript + header = ("scaling", "current", "remaining") + + speedup = naive_cost / opt_cost + max_i = max(size_list) + + path_print = " Complete contraction: %s\n" % overall_contraction + path_print += " Naive scaling: %d\n" % len(indices) + path_print += " Optimized scaling: %d\n" % max(scale_list) + path_print += " Naive FLOP count: %.3e\n" % naive_cost + path_print += " Optimized FLOP count: %.3e\n" % opt_cost + path_print += " Theoretical speedup: %3.3f\n" % speedup + path_print += " Largest intermediate: %.3e elements\n" % max_i + path_print += "-" * 74 + "\n" + path_print += "%6s %24s %40s\n" % header + path_print += "-" * 74 + + for n, contraction in enumerate(contraction_list): + inds, idx_rm, einsum_str, remaining, blas = contraction + remaining_str = ",".join(remaining) + "->" + output_subscript + path_run = (scale_list[n], einsum_str, remaining_str) + path_print += "\n%4d %24s %40s" % path_run + + path = ['einsum_path'] + path + return (path, path_print) + + +def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs): + # Arguably we dispatch on more arguments than we really should; see note in + # _einsum_path_dispatcher for why. + yield from operands + yield out + + +# Rewrite einsum to handle different cases +@array_function_dispatch(_einsum_dispatcher, module='numpy') +def einsum(*operands, out=None, optimize=False, **kwargs): + """ + einsum(subscripts, *operands, out=None, dtype=None, order='K', + casting='safe', optimize=False) + + Evaluates the Einstein summation convention on the operands. + + Using the Einstein summation convention, many common multi-dimensional, + linear algebraic array operations can be represented in a simple fashion. + In *implicit* mode `einsum` computes these values. + + In *explicit* mode, `einsum` provides further flexibility to compute + other array operations that might not be considered classical Einstein + summation operations, by disabling, or forcing summation over specified + subscript labels. + + See the notes and examples for clarification. + + Parameters + ---------- + subscripts : str + Specifies the subscripts for summation as comma separated list of + subscript labels. An implicit (classical Einstein summation) + calculation is performed unless the explicit indicator '->' is + included as well as subscript labels of the precise output form. + operands : list of array_like + These are the arrays for the operation. + out : ndarray, optional + If provided, the calculation is done into this array. + dtype : {data-type, None}, optional + If provided, forces the calculation to use the data type specified. + Note that you may have to also give a more liberal `casting` + parameter to allow the conversions. Default is None. + order : {'C', 'F', 'A', 'K'}, optional + Controls the memory layout of the output. 'C' means it should + be C contiguous. 'F' means it should be Fortran contiguous, + 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise. + 'K' means it should be as close to the layout as the inputs as + is possible, including arbitrarily permuted axes. + Default is 'K'. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Setting this to + 'unsafe' is not recommended, as it can adversely affect accumulations. + + * 'no' means the data types should not be cast at all. + * 'equiv' means only byte-order changes are allowed. + * 'safe' means only casts which can preserve values are allowed. + * 'same_kind' means only safe casts or casts within a kind, + like float64 to float32, are allowed. + * 'unsafe' means any data conversions may be done. + + Default is 'safe'. + optimize : {False, True, 'greedy', 'optimal'}, optional + Controls if intermediate optimization should occur. No optimization + will occur if False and True will default to the 'greedy' algorithm. + Also accepts an explicit contraction list from the ``np.einsum_path`` + function. See ``np.einsum_path`` for more details. Defaults to False. + + Returns + ------- + output : ndarray + The calculation based on the Einstein summation convention. + + See Also + -------- + einsum_path, dot, inner, outer, tensordot, linalg.multi_dot + einops : + similar verbose interface is provided by + `einops <https://github.com/arogozhnikov/einops>`_ package to cover + additional operations: transpose, reshape/flatten, repeat/tile, + squeeze/unsqueeze and reductions. + opt_einsum : + `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ + optimizes contraction order for einsum-like expressions + in backend-agnostic manner. + + Notes + ----- + .. versionadded:: 1.6.0 + + The Einstein summation convention can be used to compute + many multi-dimensional, linear algebraic array operations. `einsum` + provides a succinct way of representing these. + + A non-exhaustive list of these operations, + which can be computed by `einsum`, is shown below along with examples: + + * Trace of an array, :py:func:`numpy.trace`. + * Return a diagonal, :py:func:`numpy.diag`. + * Array axis summations, :py:func:`numpy.sum`. + * Transpositions and permutations, :py:func:`numpy.transpose`. + * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`. + * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`. + * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`. + * Tensor contractions, :py:func:`numpy.tensordot`. + * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`. + + The subscripts string is a comma-separated list of subscript labels, + where each label refers to a dimension of the corresponding operand. + Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)`` + is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label + appears only once, it is not summed, so ``np.einsum('i', a)`` produces a + view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)`` + describes traditional matrix multiplication and is equivalent to + :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one + operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent + to :py:func:`np.trace(a) <numpy.trace>`. + + In *implicit mode*, the chosen subscripts are important + since the axes of the output are reordered alphabetically. This + means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while + ``np.einsum('ji', a)`` takes its transpose. Additionally, + ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while, + ``np.einsum('ij,jh', a, b)`` returns the transpose of the + multiplication since subscript 'h' precedes subscript 'i'. + + In *explicit mode* the output can be directly controlled by + specifying output subscript labels. This requires the + identifier '->' as well as the list of output subscript labels. + This feature increases the flexibility of the function since + summing can be disabled or forced when required. The call + ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`, + and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`. + The difference is that `einsum` does not allow broadcasting by default. + Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the + order of the output subscript labels and therefore returns matrix + multiplication, unlike the example above in implicit mode. + + To enable and control broadcasting, use an ellipsis. Default + NumPy-style broadcasting is done by adding an ellipsis + to the left of each term, like ``np.einsum('...ii->...i', a)``. + To take the trace along the first and last axes, + you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix + product with the left-most indices instead of rightmost, one can do + ``np.einsum('ij...,jk...->ik...', a, b)``. + + When there is only one operand, no axes are summed, and no output + parameter is provided, a view into the operand is returned instead + of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)`` + produces a view (changed in version 1.10.0). + + `einsum` also provides an alternative way to provide the subscripts + and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``. + If the output shape is not provided in this format `einsum` will be + calculated in implicit mode, otherwise it will be performed explicitly. + The examples below have corresponding `einsum` calls with the two + parameter methods. + + .. versionadded:: 1.10.0 + + Views returned from einsum are now writeable whenever the input array + is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now + have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>` + and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal + of a 2D array. + + .. versionadded:: 1.12.0 + + Added the ``optimize`` argument which will optimize the contraction order + of an einsum expression. For a contraction with three or more operands this + can greatly increase the computational efficiency at the cost of a larger + memory footprint during computation. + + Typically a 'greedy' algorithm is applied which empirical tests have shown + returns the optimal path in the majority of cases. In some cases 'optimal' + will return the superlative path through a more expensive, exhaustive search. + For iterative calculations it may be advisable to calculate the optimal path + once and reuse that path by supplying it as an argument. An example is given + below. + + See :py:func:`numpy.einsum_path` for more details. + + Examples + -------- + >>> a = np.arange(25).reshape(5,5) + >>> b = np.arange(5) + >>> c = np.arange(6).reshape(2,3) + + Trace of a matrix: + + >>> np.einsum('ii', a) + 60 + >>> np.einsum(a, [0,0]) + 60 + >>> np.trace(a) + 60 + + Extract the diagonal (requires explicit form): + + >>> np.einsum('ii->i', a) + array([ 0, 6, 12, 18, 24]) + >>> np.einsum(a, [0,0], [0]) + array([ 0, 6, 12, 18, 24]) + >>> np.diag(a) + array([ 0, 6, 12, 18, 24]) + + Sum over an axis (requires explicit form): + + >>> np.einsum('ij->i', a) + array([ 10, 35, 60, 85, 110]) + >>> np.einsum(a, [0,1], [0]) + array([ 10, 35, 60, 85, 110]) + >>> np.sum(a, axis=1) + array([ 10, 35, 60, 85, 110]) + + For higher dimensional arrays summing a single axis can be done with ellipsis: + + >>> np.einsum('...j->...', a) + array([ 10, 35, 60, 85, 110]) + >>> np.einsum(a, [Ellipsis,1], [Ellipsis]) + array([ 10, 35, 60, 85, 110]) + + Compute a matrix transpose, or reorder any number of axes: + + >>> np.einsum('ji', c) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> np.einsum('ij->ji', c) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> np.einsum(c, [1,0]) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> np.transpose(c) + array([[0, 3], + [1, 4], + [2, 5]]) + + Vector inner products: + + >>> np.einsum('i,i', b, b) + 30 + >>> np.einsum(b, [0], b, [0]) + 30 + >>> np.inner(b,b) + 30 + + Matrix vector multiplication: + + >>> np.einsum('ij,j', a, b) + array([ 30, 80, 130, 180, 230]) + >>> np.einsum(a, [0,1], b, [1]) + array([ 30, 80, 130, 180, 230]) + >>> np.dot(a, b) + array([ 30, 80, 130, 180, 230]) + >>> np.einsum('...j,j', a, b) + array([ 30, 80, 130, 180, 230]) + + Broadcasting and scalar multiplication: + + >>> np.einsum('..., ...', 3, c) + array([[ 0, 3, 6], + [ 9, 12, 15]]) + >>> np.einsum(',ij', 3, c) + array([[ 0, 3, 6], + [ 9, 12, 15]]) + >>> np.einsum(3, [Ellipsis], c, [Ellipsis]) + array([[ 0, 3, 6], + [ 9, 12, 15]]) + >>> np.multiply(3, c) + array([[ 0, 3, 6], + [ 9, 12, 15]]) + + Vector outer product: + + >>> np.einsum('i,j', np.arange(2)+1, b) + array([[0, 1, 2, 3, 4], + [0, 2, 4, 6, 8]]) + >>> np.einsum(np.arange(2)+1, [0], b, [1]) + array([[0, 1, 2, 3, 4], + [0, 2, 4, 6, 8]]) + >>> np.outer(np.arange(2)+1, b) + array([[0, 1, 2, 3, 4], + [0, 2, 4, 6, 8]]) + + Tensor contraction: + + >>> a = np.arange(60.).reshape(3,4,5) + >>> b = np.arange(24.).reshape(4,3,2) + >>> np.einsum('ijk,jil->kl', a, b) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + >>> np.tensordot(a,b, axes=([1,0],[0,1])) + array([[4400., 4730.], + [4532., 4874.], + [4664., 5018.], + [4796., 5162.], + [4928., 5306.]]) + + Writeable returned arrays (since version 1.10.0): + + >>> a = np.zeros((3, 3)) + >>> np.einsum('ii->i', a)[:] = 1 + >>> a + array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + Example of ellipsis use: + + >>> a = np.arange(6).reshape((3,2)) + >>> b = np.arange(12).reshape((4,3)) + >>> np.einsum('ki,jk->ij', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + >>> np.einsum('ki,...k->i...', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + >>> np.einsum('k...,jk', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + + Chained array operations. For more complicated contractions, speed ups + might be achieved by repeatedly computing a 'greedy' path or pre-computing the + 'optimal' path and repeatedly applying it, using an + `einsum_path` insertion (since version 1.12.0). Performance improvements can be + particularly significant with larger arrays: + + >>> a = np.ones(64).reshape(2,4,8) + + Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.) + + >>> for iteration in range(500): + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a) + + Sub-optimal `einsum` (due to repeated path calculation time): ~330ms + + >>> for iteration in range(500): + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal') + + Greedy `einsum` (faster optimal path approximation): ~160ms + + >>> for iteration in range(500): + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy') + + Optimal `einsum` (best usage pattern in some use cases): ~110ms + + >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0] + >>> for iteration in range(500): + ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path) + + """ + # Special handling if out is specified + specified_out = out is not None + + # If no optimization, run pure einsum + if optimize is False: + if specified_out: + kwargs['out'] = out + return c_einsum(*operands, **kwargs) + + # Check the kwargs to avoid a more cryptic error later, without having to + # repeat default values here + valid_einsum_kwargs = ['dtype', 'order', 'casting'] + unknown_kwargs = [k for (k, v) in kwargs.items() if + k not in valid_einsum_kwargs] + if len(unknown_kwargs): + raise TypeError("Did not understand the following kwargs: %s" + % unknown_kwargs) + + # Build the contraction list and operand + operands, contraction_list = einsum_path(*operands, optimize=optimize, + einsum_call=True) + + # Handle order kwarg for output array, c_einsum allows mixed case + output_order = kwargs.pop('order', 'K') + if output_order.upper() == 'A': + if all(arr.flags.f_contiguous for arr in operands): + output_order = 'F' + else: + output_order = 'C' + + # Start contraction loop + for num, contraction in enumerate(contraction_list): + inds, idx_rm, einsum_str, remaining, blas = contraction + tmp_operands = [operands.pop(x) for x in inds] + + # Do we need to deal with the output? + handle_out = specified_out and ((num + 1) == len(contraction_list)) + + # Call tensordot if still possible + if blas: + # Checks have already been handled + input_str, results_index = einsum_str.split('->') + input_left, input_right = input_str.split(',') + + tensor_result = input_left + input_right + for s in idx_rm: + tensor_result = tensor_result.replace(s, "") + + # Find indices to contract over + left_pos, right_pos = [], [] + for s in sorted(idx_rm): + left_pos.append(input_left.find(s)) + right_pos.append(input_right.find(s)) + + # Contract! + new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos))) + + # Build a new view if needed + if (tensor_result != results_index) or handle_out: + if handle_out: + kwargs["out"] = out + new_view = c_einsum(tensor_result + '->' + results_index, new_view, **kwargs) + + # Call einsum + else: + # If out was specified + if handle_out: + kwargs["out"] = out + + # Do the contraction + new_view = c_einsum(einsum_str, *tmp_operands, **kwargs) + + # Append new items and dereference what we can + operands.append(new_view) + del tmp_operands, new_view + + if specified_out: + return out + else: + return asanyarray(operands[0], order=output_order) |