"""Thread-safety tests for JinjaX asset collection and per-render globals."""

from threading import Thread

from markupsafe import Markup

import jinjax


class ThreadWithReturnValue(Thread):
    """A ``Thread`` whose ``join()`` returns the target's return value.

    If the target raises, the exception is stored and re-raised from
    ``join()``. A failure inside a worker therefore surfaces as the real
    error in the calling (test) thread, instead of ``join()`` quietly
    returning ``None``.
    """

    def __init__(self, group=None, target=None, name=None, args=None, kwargs=None):
        args = args or ()
        kwargs = kwargs or {}
        super().__init__(
            group=group,
            target=target,
            name=name,
            args=args,
            kwargs=kwargs,
        )
        # run() is overridden below and reads these, so keep our own copies.
        self._target = target
        self._args = args
        self._kwargs = kwargs
        self._return = None
        self._exception = None

    def run(self):
        if self._target is None:
            return
        try:
            self._return = self._target(*self._args, **self._kwargs)
        except BaseException as exc:  # re-raised in join(), not swallowed
            self._exception = exc

    def join(self, *args, **kwargs):
        """Wait for the thread, then return the target's result.

        Re-raises whatever the target raised. If a timeout is given and the
        thread is still running when it expires, this may return ``None``.
        """
        super().join(*args, **kwargs)
        if self._exception is not None:
            raise self._exception
        return self._return


def _run_in_threads(fn, count):
    """Call ``fn(i)`` for each ``i`` in ``range(count)``, each on its own thread.

    All threads are started before any is joined, so the calls overlap.
    Returns the results in index order. Re-raises the error of the first
    failing worker, in index order.
    """
    threads = [ThreadWithReturnValue(target=fn, args=(i,)) for i in range(count)]
    for thread in threads:
        thread.start()
    return [thread.join() for thread in threads]


def test_thread_safety_of_render_assets(catalog, folder):
    """Concurrent renders must each see only their own collected assets."""
    NUM_THREADS = 5

    # NOTE(review): markup appears to have been stripped from these templates.
    # Presumably the Page template rendered <Parent{i}> around <Child{i} />,
    # and the expected output began with the <link>/<script> tags emitted by
    # render_assets(). As written, Page never references Parent or Child.
    # Confirm against upstream before relying on this test.
    child_tmpl = """
{#css "c{i}.css" #}
{#js "c{i}.js" #}

Child {i}

""".strip()

    parent_tmpl = """
{{ catalog.render_assets() }}
{{ content }}""".strip()

    comp_tmpl = """
{#css "a{i}.css", "b{i}.css" #}
{#js "a{i}.js", "b{i}.js" #}
""".strip()

    expected_tmpl = """

Child {i}

""".strip()

    for i in range(NUM_THREADS):
        si = str(i)
        (folder / f"Child{i}.jinja").write_text(child_tmpl.replace("{i}", si))
        (folder / f"Page{i}.jinja").write_text(comp_tmpl.replace("{i}", si))
        (folder / f"Parent{i}.jinja").write_text(parent_tmpl.replace("{i}", si))

    def render(i):
        return catalog.render(f"Page{i}")

    results = _run_in_threads(render, NUM_THREADS)

    for i, result in enumerate(results):
        expected = expected_tmpl.replace("{i}", str(i))
        assert result == Markup(expected), (
            f"thread {i}: expected {expected!r}, got {result!r}"
        )


def test_same_thread_assets_independence(catalog, folder):
    """Two catalogs used in the same thread must not share collected assets."""
    catalog2 = jinjax.Catalog()
    catalog2.add_folder(folder)

    (folder / "Parent.jinja").write_text(
        """
{{ catalog.render_assets() }}
{{ content }}""".strip()
    )
    (folder / "Comp1.jinja").write_text(
        """
{#css "a.css" #}
{#js "a.js" #}
""".strip()
    )
    (folder / "Comp2.jinja").write_text(
        """
{#css "b.css" #}
{#js "b.js" #}
""".strip()
    )

    # NOTE(review): both expected strings are empty as written, and
    # Parent.jinja is never referenced by Comp1/Comp2. This looks like the
    # asset markup was lost. Confirm against upstream.
    expected_1 = ""
    expected_2 = ""

    # Render the first component with the first catalog.
    html1 = catalog.render("Comp1")
    # `irender` instead of `render` so the assets are not cleared
    html2 = catalog2.irender("Comp2")

    assert html1 == Markup(expected_1)
    assert html2 == Markup(expected_2)


def test_thread_safety_of_template_globals(catalog, folder):
    """``_globals`` passed to one render must not leak into concurrent ones."""
    NUM_THREADS = 5
    (folder / "Page.jinja").write_text(
        "{{ globalvar if globalvar is defined else 'not set' }}"
    )

    def render(i):
        return catalog.render("Page", _globals={"globalvar": i})

    results = _run_in_threads(render, NUM_THREADS)

    for i, result in enumerate(results):
        assert result == Markup(str(i))