Skip to content

Commit

Permalink
adding requestID back :(
Browse files Browse the repository at this point in the history
  • Loading branch information
azhou-determined committed Oct 24, 2024
1 parent 6674269 commit 9e6084c
Show file tree
Hide file tree
Showing 17 changed files with 314 additions and 391 deletions.
1 change: 1 addition & 0 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2693,6 +2693,7 @@ func (a *apiServer) createTrialTx(

trialModel := model.NewTrial(
model.PausedState,
model.RequestID{},
exp.ID,
req.Hparams.AsMap(),
nil,
Expand Down
18 changes: 9 additions & 9 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,13 +558,13 @@ func (a *apiServer) KillTrial(
experiment.AuthZProvider.Get().CanEditExperiment); err != nil {
return nil, err
}
eID, err := a.m.db.ExperimentIDByTrialID(int(req.Id))
eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.Id))
if err != nil {
return nil, err
}

s := experiment.PatchTrialState{
TrialID: req.Id,
RequestID: rID,
State: model.StateWithReason{
State: model.StoppingKilledState,
InformationalReason: "user requested kill",
Expand Down Expand Up @@ -1383,7 +1383,7 @@ func (a *apiServer) ReportTrialSearcherEarlyExit(
experiment.AuthZProvider.Get().CanEditExperiment); err != nil {
return nil, err
}
eID, err := a.m.db.ExperimentIDByTrialID(int(req.TrialId))
eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.TrialId))
if err != nil {
return nil, err
}
Expand All @@ -1394,8 +1394,8 @@ func (a *apiServer) ReportTrialSearcherEarlyExit(
}

msg := experiment.UserInitiatedEarlyTrialExit{
TrialID: req.TrialId,
Reason: model.ExitedReasonFromProto(req.EarlyExit.Reason),
RequestID: rID,
Reason: model.ExitedReasonFromProto(req.EarlyExit.Reason),
}
if err := e.UserInitiatedEarlyTrialExit(msg); err != nil {
return nil, err
Expand All @@ -1415,7 +1415,7 @@ func (a *apiServer) ReportTrialProgress(
return nil, err
}

eID, err := a.m.db.ExperimentIDByTrialID(int(req.TrialId))
eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.TrialId))
if err != nil {
return nil, err
}
Expand All @@ -1433,7 +1433,7 @@ func (a *apiServer) ReportTrialProgress(
Progress: searcher.PartialUnits(req.Progress),
IsRaw: req.IsRaw,
}
if err := e.TrialReportProgress(req.TrialId, msg); err != nil {
if err := e.TrialReportProgress(rID, msg); err != nil {
return nil, err
}
return &apiv1.ReportTrialProgressResponse{}, nil
Expand All @@ -1456,15 +1456,15 @@ func (a *apiServer) ReportTrialMetrics(
}
if metricGroup == model.ValidationMetricGroup {
// Notify searcher of validation metrics.
eID, err := a.m.db.ExperimentIDByTrialID(int(req.Metrics.TrialId))
eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.Metrics.TrialId))
if err != nil {
return nil, errors.Errorf("Failed to get experiment ID from trial ID (%d)", req.Metrics.TrialId)
}
e, ok := experiment.ExperimentRegistry.Load(eID)
if !ok {
return nil, errors.Errorf("Failed to get experiment (%d) from experiment registry", eID)
}
err = e.TrialReportValidation(req.Metrics.TrialId, req.Metrics.Metrics.AvgMetrics.AsMap())
err = e.TrialReportValidation(rID, req.Metrics.Metrics.AvgMetrics.AsMap())
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type DB interface {
Migrate(migrationURL, codeURL string, actions []string) error
Close() error
GetOrCreateClusterID(telemetryID string) (string, error)
TrialExperimentAndRequestID(id int) (int, model.RequestID, error)
AddExperiment(experiment *model.Experiment, modelDef []byte, activeConfig expconf.ExperimentConfig) error
ExperimentIDByTrialID(trialID int) (int, error)
NonTerminalExperiments() ([]*model.Experiment, error)
Expand Down
19 changes: 19 additions & 0 deletions master/internal/db/postgres_experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,25 @@ SELECT experiment_id FROM trials where id = $1
return experimentID, nil
}

// TrialExperimentAndRequestID looks up, for the trial with the given ID, the ID of
// its experiment together with the trial's request ID. ErrNotFound (with a stack
// trace attached) is returned when the query yields no row; any other failure is
// wrapped with a short description of the lookup.
func (db *PgDB) TrialExperimentAndRequestID(id int) (int, model.RequestID, error) {
	var (
		experimentID int
		requestID    model.RequestID
	)

	// The trial is joined against experiments, so a row is produced only when
	// the trial's experiment_id matches an existing experiment.
	row := db.sql.QueryRow(`
SELECT e.id, t.request_id
FROM trials t, experiments e
WHERE t.experiment_id = e.id
AND t.id = $1`, id)

	err := row.Scan(&experimentID, &requestID)
	if err == sql.ErrNoRows {
		return experimentID, requestID, errors.WithStack(ErrNotFound)
	}
	if err != nil {
		return experimentID, requestID, errors.Wrap(err, "failed to get trial exp and req id")
	}
	return experimentID, requestID, nil
}

// NonTerminalExperiments finds all experiments in the database whose states are not terminal.
func (db *PgDB) NonTerminalExperiments() ([]*model.Experiment, error) {
rows, err := db.sql.Queryx(`
Expand Down
13 changes: 13 additions & 0 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -1094,3 +1094,16 @@ RETURNING true`, bun.In(uniqueExpIDs)).Scan(ctx, &res)

return nil
}

// TrialByExperimentAndRequestID fetches the trial that matches both the given
// experiment ID and request ID. Any error reported while scanning the query result
// is wrapped with the request ID and returned alongside a nil trial.
func TrialByExperimentAndRequestID(
	ctx context.Context, experimentID int, requestID model.RequestID,
) (*model.Trial, error) {
	var trial model.Trial

	query := Bun().NewSelect().
		Model(&trial).
		Where("experiment_id = ?", experimentID).
		Where("request_id = ?", requestID)

	if err := query.Scan(ctx); err != nil {
		return nil, fmt.Errorf("error querying for trial %s: %w", requestID, err)
	}
	return &trial, nil
}
Loading

0 comments on commit 9e6084c

Please sign in to comment.