injecttools.py 3.93 KB
Newer Older
Kunshan Wang's avatar
Kunshan Wang committed
1
import re
Kunshan Wang's avatar
Kunshan Wang committed
2
from typing import List, Union, Tuple, Any, Callable, TypeVar, Mapping
Kunshan Wang's avatar
Kunshan Wang committed
3 4
from typing.re import Pattern

Kunshan Wang's avatar
Kunshan Wang committed
5 6
import tempfile, os.path

Kunshan Wang's avatar
Kunshan Wang committed
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
# A line predicate understood by ``find_line``. It takes one of three forms:
#   * ``str``: matches any line containing the string as a substring;
#   * ``tuple``: ``(compiled_pattern, v1, v2, ...)``; matches a line when
#     ``compiled_pattern.search(line)`` succeeds and each captured group
#     equals the corresponding ``v`` (a ``None`` value accepts any group);
#   * callable: called with the line; a truthy result means a match.
# NOTE(review): ``Tuple[Pattern, ...]`` only describes the first element
# accurately; the trailing elements are expected group values, not patterns.
Predicate = Union[
    str,
    Tuple[Pattern, ...],
    Callable[[Any], bool],
]

def _string_contains(line, string):
    """Plain-string predicate: report whether ``string`` occurs inside ``line``."""
    found = string in line
    return found

def _pattern_value_match(line, tup):
    """
    Regex predicate for a ``(pattern, v1, v2, ...)`` tuple.

    The line matches when ``pattern.search(line)`` finds a match and every
    captured group equals the value in the same position. A ``None`` value
    accepts any group. Because the pairing uses ``zip``, surplus groups or
    surplus values are not checked.
    """
    regex = tup[0]
    expected = tup[1:]
    found = regex.search(line)
    if found is None:
        return False
    pairs = zip(found.groups(), expected)
    return all(want is None or got == want for got, want in pairs)

def _apply_func(line, func):
    """Callable predicate: delegate the decision to ``func(line)``."""
    verdict = func(line)
    return verdict

def find_line(lines: List[str], substr: Predicate, start: int = 0) -> int:
    """
    Return the index of the first line, at or after index ``start``, that
    matches ``substr``.

    How ``substr`` is interpreted depends on its type:

    * ``str`` -- the line must contain it as a substring;
    * ``tuple`` -- ``(pattern, v1, v2, ...)``: ``pattern.search(line)`` must
      succeed and each captured group must equal the value in the same
      position (``None`` accepts any group);
    * anything else -- it is called with the line, and a truthy result is a
      match.

    Raises ``KeyError`` when no line from ``start`` onwards matches. The
    message includes ``substr`` and the whole of ``lines``.
    """
    matches = (
        _string_contains if isinstance(substr, str)
        else _pattern_value_match if isinstance(substr, tuple)
        else _apply_func
    )

    for index in range(start, len(lines)):
        candidate = lines[index]
        if matches(candidate, substr):
            return index

    raise KeyError("Not found: " + str(substr) + "\n text:" + str(lines))

def extract_lines(parent: str, begin: Predicate, end: Predicate) -> str:
    """
    Return the text strictly between the first line matching ``begin`` and
    the next line after it matching ``end``. Neither marker line is included.

    ``begin`` and ``end`` are predicates in the form accepted by
    ``find_line``. The extracted lines are joined with newline characters,
    with no trailing newline. Raises ``KeyError`` (from ``find_line``) if
    either marker cannot be found.
    """
    all_lines = parent.splitlines()

    opening = find_line(all_lines, begin)
    closing = find_line(all_lines, end, opening + 1)

    return "\n".join(all_lines[opening + 1:closing])

def inject_lines(parent: str, begin: Predicate, end: Predicate, generated: str) -> str:
    """
    Replace the lines between the line containing ``begin`` and the line
    containing ``end`` (excluding both lines) in ``parent`` with ``generated``.

    ``begin`` and ``end`` are predicates in the form accepted by
    ``find_line``. ``end`` is searched for only after the ``begin`` line. The
    result is joined with newline characters. If ``parent`` ends with a line
    break, so does the result, so rewriting a file through this function does
    not strip its final newline.

    Raises ``KeyError`` (from ``find_line``) if either marker cannot be found.
    """
    lines = parent.splitlines()

    begin_line = find_line(lines, begin)
    end_line = find_line(lines, end, begin_line + 1)

    new_lines = lines[:begin_line + 1] + generated.splitlines() + lines[end_line:]

    result = "\n".join(new_lines)
    # splitlines() discards the terminator of the last line and join() does not
    # add one back. Restore it so the text keeps its final newline.
    if parent.endswith(("\n", "\r")):
        result += "\n"
    return result
