This is an automated email from the ASF dual-hosted git repository. yiconghuang pushed a commit to branch feat/add-model in repository https://gitbox.apache.org/repos/asf/texera.git
commit 816b7524889d242608133f0095ebfdbc63585d8d Author: Yicong Huang <[email protected]> AuthorDate: Wed Aug 27 02:13:10 2025 -0700 feat: working version for load model --- core/gui/src/app/app.module.ts | 6 + core/gui/src/app/common/formly/formly-config.ts | 4 + .../dashboard/component/dashboard.component.html | 2 +- .../input-auto-complete-model.component.html | 42 +++++++ .../input-auto-complete-model.component.scss | 50 ++++++++ .../input-auto-complete-model.component.spec.ts | 49 ++++++++ .../input-auto-complete-model.component.ts | 81 +++++++++++++ .../model-selection/model-selection.component.html | 77 +++++++++++++ .../model-selection/model-selection.component.scss | 89 ++++++++++++++ .../model-selection/model-selection.component.ts | 128 +++++++++++++++++++++ .../operator-property-edit-frame.component.ts | 6 + core/gui/src/assets/operator_images/LoadModel.png | Bin 0 -> 2625 bytes .../edu/uci/ics/amber/operator/LogicalOp.scala | 51 ++------ .../amber/operator/loadModel/LoadModelOpDesc.scala | 75 ++++++++++++ 14 files changed, 615 insertions(+), 45 deletions(-) diff --git a/core/gui/src/app/app.module.ts b/core/gui/src/app/app.module.ts index 9b6f79e61b..acf9803b9c 100644 --- a/core/gui/src/app/app.module.ts +++ b/core/gui/src/app/app.module.ts @@ -182,6 +182,10 @@ import { import { UserModelStagedObjectsListComponent } from "./dashboard/component/user/user-model/user-dataset-explorer/user-model-staged-objects-list/user-model-staged-objects-list.component"; +import { + InputAutoCompleteModelComponent +} from "./workspace/component/input-autocomplete-model/input-auto-complete-model.component"; +import { ModelSelectionComponent } from "./workspace/component/model-selection/model-selection.component"; registerLocaleData(en); @@ -261,6 +265,8 @@ registerLocaleData(en); SortButtonComponent, FiltersComponent, FiltersInstructionsComponent, + InputAutoCompleteModelComponent, + ModelSelectionComponent, SearchComponent, SearchResultsComponent, 
PortPropertyEditFrameComponent, diff --git a/core/gui/src/app/common/formly/formly-config.ts b/core/gui/src/app/common/formly/formly-config.ts index d950bd3690..245b79fd89 100644 --- a/core/gui/src/app/common/formly/formly-config.ts +++ b/core/gui/src/app/common/formly/formly-config.ts @@ -27,6 +27,9 @@ import { PresetWrapperComponent } from "./preset-wrapper/preset-wrapper.componen import { InputAutoCompleteComponent } from "../../workspace/component/input-autocomplete/input-autocomplete.component"; import { CollabWrapperComponent } from "./collab-wrapper/collab-wrapper/collab-wrapper.component"; import { FormlyRepeatDndComponent } from "./repeat-dnd/repeat-dnd.component"; +import { + InputAutoCompleteModelComponent +} from "../../workspace/component/input-autocomplete-model/input-auto-complete-model.component"; /** * Configuration for using Json Schema with Formly. @@ -77,6 +80,7 @@ export const TEXERA_FORMLY_CONFIG = { { name: "multischema", component: MultiSchemaTypeComponent }, { name: "codearea", component: CodeareaCustomTemplateComponent }, { name: "inputautocomplete", component: InputAutoCompleteComponent, wrappers: ["form-field"] }, + { name: "inputautocompleteModel", component: InputAutoCompleteModelComponent, wrappers: ["form-field"] }, { name: "repeat-section-dnd", component: FormlyRepeatDndComponent }, ], wrappers: [ diff --git a/core/gui/src/app/dashboard/component/dashboard.component.html b/core/gui/src/app/dashboard/component/dashboard.component.html index 92cae29991..604e5a1599 100644 --- a/core/gui/src/app/dashboard/component/dashboard.component.html +++ b/core/gui/src/app/dashboard/component/dashboard.component.html @@ -106,7 +106,7 @@ [routerLink]="DASHBOARD_USER_MODEL"> <span nz-icon - nzType="database"></span> + nzType="code-sandbox"></span> <span>Models</span> </li> <li diff --git a/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.html 
b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.html new file mode 100644 index 0000000000..e1354968db --- /dev/null +++ b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.html @@ -0,0 +1,42 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<div [ngClass]="{'input-autocomplete-container': isFileSelectionEnabled}"> + <input + *ngIf="selectedFilePath || !isFileSelectionEnabled" + nz-input + [readOnly]="isFileSelectionEnabled" + required + [formControl]="formControl" + [formlyAttributes]="field" /> + <button + *ngIf="isFileSelectionEnabled" + nz-button + class="file-select-button" + nzSize="small" + (click)="isFileSelectionEnabled && onClickOpenFileSelectionModal()"> + {{ selectedFilePath ? 
'Reselect Model' : 'Select Model' }} + </button> +</div> +<div + class="alert alert-danger" + role="alert" + *ngIf="props.showError && formControl.errors"> + <formly-validation-message [field]="field"></formly-validation-message> +</div> diff --git a/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.scss b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.scss new file mode 100644 index 0000000000..cc76590399 --- /dev/null +++ b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.scss @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +mat-form-field { + width: 100%; +} +.input-autocomplete-container { + display: flex; + align-items: center; + width: 100%; + + input { + flex: 1; + margin-right: 10px; + } + + button { + white-space: nowrap; + } + + .file-select-button { + border: 2px solid #1890ff; + color: #1890ff; + + &:hover { + background-color: #e6f7ff; + border-color: #1890ff; + } + + &:focus { + box-shadow: 0 0 0 2px rgba(24, 144, 255, 0.2); + } + } +} diff --git a/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.spec.ts b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.spec.ts new file mode 100644 index 0000000000..eba11fafd1 --- /dev/null +++ b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.spec.ts @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { ComponentFixture, TestBed, waitForAsync } from "@angular/core/testing"; +import { FormControl, ReactiveFormsModule } from "@angular/forms"; +import { InputAutoCompleteModelComponent } from "./input-auto-complete-model.component"; +import { HttpClientTestingModule } from "@angular/common/http/testing"; +import { NzModalService } from "ng-zorro-antd/modal"; +import { commonTestProviders } from "../../../common/testing/test-utils"; + +describe("InputAutoCompleteComponent", () => { + let component: InputAutoCompleteModelComponent; + let fixture: ComponentFixture<InputAutoCompleteModelComponent>; + + beforeEach(waitForAsync(() => { + TestBed.configureTestingModule({ + declarations: [InputAutoCompleteModelComponent], + imports: [ReactiveFormsModule, HttpClientTestingModule], + providers: [NzModalService, ...commonTestProviders], + }).compileComponents(); + })); + + beforeEach(() => { + fixture = TestBed.createComponent(InputAutoCompleteModelComponent); + component = fixture.componentInstance; + component.field = { props: {}, formControl: new FormControl() }; + fixture.detectChanges(); + }); + + it("should create", () => { + expect(component).toBeTruthy(); + }); +}); diff --git a/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.ts b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.ts new file mode 100644 index 0000000000..a13fb1d4fc --- /dev/null +++ b/core/gui/src/app/workspace/component/input-autocomplete-model/input-auto-complete-model.component.ts @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component } from "@angular/core"; +import { FieldType, FieldTypeConfig } from "@ngx-formly/core"; +import { UntilDestroy, untilDestroyed } from "@ngneat/until-destroy"; +import { WorkflowActionService } from "../../service/workflow-graph/model/workflow-action.service"; +import { NzModalService } from "ng-zorro-antd/modal"; +import { DatasetFileNode, getFullPathFromDatasetFileNode } from "../../../common/type/datasetVersionFileTree"; +import { GuiConfigService } from "../../../common/service/gui-config.service"; +import { ModelSelectionComponent } from "../model-selection/model-selection.component"; + +@UntilDestroy() +@Component({ + selector: "texera-input-autocomplete-model-template", + templateUrl: "./input-auto-complete-model.component.html", + styleUrls: ["input-auto-complete-model.component.scss"], +}) +export class InputAutoCompleteModelComponent extends FieldType<FieldTypeConfig> { + constructor( + private modalService: NzModalService, + public workflowActionService: WorkflowActionService, + private config: GuiConfigService, + ) { + super(); + } + + onClickOpenFileSelectionModal(): void { + const modal = this.modalService.create({ + nzTitle: "Please select one model", + nzContent: ModelSelectionComponent, + nzFooter: null, + nzData: { + selectedFilePath: this.formControl.getRawValue(), + }, + nzBodyStyle: { + // Enables the file selection window to be resizable + resize: "both", + overflow: "auto", + minHeight: "200px", + minWidth: "550px", + maxWidth: "90vw", + maxHeight: "80vh", + }, + nzWidth: "fit-content", + }); + // 
Handle the selection from the modal (afterClose emits undefined when the modal is dismissed without picking a file) + modal.afterClose.pipe(untilDestroyed(this)).subscribe(fileNode => { + const node = fileNode as DatasetFileNode | undefined; + if (node) this.formControl.setValue(getFullPathFromDatasetFileNode(node)); + }); + } + + get enableDatasetSource(): boolean { + return this.config.env.userSystemEnabled && this.config.env.selectingFilesFromDatasetsEnabled; + } + + get isFileSelectionEnabled(): boolean { + return this.enableDatasetSource; + } + + get selectedFilePath(): string | null { + return this.formControl.value; + } +} diff --git a/core/gui/src/app/workspace/component/model-selection/model-selection.component.html b/core/gui/src/app/workspace/component/model-selection/model-selection.component.html new file mode 100644 index 0000000000..dd06885693 --- /dev/null +++ b/core/gui/src/app/workspace/component/model-selection/model-selection.component.html @@ -0,0 +1,77 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<nz-spin [nzSpinning]="isAccessibleModelsLoading"> + <div class="container"> + <div class="selection-row"> + <nz-select + nzShowSearch + nzAllowClear + nzPlaceHolder="Select model" + [(ngModel)]="selectedModel" + (ngModelChange)="onDatasetChange()" + [ngStyle]="{'width': isModelSelected ?
'60%' : '100%'}" + class="select-dataset"> + <nz-option + *ngFor="let model of models" + [nzValue]="model" + [nzLabel]="model.model.name" + [nzCustomContent]="true"> + <div class="dataset-option"> + <div class="dataset-id-container">{{ model.model.mid }}</div> + <span class="dataset-name">{{ model.model.name }}</span> + <span + class="dataset-owner" + *ngIf="model.isOwner"> + OWNER + </span> + <span + class="dataset-access-privilege" + *ngIf="!model.isOwner"> + {{ model.accessPrivilege }} + </span> + </div> + </nz-option> + </nz-select> + + <nz-select + *ngIf="selectedModel" + nzShowSearch + nzAllowClear + nzPlaceHolder="Select version" + [(ngModel)]="selectedVersion" + (ngModelChange)="onVersionChange()" + class="select-version"> + <nz-option + *ngFor="let version of modelVersions" + [nzValue]="version" + [nzLabel]="version.name"> + </nz-option> + </nz-select> + </div> + + <texera-user-dataset-version-filetree + *ngIf="suggestedFileTreeNodes.length > 0" + [isExpandAllAfterViewInit]="true" + [fileTreeNodes]="suggestedFileTreeNodes" + (selectedTreeNode)="onFileTreeNodeSelected($event)" + class="texera-user-dataset-version-filetree"> + </texera-user-dataset-version-filetree> + </div> +</nz-spin> diff --git a/core/gui/src/app/workspace/component/model-selection/model-selection.component.scss b/core/gui/src/app/workspace/component/model-selection/model-selection.component.scss new file mode 100644 index 0000000000..0e7ceadbf9 --- /dev/null +++ b/core/gui/src/app/workspace/component/model-selection/model-selection.component.scss @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +:host { + display: block; + padding: 16px; +} + +.container { + display: flex; + flex-direction: column; + gap: 16px; +} + +.selection-row { + display: flex; + gap: 16px; +} + +.nz-select { + width: 100%; +} + +.select-dataset { + transition: width 0.3s; /* Add animation effect */ +} + +.select-version { + flex: 1; + min-width: 0; +} + +.texera-user-dataset-version-filetree { + margin-top: 16px; +} + +.dataset-option { + display: flex; + align-items: center; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.dataset-name { + display: inline-block; + max-width: calc(100% - 100px); /* Adjust to fit other elements */ + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.dataset-id-container { + background-color: grey; + color: white; + width: 23px; + height: 23px; + border-radius: 50%; + display: flex; + justify-content: center; + align-items: center; + font-size: 12px; + margin-right: 5px; + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); + overflow: hidden; +} + +.dataset-owner, +.dataset-access-privilege { + margin-left: 10px; + font-size: 12px; + color: grey; +} diff --git a/core/gui/src/app/workspace/component/model-selection/model-selection.component.ts b/core/gui/src/app/workspace/component/model-selection/model-selection.component.ts new file mode 100644 index 0000000000..d7981a29bf --- /dev/null +++ b/core/gui/src/app/workspace/component/model-selection/model-selection.component.ts @@ -0,0 +1,128 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component, inject, OnInit } from "@angular/core"; +import { NZ_MODAL_DATA, NzModalRef } from "ng-zorro-antd/modal"; +import { UntilDestroy, untilDestroyed } from "@ngneat/until-destroy"; +import { DatasetFileNode } from "../../../common/type/datasetVersionFileTree"; +import { finalize } from "rxjs/operators"; +import { DashboardModel } from "../../../dashboard/type/dashboard-model.interface"; +import { ModelService } from "../../../dashboard/service/user/model/model.service"; +import { parseFilePathToDatasetFile } from "../../../common/type/dataset-file"; +import { ModelVersion } from "../../../common/type/model"; + +@UntilDestroy() +@Component({ + selector: "texera-model-selection-model", + templateUrl: "model-selection.component.html", + styleUrls: ["model-selection.component.scss"], +}) +export class ModelSelectionComponent implements OnInit { + readonly selectedFilePath: string = inject(NZ_MODAL_DATA).selectedFilePath; + private _models: ReadonlyArray<DashboardModel> = []; + + // true while the accessible models are still being fetched from the backend (drives the nz-spin overlay) + isAccessibleModelsLoading = true; + + selectedModel?: DashboardModel; + selectedVersion?: ModelVersion; + modelVersions?: ModelVersion[]; + suggestedFileTreeNodes: DatasetFileNode[] = []; + isModelSelected: boolean = false;
+ + constructor( + private modalRef: NzModalRef, + private modelService: ModelService + ) {} + + ngOnInit() { + this.isAccessibleModelsLoading = true; + + // retrieve all the accessible models from the backend; finalize() stops the spinner even if the request fails + this.modelService + .retrieveAccessibleModels() + .pipe(finalize(() => (this.isAccessibleModelsLoading = false)), untilDestroyed(this)) + .subscribe(models => { + this._models = models; + this.isAccessibleModelsLoading = false; + if (!this.selectedFilePath) { + return; + } + // if the user already selected a file, preselect its model & version and load that version's file tree + const selectedDatasetFile = parseFilePathToDatasetFile(this.selectedFilePath); + this.selectedModel = this.models.find( + d => d.ownerEmail === selectedDatasetFile.ownerEmail && d.model.name === selectedDatasetFile.datasetName + ); + this.isModelSelected = !!this.selectedModel; + if (this.selectedModel && this.selectedModel.model.mid !== undefined) { + this.modelService + .retrieveModelVersionList(this.selectedModel.model.mid) + .pipe(untilDestroyed(this)) + .subscribe(versions => { + this.modelVersions = versions; + this.selectedVersion = this.modelVersions.find(v => v.name === selectedDatasetFile.versionName); + this.onVersionChange(); + }); + } + }); + } + + onDatasetChange() { + this.selectedVersion = undefined; + this.suggestedFileTreeNodes = []; + this.isModelSelected = !!this.selectedModel; + if (this.selectedModel && this.selectedModel.model.mid !== undefined) { + this.modelService + .retrieveModelVersionList(this.selectedModel.model.mid) + .pipe(untilDestroyed(this)) + .subscribe(versions => { + this.modelVersions = versions; + if (this.modelVersions && this.modelVersions.length > 0) { + this.selectedVersion = this.modelVersions[0]; + this.onVersionChange(); + } + }); + } + } + + onVersionChange() { + this.suggestedFileTreeNodes = []; + if ( + this.selectedModel && + this.selectedModel.model.mid !== undefined && + this.selectedVersion && + this.selectedVersion.mvid !== undefined + ) { + this.modelService
.retrieveModelVersionFileTree(this.selectedModel.model.mid, this.selectedVersion.mvid) + .pipe(untilDestroyed(this)) + .subscribe(data => { + this.suggestedFileTreeNodes = data.fileNodes; + }); + } + } + + onFileTreeNodeSelected(node: DatasetFileNode) { + this.modalRef.close(node); + } + + get models(): ReadonlyArray<DashboardModel> { + return this._models; + } +} diff --git a/core/gui/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts b/core/gui/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts index 8cc463ff8f..9ef8a3fbb1 100644 --- a/core/gui/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts +++ b/core/gui/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts @@ -453,6 +453,12 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On mappedField.type = "inputautocomplete"; } + + // if the field key is modelPath, render it with the custom model-selection autocomplete input template + if (mappedField.key === "modelPath" || mappedField.key === "model path") { + mappedField.type = "inputautocompleteModel"; + } + // if the title is python script (for Python UDF), then make this field a custom template 'codearea' if (mapSource?.description?.toLowerCase() === "input your code here") { if (mappedField.type) { diff --git a/core/gui/src/assets/operator_images/LoadModel.png b/core/gui/src/assets/operator_images/LoadModel.png new file mode 100644 index 0000000000..56911e250b Binary files /dev/null and b/core/gui/src/assets/operator_images/LoadModel.png differ diff --git a/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/LogicalOp.scala b/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/LogicalOp.scala index 95bb2f87d6..0da6868e3e 100644 ---
a/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/LogicalOp.scala +++ b/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/LogicalOp.scala @@ -24,11 +24,7 @@ import com.fasterxml.jackson.annotation._ import com.kjetland.jackson.jsonSchema.annotations.JsonSchemaTitle import edu.uci.ics.amber.core.executor.OperatorExecutor import edu.uci.ics.amber.core.tuple.Schema -import edu.uci.ics.amber.core.virtualidentity.{ - ExecutionIdentity, - OperatorIdentity, - WorkflowIdentity -} +import edu.uci.ics.amber.core.virtualidentity.{ExecutionIdentity, OperatorIdentity, WorkflowIdentity} import edu.uci.ics.amber.core.workflow.WorkflowContext.{DEFAULT_EXECUTION_ID, DEFAULT_WORKFLOW_ID} import edu.uci.ics.amber.core.workflow.{PhysicalOp, PhysicalPlan, PortIdentity} import edu.uci.ics.amber.operator.aggregate.AggregateOpDesc @@ -39,22 +35,15 @@ import edu.uci.ics.amber.operator.distinct.DistinctOpDesc import edu.uci.ics.amber.operator.dummy.DummyOpDesc import edu.uci.ics.amber.operator.filter.SpecializedFilterOpDesc import edu.uci.ics.amber.operator.hashJoin.HashJoinOpDesc -import edu.uci.ics.amber.operator.huggingFace.{ - HuggingFaceIrisLogisticRegressionOpDesc, - HuggingFaceSentimentAnalysisOpDesc, - HuggingFaceSpamSMSDetectionOpDesc, - HuggingFaceTextSummarizationOpDesc -} +import edu.uci.ics.amber.operator.huggingFace.{HuggingFaceIrisLogisticRegressionOpDesc, HuggingFaceSentimentAnalysisOpDesc, HuggingFaceSpamSMSDetectionOpDesc, HuggingFaceTextSummarizationOpDesc} import edu.uci.ics.amber.operator.ifStatement.IfOpDesc import edu.uci.ics.amber.operator.intersect.IntersectOpDesc import edu.uci.ics.amber.operator.intervalJoin.IntervalJoinOpDesc import edu.uci.ics.amber.operator.keywordSearch.KeywordSearchOpDesc import edu.uci.ics.amber.operator.limit.LimitOpDesc +import edu.uci.ics.amber.operator.loadModel.LoadModelOpDesc import edu.uci.ics.amber.operator.machineLearning.Scorer.MachineLearningScorerOpDesc -import 
edu.uci.ics.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer.{ - SklearnAdvancedKNNClassifierTrainerOpDesc, - SklearnAdvancedKNNRegressorTrainerOpDesc -} +import edu.uci.ics.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer.{SklearnAdvancedKNNClassifierTrainerOpDesc, SklearnAdvancedKNNRegressorTrainerOpDesc} import edu.uci.ics.amber.operator.machineLearning.sklearnAdvanced.SVCTrainer.SklearnAdvancedSVCTrainerOpDesc import edu.uci.ics.amber.operator.machineLearning.sklearnAdvanced.SVRTrainer.SklearnAdvancedSVRTrainerOpDesc import edu.uci.ics.amber.operator.metadata.{OPVersion, OperatorInfo, PropertyNameConstants} @@ -64,38 +53,11 @@ import edu.uci.ics.amber.operator.regex.RegexOpDesc import edu.uci.ics.amber.operator.reservoirsampling.ReservoirSamplingOpDesc import edu.uci.ics.amber.operator.sklearn._ import edu.uci.ics.amber.operator.sleep.SleepOpDesc -import edu.uci.ics.amber.operator.sklearn.training.{ - SklearnTrainingAdaptiveBoostingOpDesc, - SklearnTrainingBaggingOpDesc, - SklearnTrainingBernoulliNaiveBayesOpDesc, - SklearnTrainingComplementNaiveBayesOpDesc, - SklearnTrainingDecisionTreeOpDesc, - SklearnTrainingDummyClassifierOpDesc, - SklearnTrainingExtraTreeOpDesc, - SklearnTrainingExtraTreesOpDesc, - SklearnTrainingGaussianNaiveBayesOpDesc, - SklearnTrainingGradientBoostingOpDesc, - SklearnTrainingKNNOpDesc, - SklearnTrainingLinearSVMOpDesc, - SklearnTrainingMultiLayerPerceptronOpDesc, - SklearnTrainingMultinomialNaiveBayesOpDesc, - SklearnTrainingNearestCentroidOpDesc, - SklearnTrainingPassiveAggressiveOpDesc, - SklearnTrainingPerceptronOpDesc, - SklearnTrainingProbabilityCalibrationOpDesc, - SklearnTrainingRandomForestOpDesc, - SklearnTrainingRidgeCVOpDesc, - SklearnTrainingRidgeOpDesc, - SklearnTrainingSDGOpDesc, - SklearnTrainingSVMOpDesc -} +import edu.uci.ics.amber.operator.sklearn.training.{SklearnTrainingAdaptiveBoostingOpDesc, SklearnTrainingBaggingOpDesc, SklearnTrainingBernoulliNaiveBayesOpDesc, 
SklearnTrainingComplementNaiveBayesOpDesc, SklearnTrainingDecisionTreeOpDesc, SklearnTrainingDummyClassifierOpDesc, SklearnTrainingExtraTreeOpDesc, SklearnTrainingExtraTreesOpDesc, SklearnTrainingGaussianNaiveBayesOpDesc, SklearnTrainingGradientBoostingOpDesc, SklearnTrainingKNNOpDesc, SklearnTrainingLinearSVMOpDesc, SklearnTra [...] import edu.uci.ics.amber.operator.sort.SortOpDesc import edu.uci.ics.amber.operator.sortPartitions.SortPartitionsOpDesc import edu.uci.ics.amber.operator.source.apis.reddit.RedditSearchSourceOpDesc -import edu.uci.ics.amber.operator.source.apis.twitter.v2.{ - TwitterFullArchiveSearchSourceOpDesc, - TwitterSearchSourceOpDesc -} +import edu.uci.ics.amber.operator.source.apis.twitter.v2.{TwitterFullArchiveSearchSourceOpDesc, TwitterSearchSourceOpDesc} import edu.uci.ics.amber.operator.source.fetcher.URLFetcherOpDesc import edu.uci.ics.amber.operator.source.scan.FileScanSourceOpDesc import edu.uci.ics.amber.operator.source.scan.arrow.ArrowSourceOpDesc @@ -266,6 +228,7 @@ trait StateTransferFunc new Type(value = classOf[ArrowSourceOpDesc], name = "ArrowSource"), new Type(value = classOf[MachineLearningScorerOpDesc], name = "Scorer"), new Type(value = classOf[SortOpDesc], name = "Sort"), + new Type(value = classOf[LoadModelOpDesc], name = "LoadModel"), new Type(value = classOf[SklearnLogisticRegressionOpDesc], name = "SklearnLogisticRegression"), new Type( value = classOf[SklearnLogisticRegressionCVOpDesc], diff --git a/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/loadModel/LoadModelOpDesc.scala b/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/loadModel/LoadModelOpDesc.scala new file mode 100644 index 0000000000..9edb29b048 --- /dev/null +++ b/core/workflow-operator/src/main/scala/edu/uci/ics/amber/operator/loadModel/LoadModelOpDesc.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package edu.uci.ics.amber.operator.loadModel + +import com.fasterxml.jackson.annotation.{JsonProperty, JsonPropertyDescription} +import com.kjetland.jackson.jsonSchema.annotations.JsonSchemaTitle +import edu.uci.ics.amber.core.tuple.{AttributeType, Schema} +import edu.uci.ics.amber.core.workflow.{OutputPort, PortIdentity} +import edu.uci.ics.amber.operator.PythonOperatorDescriptor +import edu.uci.ics.amber.operator.metadata.{OperatorGroupConstants, OperatorInfo} + +class LoadModelOpDesc extends PythonOperatorDescriptor { + + @JsonProperty(required = true) + @JsonSchemaTitle("modelPath") + @JsonPropertyDescription("The model to load") + var modelPath: String = "" + + override def getOutputSchemas( + inputSchemas: Map[PortIdentity, Schema] + ): Map[PortIdentity, Schema] = { + val outputSchema = Schema() + .add("model", AttributeType.BINARY) + Map(operatorInfo.outputPorts.head.id -> outputSchema) + } + + override def operatorInfo: OperatorInfo = + OperatorInfo( + "Load Model", + "Loads a machine learning model from the specified path", + OperatorGroupConstants.MACHINE_LEARNING_GENERAL_GROUP, + inputPorts = List(), + outputPorts = List(OutputPort()) + ) + + override def generatePythonCode(): String = { + s"""from pytexera import * + |import tensorflow as tf + |import 
tempfile + |import os + |class GenerateOperator(UDFSourceOperator): + | + | @overrides + | + | def produce(self) -> Iterator[Union[TupleLike, TableLike, None]]: + | file = DatasetFileDocument("$modelPath") + | bytes = file.read_file().getvalue() # return an io.Bytes object + | + | with tempfile.NamedTemporaryFile(suffix='.h5', delete=False) as tmp_file: + | tmp_file.write(bytes) + | tmp_file.flush() + | model = tf.keras.models.load_model(tmp_file.name, compile=False) + | os.unlink(tmp_file.name) # Clean up temporary file + | + | yield {"model": model} + |""".stripMargin + } +}
