diff --git a/lib/rls.rb b/lib/rls.rb index 3c1b666..6e9835d 100644 --- a/lib/rls.rb +++ b/lib/rls.rb @@ -22,10 +22,6 @@ def configuration Rails.application.config.rls end - def connection - ActiveRecord::Base.connection - end - def role configuration.role end @@ -38,12 +34,12 @@ def admin RLS::Current.admin end - def disable_rls_role! + def disable! self.admin = true ActiveRecord::Base.connection_pool.disconnect! end - def enable_rls_role! + def enable! self.admin = false ActiveRecord::Base.connection_pool.disconnect! end @@ -63,12 +59,11 @@ def process(tenant_id, &block) end def set!(tenant_id) - connection.execute format(SET_CUSTOMER_ID_SQL, connection.quote(tenant_id)) + connection.rls_set(tenant_id:) end def reset! - connection.execute RESET_CUSTOMER_ID_SQL - connection.clear_query_cache + connection.rls_reset end end diff --git a/lib/rls/extensions/postgresql_adapter.rb b/lib/rls/extensions/postgresql_adapter.rb index 3d4e2a6..b1275a5 100644 --- a/lib/rls/extensions/postgresql_adapter.rb +++ b/lib/rls/extensions/postgresql_adapter.rb @@ -1,9 +1,23 @@ module RLS module Extensions module PostgreSQLAdapter + SET_ROLE_SQL = 'SET ROLE %s'.freeze + + SET_TENANT_ID_SQL = 'SET rls.tenant_id = %s'.freeze + RESET_TENANT_ID_SQL = 'RESET rls.tenant_id'.freeze + def initialize(...) super - execute("SET ROLE #{RLS.role}") unless RLS.admin + execute(format(SET_ROLE_SQL, quote(RLS.role))) unless RLS.admin + end + + def rls_set(tenant_id:) + execute(format(SET_TENANT_ID_SQL, quote(tenant_id))) + end + + def rls_reset + execute(RESET_TENANT_ID_SQL) + clear_query_cache end end end diff --git a/lib/tasks/rls.rake b/lib/tasks/rls.rake index 30b5228..b65aa9c 100644 --- a/lib/tasks/rls.rake +++ b/lib/tasks/rls.rake @@ -1,22 +1,32 @@ # frozen_string_literal: true -Rake::Task['db:load_config'].enhance(['rls:disable_rls_role']) +# disable before +Rake::Task['db:load_config'].enhance(['rls:disable']) + +# enable after +Rake::Task.tasks.each do |task| + if task.prerequisites.any? { |pre| pre == 'db:load_config' || (pre == 'load_config' && task.name.start_with?('db:')) } + task.enhance do + Rake::Task['rls:enable'].invoke + end + end +end namespace :rls do def connection @connection ||= RLS.connection end - task disable_rls_role: :environment do - RLS.disable_rls_role! + task disable: :environment do + RLS.disable! end - task enable_rls_role: :environment do - RLS.enable_rls_role! + task enable: :environment do + RLS.enable! end task create_role: :environment do - RLS.disable_rls_role! + RLS.disable! RLS.connection.execute <<~SQL CREATE ROLE "#{RLS.role}" WITH NOLOGIN; @@ -28,11 +38,11 @@ namespace :rls do puts "Role #{RLS.role} created" - RLS.enable_rls_role! + RLS.enable! end task drop_role: :environment do - RLS.disable_rls_role! + RLS.disable! RLS.connection.execute <<~SQL ALTER DEFAULT PRIVILEGES IN SCHEMA public REVOKE ALL ON TABLES FROM "#{RLS.role}"; @@ -45,7 +55,7 @@ namespace :rls do puts "Role #{RLS.role} dropped" - RLS.enable_rls_role! + RLS.enable! end end