diff --git a/PyTeX/format/auto_format.py b/PyTeX/format/auto_format.py index 9879064..365ecaa 100644 --- a/PyTeX/format/auto_format.py +++ b/PyTeX/format/auto_format.py @@ -8,6 +8,7 @@ from .simple_tex_formatter import SimpleTeXFormatter from .dtx_formatter import DTXFormatter from .pytex_formatter import PyTeXFormatter from .git_version_info import GitVersionInfo +from .default_macros import get_default_macros def formatter_from_file_extension( @@ -15,7 +16,8 @@ def formatter_from_file_extension( config: Optional[FormattingConfig] = None, git_version_info: Optional[GitVersionInfo] = None, locate_file_config: bool = True, - allow_infile_config: bool = True + allow_infile_config: bool = True, + default_macros: bool = True, ) -> PyTeXFormatter: switcher: Dict[str, Type[Union[DTXFormatter, SimpleTeXFormatter, DictFormatter]]] = { 'dtx.pytex': DTXFormatter, @@ -29,10 +31,15 @@ def formatter_from_file_extension( except: raise NotImplementedError - return switcher[extension]( + formatter = switcher[extension]( input_file=input_file, config=config, git_version_info=git_version_info, locate_file_config=locate_file_config, allow_infile_config=allow_infile_config ) + if default_macros: + formatter.macros = get_default_macros() + return formatter + + diff --git a/PyTeX/format/macros.py b/PyTeX/format/macros.py index 169e977..10d6052 100644 --- a/PyTeX/format/macros.py +++ b/PyTeX/format/macros.py @@ -1,24 +1,27 @@ import re -from typing import List, Union +from typing import List, Union, Tuple, Dict from .constants import * from .enums import FormatterProperty, Argument +from abc import ABC, abstractmethod class MacroReplacement: def __init__( self, replacement: str, - format_type: str = '%', *args, - **kwargs + **kwargs, ): + if 'format_type' in kwargs.keys(): + self.format_type = kwargs['format_type'] + else: + self.format_type = '%' self.replacement: str = replacement - self.format_type = '%', self.args = args self.kwargs = kwargs - def make_format_args(self, formatter, *call_args): + def make_format_args(self, formatter, *call_args) -> Tuple[Tuple, Dict]: new_args = [] for arg in self.args: if type(arg) == FormatterProperty: @@ -38,44 +41,67 @@ class MacroReplacement: new_kwargs = {} for kw in self.kwargs.keys(): if type(self.kwargs[kw]) == FormatterProperty: - new_kwargs[kw] = getattr(formatter, self.kwargs[kw].value) + new_kwargs[kw] = formatter.attribute_dict[self.kwargs[kw].value] elif type(self.kwargs[kw]) == Argument: new_kwargs[kw] = call_args[self.kwargs[kw].value - 1] elif type(self.kwargs[kw]) == str: new_kwargs[kw] = self.kwargs[kw] else: raise NotImplementedError - return new_args, new_kwargs + return tuple(new_args), new_kwargs def format(self, formatter, *call_args) -> str: args, kwargs = self.make_format_args(formatter, *call_args) if self.format_type == '%': - if self.args: + if self.kwargs: raise NotImplementedError # Currently, not supported - return self.replacement % kwargs + return self.replacement % args elif self.format_type == '{': return self.replacement.format( - *args, **kwargs + *args, **kwargs, **formatter.attribute_dict ) else: raise NotImplementedError -class Macro: +class Macro(ABC): + @abstractmethod def __init__(self): raise NotImplementedError + @abstractmethod def matches(self, line: str) -> bool: raise NotImplementedError + @abstractmethod def apply(self, line: str, formatter) -> Union[str, List[str]]: raise NotImplementedError -class BasicMacro(Macro): +class SimpleMacro(Macro): def __init__( self, - macroname, + macroname: str, + macro_replacement: MacroReplacement + ): + self.macroname = macroname + self.macro_replacement = macro_replacement + + def matches(self, line: str) -> bool: + return line.find(FORMATTER_PREFIX + self.macroname) != -1 + + def apply(self, line: str, formatter) -> Union[str, List[str]]: + return line.replace( + FORMATTER_PREFIX + self.macroname, + self.macro_replacement.format( + formatter + )) + + +class ArgumentMacro(Macro): + def __init__( + self, + macroname: str, num_args: int, macro_replacement: MacroReplacement ): @@ -88,6 +114,8 @@ class BasicMacro(Macro): )) def matches(self, line: str) -> bool: + if line.find('!!') != -1: + pass match = re.search(self._search_regex, line) if match is None: return False diff --git a/PyTeX/format/tex_formatter.py b/PyTeX/format/tex_formatter.py index 3777389..01299f7 100644 --- a/PyTeX/format/tex_formatter.py +++ b/PyTeX/format/tex_formatter.py @@ -36,7 +36,7 @@ class LineStream: self.reserve_lines(pos + 1) return self._cached_lines[pos] - def set_future_line(self, line, pos): + def set_future_line(self, pos: int, line: str): self.reserve_lines(pos + 1) self._cached_lines[pos] = line @@ -94,6 +94,10 @@ class TexFormatter(PyTeXFormatter, ABC): def macros(self) -> List[Macro]: return self._macros + @macros.setter + def macros(self, macros: List[Macro]): + self._macros = macros + def _handle_macro(self, macro: Macro): res = macro.apply( self.line_stream.current_line(),