forked from elastic/kibana
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enterprise Search] Add model management API logic (elastic#172120)
## Summary Adding parts of ML model management API logic: - Fetch models - Cached and pollable wrapper for model fetching - Create model - Start model These API logic pieces map to existing API endpoints and are currently unused. Their purpose is to enable one-click deployment of models within pipeline configuration. --------- Co-authored-by: kibanamachine <[email protected]>
- Loading branch information
1 parent
508e9da
commit 4ab4239
Showing
8 changed files
with
525 additions
and
0 deletions.
There are no files selected for viewing
229 changes: 229 additions & 0 deletions
229
...pplications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { LogicMounter } from '../../../__mocks__/kea_logic'; | ||
|
||
import { HttpError, Status } from '../../../../../common/types/api'; | ||
import { MlModelDeploymentState } from '../../../../../common/types/ml'; | ||
|
||
import { MlModel } from '../../../../../common/types/ml'; | ||
|
||
import { | ||
CachedFetchModelsApiLogic, | ||
CachedFetchModelsApiLogicValues, | ||
} from './cached_fetch_models_api_logic'; | ||
import { FetchModelsApiLogic } from './fetch_models_api_logic'; | ||
|
||
const DEFAULT_VALUES: CachedFetchModelsApiLogicValues = { | ||
data: [], | ||
isInitialLoading: false, | ||
isLoading: false, | ||
modelsData: null, | ||
pollTimeoutId: null, | ||
status: Status.IDLE, | ||
}; | ||
|
||
const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [ | ||
{ | ||
modelId: 'model_1', | ||
title: 'Model 1', | ||
type: 'ner', | ||
deploymentState: MlModelDeploymentState.NotDeployed, | ||
startTime: 0, | ||
targetAllocationCount: 0, | ||
nodeAllocationCount: 0, | ||
threadsPerAllocation: 0, | ||
isPlaceholder: false, | ||
hasStats: false, | ||
}, | ||
]; | ||
const FETCH_MODELS_API_ERROR_RESPONSE = { | ||
body: { | ||
error: 'Error while fetching models', | ||
message: 'Error while fetching models', | ||
statusCode: 500, | ||
}, | ||
} as HttpError; | ||
|
||
jest.useFakeTimers(); | ||
|
||
describe('TextExpansionCalloutLogic', () => { | ||
const { mount } = new LogicMounter(CachedFetchModelsApiLogic); | ||
const { mount: mountFetchModelsApiLogic } = new LogicMounter(FetchModelsApiLogic); | ||
|
||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
mountFetchModelsApiLogic(); | ||
mount(); | ||
}); | ||
|
||
describe('listeners', () => { | ||
describe('apiError', () => { | ||
it('sets new polling timeout if a timeout ID is already set', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.apiError(FETCH_MODELS_API_ERROR_RESPONSE); | ||
|
||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled(); | ||
}); | ||
}); | ||
|
||
describe('apiSuccess', () => { | ||
it('sets new polling timeout if a timeout ID is already set', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE); | ||
|
||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled(); | ||
}); | ||
}); | ||
|
||
describe('createPollTimeout', () => { | ||
const duration = 5000; | ||
it('clears polling timeout if it is set', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
jest.spyOn(global, 'clearTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.createPollTimeout(duration); | ||
|
||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id'); | ||
}); | ||
it('sets polling timeout', () => { | ||
jest.spyOn(global, 'setTimeout'); | ||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'setTimeoutId'); | ||
|
||
CachedFetchModelsApiLogic.actions.createPollTimeout(duration); | ||
|
||
expect(setTimeout).toHaveBeenCalledWith(expect.any(Function), duration); | ||
expect(CachedFetchModelsApiLogic.actions.setTimeoutId).toHaveBeenCalled(); | ||
}); | ||
}); | ||
|
||
describe('startPolling', () => { | ||
it('clears polling timeout if it is set', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
jest.spyOn(global, 'clearTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.startPolling(); | ||
|
||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id'); | ||
}); | ||
it('makes API request and sets polling timeout', () => { | ||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'makeRequest'); | ||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.startPolling(); | ||
|
||
expect(CachedFetchModelsApiLogic.actions.makeRequest).toHaveBeenCalled(); | ||
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled(); | ||
}); | ||
}); | ||
|
||
describe('stopPolling', () => { | ||
it('clears polling timeout if it is set', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
jest.spyOn(global, 'clearTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.stopPolling(); | ||
|
||
expect(clearTimeout).toHaveBeenCalledWith('timeout-id'); | ||
}); | ||
it('clears polling timeout value', () => { | ||
jest.spyOn(CachedFetchModelsApiLogic.actions, 'clearPollTimeout'); | ||
|
||
CachedFetchModelsApiLogic.actions.stopPolling(); | ||
|
||
expect(CachedFetchModelsApiLogic.actions.clearPollTimeout).toHaveBeenCalled(); | ||
}); | ||
}); | ||
}); | ||
|
||
describe('reducers', () => { | ||
describe('modelsData', () => { | ||
it('gets cleared on API reset', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
modelsData: [], | ||
}); | ||
|
||
CachedFetchModelsApiLogic.actions.apiReset(); | ||
|
||
expect(CachedFetchModelsApiLogic.values.modelsData).toBe(null); | ||
}); | ||
it('gets set on API success', () => { | ||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE); | ||
|
||
expect(CachedFetchModelsApiLogic.values.modelsData).toEqual(FETCH_MODELS_API_DATA_RESPONSE); | ||
}); | ||
}); | ||
|
||
describe('pollTimeoutId', () => { | ||
it('gets cleared on clear timeout action', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
pollTimeoutId: 'timeout-id', | ||
}); | ||
|
||
CachedFetchModelsApiLogic.actions.clearPollTimeout(); | ||
|
||
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toBe(null); | ||
}); | ||
it('gets set on set timeout action', () => { | ||
const timeout = setTimeout(() => {}, 500); | ||
|
||
CachedFetchModelsApiLogic.actions.setTimeoutId(timeout); | ||
|
||
expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toEqual(timeout); | ||
}); | ||
}); | ||
}); | ||
|
||
describe('selectors', () => { | ||
describe('isInitialLoading', () => { | ||
it('true if API is idle', () => { | ||
mount(DEFAULT_VALUES); | ||
|
||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true); | ||
}); | ||
it('true if API is loading for the first time', () => { | ||
mount({ | ||
...DEFAULT_VALUES, | ||
status: Status.LOADING, | ||
}); | ||
|
||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true); | ||
}); | ||
it('false if the API is neither idle nor loading', () => { | ||
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE); | ||
|
||
expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(false); | ||
}); | ||
}); | ||
}); | ||
}); |
125 changes: 125 additions & 0 deletions
125
...lic/applications/enterprise_search_content/api/ml_models/cached_fetch_models_api_logic.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { kea, MakeLogicType } from 'kea'; | ||
|
||
import { isEqual } from 'lodash'; | ||
|
||
import { Status } from '../../../../../common/types/api'; | ||
import { MlModel } from '../../../../../common/types/ml'; | ||
import { Actions } from '../../../shared/api_logic/create_api_logic'; | ||
|
||
import { FetchModelsApiLogic, FetchModelsApiResponse } from './fetch_models_api_logic'; | ||
|
||
const FETCH_MODELS_POLLING_DURATION = 5000; // 5 seconds | ||
const FETCH_MODELS_POLLING_DURATION_ON_FAILURE = 30000; // 30 seconds | ||
|
||
export interface CachedFetchModlesApiLogicActions { | ||
apiError: Actions<{}, FetchModelsApiResponse>['apiError']; | ||
apiReset: Actions<{}, FetchModelsApiResponse>['apiReset']; | ||
apiSuccess: Actions<{}, FetchModelsApiResponse>['apiSuccess']; | ||
clearPollTimeout(): void; | ||
createPollTimeout(duration: number): { duration: number }; | ||
makeRequest: Actions<{}, FetchModelsApiResponse>['makeRequest']; | ||
setTimeoutId(id: NodeJS.Timeout): { id: NodeJS.Timeout }; | ||
startPolling(): void; | ||
stopPolling(): void; | ||
} | ||
|
||
export interface CachedFetchModelsApiLogicValues { | ||
data: FetchModelsApiResponse; | ||
isInitialLoading: boolean; | ||
isLoading: boolean; | ||
modelsData: MlModel[] | null; | ||
pollTimeoutId: NodeJS.Timeout | null; | ||
status: Status; | ||
} | ||
|
||
export const CachedFetchModelsApiLogic = kea< | ||
MakeLogicType<CachedFetchModelsApiLogicValues, CachedFetchModlesApiLogicActions> | ||
>({ | ||
actions: { | ||
clearPollTimeout: true, | ||
createPollTimeout: (duration) => ({ duration }), | ||
setTimeoutId: (id) => ({ id }), | ||
startPolling: true, | ||
stopPolling: true, | ||
}, | ||
connect: { | ||
actions: [FetchModelsApiLogic, ['apiSuccess', 'apiError', 'apiReset', 'makeRequest']], | ||
values: [FetchModelsApiLogic, ['data', 'status']], | ||
}, | ||
events: ({ values }) => ({ | ||
beforeUnmount: () => { | ||
if (values.pollTimeoutId) { | ||
clearTimeout(values.pollTimeoutId); | ||
} | ||
}, | ||
}), | ||
listeners: ({ actions, values }) => ({ | ||
apiError: () => { | ||
if (values.pollTimeoutId) { | ||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION_ON_FAILURE); | ||
} | ||
}, | ||
apiSuccess: () => { | ||
if (values.pollTimeoutId) { | ||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION); | ||
} | ||
}, | ||
createPollTimeout: ({ duration }) => { | ||
if (values.pollTimeoutId) { | ||
clearTimeout(values.pollTimeoutId); | ||
} | ||
|
||
const timeoutId = setTimeout(() => { | ||
actions.makeRequest({}); | ||
}, duration); | ||
actions.setTimeoutId(timeoutId); | ||
}, | ||
startPolling: () => { | ||
if (values.pollTimeoutId) { | ||
clearTimeout(values.pollTimeoutId); | ||
} | ||
actions.makeRequest({}); | ||
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION); | ||
}, | ||
stopPolling: () => { | ||
if (values.pollTimeoutId) { | ||
clearTimeout(values.pollTimeoutId); | ||
} | ||
actions.clearPollTimeout(); | ||
}, | ||
}), | ||
path: ['enterprise_search', 'content', 'api', 'fetch_models_api_wrapper'], | ||
reducers: { | ||
modelsData: [ | ||
null, | ||
{ | ||
apiReset: () => null, | ||
apiSuccess: (currentState, newState) => | ||
isEqual(currentState, newState) ? currentState : newState, | ||
}, | ||
], | ||
pollTimeoutId: [ | ||
null, | ||
{ | ||
clearPollTimeout: () => null, | ||
setTimeoutId: (_, { id }) => id, | ||
}, | ||
], | ||
}, | ||
selectors: ({ selectors }) => ({ | ||
isInitialLoading: [ | ||
() => [selectors.status, selectors.modelsData], | ||
( | ||
status: CachedFetchModelsApiLogicValues['status'], | ||
modelsData: CachedFetchModelsApiLogicValues['modelsData'] | ||
) => status === Status.IDLE || (modelsData === null && status === Status.LOADING), | ||
], | ||
}), | ||
}); |
30 changes: 30 additions & 0 deletions
30
...ublic/applications/enterprise_search_content/api/ml_models/create_model_api_logic.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { mockHttpValues } from '../../../__mocks__/kea_logic'; | ||
|
||
import { nextTick } from '@kbn/test-jest-helpers'; | ||
|
||
import { createModel } from './create_model_api_logic'; | ||
|
||
describe('CreateModelApiLogic', () => { | ||
const { http } = mockHttpValues; | ||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
}); | ||
describe('createModel', () => { | ||
it('calls correct api', async () => { | ||
const mockResponseBody = { modelId: 'model_1', deploymentState: '' }; | ||
http.post.mockReturnValue(Promise.resolve(mockResponseBody)); | ||
|
||
const result = createModel({ modelId: 'model_1' }); | ||
await nextTick(); | ||
expect(http.post).toHaveBeenCalledWith('/internal/enterprise_search/ml/models/model_1'); | ||
await expect(result).resolves.toEqual(mockResponseBody); | ||
}); | ||
}); | ||
}); |
Oops, something went wrong.