From 6bb3c2f30122283b41d444cf1e11a4df08b274d8 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 19 Aug 2023 19:59:45 -0700 Subject: [PATCH] Added support for TensorrtExecutionProvider [skip ci] --- lib/onnxruntime/ffi.rb | 10 +++++----- lib/onnxruntime/inference_session.rb | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index e3b80f8..ea0aa75 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -185,11 +185,11 @@ class Api < ::FFI::Struct :ReleasePrepackedWeightsContainer, callback(%i[], :pointer), :CreateSessionWithPrepackedWeightsContainer, callback(%i[], :pointer), :CreateSessionFromArrayWithPrepackedWeightsContainer, callback(%i[], :pointer), - :SessionOptionsAppendExecutionProvider_TensorRT_V2, callback(%i[], :pointer), - :CreateTensorRTProviderOptions, callback(%i[], :pointer), - :UpdateTensorRTProviderOptions, callback(%i[], :pointer), - :GetTensorRTProviderOptionsAsString, callback(%i[], :pointer), - :ReleaseTensorRTProviderOptions, callback(%i[], :pointer), + :SessionOptionsAppendExecutionProvider_TensorRT_V2, callback(%i[pointer pointer], :pointer), + :CreateTensorRTProviderOptions, callback(%i[pointer], :pointer), + :UpdateTensorRTProviderOptions, callback(%i[pointer pointer pointer size_t], :pointer), + :GetTensorRTProviderOptionsAsString, callback(%i[pointer pointer pointer], :pointer), + :ReleaseTensorRTProviderOptions, callback(%i[pointer], :pointer), :EnableOrtCustomOps, callback(%i[], :pointer), :RegisterAllocator, callback(%i[], :pointer), :UnregisterAllocator, callback(%i[], :pointer), diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index e3d2162..1097b78 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -68,6 +68,11 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr release :CUDAProviderOptions, cuda_options when "CPUExecutionProvider" break + when "TensorrtExecutionProvider" + tensor_rt_options = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CreateTensorRTProviderOptions].call(tensor_rt_options) + check_status api[:SessionOptionsAppendExecutionProvider_TensorRT_V2].call(session_options.read_pointer, tensor_rt_options.read_pointer) + release :TensorRTProviderOptions, tensor_rt_options else raise ArgumentError, "Provider not supported: #{provider}" end