Kunshan Wang's avatar
Kunshan Wang committed
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124

# Marker prefixes for generated regions. The full marker is the prefix
# followed by the injection-point name (see StandardInjectableFile.inject_many).
STANDARD_PREFIX_BEGIN = "GEN:BEGIN:"
STANDARD_PREFIX_END = "GEN:END:"

class StandardInjectableFile(object):
    """
    A file on disk whose regions delimited by ``GEN:BEGIN:<name>`` and
    ``GEN:END:<name>`` marker lines can be regenerated.

    ``injection_points`` lists the region names that ``inject_many`` accepts
    without ``force``.
    """

    def __init__(self, path: str, injection_points: List[str] = None):
        self.path = path
        # Avoid a shared mutable default: create a new list only when none is
        # given. A list supplied by the caller is stored as-is.
        self.injection_points = [] if injection_points is None else injection_points

    def inject_many(self, m: Mapping[str, str], force=False):
        """
        Rewrite several generated regions of the file in one pass.

        :param m: maps injection-point name to the text placed between that
            point's begin and end marker lines.
        :param force: when true, names missing from ``injection_points`` are
            accepted too.
        :raises Exception: for an unknown injection point (unless ``force``).
            Raised before anything is written.
        :raises KeyError: from ``find_line`` (via ``inject_lines``) when a
            marker line is missing. Also raised before anything is written.

        The unmodified contents are first copied to a new temporary file that
        is not deleted, and the path is printed. The file is then overwritten.
        Both reads and writes use ``open`` without an explicit encoding.
        """
        with open(self.path) as src:
            original = src.read()

        updated = original
        for point, content in m.items():
            if point not in self.injection_points and not force:
                raise Exception("Unknown injection point '{}'".format(point))
            updated = inject_lines(
                updated,
                STANDARD_PREFIX_BEGIN + point,
                STANDARD_PREFIX_END + point,
                content,
            )

        # Keep a recoverable copy of the old text before clobbering the file.
        with tempfile.NamedTemporaryFile("w", delete=False) as backup:
            print("Backup to temporary file: {} -> {}".format(self.path, backup.name))
            backup.write(original)

        with open(self.path, "w") as dst:
            print("Writing to file: {}".format(self.path))
            dst.write(updated)

def make_injectable_file_set(
        root_path: str,
        items    : List[Tuple[str, str, List[str]]],
        ) -> Mapping[str, StandardInjectableFile]:
    """
    Build a dict of ``StandardInjectableFile`` objects keyed by a short name.

    Each entry of ``items`` is ``(name, relative_path, injection_points)``.
    The file path is ``relative_path`` joined onto ``root_path``. If a name
    repeats, the later entry wins.
    """
    return {
        name: StandardInjectableFile(os.path.join(root_path, rel_path), points)
        for name, rel_path, points in items
    }
        
class InjectableFileSet(object):
    """A collection of ``StandardInjectableFile`` objects keyed by file path."""

    def __init__(self, m: Mapping[str, List[str]]):
        """
        :param m: maps each file path to the list of injection-point names
            allowed in that file.
        """
        self.injectable_files = {
            path: StandardInjectableFile(path, inj_points)
            for path, inj_points in m.items()
        }

    def __getitem__(self, key):
        """
        Return the ``StandardInjectableFile`` registered under path ``key``.

        :raises Exception: if ``key`` was not registered. The underlying
            ``KeyError`` is attached as ``__cause__``.
        """
        try:
            return self.injectable_files[key]
        except KeyError as e:
            # Chain explicitly so the traceback shows the lookup failure as the
            # direct cause rather than as an error raised while handling it.
            raise Exception("Unknown injectable file {}".format(key)) from e