From 5f19a1f251336b0d7b312d1531a547b9c83c7c1b Mon Sep 17 00:00:00 2001 From: yanghuan Date: Thu, 22 Feb 2024 15:31:54 +0800 Subject: [PATCH] fix #443 --- .../CoreSystem.Lua/CoreSystem/Array.lua | 2 +- .../CoreSystem/Collections/Dictionary.lua | 692 ++++++++++-------- .../CoreSystem/Collections/HashSet.lua | 232 +++--- .../CoreSystem/Collections/Linq.lua | 35 +- 4 files changed, 471 insertions(+), 490 deletions(-) diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua index 8bd32d45..2064ca80 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Array.lua @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --]] -local System = System +local System = _G.System local define = System.define local throw = System.throw local div = System.div diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Dictionary.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Dictionary.lua index 2b75b398..946e8c2e 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Dictionary.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Dictionary.lua @@ -14,13 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. --]] -local System = System +local System = _G.System local define = System.define local throw = System.throw local null = System.null local falseFn = System.falseFn local each = System.each -local lengthFn = System.lengthFn local versions = System.versions local Array = System.Array local toString = System.toString @@ -39,7 +38,6 @@ local select = select local getmetatable = getmetatable local setmetatable = setmetatable local tconcat = table.concat -local tremove = table.remove local type = type local counts = setmetatable({}, { __mode = "k" }) @@ -201,215 +199,231 @@ local DictionaryCollection = define("System.Collections.Generic.DictionaryCollec end }, 1) -local function add(this, key, value) - if key == nil then throw(ArgumentNullException("key")) end - if this[key] ~= nil then throw(ArgumentException("key already exists")) end - this[key] = value == nil and null or value - local t = counts[this] - if t then - t[1] = t[1] + 1 - t[2] = t[2] + 1 - else - counts[this] = { 1, 1 } +local ArrayDictionaryFn +local Dictionary = (function () + local function add(this, key, value) + if key == nil then throw(ArgumentNullException("key")) end + if this[key] ~= nil then throw(ArgumentException("key already exists")) end + this[key] = value == nil and null or value + local t = counts[this] + if t then + t[1] = t[1] + 1 + t[2] = t[2] + 1 + else + counts[this] = { 1, 1 } + end end -end -local function remove(this, key) - if key == nil then throw(ArgumentNullException("key")) end - if this[key] ~= nil then - this[key] = nil - local t = counts[this] - t[1] = t[1] - 1 - t[2] = t[2] + 1 - return true + local function remove(this, key) + if key == nil then throw(ArgumentNullException("key")) end + if this[key] ~= nil then + this[key] = nil + local t = counts[this] + t[1] = t[1] - 1 + t[2] = t[2] + 1 + return true + end + return false end - return false -end -local function buildFromDictionary(this, dictionary) - if dictionary == nil then throw(ArgumentNullException("dictionary")) end - local count = 0 - for k, v in pairs(dictionary) do - this[k] = v - count = count + 1 + local function buildFromDictionary(this, dictionary) + if dictionary == nil then throw(ArgumentNullException("dictionary")) end + local count = 0 + for k, v in pairs(dictionary) do + this[k] = v + count = count + 1 + end + counts[this] = { count, 0 } end - counts[this] = { count, 0 } -end -local ArrayDictionaryFn -local function buildHasComparer(this, ...) - local Dictionary = ArrayDictionaryFn(this.__genericTKey__, this.__genericTValue__) - Dictionary.__ctor__(this, ...) - return setmetatable(this, Dictionary) -end + local function buildHasComparer(this, ...) + local Dictionary = ArrayDictionaryFn(this.__genericTKey__, this.__genericTValue__) + setmetatable(this, Dictionary) + Dictionary.__ctor__(this, ...) + end -local Dictionary = { - getIsFixedSize = falseFn, - getIsReadOnly = falseFn, - __ctor__ = function (this, ...) - local n = select("#", ...) - if n == 0 then - elseif n == 1 then - local comparer = ... - if comparer == nil or type(comparer) == "number" then - else - local equals = comparer.EqualsOf - if equals == nil then - buildFromDictionary(this, comparer) - else + return { + getIsFixedSize = falseFn, + getIsReadOnly = falseFn, + __ctor__ = function (this, ...) + local n = select("#", ...) + if n == 1 then + local comparer = ... + if type(comparer) == "table" then + local equals = comparer.EqualsOf + if equals == nil then + buildFromDictionary(this, comparer) + else + buildHasComparer(this, ...) + end + end + elseif n == 2 then + local dictionary, comparer = ... + if comparer ~= nil then buildHasComparer(this, ...) end + if type(dictionary) ~= "number" then + buildFromDictionary(this, dictionary) + end end - else - local dictionary, comparer = ... - if comparer ~= nil then - buildHasComparer(this, ...) - end - if type(dictionary) ~= "number" then - buildFromDictionary(this, dictionary) + end, + AddKeyValue = add, + Add = function (this, ...) + local k, v + if select("#", ...) == 1 then + local pair = ... + k, v = pair[1], pair[2] + else + k, v = ... end - end - end, - AddKeyValue = add, - Add = function (this, ...) - local k, v - if select("#", ...) == 1 then - local pair = ... - k, v = pair[1], pair[2] - else - k, v = ... - end - add(this, k ,v) - end, - Clear = function (this) - for k, _ in pairs(this) do - this[k] = nil - end - counts[this] = nil - end, - ContainsKey = function (this, key) - if key == nil then throw(ArgumentNullException("key")) end - return this[key] ~= nil - end, - ContainsValue = function (this, value) - if value == nil then - for _, v in pairs(this) do - if v == null then - return true - end + add(this, k ,v) + end, + Clear = function (this) + for k, _ in pairs(this) do + this[k] = nil end - else - local comparer = EqualityComparer(this.__genericTValue__).getDefault() - local equals = comparer.EqualsOf + counts[this] = nil + end, + ContainsKey = function (this, key) + if key == nil then throw(ArgumentNullException("key")) end + return this[key] ~= nil + end, + ContainsValue = function (this, value) + if value == nil then for _, v in pairs(this) do - if v ~= null then - if equals(comparer, value, v ) then - return true - end + if v == null then + return true end + end + else + local comparer = EqualityComparer(this.__genericTValue__).getDefault() + local equals = comparer.EqualsOf + for _, v in pairs(this) do + if v ~= null then + if equals(comparer, value, v ) then + return true + end + end + end end - end - return false - end, - Contains = function (this, pair) - local key = pair[1] - if key == nil then throw(ArgumentNullException("key")) end - local value = this[key] - if value ~= nil then - if value == null then value = nil end - local comparer = EqualityComparer(this.__genericTValue__).getDefault() - if comparer:EqualsOf(value, pair[2]) then - return true - end - end - return false - end, - CopyTo = function (this, array, index) - local count = getCount(this) - checkIndexAndCount(array, index, count) - if count > 0 then - local T = this.__genericT__ - index = index + 1 - for k, v in pairs(this) do - if v == null then v = nil end - array[index] = setmetatable({ k, v }, T) - index = index + 1 - end - end - end, - RemoveKey = remove, - Remove = function (this, key) - if isKeyValuePair(key) then - local k, v = key[1], key[2] - if k == nil then throw(ArgumentNullException("key")) end - local value = this[k] + return false + end, + Contains = function (this, pair) + local key = pair[1] + if key == nil then throw(ArgumentNullException("key")) end + local value = this[key] if value ~= nil then if value == null then value = nil end local comparer = EqualityComparer(this.__genericTValue__).getDefault() - if comparer:EqualsOf(value, v) then - remove(this, k) + if comparer:EqualsOf(value, pair[2]) then return true end end return false - end - return remove(this, key) - end, - TryAdd = function (this, key, value) - if key == nil then throw(ArgumentNullException("key")) end - local exists = this:TryGetValue(key) - if exists then - return false - end - this:set(key, value) - return true - end, - TryGetValue = function (this, key) - if key == nil then throw(ArgumentNullException("key")) end - local value = this[key] - if value == nil then - return false, this.__genericTValue__:default() - end - if value == null then return true end - return true, value - end, - getComparer = function (this) - return EqualityComparer(this.__genericTKey__).getDefault() - end, - getCount = getCount, - get = function (this, key) - if key == nil then throw(ArgumentNullException("key")) end - local value = this[key] - if value == nil then throw(KeyNotFoundException()) end - if value ~= null then - return value - end - return nil - end, - set = function (this, key, value) - if key == nil then throw(ArgumentNullException("key")) end - local t = counts[this] - if t then - if this[key] == nil then - t[1] = t[1] + 1 + end, + CopyTo = function (this, array, index) + local count = getCount(this) + checkIndexAndCount(array, index, count) + if count > 0 then + local T = this.__genericT__ + index = index + 1 + for k, v in pairs(this) do + if v == null then v = nil end + array[index] = setmetatable({ k, v }, T) + index = index + 1 + end end - t[2] = t[2] + 1 - else - counts[this] = { 1, 1 } + end, + RemoveKey = remove, + Remove = function (this, key) + if isKeyValuePair(key) then + local k, v = key[1], key[2] + if k == nil then throw(ArgumentNullException("key")) end + local value = this[k] + if value ~= nil then + if value == null then value = nil end + local comparer = EqualityComparer(this.__genericTValue__).getDefault() + if comparer:EqualsOf(value, v) then + remove(this, k) + return true + end + end + return false + end + return remove(this, key) + end, + removeWhere = function (this, match) + local count = 0 + for k, v in pairs(this) do + if match(k, v) then + this[k] = nil + count = count + 1 + end + end + if count > 0 then + local t = counts[this] + t[1] = t[1] - count + t[2] = t[2] + 1 + end + return count + end, + TryAdd = function (this, key, value) + if key == nil then throw(ArgumentNullException("key")) end + local exists = this:TryGetValue(key) + if exists then + return false + end + this:set(key, value) + return true + end, + TryGetValue = function (this, key) + if key == nil then throw(ArgumentNullException("key")) end + local value = this[key] + if value == nil then + return false, this.__genericTValue__:default() + end + if value == null then return true end + return true, value + end, + getComparer = function (this) + return EqualityComparer(this.__genericTKey__).getDefault() + end, + getCount = getCount, + get = function (this, key) + if key == nil then throw(ArgumentNullException("key")) end + local value = this[key] + if value == nil then throw(KeyNotFoundException()) end + if value ~= null then + return value + end + return nil + end, + set = function (this, key, value) + if key == nil then throw(ArgumentNullException("key")) end + local t = counts[this] + if t then + if this[key] == nil then + t[1] = t[1] + 1 + end + t[2] = t[2] + 1 + else + counts[this] = { 1, 1 } + end + this[key] = value == nil and null or value + end, + GetEnumerator = dictionaryEnumerator, + getKeys = function (this) + return DictionaryCollection(this.__genericTKey__)(this, 1) + end, + getValues = function (this) + return DictionaryCollection(this.__genericTValue__)(this, 2) end - this[key] = value == nil and null or value - end, - GetEnumerator = dictionaryEnumerator, - getKeys = function (this) - return DictionaryCollection(this.__genericTKey__)(this, 1) - end, - getValues = function (this) - return DictionaryCollection(this.__genericTValue__)(this, 2) - end -} + } +end)() local ArrayDictionaryEnumerator = define("System.Collections.Generic.ArrayDictionaryEnumerator", function (T) return { + __genericT__ = T, base = { System.IEnumerator_1(T) } } end, { @@ -421,23 +435,34 @@ end, { throwFailedVersion() end local index = this.index - local pair = t[index] - if pair ~= nil then - if this.kind then - this.current = pair[2] + while true do + local pair = t[index] + if pair == nil then + break else - this.current = pair[1] + local k = pair[1] + if k ~= nil then + if this.kind then + this.current = pair[2] + else + this.current = k + end + this.index = index + 1 + return true + else + index = index + 1 + end end - this.index = index + 1 - return true end - this.current = nil + this.current = this.__genericT__:default() return false end }, 1) local arrayDictionaryEnumerator = function (t, kind, T) - return setmetatable({ list = t, kind = kind, index = 1, version = versions[t], currnet = T:default() }, ArrayDictionaryEnumerator(T)) + return setmetatable({ + list = t, kind = kind, index = 1, version = versions[t], currnet = T:default() + }, ArrayDictionaryEnumerator(T)) end local ArrayDictionaryCollection = define("System.Collections.Generic.ArrayDictionaryCollection", function (T) @@ -451,7 +476,7 @@ local ArrayDictionaryCollection = define("System.Collections.Generic.ArrayDictio this.kind = kind end, getCount = function (this) - return #this.dict + return this.dict.count end, get = function (this, index) local p = this.dict[index + 1] @@ -473,86 +498,118 @@ local ArrayDictionaryCollection = define("System.Collections.Generic.ArrayDictio }, 1) local ArrayDictionary = (function () - local function buildFromDictionary(this, dictionary) + local function update(this, add, key, value, set) + if key == nil then throw(ArgumentNullException("key")) end + local comparer, indexs = this.comparer, this.indexs + local code = comparer:GetHashCodeOf(key) + while true do + local index = indexs[code] + local pair = this[index] + if pair == nil then + if add then + local freeList, count = this.freeList, this.count + if freeList then + index = freeList + pair = this[index] + this.freeList = pair[3] + pair[1], pair[2], pair[3] = key, value, nil + else + index = count + 1 + this[index] = setmetatable({ key, value }, this.__genericT__) + end + indexs[code] = index + this.count = count + 1 + versions[this] = (versions[this] or 0) + 1 + else + return false + end + return + else + if comparer:EqualsOf(pair[1], key) then + if add then + if set then + pair[2] = value + return + else + throw(ArgumentException("key already exists")) + end + else + indexs[code] = nil + local freeList, count = this.freeList, this.count + pair[1], pair[2], pair[3] = nil, nil, freeList + this.freeList = index + this.count = count - 1 + versions[this] = (versions[this] or 0) + 1 + return true + end + else + code = code + 1 + end + end + end + end + + local function addRange(this, dictionary) if dictionary == nil then throw(ArgumentNullException("dictionary")) end - local count = 1 - local T = this.__genericT__ for _, pair in each(dictionary) do local k, v = pair[1], pair[2] if type(k) == "table" and k.class == 'S' then k = k:__clone__() end - this[count] = setmetatable({ k, v }, T) - count = count + 1 + update(this, true, k, v) end end - local function add(this, key, value, set) - if key == nil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - if equals(comparer, this[i][1], key) then - if set then - this[i][2] = value - return - else - throw(ArgumentException("key already exists")) - end - end + local function find(this, key) + local comparer, indexs = this.comparer, this.indexs + local code = comparer:GetHashCodeOf(key) + while true do + local index = indexs[code] + local pair = this[index] + if pair == nil then + return nil end - end - this[len + 1] = setmetatable({ key, value }, this.__genericT__) - versions[this] = (versions[this] or 0) + 1 - end - - local function remove(this, key) - if key == nil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - if equals(comparer, this[i][1], key) then - tremove(this, i) - versions[this] = (versions[this] or 0) + 1 - return true - end + if comparer:EqualsOf(pair[1], key) then + return pair + else + code = code + 1 end end - return false end return { + count = 0, getIsFixedSize = falseFn, getIsReadOnly = falseFn, __ctor__ = function (this, ...) - local Comparer + local Comparer, dict local n = select("#", ...) - if n == 0 then - elseif n == 1 then + if n == 1 then local comparer = ... - if comparer == nil or type(comparer) == "number" then - else + if type(comparer) == "table" then local equals = comparer.EqualsOf if equals == nil then - buildFromDictionary(this, comparer) + dict = comparer else Comparer = comparer end end - else + elseif n == 2 then local dictionary, comparer = ... if type(dictionary) ~= "number" then - buildFromDictionary(this, dictionary) + dict = dictionary end Comparer = comparer end this.comparer = Comparer or EqualityComparer(this.__genericTKey__).getDefault() + this.indexs = {} + if dict then + addRange(this, dict) + end + end, + AddKeyValue = function (this, k, v) + update(this, true, k, v) end, - AddKeyValue = add, Add = function (this, ...) local k, v if select("#", ...) == 1 then @@ -561,32 +618,27 @@ local ArrayDictionary = (function () else k, v = ... end - add(this, k ,v) + update(this, true, k ,v) + end, + Clear = function (this) + local count = this.count + if count > 0 then + this.indexs, this.count, this.freeList = {}, 0, nil + Array.clear(this) + end end, - Clear = Array.clear, ContainsKey = function (this, key) if key == nil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - if equals(comparer, this[i][1], key) then - return true - end - end - end - return false + local pair = find(this, key) + return pair ~= nil end, ContainsValue = function (this, value) - local len = #this - if len > 0 then - local comparer = EqualityComparer(this.__genericTValue__).getDefault() - local equals = comparer.EqualsOf - for i = 1, len do - if equals(comparer, value, this[i][2]) then - return true - end + local comparer = EqualityComparer(this.__genericTValue__).getDefault() + local equals = comparer.EqualsOf + for i = 1, #this do + local pair = this[i] + if pair[1] ~= nil and equals(comparer, value, pair[2]) then + return true end end return false @@ -594,98 +646,93 @@ local ArrayDictionary = (function () Contains = function (this, pair) local key = pair[1] if key == nil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - local t = this[i] - if equals(comparer, t[1], key) then - local comparer = EqualityComparer(this.__genericTValue__).getDefault() - if comparer:EqualsOf(t[2], pair[2]) then - return true - end - end + local p = find(this, key) + if p then + local comparer = EqualityComparer(this.__genericTValue__).getDefault() + if comparer:EqualsOf(p[2], pair[2]) then + return true end end return false end, CopyTo = function (this, array, index) - local count = #this + local count = this.count checkIndexAndCount(array, index, count) if count > 0 then - local KeyValuePair = this.__genericT__ + local T = this.__genericT__ index = index + 1 - for i = 1, count do - local t = this[i] - array[index] = setmetatable({ t[1]:__clone__(), t[2] }, KeyValuePair) - index = index + 1 - end - end - end, - RemoveKey = remove, - Remove = function (this, key) - if isKeyValuePair(key) then - local len = #this - local k, v = key[1], key[2] for i = 1, #this do - local pair = this[i] - if pair[1]:EqualsObj(k) then - local comparer = EqualityComparer(this.__genericTValue__).getDefault() - if comparer:EqualsOf(pair[2], v) then - tremove(this, i) - return true + local p = this[i] + local k, v = p[1], p[2] + if k ~= nil then + if type(k) == "table" and k.class == 'S' then + k = k:__clone__() end + array[index] = setmetatable({ k, v }, T) + index = index + 1 end end end + end, + RemoveKey = function (this, key) + return update(this, false, key) + end, + Remove = function (this, pair) + if isKeyValuePair(pair) then + if this:Contains(pair) then + update(this, false, pair[1]) + return true + end + end return false end, + removeWhere = function (this, match) + local count = 0 + for i = 1, #this do + local p = this[i] + local k, v = p[1], p[2] + if k ~= nil then + if match(k, v) then + update(this, false, k) + count = count + 1 + end + end + end + return count + end, TryAdd = function (this, key, value) if key == nil then throw(ArgumentNullException("key")) end - local exists, currentValue = this:TryGetValue(key) + local exists = this:TryGetValue(key) if exists then return false end this:set(key, value) return true end, - TryGetValue = function (this, key, hasNil) - if key == nil and not hasNil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - local pair = this[i] - if equals(comparer, pair[1], key) then - return true, pair[2] - end - end + TryGetValue = function (this, key) + if key == nil then throw(ArgumentNullException("key")) end + local pair = find(this, key) + if pair then + return true, pair[2] end return false, this.__genericTValue__:default() end, getComparer = function (this) return this.comparer end, - getCount = lengthFn, + getCount = function (this) + return this.count + end, get = function (this, key) if key == nil then throw(ArgumentNullException("key")) end - local len = #this - if len > 0 then - local comparer = this.comparer - local equals = comparer.EqualsOf - for i = 1, len do - local pair = this[i] - if equals(comparer, pair[1], key) then - return pair[2] - end - end + local pair = find(this, key) + if pair then + return pair[2] end throw(KeyNotFoundException()) end, set = function (this, key, value) - add(this, key, value, true) + update(this, true, key, value, true) end, GetEnumerator = Array.GetEnumerator, getKeys = function (this) @@ -715,18 +762,15 @@ function System.isDictLike(t) end local DictionaryFn = define("System.Collections.Generic.Dictionary", function(TKey, TValue) - local array, len + local array if hasHash(TKey) then array = ArrayDictionary - else - len = getCount end return { base = { System.IDictionary_2(TKey, TValue), System.IDictionary, System.IReadOnlyDictionary_2(TKey, TValue) }, __genericT__ = KeyValuePairFn(TKey, TValue), __genericTKey__ = TKey, __genericTValue__ = TValue, - __len = len }, array end, Dictionary, 2) diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/HashSet.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/HashSet.lua index 76fda0b5..72564ed4 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/HashSet.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/HashSet.lua @@ -17,30 +17,15 @@ limitations under the License. local System = System local throw = System.throw local each = System.each -local Dictionary = System.Dictionary local wrap = System.wrap local unWrap = System.unWrap -local getEnumerator = Dictionary.GetEnumerator local ArgumentNullException = System.ArgumentNullException -local assert = assert -local pairs = pairs local select = select - -local counts = System.counts - -local function build(this, collection, comparer) - if comparer ~= nil then - assert(false) - end - if collection == nil then - throw(ArgumentNullException("collection")) - end - this:UnionWith(collection) -end +local setmetatable = setmetatable local function checkUniqueAndUnfoundElements(this, other, returnIfUnfound) - if #this == 0 then + if this:getCount() == 0 then local numElementsInOther = 0 for _, item in each(other) do numElementsInOther = numElementsInOther + 1 @@ -48,12 +33,11 @@ local function checkUniqueAndUnfoundElements(this, other, returnIfUnfound) end return 0, numElementsInOther end - local set, uniqueCount, unfoundCount = {}, 0, 0 + local set, uniqueCount, unfoundCount = this:newSet(), 0, 0 for _, item in each(other) do - item = wrap(item) - if this[item] ~= nil then - if set[item] == nil then - set[item] = true + if this:Contains(item) then + if not set:Contains(item) then + set:Add(item) uniqueCount = uniqueCount + 1 end else @@ -66,103 +50,92 @@ local function checkUniqueAndUnfoundElements(this, other, returnIfUnfound) return uniqueCount, unfoundCount end +local HashSetEnumerator = System.define("System.Collections.Generic.HashSetEnumerator", function (T) + return { + __genericT__ = T, + base = { System.IEnumerator_1(T) } + } +end, { + getCurrent = System.getCurrent, + Dispose = System.emptyFn, + MoveNext = function (this) + if this.en:MoveNext() then + local pair = this.en.current + this.current = unWrap(pair[1]) + return true + end + this.current = this.__genericT__:default() + return false + end +}, 1) + local HashSet = { __ctor__ = function (this, ...) - local len = select("#", ...) - if len == 0 then - elseif len == 1 then - local collection = ... - if collection == nil then return end - if collection.GetEnumerator ~= nil then - build(this, collection, nil) - else - assert(true) + local n = select("#", ...) + local collection, comparer + if n == 1 then + local c = ... + if type(c) == "table" then + if c.GetEnumerator then + collection = c + end + comparer = c end - else - build(this, ...) + elseif n == 2 then + collection, comparer = ... + end + this.dict = System.Dictionary(this.__genericT__, System.Boolean)(comparer) + if collection then + this:UnionWith(collection) end end, - Clear = Dictionary.Clear, - getCount = Dictionary.getCount, + newSet = function (this) + return System.HashSet(this.__genericT__)(this.dict.comparer) + end, + Clear = function (this) + this.dict:Clear() + end, + getCount = function (this) + return this.dict:getCount() + end, getIsReadOnly = System.falseFn, - Contains = function (this, item) - item = wrap(item) - return this[item] ~= nil + Contains = function (this, v) + v = wrap(v) + return this.dict:ContainsKey(v) end, - Remove = function (this, item) - item = wrap(item) - if this[item] then - this[item] = nil - local t = counts[this] - t[1] = t[1] - 1 - t[2] = t[2] + 1 - return true - end - return false + Remove = function (this, v) + v = wrap(v) + return this.dict:RemoveKey(v) end, GetEnumerator = function (this) - return getEnumerator(this, 1) + return setmetatable({ en = this.dict:GetEnumerator() }, HashSetEnumerator(this.__genericT__)) end, Add = function (this, v) v = wrap(v) - if this[v] == nil then - this[v] = true - local t = counts[this] - if t then - t[1] = t[1] + 1 - t[2] = t[2] + 1 - else - counts[this] = { 1, 1 } - end - return true - end - return false + return this.dict:TryAdd(v, true) end, UnionWith = function (this, other) if other == nil then throw(ArgumentNullException("other")) end - local count = 0 for _, v in each(other) do - v = wrap(v) - if this[v] == nil then - this[v] = true - count = count + 1 - end - end - if count > 0 then - local t = counts[this] - if t then - t[1] = t[1] + count - t[2] = t[2] + 1 - else - counts[this] = { count, 1 } - end + this:Add(v) end end, IntersectWith = function (this, other) if other == nil then throw(ArgumentNullException("other")) end - local set = {} - for _, v in each(other) do - v = wrap(v) - if this[v] ~= nil then - set[v] = true - end + if this == other or this:getCount() == 0 then + return end - local count = 0 - for v, _ in pairs(this) do - if set[v] == nil then - this[v] = nil - count = count + 1 + local set = this:newSet() + for _, v in each(other) do + if this:Contains(v) then + set:Add(v) end end - if count > 0 then - local t = counts[this] - t[1] = t[1] - count - t[2] = t[2] + 1 - end + this.dict = set.dict end, ExceptWith = function (this, other) if other == nil then @@ -172,49 +145,25 @@ local HashSet = { this:Clear() return end - local count = 0 for _, v in each(other) do - v = wrap(v) - if this[v] ~= nil then - this[v] = nil - count = count + 1 - end - end - if count > 0 then - local t = counts[this] - t[1] = t[1] - count - t[2] = t[2] + 1 + this:Remove(v) end end, SymmetricExceptWith = function (this, other) if other == nil then throw(ArgumentNullException("other")) end + if this:getCount() == 0 then + this:UnionWith(other) + return + end if other == this then this:Clear() return end - local set = {} - local count = 0 - local changed = false for _, v in each(other) do - v = wrap(v) - if this[v] == nil then - this[v] = true - count = count + 1 - changed = true - set[v] = true - elseif set[v] == nil then - this[v] = nil - count = count - 1 - changed = true - end - end - if changed then - local t = counts[this] - if t then - t[1] = t[1] + count - t[2] = t[2] + 1 + if this:Contains(v) then + this:Remove(v) else - counts[this] = { count, 1 } + this:Add(v) end end end, @@ -222,7 +171,7 @@ local HashSet = { if other == nil then throw(ArgumentNullException("other")) end - local count = #this + local count = this:getCount() if count == 0 then return true end @@ -241,8 +190,7 @@ local HashSet = { throw(ArgumentNullException("other")) end for _, element in each(other) do - element = wrap(element) - if this[element] == nil then + if not this:Contains(element) then return false end end @@ -252,7 +200,7 @@ local HashSet = { if other == nil then throw(ArgumentNullException("other")) end - local count = #this + local count = this:getCount() if count == 0 then return false end @@ -263,12 +211,11 @@ local HashSet = { if other == nil then throw(ArgumentNullException("other")) end - if #this == 0 then + if this:getCount() == 0 then return false end for _, element in each(other) do - element = wrap(element) - if this[element] ~= nil then + if this:Contains(element) then return true end end @@ -285,19 +232,9 @@ local HashSet = { if match == nil then throw(ArgumentNullException("match")) end - local numRemoved = 0 - for v, _ in pairs(this) do - if match(unWrap(v)) then - this[v] = nil - numRemoved = numRemoved + 1 - end - end - if numRemoved > 0 then - local t = counts[this] - t[1] = t[1] - numRemoved - t[2] = t[2] + 1 - end - return numRemoved + return this.dict:removeWhere(function (k, v) + return match(unWrap(k)) + end) end, TrimExcess = System.emptyFn } @@ -306,11 +243,10 @@ function System.hashSetFromTable(t, T) return setmetatable(t, HashSet(T)) end -System.HashSet = System.define("System.Collections.Generic.HashSet", function(T) - return { - base = { System.ICollection_1(T), System.ISet_1(T) }, +System.HashSet = System.define("System.Collections.Generic.HashSet", function(T) + return { + base = { System.ICollection_1(T), System.ISet_1(T) }, __genericT__ = T, __genericTKey__ = T, - __len = HashSet.getCount } end, HashSet, 1) diff --git a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua index 95184597..3c214ff9 100644 --- a/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua +++ b/CSharp.lua/CoreSystem.Lua/CoreSystem/Collections/Linq.lua @@ -280,22 +280,24 @@ local Grouping = define("System.Linq.Grouping", function (TKey, TElement) end, nil, 2) local function getGrouping(this, key, create) - local t = this.groups - local found, group = t:TryGetValue(key, true) - if found then return group end + local comparer = this.comparer + for i = 1, #this do + local group = this[i] + if comparer:EqualsOf(group.key, key) then + return group + end + end if create then - group = setmetatable({ key = key }, Grouping(this.__genericTKey__, this.__genericTElement__)) - t[#t + 1] = setmetatable({ key, group }, t.__genericT__) + local group = setmetatable({ key = key }, this.__genericT__) + this[#this + 1] = group + return group end - return group + return nil end local Lookup = { __ctor__ = function (this, comparer) - local TKey = this.__genericTKey__ - comparer = comparer or EqualityComparer(TKey).getDefault() - local G = Grouping(TKey, this.__genericTElement__) - this.groups = System.Dictionary(TKey, G)(comparer) + this.comparer = comparer or EqualityComparer(this.__genericTKey__).getDefault() end, get = function (this, key) local grouping = getGrouping(this, key) @@ -303,20 +305,19 @@ local Lookup = { return Empty(this.__genericTElement__) end, GetCount = function (this) - return #this.groups + return #this end, Contains = function (this, key) return getGrouping(this, key) ~= nil end, - GetEnumerator = function (this) - return this.groups:getValues():GetEnumerator() - end + GetEnumerator = Array.GetEnumerator, } local LookupFn = define("System.Linq.Lookup", function(TKey, TElement) local cls = { __genericTKey__ = TKey, __genericTElement__ = TElement, + __genericT__ = Grouping(TKey, TElement), } return cls end, Lookup, 2) @@ -613,7 +614,7 @@ function Enumerable.Zip(first, second, resultSelector, TResult) if e1:MoveNext() and e2:MoveNext() then return true, resultSelector(e1:getCurrent(), e2:getCurrent()) end - end, + end, function() e2 = second:GetEnumerator() end) @@ -653,7 +654,7 @@ function Enumerable.Distinct(source, comparer) while en:MoveNext() do local current = en:getCurrent() if addToSet(set, current, getHashCode, comparer) then - return true, current + return true, current end end return false @@ -1185,7 +1186,7 @@ function Enumerable.Aggregate(source, ...) result = func(result, element) end return result - else + else local seed, func, resultSelector = ... if func == nil then throw(ArgumentNullException("func")) end if resultSelector == nil then throw(ArgumentNullException("resultSelector")) end