From 6bb3c2f30122283b41d444cf1e11a4df08b274d8 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 19 Aug 2023 19:59:45 -0700 Subject: [PATCH 1/2] Added support for TensorrtExecutionProvider [skip ci] --- lib/onnxruntime/ffi.rb | 10 +++++----- lib/onnxruntime/inference_session.rb | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index e3b80f8..ea0aa75 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -185,11 +185,11 @@ class Api < ::FFI::Struct :ReleasePrepackedWeightsContainer, callback(%i[], :pointer), :CreateSessionWithPrepackedWeightsContainer, callback(%i[], :pointer), :CreateSessionFromArrayWithPrepackedWeightsContainer, callback(%i[], :pointer), - :SessionOptionsAppendExecutionProvider_TensorRT_V2, callback(%i[], :pointer), - :CreateTensorRTProviderOptions, callback(%i[], :pointer), - :UpdateTensorRTProviderOptions, callback(%i[], :pointer), - :GetTensorRTProviderOptionsAsString, callback(%i[], :pointer), - :ReleaseTensorRTProviderOptions, callback(%i[], :pointer), + :SessionOptionsAppendExecutionProvider_TensorRT_V2, callback(%i[pointer pointer], :pointer), + :CreateTensorRTProviderOptions, callback(%i[pointer], :pointer), + :UpdateTensorRTProviderOptions, callback(%i[pointer pointer pointer size_t], :pointer), + :GetTensorRTProviderOptionsAsString, callback(%i[pointer pointer pointer], :pointer), + :ReleaseTensorRTProviderOptions, callback(%i[pointer], :pointer), :EnableOrtCustomOps, callback(%i[], :pointer), :RegisterAllocator, callback(%i[], :pointer), :UnregisterAllocator, callback(%i[], :pointer), diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index e3d2162..1097b78 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -68,6 +68,11 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr release :CUDAProviderOptions, cuda_options when "CPUExecutionProvider" break + when "TensorrtExecutionProvider" + tensor_rt_options = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CreateTensorRTProviderOptions].call(tensor_rt_options) + check_status api[:SessionOptionsAppendExecutionProvider_TensorRT_V2].call(session_options.read_pointer, tensor_rt_options.read_pointer) + release :TensorRTProviderOptions, tensor_rt_options else raise ArgumentError, "Provider not supported: #{provider}" end From 8421aed98eaff2db7db3410d2811ab4a8957ac23 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 15 Oct 2023 11:45:05 -0700 Subject: [PATCH 2/2] Set logger [skip ci] --- lib/onnxruntime/ffi.rb | 2 +- lib/onnxruntime/inference_session.rb | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index ea0aa75..fc9ea3f 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -19,7 +19,7 @@ class Api < ::FFI::Struct :GetErrorCode, callback(%i[pointer], :pointer), :GetErrorMessage, callback(%i[pointer], :pointer), :CreateEnv, callback(%i[int string pointer], :pointer), - :CreateEnvWithCustomLogger, callback(%i[], :pointer), + :CreateEnvWithCustomLogger, callback(%i[pointer pointer int string pointer], :pointer), :EnableTelemetryEvents, callback(%i[pointer], :pointer), :DisableTelemetryEvents, callback(%i[pointer], :pointer), :CreateSession, callback(%i[pointer pointer pointer pointer], :pointer), diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index 1097b78..706964e 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -614,8 +614,11 @@ def env Utils.mutex.synchronize do @@env ||= begin env = ::FFI::MemoryPointer.new(:pointer) - check_status api[:CreateEnv].call(3, "Default", env) - at_exit { release :Env, env } + logging_function = ::FFI::Function.new(:void, [:pointer, :int, :pointer, :pointer, :pointer, :pointer]) do |param, severity, category, logid, code_location, message| + puts message.read_string + end + env.instance_variable_set(:@logging_function, logging_function) + check_status api[:CreateEnvWithCustomLogger].call(logging_function, nil, 0, "Default", env) # disable telemetry # https://github.com/microsoft/onnxruntime/blob/master/docs/Privacy.md check_status api[:DisableTelemetryEvents].call(env)