Browse Source

feat: support regenerating ai results #2012 (#2528)

* feat: added emoji and network image support

* fix: code cleanup and improvements

* feat: rewrite AI response #2012

* fix: code refactor, rewrite in smart edit

* chore: delete the redundant file

---------

Co-authored-by: ahmeduzair890 <[email protected]>
Co-authored-by: rizwan3395 <rizwan.ios33gmail.com>
Co-authored-by: Lucas.Xu <[email protected]>
Muhammad Rizwan 2 years ago
parent
commit
1f1a97b21b

+ 1 - 0
frontend/appflowy_flutter/assets/translations/en.json

@@ -359,6 +359,7 @@
       "autoGeneratorGenerate": "Generate",
       "autoGeneratorGenerate": "Generate",
       "autoGeneratorHintText": "Ask OpenAI ...",
       "autoGeneratorHintText": "Ask OpenAI ...",
       "autoGeneratorCantGetOpenAIKey": "Can't get OpenAI key",
       "autoGeneratorCantGetOpenAIKey": "Can't get OpenAI key",
+      "autoGeneratorRewrite": "Rewrite",
       "smartEdit": "AI Assistants",
       "smartEdit": "AI Assistants",
       "openAI": "OpenAI",
       "openAI": "OpenAI",
       "smartEditFixSpelling": "Fix spelling",
       "smartEditFixSpelling": "Fix spelling",

+ 102 - 2
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/auto_completion_node_widget.dart

@@ -1,5 +1,4 @@
 import 'dart:convert';
 import 'dart:convert';
-
 import 'package:appflowy/plugins/document/presentation/plugins/openai/service/openai_client.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/service/openai_client.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/util/learn_more_action.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/util/learn_more_action.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/discard_dialog.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/discard_dialog.dart';
@@ -21,6 +20,8 @@ import '../util/editor_extension.dart';
 
 
 const String kAutoCompletionInputType = 'auto_completion_input';
 const String kAutoCompletionInputType = 'auto_completion_input';
 const String kAutoCompletionInputString = 'auto_completion_input_string';
 const String kAutoCompletionInputString = 'auto_completion_input_string';
+const String kAutoCompletionGenerationCount =
+    'auto_completion_generation_count';
 const String kAutoCompletionInputStartSelection =
 const String kAutoCompletionInputStartSelection =
     'auto_completion_input_start_selection';
     'auto_completion_input_start_selection';
 
 
@@ -124,7 +125,8 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
   }
   }
 
 
   Widget _buildAutoGeneratorPanel(BuildContext context) {
   Widget _buildAutoGeneratorPanel(BuildContext context) {
-    if (text.isEmpty) {
+    if (text.isEmpty &&
+        widget.node.attributes[kAutoCompletionGenerationCount] < 1) {
       return Column(
       return Column(
         mainAxisSize: MainAxisSize.min,
         mainAxisSize: MainAxisSize.min,
         children: [
         children: [
@@ -204,6 +206,15 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
     );
     );
   }
   }
 
 
+  Future<void> _updateGenerationCount() async {
+    final transaction = widget.editorState.transaction;
+    transaction.updateNode(widget.node, {
+      kAutoCompletionGenerationCount:
+          widget.node.attributes[kAutoCompletionGenerationCount] + 1
+    });
+    await widget.editorState.apply(transaction);
+  }
+
   Widget _buildFooterWidget(BuildContext context) {
   Widget _buildFooterWidget(BuildContext context) {
     return Row(
     return Row(
       children: [
       children: [
@@ -212,6 +223,11 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
           onPressed: () => _onExit(),
           onPressed: () => _onExit(),
         ),
         ),
         const Space(10, 0),
         const Space(10, 0),
+        SecondaryTextButton(
+          LocaleKeys.document_plugins_autoGeneratorRewrite.tr(),
+          onPressed: () => _onRewrite(),
+        ),
+        const Space(10, 0),
         SecondaryTextButton(
         SecondaryTextButton(
           LocaleKeys.button_discard.tr(),
           LocaleKeys.button_discard.tr(),
           onPressed: () => _onDiscard(),
           onPressed: () => _onDiscard(),
@@ -272,6 +288,64 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
           await _showError(error.message);
           await _showError(error.message);
         },
         },
       );
       );
+      await _updateGenerationCount();
+    }, (error) async {
+      loading.stop();
+      await _showError(
+        LocaleKeys.document_plugins_autoGeneratorCantGetOpenAIKey.tr(),
+      );
+    });
+  }
+
+  Future<void> _onRewrite() async {
+    String previousOutput = _getPreviousOutput()!;
+    final loading = Loading(context);
+    loading.start();
+    // clear previous response
+    final selection =
+        widget.node.attributes[kAutoCompletionInputStartSelection];
+    if (selection != null) {
+      final start = Selection.fromJson(json.decode(selection)).start.path;
+      final end = widget.node.previous?.path;
+      if (end != null) {
+        final transaction = widget.editorState.transaction;
+        transaction.deleteNodesAtPath(
+          start,
+          end.last - start.last + 1,
+        );
+        await widget.editorState.apply(transaction);
+      }
+    }
+    // generate new response
+    final result = await UserBackendService.getCurrentUserProfile();
+    result.fold((userProfile) async {
+      final openAIRepository = HttpOpenAIRepository(
+        client: http.Client(),
+        apiKey: userProfile.openaiKey,
+      );
+      await openAIRepository.getStreamedCompletions(
+        prompt: _rewritePrompt(previousOutput),
+        onStart: () async {
+          loading.stop();
+          await _makeSurePreviousNodeIsEmptyTextNode();
+        },
+        onProcess: (response) async {
+          if (response.choices.isNotEmpty) {
+            final text = response.choices.first.text;
+            await widget.editorState.autoInsertText(
+              text,
+              inputType: TextRobotInputType.word,
+              delay: Duration.zero,
+            );
+          }
+        },
+        onEnd: () async {},
+        onError: (error) async {
+          loading.stop();
+          await _showError(error.message);
+        },
+      );
+      await _updateGenerationCount();
     }, (error) async {
     }, (error) async {
       loading.stop();
       loading.stop();
       await _showError(
       await _showError(
@@ -280,6 +354,31 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
     });
     });
   }
   }
 
 
+  String? _getPreviousOutput() {
+    final selection =
+        widget.node.attributes[kAutoCompletionInputStartSelection];
+    if (selection != null) {
+      final start = Selection.fromJson(json.decode(selection)).start.path;
+      final end = widget.node.previous?.path;
+      if (end != null) {
+        String lastOutput = "";
+        for (var i = start.last; i < end.last - start.last + 2; i++) {
+          TextNode? textNode =
+              widget.editorState.document.nodeAtPath([i]) as TextNode?;
+          lastOutput = "$lastOutput ${textNode!.toPlainText()}";
+        }
+        return lastOutput.trim();
+      }
+    }
+    return null;
+  }
+
+  String _rewritePrompt(String previousOutput) {
+    String prompt =
+        'I am not satisfied with your previous response($previousOutput) to the query ($text) please write another one';
+    return prompt;
+  }
+
   Future<void> _onDiscard() async {
   Future<void> _onDiscard() async {
     final selection =
     final selection =
         widget.node.attributes[kAutoCompletionInputStartSelection];
         widget.node.attributes[kAutoCompletionInputStartSelection];
@@ -293,6 +392,7 @@ class _AutoCompletionInputState extends State<_AutoCompletionInput> {
           end.last - start.last + 1,
           end.last - start.last + 1,
         );
         );
         await widget.editorState.apply(transaction);
         await widget.editorState.apply(transaction);
+        await _makeSurePreviousNodeIsEmptyTextNode();
       }
       }
     }
     }
     _onExit();
     _onExit();

