pytex/formatter/tex_formatter.py

139 lines
5.8 KiB
Python

import datetime
import re
from pathlib import Path
from typing import Dict, Optional, List
from datetime import *
from PyTeX.base import Attributes, Args
from PyTeX.errors import *
from PyTeX.utils import ensure_file_integrity
from .formatter import Formatter
class TexFormatter(Formatter):
    """Formatter that expands ``__KEYWORD__`` placeholders in a TeX source file.

    Two kinds of placeholders are supported:

    * plain replacements (``add_replacement``): every occurrence of the
      keyword is substituted by a fixed string;
    * argument replacements (``add_arg_replacement``): occurrences of the
      form ``__KEYWORD__(arg1,arg2,...)`` are substituted by a template that
      is filled in with the captured arguments and/or attributes of this
      formatter.
    """

    def __init__(self, name: str, author: str, header: Optional[List[str]], file_extension: str,
                 tex_version: str, version: str, latex_name: str):
        """Derive all naming attributes from the given package metadata.

        :param name: Human-readable name; lower-cased, stripped and with
            spaces turned into ``-`` to build the file name.
        :param author: Author name; the first letter of the first word plus
            the last word (lower-cased, ``ß`` -> ``ss``) form the acronym.
        :param header: Optional lines that ``format_file`` writes as a
            ``%``-commented banner at the top of the output.
        :param file_extension: Appended verbatim to the derived name.
        :param tex_version: ``'LaTeX2e'`` or ``'LaTeX3'``; selects the macro
            prefix style.
        :param version: Version string; one leading ``v`` is dropped.
        :param latex_name: If ``'prepend-author'``, the author acronym is
            prepended to the derived name.
        :raises UnknownTexVersionError: if ``tex_version`` is neither of the
            two supported values.
        :raises IndexError: if ``author`` contains no non-whitespace text.
        """
        self.version = version[1:] if version.startswith('v') else version
        self.header = header
        self.name_raw = name
        self.author = author
        self.latex_name = latex_name

        # split() (no separator) ignores leading/trailing/repeated whitespace,
        # so 'John Doe ' yields 'jdoe' rather than an acronym with an empty
        # surname part.
        author_parts = self.author.lower().replace('ß', 'ss').split()
        self.author_acronym = author_parts[0][0] + author_parts[-1]

        # Take one timestamp so that date and year can never disagree.
        now = datetime.now()
        self.date = now.strftime('%Y/%m/%d')
        self.year = now.year

        self.replace_dict: Dict = {}
        self.arg_replace_dict: Dict = {}
        self.source_file_name = "not specified"

        base_name = self.name_raw.lower().strip().replace(' ', '-')
        if self.latex_name == 'prepend-author':
            self.name_lowercase = '{prefix}-{name}'.format(prefix=self.author_acronym, name=base_name)
        else:
            self.name_lowercase = base_name
        self.file_name = self.name_lowercase + file_extension

        if tex_version == 'LaTeX2e':
            self.prefix = self.name_lowercase.replace('-', '@') + '@'
        elif tex_version == 'LaTeX3':
            self.prefix = '__' + self.name_lowercase.replace('-', '_') + '_'
        else:
            raise UnknownTexVersionError(tex_version)
        super().__init__()

    @staticmethod
    def __command_name2keyword(keyword: str):
        """Turn a command name such as ``'package name'`` into ``'__PACKAGE_NAME__'``."""
        return '__' + keyword.upper().strip().replace(' ', '_') + '__'

    def expected_file_name(self):
        """Return the name of the file that ``format_file`` writes."""
        return self.file_name

    @property
    def filename(self):
        """Name of the file that ``format_file`` writes."""
        return self.file_name

    def __parse_replacement_args(self, match_groups, *user_args, **user_kwargs):
        """Resolve user-supplied format arguments into concrete values.

        Each argument is resolved by its type:

        * ``Attributes``: its ``value`` is looked up as an attribute of
          ``self``;
        * ``Args``: its ``value`` indexes ``match_groups`` and the result is
          stripped;
        * ``str``: used as given (positional strings are stripped, keyword
          strings are kept verbatim);
        * anything else: replaced by the literal ``'ERROR'``.

        :param match_groups: Captured regex groups for the current match.
        :return: ``(args_tuple, kwargs_dict)`` ready for ``str.format``.
        """
        new_args = []
        for arg in user_args:
            if isinstance(arg, Attributes):
                new_args.append(getattr(self, arg.value))
            elif isinstance(arg, Args):
                new_args.append(match_groups[arg.value].strip())
            elif isinstance(arg, str):
                new_args.append(arg.strip())
            else:
                # One placeholder per argument keeps positional indices
                # aligned with the format template.
                new_args.append('ERROR')

        new_kwargs = {}
        for kw, value in user_kwargs.items():
            if isinstance(value, Attributes):
                new_kwargs[kw] = getattr(self, value.value)
            elif isinstance(value, Args):
                new_kwargs[kw] = match_groups[value.value].strip()
            elif isinstance(value, str):
                new_kwargs[kw] = value
            else:
                new_kwargs[kw] = 'ERROR'
        return tuple(new_args), new_kwargs

    def __format_string(self, contents: str) -> str:
        """Apply every plain keyword replacement to ``contents``."""
        for keyword, replacement in self.replace_dict.items():
            contents = contents.replace(keyword, replacement)
        return contents

    def __format_string_with_arg(self, contents: str) -> str:
        """Expand every ``__KEYWORD__(arg, ...)`` occurrence in ``contents``.

        Arguments are separated by commas and the argument list ends at the
        first ``)`` that is not preceded by ``@``; a literal ``)`` inside an
        argument is therefore written as ``@)`` and unescaped before use.
        """
        for command, spec in self.arg_replace_dict.items():
            # The keyword is matched literally, exactly as plain replacements
            # treat their keywords.
            search_regex = re.compile(r'{keyword}\({arguments}(?<!@)\)'.format(
                keyword=re.escape(command),
                arguments=','.join(['(.*?)'] * spec['num_args'])
            ))
            match = search_regex.search(contents)
            # Rescan after each substitution: str.replace below rewrites all
            # textually identical occurrences at once, and the remaining
            # matches may carry different arguments.
            while match is not None:
                groups = [group.replace('@)', ')') for group in match.groups()]
                format_args, format_kwargs = self.__parse_replacement_args(
                    groups,
                    *spec['format_args'],
                    **spec['format_kwargs']
                )
                contents = contents.replace(
                    match.group(),
                    spec['replacement'].format(*format_args, **format_kwargs)
                )
                match = search_regex.search(contents)
        return contents

    def add_replacement(self, keyword: str, replacement: str, *args, **kwargs):
        """Register a plain replacement for ``keyword``.

        ``replacement`` is formatted immediately with the resolved ``args`` /
        ``kwargs``, so attribute values are frozen at registration time.
        """
        args, kwargs = self.__parse_replacement_args([], *args, **kwargs)
        self.replace_dict[self.__command_name2keyword(keyword)] = replacement.format(*args, **kwargs)

    def add_arg_replacement(self, num_args: int, keyword: str, replacement: str, *args, **kwargs):
        """Register a replacement for ``keyword`` taking ``num_args`` arguments.

        ``replacement`` is a ``str.format`` template; ``args`` / ``kwargs``
        are stored unresolved and evaluated for each match at format time.
        """
        self.arg_replace_dict[self.__command_name2keyword(keyword)] = {
            'num_args': num_args,
            'replacement': replacement,
            'format_args': args,
            'format_kwargs': kwargs
        }

    def format_file(
            self,
            input_path: Path,
            output_dir: Optional[Path] = None,
            relative_name: Optional[str] = None,
            last_build_info: Optional[List[Dict]] = None) -> List[str]:
        """Format ``input_path`` and write the result into ``output_dir``.

        :param input_path: Source file to read.
        :param output_dir: Target directory (created if missing); defaults to
            the directory of ``input_path``.
        :param relative_name: Relative path of the source; only its parent
            directory is used, to build the name handed to
            ``ensure_file_integrity``.
        :param last_build_info: Passed through to ``ensure_file_integrity``.
        :return: A one-element list containing the written file's name.
        """
        # Set before formatting: argument replacements may read this
        # attribute while the lines are processed.
        self.source_file_name = str(input_path.name)
        with input_path.open() as input_file:
            lines = input_file.readlines()

        chunks: List[str] = []
        if self.header:
            rule = '%' * 80
            chunks.append(
                rule + '\n'
                + '\n'.join('% ' + line for line in self.header)
                + '\n' + rule + '\n\n'
            )
        chunks.extend(self.__format_string_with_arg(self.__format_string(line)) for line in lines)

        if output_dir is None:
            output_dir = input_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / self.file_name

        # NOTE(review): with no relative_name the bare file name is used
        # instead of raising TypeError from Path(None) - confirm this is the
        # name ensure_file_integrity should see in that case.
        relative_parent = Path(relative_name).parent if relative_name is not None else Path()
        ensure_file_integrity(output_path, str(relative_parent / self.file_name), last_build_info)
        output_path.write_text(''.join(chunks))
        return [str(self.file_name)]