"""Minimal command-line demo of the RKNN native executor.

Generates a random input tensor, opens an ``RknnSession`` on the given model
and times three usage patterns: serial ``run`` calls, ``run`` calls with a
batched input, and requests submitted through the session's worker pool.
"""

import argparse
import time
from pathlib import Path
from typing import List

import numpy as np
import numpy.typing as npt

try:
    from .immich_session import RknnSession
except ImportError:
    from rknn_multi_executor.immich_session import RknnSession  # type: ignore


# Float dtype names accepted by --dtype; values are drawn from np.random.rand.
_FLOAT_DTYPES = {
    "float32": np.float32,
    "float16": np.float16,
}

# Integer dtype names accepted by --dtype mapped to
# (numpy dtype, inclusive low, exclusive high) for np.random.randint.
# The int32 range is a modest default; adjust as needed for your model
# (e.g., vocab size).
_INT_DTYPE_RANGES = {
    "int32": (np.int32, 0, 1000),
    "int8": (np.int8, -128, 128),
    "uint8": (np.uint8, 0, 256),
}


def parse_shape(shape_str: str) -> List[int]:
    """Parse a comma-separated shape string such as ``"1,3,640,640"``.

    Whitespace around entries is ignored and empty entries (e.g. from a
    trailing comma) are skipped.

    Args:
        shape_str: Comma-separated list of integer dimensions.

    Returns:
        The dimensions as a list of ints, in the order given.

    Raises:
        ValueError: If no dimensions are present or an entry is not an
            integer.
    """
    parts = [p.strip() for p in shape_str.split(",") if p.strip()]
    if not parts:
        raise ValueError("Invalid shape string")
    try:
        return [int(p) for p in parts]
    except ValueError as err:
        # Same exception type as a bare int() failure, but naming the input.
        raise ValueError(f"Invalid shape string: {shape_str!r}") from err


def _random_tensor(shape: List[int], dtype: str) -> npt.NDArray[np.generic]:
    """Return a random tensor of the given shape for a supported dtype name.

    Args:
        shape: Tensor dimensions.
        dtype: One of the keys of ``_FLOAT_DTYPES`` or ``_INT_DTYPE_RANGES``.

    Raises:
        ValueError: If ``dtype`` is not a supported name.
    """
    dims = tuple(shape)
    if dtype in _FLOAT_DTYPES:
        return np.random.rand(*dims).astype(_FLOAT_DTYPES[dtype])
    if dtype in _INT_DTYPE_RANGES:
        np_dtype, low, high = _INT_DTYPE_RANGES[dtype]
        return np.random.randint(low, high, size=dims, dtype=np_dtype)
    raise ValueError(f"Unsupported dtype: {dtype}")


def _output_shapes(outs) -> list:
    """Return the ``shape`` of each output, or ``None`` where it has none."""
    return [getattr(o, "shape", None) for o in outs]


def main() -> None:
    """Parse CLI arguments and run the serial, batch and pool demos.

    Raises:
        ValueError: If the requested shape has fewer than two dimensions or
            contains a non-positive dimension.
        RuntimeError: If the session reports no inputs for the model.
    """
    parser = argparse.ArgumentParser(description="Minimal RKNN Native Executor usage")
    parser.add_argument("--model", required=True, type=Path, help="Path to .rknn model")
    parser.add_argument("--num-workers", type=int, default=3, help="Number of worker contexts")
    parser.add_argument(
        "--shape",
        type=str,
        default="1,3,640,640",
        help="Input shape as comma-separated list, e.g. 1,3,640,640",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "int32", "int8", "uint8"],
        help="Data type for randomly generated input tensor",
    )
    args = parser.parse_args()

    shape = parse_shape(args.shape)
    if len(shape) < 2:
        raise ValueError("Shape must have at least 2 dims (e.g., NCHW)")
    if any(dim <= 0 for dim in shape):
        # Reject before building a tensor: negative dims fail inside numpy
        # with a less helpful message and a zero dim yields an empty tensor.
        raise ValueError("Shape dimensions must be positive")

    # Generate a random input tensor with the requested dtype
    gen_t0 = time.perf_counter()
    x = _random_tensor(shape, args.dtype)
    gen_t1 = time.perf_counter()
    print(
        f"[timing] generated random {args.dtype} tensor with shape {shape} "
        f"in {(gen_t1-gen_t0)*1000:.2f} ms"
    )
    time.sleep(1)

    session_t0 = time.perf_counter()
    session = RknnSession(args.model.as_posix(), num_workers=args.num_workers)
    session_t1 = time.perf_counter()
    # Everything that touches the session lives inside the try so that
    # close() runs even if the very first query on it fails.
    try:
        inputs = session.get_inputs()
        print(inputs)
        print(f"[timing] session init took {(session_t1-session_t0)*1000:.2f} ms")
        print("IO description:", session.io_info)
        if not inputs:
            raise RuntimeError("Model exposes no inputs")
        input_name = inputs[0].name or "input"

        # Serial demo
        for i in range(3):
            t0 = time.perf_counter()
            outs = session.run(None, {input_name: x})
            t1 = time.perf_counter()
            print(
                f"[serial {i+1}] start={t0:.6f}s end={t1:.6f}s "
                f"dur_ms={(t1-t0)*1000:.2f} shapes={_output_shapes(outs)}"
            )

        # Batch demo (single RKNN session call with batched input to exercise batch mode)
        batch_repeats = 3
        if shape[0] == 1:
            batch_shape = [batch_repeats, *shape[1:]]
        else:
            # Leading dim is already != 1; use the requested shape as-is.
            batch_shape = shape
        x_batch = _random_tensor(batch_shape, args.dtype)
        for i in range(3):
            t0 = time.perf_counter()
            outs = session.run(None, {input_name: x_batch})
            t1 = time.perf_counter()
            print(
                f"[batch {i+1}] start={t0:.6f}s end={t1:.6f}s "
                f"dur_ms={(t1-t0)*1000:.2f} shapes={_output_shapes(outs)}"
            )
        time.sleep(1)

        # Parallel demo using Immich-style pool
        total_requests = 5 * args.num_workers
        print(f"[pool] submitting {total_requests} requests with {args.num_workers} worker contexts")
        batch_t0 = time.perf_counter()
        # Submit everything first, then collect results in submission order.
        futures = [session.rknnpool.put([x]) for _ in range(total_requests)]
        for idx, fut in enumerate(futures):
            res = fut.result()
            outs = res.outputs
            lat_ms = res.duration_s * 1000.0
            print(f"[parallel {idx+1}] dur_ms={lat_ms:.2f} shapes={_output_shapes(outs)}")
        batch_t1 = time.perf_counter()
        print(f"[parallel batch] total_ms={(batch_t1-batch_t0)*1000:.2f}")
        time.sleep(1)
    finally:
        session.close()


if __name__ == "__main__":
    main()