+ 1 - 0
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/auto_completion_plugins.dart

@@ -14,6 +14,7 @@ SelectionMenuItem autoGeneratorMenuItem = SelectionMenuItem.node(
       type: kAutoCompletionInputType,
       type: kAutoCompletionInputType,
       attributes: {
       attributes: {
         kAutoCompletionInputString: '',
         kAutoCompletionInputString: '',
+        kAutoCompletionGenerationCount: 0,
       },
       },
     );
     );
     return node;
     return node;

+ 23 - 5
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/smart_edit_node_widget.dart

@@ -211,15 +211,15 @@ class _SmartEditInputState extends State<_SmartEditInput> {
   }
   }
 
 
   Widget _buildResultWidget(BuildContext context) {
   Widget _buildResultWidget(BuildContext context) {
-    final loading = Padding(
+    final loadingWidget = Padding(
       padding: const EdgeInsets.symmetric(horizontal: 4.0),
       padding: const EdgeInsets.symmetric(horizontal: 4.0),
       child: SizedBox.fromSize(
       child: SizedBox.fromSize(
         size: const Size.square(14),
         size: const Size.square(14),
         child: const CircularProgressIndicator(),
         child: const CircularProgressIndicator(),
       ),
       ),
     );
     );
-    if (result.isEmpty) {
-      return loading;
+    if (result.isEmpty || loading) {
+      return loadingWidget;
     }
     }
     return Flexible(
     return Flexible(
       child: Text(
       child: Text(
@@ -231,6 +231,18 @@ class _SmartEditInputState extends State<_SmartEditInput> {
   Widget _buildInputFooterWidget(BuildContext context) {
   Widget _buildInputFooterWidget(BuildContext context) {
     return Row(
     return Row(
       children: [
       children: [
+        FlowyRichTextButton(
+          TextSpan(
+            children: [
+              TextSpan(
+                text: LocaleKeys.document_plugins_autoGeneratorRewrite.tr(),
+                style: Theme.of(context).textTheme.bodyMedium,
+              ),
+            ],
+          ),
+          onPressed: () => _requestCompletions(rewrite: true),
+        ),
+        const Space(10, 0),
         FlowyRichTextButton(
         FlowyRichTextButton(
           TextSpan(
           TextSpan(
             children: [
             children: [
@@ -272,7 +284,7 @@ class _SmartEditInputState extends State<_SmartEditInput> {
           ),
           ),
           onPressed: () async => await _onExit(),
           onPressed: () async => await _onExit(),
         ),
         ),
-        const Spacer(flex: 2),
+        const Spacer(flex: 1),
         Expanded(
         Expanded(
           child: FlowyText.regular(
           child: FlowyText.regular(
             overflow: TextOverflow.ellipsis,
             overflow: TextOverflow.ellipsis,
@@ -359,7 +371,13 @@ class _SmartEditInputState extends State<_SmartEditInput> {
     );
     );
   }
   }
 
 
-  Future<void> _requestCompletions() async {
+  Future<void> _requestCompletions({bool rewrite = false}) async {
+    if (rewrite) {
+      setState(() {
+        loading = true;
+        result = "";
+      });
+    }
     final openAIRepository = await getIt.getAsync<OpenAIRepository>();
     final openAIRepository = await getIt.getAsync<OpenAIRepository>();
 
 
     var lines = input.split('\n\n');
     var lines = input.split('\n\n');